From cdf3ad56bcc588149d6a44ec2704205338d112d6 Mon Sep 17 00:00:00 2001
From: alexcbb <alexchapin@hotmail.fr>
Date: Wed, 19 Jul 2023 12:21:28 +0200
Subject: [PATCH] Move Attention projections to the input device

The q_proj/k_proj/v_proj/out_proj layers were created with .cuda(), which
forces the module onto a GPU at construction time and breaks CPU-only runs.
Create them as plain nn.Linear layers and move them to the device of the
incoming tensors in forward() instead.

---
 osrt/sam/transformer.py | 14 ++++++++++----
 1 file changed, 10 insertions(+), 4 deletions(-)

diff --git a/osrt/sam/transformer.py b/osrt/sam/transformer.py
index d70b6e7..4d88344 100644
--- a/osrt/sam/transformer.py
+++ b/osrt/sam/transformer.py
@@ -199,10 +199,10 @@ class Attention(nn.Module):
         self.internal_dim = embedding_dim // downsample_rate
         self.num_heads = num_heads
         assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim."
-        self.q_proj = nn.Linear(embedding_dim, self.internal_dim).cuda()
-        self.k_proj = nn.Linear(embedding_dim, self.internal_dim).cuda()
-        self.v_proj = nn.Linear(embedding_dim, self.internal_dim).cuda()
-        self.out_proj = nn.Linear(self.internal_dim, embedding_dim).cuda()
+        self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
+        self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
+        self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
+        self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
 
     def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
         b, n, c = x.shape
@@ -215,6 +215,12 @@
         return x.reshape(b, n_tokens, n_heads * c_per_head)  # B x N_tokens x C
 
     def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
+        # Move the projection layers to the device of the incoming tensors.
+        self.q_proj = self.q_proj.to(q.device)
+        self.k_proj = self.k_proj.to(k.device)
+        self.v_proj = self.v_proj.to(v.device)
+        self.out_proj = self.out_proj.to(v.device)
+
         # Input projections
         q = self.q_proj(q)
         k = self.k_proj(k)
-- 
GitLab
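
A minimal usage sketch of the behaviour after this patch, assuming the Attention
constructor and forward() signatures shown in the hunk context (embedding_dim,
num_heads; q, k, v tensors). The embedding size, head count, and tensor shapes
below are illustrative values, not taken from the repository:

    import torch

    from osrt.sam.transformer import Attention

    # Illustrative sizes; the real model defines its own embedding_dim / num_heads.
    attn = Attention(embedding_dim=256, num_heads=8)  # no GPU touched at construction

    device = "cuda" if torch.cuda.is_available() else "cpu"
    q = torch.randn(1, 64, 256, device=device)  # batch x tokens x embedding_dim
    k = torch.randn(1, 64, 256, device=device)
    v = torch.randn(1, 64, 256, device=device)

    # forward() now moves the projection layers onto q/k/v's device on demand,
    # so the call works whether the inputs live on the CPU or on a CUDA device.
    out = attn(q, k, v)
    print(out.shape, out.device)  # torch.Size([1, 64, 256]) on the chosen device

Because the projections follow the inputs' device, the same module instance runs
on CPU-only machines and on CUDA without an explicit .cuda() call on the model.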