diff --git a/osrt/sam/mask_decoder.py b/osrt/sam/mask_decoder.py
index 347c5c5977a65c36553c487416a281e0e5694356..49f54cc6a7b0f0032433ffd81b85144b7a67c3e0 100644
--- a/osrt/sam/mask_decoder.py
+++ b/osrt/sam/mask_decoder.py
@@ -118,13 +118,14 @@ class MaskDecoder(nn.Module):
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Predicts masks. See 'forward' for more details."""
         # Concatenate output tokens
+        device = sparse_prompt_embeddings.device
         output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
-        output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
-        tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
+        output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1).to(device)
+        tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1).to(device)
 
         # Expand per-image data in batch direction to be per-mask
         if image_embeddings.shape[0] != tokens.shape[0]:
-            src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
+            src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0).to(device)
         else:
             src = image_embeddings
         src = src + dense_prompt_embeddings
@@ -141,10 +142,11 @@ class MaskDecoder(nn.Module):
         # Upscale mask embeddings and predict masks using the mask tokens
         src = src.transpose(1, 2).view(b, c, h, w)
         upscaled_embedding = self.output_upscaling(src)
+        upscaled_embedding = upscaled_embedding.to(device)
         hyper_in_list: List[torch.Tensor] = []
         for i in range(self.num_mask_tokens):
             hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
-        hyper_in = torch.stack(hyper_in_list, dim=1)
+        hyper_in = torch.stack(hyper_in_list, dim=1).to(device)
         b, c, h, w = upscaled_embedding.shape
         masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)