Skip to content
Snippets Groups Projects
Commit d8dcb3df authored by Alexandre Chapin's avatar Alexandre Chapin :race_car:
Browse files

Set all layers to right device

parent 86ac08b5
No related branches found
No related tags found
No related merge requests found
...@@ -158,6 +158,7 @@ class TwoWayAttentionBlock(nn.Module): ...@@ -158,6 +158,7 @@ class TwoWayAttentionBlock(nn.Module):
q = queries + query_pe q = queries + query_pe
attn_out = self.self_attn(q=q, k=q, v=queries) attn_out = self.self_attn(q=q, k=q, v=queries)
queries = queries + attn_out queries = queries + attn_out
self.norm1 = self.norm1.to(queries.device)
queries = self.norm1(queries) queries = self.norm1(queries)
# Cross attention block, tokens attending to image embedding # Cross attention block, tokens attending to image embedding
...@@ -165,18 +166,22 @@ class TwoWayAttentionBlock(nn.Module): ...@@ -165,18 +166,22 @@ class TwoWayAttentionBlock(nn.Module):
k = keys + key_pe k = keys + key_pe
attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
queries = queries + attn_out queries = queries + attn_out
self.norm2 = self.norm2.to(queries.device)
queries = self.norm2(queries) queries = self.norm2(queries)
# MLP block # MLP block
self.mlp = self.mlp.to(queries.device)
mlp_out = self.mlp(queries) mlp_out = self.mlp(queries)
queries = queries + mlp_out queries = queries + mlp_out
queries = self.norm3(queries) queries = self.norm3(queries)
self.norm3 = self.norm3.to(queries.device)
# Cross attention block, image embedding attending to tokens # Cross attention block, image embedding attending to tokens
q = queries + query_pe q = queries + query_pe
k = keys + key_pe k = keys + key_pe
attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries) attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
keys = keys + attn_out keys = keys + attn_out
self.norm4 = self.norm4.to(keys.device)
keys = self.norm4(keys) keys = self.norm4(keys)
return queries, keys return queries, keys
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment