diff --git a/osrt/sam/transformer.py b/osrt/sam/transformer.py
index fe947052f89e684f35837343afeca11ca042bb4e..b3fc70a6407bb83c1aa61e9bb4cc24b72a81776a 100644
--- a/osrt/sam/transformer.py
+++ b/osrt/sam/transformer.py
@@ -164,6 +164,7 @@ class TwoWayAttentionBlock(nn.Module):
         # Cross attention block, tokens attending to image embedding
         q = queries + query_pe
         k = keys + key_pe
+        q, k, keys = q.to(queries.device), k.to(queries.device), keys.to(queries.device)
         attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
         queries = queries + attn_out
         self.norm2 = self.norm2.to(queries.device)