Skip to content
Snippets Groups Projects
Commit 8626edef authored by Alexandre Chapin's avatar Alexandre Chapin :race_car:
Browse files

Fix keys device

parent 50ab36b3
No related branches found
No related tags found
No related merge requests found
......@@ -179,10 +179,11 @@ class TwoWayAttentionBlock(nn.Module):
# Cross attention block, image embedding attending to tokens
q = queries + query_pe
key_pe, keys = key_pe.to(queries.device), keys.to(queries.device)
k = keys + key_pe
attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
keys = keys + attn_out
self.norm4 = self.norm4.to(keys.device)
self.norm4 = self.norm4.to(queries.device)
keys = self.norm4(keys)
return queries, keys
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment