diff --git a/osrt/encoder.py b/osrt/encoder.py
index 7c09f41e6b4adb093c7be4df70bc176248713eff..e5292bad070d61e9b673ce9fb9ff59989377e1f2 100644
--- a/osrt/encoder.py
+++ b/osrt/encoder.py
@@ -191,7 +191,7 @@ class FeatureMasking(nn.Module):
         im_size = self.resize.apply_image(images[0]).shape[-3:-1]
 
         ### Pre-process images for the image encoder (Resize and Pad)
-        images = torch.stack([self.preprocess(x) for x in images], device=self.mask_generator.device)
+        images = torch.stack([self.preprocess(x) for x in images]).to(self.mask_generator.device)
 
         ### Encode images 
         image_embeddings, embed_no_red = self.mask_generator.image_encoder(images, before_channel_reduc=True) # [B x N, C, H, W]