diff --git a/osrt/encoder.py b/osrt/encoder.py index 7c09f41e6b4adb093c7be4df70bc176248713eff..e5292bad070d61e9b673ce9fb9ff59989377e1f2 100644 --- a/osrt/encoder.py +++ b/osrt/encoder.py @@ -191,7 +191,7 @@ class FeatureMasking(nn.Module): im_size = self.resize.apply_image(images[0]).shape[-3:-1] ### Pre-process images for the image encoder (Resize and Pad) - images = torch.stack([self.preprocess(x) for x in images], device=self.mask_generator.device) + images = torch.stack([self.preprocess(x) for x in images]).to(self.mask_generator.device) ### Encode images image_embeddings, embed_no_red = self.mask_generator.image_encoder(images, before_channel_reduc=True) # [B x N, C, H, W]