diff --git a/osrt/sam/transformer.py b/osrt/sam/transformer.py
index 1049c18b7ab68aa3c577fb020d7136d91d57d81b..fe947052f89e684f35837343afeca11ca042bb4e 100644
--- a/osrt/sam/transformer.py
+++ b/osrt/sam/transformer.py
@@ -221,8 +221,8 @@ class Attention(nn.Module):
 
     def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
         self.q_proj = self.q_proj.to(q.device)
-        self.k_proj = self.k_proj.to(k.device)
-        self.v_proj = self.v_proj.to(v.device)
+        self.k_proj = self.k_proj.to(q.device)
+        self.v_proj = self.v_proj.to(q.device)
 
         # Input projections
         q = self.q_proj(q)