Commit d82172c8 authored by Alexandre Chapin

Remove devices

parent 8626edef
@@ -125,14 +125,14 @@ class MaskDecoder(nn.Module):
         # Expand per-image data in batch direction to be per-mask
         if image_embeddings.shape[0] != tokens.shape[0]:
-            src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0).to(dense_prompt_embeddings.device)
+            src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0).to(device)
         else:
-            src = image_embeddings.to(dense_prompt_embeddings.device)
+            src = image_embeddings.to(device)
         src = src + dense_prompt_embeddings
-        pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
+        pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0).to(device)
         b, c, h, w = src.shape
         # Run the transformer
         hs, src = self.transformer(src, pos_src, tokens)
         iou_token_out = hs[:, 0, :]
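The new lines above assume a single `device` name is already in scope inside `MaskDecoder.predict_masks`; the commit itself does not show where that name comes from. A minimal sketch of one common way to provide it, using illustrative names rather than the repository's actual code:

import torch
import torch.nn as nn

# Minimal sketch (assumption, not code from this commit): derive one `device`
# from the module's own parameters, then move incoming tensors onto it once.
class ToyDecoder(nn.Module):  # hypothetical stand-in for MaskDecoder
    def __init__(self) -> None:
        super().__init__()
        self.proj = nn.Linear(8, 8)

    def forward(self, image_embeddings: torch.Tensor) -> torch.Tensor:
        device = next(self.parameters()).device  # single source of truth
        return self.proj(image_embeddings.to(device))

decoder = ToyDecoder()
out = decoder(torch.randn(2, 8))  # the input is moved onto the decoder's device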
@@ -158,32 +158,30 @@ class TwoWayAttentionBlock(nn.Module):
             q = queries + query_pe
             attn_out = self.self_attn(q=q, k=q, v=queries)
             queries = queries + attn_out
-        self.norm1 = self.norm1.to(queries.device)
+        self.norm1 = self.norm1
         queries = self.norm1(queries)
         # Cross attention block, tokens attending to image embedding
         q = queries + query_pe
         k = keys + key_pe
-        q, k, keys = q.to(queries.device), k.to(queries.device), keys.to(queries.device)
         attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
         queries = queries + attn_out
-        self.norm2 = self.norm2.to(queries.device)
+        self.norm2 = self.norm2
         queries = self.norm2(queries)
         # MLP block
-        self.mlp = self.mlp.to(queries.device)
+        self.mlp = self.mlp
         mlp_out = self.mlp(queries)
         queries = queries + mlp_out
-        self.norm3 = self.norm3.to(queries.device)
+        self.norm3 = self.norm3
         queries = self.norm3(queries)
         # Cross attention block, image embedding attending to tokens
         q = queries + query_pe
-        key_pe, keys = key_pe.to(queries.device), keys.to(queries.device)
         k = keys + key_pe
         attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
         keys = keys + attn_out
-        self.norm4 = self.norm4.to(queries.device)
+        self.norm4 = self.norm4
         keys = self.norm4(keys)
         return queries, keys
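With the per-forward `.to(queries.device)` moves removed from this block, `norm1` through `norm4` and `mlp` have to be on the same device as `queries` and `keys` before `forward` runs. A small sketch of the usual way to guarantee that, moving the whole block once; the class and variable names are illustrative, not the repository's:

import torch
import torch.nn as nn

# Sketch under assumptions (illustrative names): moving the whole block with a
# single .to(device) call places every registered submodule (norms, MLP) on
# that device, so no per-forward moves are needed.
class ToyBlock(nn.Module):
    def __init__(self, dim: int) -> None:
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, dim))

    def forward(self, queries: torch.Tensor) -> torch.Tensor:
        queries = self.norm1(queries)       # norm already lives on the right device
        return queries + self.mlp(queries)  # and so does the MLP

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
block = ToyBlock(32).to(device)             # one move covers norm1 and mlp
queries = torch.randn(4, 32, device=device)
out = block(queries)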
@@ -222,9 +220,6 @@ class Attention(nn.Module):
         return x.reshape(b, n_tokens, n_heads * c_per_head)  # B x N_tokens x C
     def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
-        self.q_proj = self.q_proj.to(q.device)
-        self.k_proj = self.k_proj.to(q.device)
-        self.v_proj = self.v_proj.to(q.device)
         # Input projections
         q = self.q_proj(q)
@@ -245,7 +240,6 @@ class Attention(nn.Module):
         # Get output
         out = attn @ v
         out = self._recombine_heads(out)
-        self.out_proj = self.out_proj.to(out.device)
         out = self.out_proj(out)
         return out
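Likewise, `Attention.forward` no longer moves `q_proj`, `k_proj`, `v_proj`, or `out_proj` onto the inputs' device, so the caller has to pass tensors that already share a device with those projections. A self-contained sketch of that contract, with illustrative names only:

import torch
import torch.nn as nn

# Sketch under assumptions (illustrative names, not the repository's classes):
# with the per-call .to(...) moves gone, q/k/v must already be on the same
# device as the projection layers when forward is called.
class TinyAttention(nn.Module):
    def __init__(self, dim: int) -> None:
        super().__init__()
        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)
        self.out_proj = nn.Linear(dim, dim)

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        # No .to(q.device) here: device placement is the caller's responsibility.
        q, k, v = self.q_proj(q), self.k_proj(k), self.v_proj(v)
        attn = torch.softmax(q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5, dim=-1)
        return self.out_proj(attn @ v)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
attention = TinyAttention(64).to(device)
x = torch.randn(2, 16, 64, device=device)
out = attention(x, x, x)  # works because module and inputs share one device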