diff --git a/osrt/sam/transformer.py b/osrt/sam/transformer.py index 1049c18b7ab68aa3c577fb020d7136d91d57d81b..fe947052f89e684f35837343afeca11ca042bb4e 100644 --- a/osrt/sam/transformer.py +++ b/osrt/sam/transformer.py @@ -221,8 +221,8 @@ class Attention(nn.Module): def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: self.q_proj = self.q_proj.to(q.device) - self.k_proj = self.k_proj.to(k.device) - self.v_proj = self.v_proj.to(v.device) + self.k_proj = self.k_proj.to(q.device) + self.v_proj = self.v_proj.to(q.device) # Input projections q = self.q_proj(q)