diff --git a/osrt/sam/transformer.py b/osrt/sam/transformer.py
index a39d56427505658b0a1fab2307e6afebeb36a33a..1049c18b7ab68aa3c577fb020d7136d91d57d81b 100644
--- a/osrt/sam/transformer.py
+++ b/osrt/sam/transformer.py
@@ -158,6 +158,7 @@ class TwoWayAttentionBlock(nn.Module):
             q = queries + query_pe
             attn_out = self.self_attn(q=q, k=q, v=queries)
             queries = queries + attn_out
+        self.norm1 = self.norm1.to(queries.device)  # keep the norm's weights on the activations' device
         queries = self.norm1(queries)
 
         # Cross attention block, tokens attending to image embedding
@@ -165,18 +166,22 @@ class TwoWayAttentionBlock(nn.Module):
         k = keys + key_pe
         attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
         queries = queries + attn_out
+        self.norm2 = self.norm2.to(queries.device)
         queries = self.norm2(queries)
 
         # MLP block
+        self.mlp = self.mlp.to(queries.device)
         mlp_out = self.mlp(queries)
         queries = queries + mlp_out
+        self.norm3 = self.norm3.to(queries.device)
         queries = self.norm3(queries)
 
         # Cross attention block, image embedding attending to tokens
         q = queries + query_pe
         k = keys + key_pe
         attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
         keys = keys + attn_out
+        self.norm4 = self.norm4.to(keys.device)
         keys = self.norm4(keys)
 
         return queries, keys
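
Note on usage: the per-layer `.to(...)` calls above re-pin each submodule on every forward pass. A minimal sketch of the conventional one-shot alternative is below, assuming the `TwoWayAttentionBlock(embedding_dim, num_heads, ...)` constructor and `forward(queries, keys, query_pe, key_pe)` signature from the upstream SAM code this file is based on; the tensor shapes are illustrative only.

```python
import torch

from osrt.sam.transformer import TwoWayAttentionBlock

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Moving the whole module once relocates norm1-norm4, the MLP, and all
# attention layers together, so no per-layer .to(...) is needed in forward().
block = TwoWayAttentionBlock(embedding_dim=256, num_heads=8).to(device)

queries = torch.randn(1, 5, 256, device=device)     # sparse prompt tokens
keys = torch.randn(1, 64 * 64, 256, device=device)  # flattened image embedding
query_pe = torch.randn_like(queries)                # positional encodings
key_pe = torch.randn_like(keys)

queries, keys = block(queries=queries, keys=keys, query_pe=query_pe, key_pe=key_pe)
```

The in-forward `.to(...)` pattern in the patch trades that one-time setup for robustness when inputs may arrive on a device the module was not initialized on.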