diff --git a/osrt/sam/transformer.py b/osrt/sam/transformer.py
index 4d88344287b4829fe6fbd6c0bcdb48f9c9ab91ba..a39d56427505658b0a1fab2307e6afebeb36a33a 100644
--- a/osrt/sam/transformer.py
+++ b/osrt/sam/transformer.py
@@ -238,6 +238,7 @@ class Attention(nn.Module):
         # Get output
         out = attn @ v
         out = self._recombine_heads(out)
+        self.out_proj = self.out_proj.to(out.device)
         out = self.out_proj(out)
 
         return out