Skip to content
Snippets Groups Projects
Commit be1c9aa6 authored by Alexandre Chapin's avatar Alexandre Chapin :race_car:
Browse files

Device

parent d8dcb3df
No related branches found
No related tags found
No related merge requests found
......@@ -221,8 +221,8 @@ class Attention(nn.Module):
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
self.q_proj = self.q_proj.to(q.device)
self.k_proj = self.k_proj.to(k.device)
self.v_proj = self.v_proj.to(v.device)
self.k_proj = self.k_proj.to(q.device)
self.v_proj = self.v_proj.to(q.device)
# Input projections
q = self.q_proj(q)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment