Commit cdf3ad56 authored by Alexandre Chapin

To device

parent 570c677f
@@ -199,10 +199,10 @@ class Attention(nn.Module):
         self.internal_dim = embedding_dim // downsample_rate
         self.num_heads = num_heads
         assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim."
-        self.q_proj = nn.Linear(embedding_dim, self.internal_dim).cuda()
-        self.k_proj = nn.Linear(embedding_dim, self.internal_dim).cuda()
-        self.v_proj = nn.Linear(embedding_dim, self.internal_dim).cuda()
-        self.out_proj = nn.Linear(self.internal_dim, embedding_dim).cuda()
+        self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
+        self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
+        self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
+        self.out_proj = nn.Linear(self.internal_dim, embedding_dim)

     def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
         b, n, c = x.shape
@@ -215,6 +215,10 @@ class Attention(nn.Module):
         return x.reshape(b, n_tokens, n_heads * c_per_head)  # B x N_tokens x C

     def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
+        self.q_proj = self.q_proj.to(q.device)
+        self.k_proj = self.k_proj.to(k.device)
+        self.v_proj = self.v_proj.to(v.device)
         # Input projections
         q = self.q_proj(q)
         k = self.k_proj(k)
...
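For context, a minimal sketch of the standard PyTorch pattern this commit moves toward (not part of the commit; the constructor signature is reconstructed from the diff, and the embedding_dim/num_heads values below are illustrative): nn.Linear layers assigned as attributes in __init__ are registered submodules, so a single module-level .to(device) moves all of their parameters at once, removing the need for hard-coded .cuda() calls or per-forward .to(...) moves.

    import torch
    from torch import nn

    class Attention(nn.Module):
        def __init__(self, embedding_dim: int, num_heads: int, downsample_rate: int = 1) -> None:
            super().__init__()
            self.internal_dim = embedding_dim // downsample_rate
            self.num_heads = num_heads
            assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim."
            # Registered submodules: no .cuda() here, so the module stays device-agnostic.
            self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
            self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
            self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
            self.out_proj = nn.Linear(self.internal_dim, embedding_dim)

    # One call moves q_proj, k_proj, v_proj, and out_proj together.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    attn = Attention(embedding_dim=256, num_heads=8).to(device)

Under this assumption, the .to(q.device) calls added in forward act as a safety net when inputs arrive on a different device than the module; if the module is always moved up front as above, they are redundant.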