Skip to content
Snippets Groups Projects
Commit e7077179 authored by Alexandre Chapin's avatar Alexandre Chapin :race_car:
Browse files

Fix get device

parent bf7f214a
No related branches found
No related tags found
No related merge requests found
...@@ -199,7 +199,7 @@ class Attention(nn.Module): ...@@ -199,7 +199,7 @@ class Attention(nn.Module):
self.internal_dim = embedding_dim // downsample_rate self.internal_dim = embedding_dim // downsample_rate
self.num_heads = num_heads self.num_heads = num_heads
assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim." assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim."
device = super().parameters().device device = next(super().parameters()).device
self.q_proj = nn.Linear(embedding_dim, self.internal_dim).to(device) self.q_proj = nn.Linear(embedding_dim, self.internal_dim).to(device)
self.k_proj = nn.Linear(embedding_dim, self.internal_dim).to(device) self.k_proj = nn.Linear(embedding_dim, self.internal_dim).to(device)
self.v_proj = nn.Linear(embedding_dim, self.internal_dim).to(device) self.v_proj = nn.Linear(embedding_dim, self.internal_dim).to(device)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment