From be1c9aa686a92b7c410bbb299ffa8bb2fcaaf4f6 Mon Sep 17 00:00:00 2001 From: alexcbb <alexchapin@hotmail.fr> Date: Wed, 19 Jul 2023 15:10:27 +0200 Subject: [PATCH] Device --- osrt/sam/transformer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/osrt/sam/transformer.py b/osrt/sam/transformer.py index 1049c18..fe94705 100644 --- a/osrt/sam/transformer.py +++ b/osrt/sam/transformer.py @@ -221,8 +221,8 @@ class Attention(nn.Module): def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: self.q_proj = self.q_proj.to(q.device) - self.k_proj = self.k_proj.to(k.device) - self.v_proj = self.v_proj.to(v.device) + self.k_proj = self.k_proj.to(q.device) + self.v_proj = self.v_proj.to(q.device) # Input projections q = self.q_proj(q) -- GitLab