diff --git a/osrt/sam/prompt_encoder.py b/osrt/sam/prompt_encoder.py
index f03ce7148560454481cee54b1c634e5fbd23b5de..cc1ed52314da88c0c84a79f4511e724a7fe6a670 100644
--- a/osrt/sam/prompt_encoder.py
+++ b/osrt/sam/prompt_encoder.py
@@ -151,11 +151,11 @@ class PromptEncoder(nn.Module):
             Bx(embed_dim)x(embed_H)x(embed_W)
         """
         bs = self._get_batch_size(points, boxes, masks)
-        sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device())
+        sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=points.device)
         if points is not None:
             coords, labels = points
-            point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
-            sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
+            point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)).to(points.device)
+            sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1).to(points.device)
         if boxes is not None:
             box_embeddings = self._embed_boxes(boxes)
             sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)