Skip to content
Snippets Groups Projects
Commit 570c677f authored by Alexandre Chapin's avatar Alexandre Chapin :race_car:
Browse files

To cuda

parent e7077179
No related branches found
No related tags found
No related merge requests found
...@@ -199,11 +199,10 @@ class Attention(nn.Module): ...@@ -199,11 +199,10 @@ class Attention(nn.Module):
self.internal_dim = embedding_dim // downsample_rate self.internal_dim = embedding_dim // downsample_rate
self.num_heads = num_heads self.num_heads = num_heads
assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim." assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim."
device = next(super().parameters()).device self.q_proj = nn.Linear(embedding_dim, self.internal_dim).cuda()
self.q_proj = nn.Linear(embedding_dim, self.internal_dim).to(device) self.k_proj = nn.Linear(embedding_dim, self.internal_dim).cuda()
self.k_proj = nn.Linear(embedding_dim, self.internal_dim).to(device) self.v_proj = nn.Linear(embedding_dim, self.internal_dim).cuda()
self.v_proj = nn.Linear(embedding_dim, self.internal_dim).to(device) self.out_proj = nn.Linear(self.internal_dim, embedding_dim).cuda()
self.out_proj = nn.Linear(self.internal_dim, embedding_dim).to(device)
def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
b, n, c = x.shape b, n, c = x.shape
......
0% — Loading, or loading failed.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment