diff --git a/osrt/sam/transformer.py b/osrt/sam/transformer.py index 4d88344287b4829fe6fbd6c0bcdb48f9c9ab91ba..a39d56427505658b0a1fab2307e6afebeb36a33a 100644 --- a/osrt/sam/transformer.py +++ b/osrt/sam/transformer.py @@ -238,6 +238,7 @@ class Attention(nn.Module): # Get output out = attn @ v out = self._recombine_heads(out) + self.out_proj = self.out_proj.to(out.device) out = self.out_proj(out) return out