From cdf3ad56bcc588149d6a44ec2704205338d112d6 Mon Sep 17 00:00:00 2001
From: alexcbb <alexchapin@hotmail.fr>
Date: Wed, 19 Jul 2023 12:21:28 +0200
Subject: [PATCH] To device

---
 osrt/sam/transformer.py | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/osrt/sam/transformer.py b/osrt/sam/transformer.py
index d70b6e7..4d88344 100644
--- a/osrt/sam/transformer.py
+++ b/osrt/sam/transformer.py
@@ -199,10 +199,10 @@ class Attention(nn.Module):
         self.internal_dim = embedding_dim // downsample_rate
         self.num_heads = num_heads
         assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim."
-        self.q_proj = nn.Linear(embedding_dim, self.internal_dim).cuda()
-        self.k_proj = nn.Linear(embedding_dim, self.internal_dim).cuda()
-        self.v_proj = nn.Linear(embedding_dim, self.internal_dim).cuda()
-        self.out_proj = nn.Linear(self.internal_dim, embedding_dim).cuda()
+        self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
+        self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
+        self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
+        self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
 
     def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
         b, n, c = x.shape
@@ -215,6 +215,10 @@ class Attention(nn.Module):
         return x.reshape(b, n_tokens, n_heads * c_per_head)  # B x N_tokens x C
 
     def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
+        self.q_proj = self.q_proj.to(q.device)
+        self.k_proj = self.k_proj.to(k.device)
+        self.v_proj = self.v_proj.to(v.device)
+
         # Input projections
         q = self.q_proj(q)
         k = self.k_proj(k)
--
GitLab
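
Usage sketch (not part of the patch): with the hard-coded .cuda() calls removed from __init__, the caller chooses the device, and forward() additionally moves the q/k/v projections to the device of the incoming tensors. The snippet below assumes the patched osrt.sam.transformer.Attention is importable; the dimensions are illustrative only.

# Minimal usage sketch, assuming the patched Attention class above.
import torch
from osrt.sam.transformer import Attention

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# No .cuda() at construction any more: place the whole module explicitly.
attn = Attention(embedding_dim=256, num_heads=8).to(device)

q = k = v = torch.randn(2, 16, 256, device=device)

# forward() re-attaches q_proj/k_proj/v_proj to the inputs' device before
# projecting, so those projections follow wherever q/k/v actually live.
out = attn(q, k, v)
print(out.shape, out.device)  # torch.Size([2, 16, 256]) on the chosen device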