From be1c9aa686a92b7c410bbb299ffa8bb2fcaaf4f6 Mon Sep 17 00:00:00 2001
From: alexcbb <alexchapin@hotmail.fr>
Date: Wed, 19 Jul 2023 15:10:27 +0200
Subject: [PATCH] Device

---
 osrt/sam/transformer.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/osrt/sam/transformer.py b/osrt/sam/transformer.py
index 1049c18..fe94705 100644
--- a/osrt/sam/transformer.py
+++ b/osrt/sam/transformer.py
@@ -221,8 +221,8 @@ class Attention(nn.Module):
 
     def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
         self.q_proj = self.q_proj.to(q.device)
-        self.k_proj = self.k_proj.to(k.device)
-        self.v_proj = self.v_proj.to(v.device)
+        self.k_proj = self.k_proj.to(q.device)
+        self.v_proj = self.v_proj.to(q.device)
 
         # Input projections
         q = self.q_proj(q)
-- 
GitLab