diff --git a/osrt/sam/mask_decoder.py b/osrt/sam/mask_decoder.py
index 4043e79cfda41621128c0d7ce761031be7fae36f..c9b6cce2d603661e47d8f8946a854a728fd27bd2 100644
--- a/osrt/sam/mask_decoder.py
+++ b/osrt/sam/mask_decoder.py
@@ -125,14 +125,14 @@ class MaskDecoder(nn.Module):
 
         # Expand per-image data in batch direction to be per-mask
         if image_embeddings.shape[0] != tokens.shape[0]:
-            src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0).to(dense_prompt_embeddings.device)
+            src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0).to(device)
         else:
-            src = image_embeddings.to(dense_prompt_embeddings.device)
+            src = image_embeddings.to(device)
         src = src + dense_prompt_embeddings
 
-        pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
+        pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0).to(device)
         b, c, h, w = src.shape
-        
+
         # Run the transformer
         hs, src = self.transformer(src, pos_src, tokens)
         iou_token_out = hs[:, 0, :]
diff --git a/osrt/sam/transformer.py b/osrt/sam/transformer.py
index 7634493f6b4791370d4926f65506d560f9632cf5..5f51dd0700928dab510e76e40f88e15eac508901 100644
--- a/osrt/sam/transformer.py
+++ b/osrt/sam/transformer.py
@@ -158,32 +158,25 @@ class TwoWayAttentionBlock(nn.Module):
             q = queries + query_pe
             attn_out = self.self_attn(q=q, k=q, v=queries)
             queries = queries + attn_out
-        self.norm1 = self.norm1.to(queries.device)
         queries = self.norm1(queries)
 
         # Cross attention block, tokens attending to image embedding
         q = queries + query_pe
         k = keys + key_pe
-        q, k, keys = q.to(queries.device), k.to(queries.device), keys.to(queries.device)
         attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
         queries = queries + attn_out
-        self.norm2 = self.norm2.to(queries.device)
         queries = self.norm2(queries)
 
         # MLP block
-        self.mlp = self.mlp.to(queries.device)
         mlp_out = self.mlp(queries)
         queries = queries + mlp_out
-        self.norm3 = self.norm3.to(queries.device)
         queries = self.norm3(queries)
 
         # Cross attention block, image embedding attending to tokens
         q = queries + query_pe
-        key_pe, keys = key_pe.to(queries.device), keys.to(queries.device)
         k = keys + key_pe
         attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
         keys = keys + attn_out
-        self.norm4 = self.norm4.to(queries.device)
         keys = self.norm4(keys)
 
         return queries, keys
@@ -222,9 +215,6 @@ class Attention(nn.Module):
         return x.reshape(b, n_tokens, n_heads * c_per_head)  # B x N_tokens x C
 
     def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
-        self.q_proj = self.q_proj.to(q.device)
-        self.k_proj = self.k_proj.to(q.device)
-        self.v_proj = self.v_proj.to(q.device)
 
         # Input projections
         q = self.q_proj(q)
@@ -245,7 +235,6 @@ class Attention(nn.Module):
         # Get output
         out = attn @ v
         out = self._recombine_heads(out)
-        self.out_proj = self.out_proj.to(out.device)
         out = self.out_proj(out)
 
         return out