diff --git a/osrt/sam/mask_decoder.py b/osrt/sam/mask_decoder.py
index 4043e79cfda41621128c0d7ce761031be7fae36f..c9b6cce2d603661e47d8f8946a854a728fd27bd2 100644
--- a/osrt/sam/mask_decoder.py
+++ b/osrt/sam/mask_decoder.py
@@ -125,14 +125,14 @@ class MaskDecoder(nn.Module):
 
         # Expand per-image data in batch direction to be per-mask
         if image_embeddings.shape[0] != tokens.shape[0]:
-            src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0).to(dense_prompt_embeddings.device)
+            src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0).to(device)
         else:
-            src = image_embeddings.to(dense_prompt_embeddings.device)
+            src = image_embeddings.to(device)
         src = src + dense_prompt_embeddings
-        pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
+        pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0).to(device)
         b, c, h, w = src.shape
-        
+
         # Run the transformer
         hs, src = self.transformer(src, pos_src, tokens)
         iou_token_out = hs[:, 0, :]
diff --git a/osrt/sam/transformer.py b/osrt/sam/transformer.py
index 7634493f6b4791370d4926f65506d560f9632cf5..5f51dd0700928dab510e76e40f88e15eac508901 100644
--- a/osrt/sam/transformer.py
+++ b/osrt/sam/transformer.py
@@ -158,32 +158,30 @@ class TwoWayAttentionBlock(nn.Module):
             q = queries + query_pe
             attn_out = self.self_attn(q=q, k=q, v=queries)
             queries = queries + attn_out
-        self.norm1 = self.norm1.to(queries.device)
+        self.norm1 = self.norm1
         queries = self.norm1(queries)
 
         # Cross attention block, tokens attending to image embedding
         q = queries + query_pe
         k = keys + key_pe
-        q, k, keys = q.to(queries.device), k.to(queries.device), keys.to(queries.device)
         attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
         queries = queries + attn_out
-        self.norm2 = self.norm2.to(queries.device)
+        self.norm2 = self.norm2
         queries = self.norm2(queries)
 
         # MLP block
-        self.mlp = self.mlp.to(queries.device)
+        self.mlp = self.mlp
         mlp_out = self.mlp(queries)
         queries = queries + mlp_out
-        self.norm3 = self.norm3.to(queries.device)
+        self.norm3 = self.norm3
         queries = self.norm3(queries)
 
         # Cross attention block, image embedding attending to tokens
         q = queries + query_pe
-        key_pe, keys = key_pe.to(queries.device), keys.to(queries.device)
         k = keys + key_pe
         attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
         keys = keys + attn_out
-        self.norm4 = self.norm4.to(queries.device)
+        self.norm4 = self.norm4
         keys = self.norm4(keys)
 
         return queries, keys
@@ -222,9 +220,6 @@ class Attention(nn.Module):
         return x.reshape(b, n_tokens, n_heads * c_per_head)  # B x N_tokens x C
 
     def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
-        self.q_proj = self.q_proj.to(q.device)
-        self.k_proj = self.k_proj.to(q.device)
-        self.v_proj = self.v_proj.to(q.device)
 
         # Input projections
         q = self.q_proj(q)
@@ -245,7 +240,6 @@ class Attention(nn.Module):
         # Get output
         out = attn @ v
         out = self._recombine_heads(out)
-        self.out_proj = self.out_proj.to(out.device)
         out = self.out_proj(out)
 
         return out