diff --git a/osrt/sam/mask_decoder.py b/osrt/sam/mask_decoder.py index 9967d162cc15d1c41e257710b265a41cc9f234a3..d0739c84374b9e95db54cdf98d4141bd4652c99e 100644 --- a/osrt/sam/mask_decoder.py +++ b/osrt/sam/mask_decoder.py @@ -134,6 +134,7 @@ class MaskDecoder(nn.Module): b, c, h, w = src.shape + self.transformer.to(src.device) # Run the transformer hs, src = self.transformer(src, pos_src, tokens) iou_token_out = hs[:, 0, :]