diff --git a/osrt/sam/transformer.py b/osrt/sam/transformer.py
index a39d56427505658b0a1fab2307e6afebeb36a33a..1049c18b7ab68aa3c577fb020d7136d91d57d81b 100644
--- a/osrt/sam/transformer.py
+++ b/osrt/sam/transformer.py
@@ -158,6 +158,7 @@ class TwoWayAttentionBlock(nn.Module):
         q = queries + query_pe
         attn_out = self.self_attn(q=q, k=q, v=queries)
         queries = queries + attn_out
+        self.norm1 = self.norm1.to(queries.device)
         queries = self.norm1(queries)
 
         # Cross attention block, tokens attending to image embedding
@@ -165,18 +166,22 @@
         k = keys + key_pe
         attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
         queries = queries + attn_out
+        self.norm2 = self.norm2.to(queries.device)
         queries = self.norm2(queries)
 
         # MLP block
+        self.mlp = self.mlp.to(queries.device)
         mlp_out = self.mlp(queries)
         queries = queries + mlp_out
+        self.norm3 = self.norm3.to(queries.device)
         queries = self.norm3(queries)
 
         # Cross attention block, image embedding attending to tokens
         q = queries + query_pe
         k = keys + key_pe
         attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
         keys = keys + attn_out
+        self.norm4 = self.norm4.to(keys.device)
         keys = self.norm4(keys)
 
         return queries, keys