diff --git a/osrt/sam/sam.py b/osrt/sam/sam.py
index 8074cff6b40addc6b66f7ab4962218eef20da13c..764b485fbf061b373447f10c0deebfbe6b06efeb 100644
--- a/osrt/sam/sam.py
+++ b/osrt/sam/sam.py
@@ -109,10 +109,10 @@ class Sam(nn.Module):
                 masks=image_record.get("mask_inputs", None),
             )
             low_res_masks, iou_predictions = self.mask_decoder(
-                image_embeddings=curr_embedding.unsqueeze(0),
-                image_pe=self.prompt_encoder.get_dense_pe(),
-                sparse_prompt_embeddings=sparse_embeddings,
-                dense_prompt_embeddings=dense_embeddings,
+                image_embeddings=curr_embedding.unsqueeze(0).to(self.device),
+                image_pe=self.prompt_encoder.get_dense_pe().to(self.device),
+                sparse_prompt_embeddings=sparse_embeddings.to(self.device),
+                dense_prompt_embeddings=dense_embeddings.to(self.device),
                 multimask_output=multimask_output,
             )
             masks = self.postprocess_masks(