From 17290f74425673e914d60a23f9a56c9fb7315a14 Mon Sep 17 00:00:00 2001
From: alexcbb <alexchapin@hotmail.fr>
Date: Wed, 19 Jul 2023 08:24:26 +0200
Subject: [PATCH] Fix device placement in the prompt encoder

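Allocate the sparse prompt embeddings on the device of the incoming
point prompts instead of the device returned by self._get_device(), and
move the point embeddings onto that same device before concatenation,
so that the tensors combined in forward() all live on one device.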
---
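Reviewer note (below the --- cut line, so it stays out of the commit):
the diff keys the device off points, but forward() unpacks points as
`coords, labels = points` and also accepts points=None (see the
`if points is not None` guard), so points.device would raise when no
point prompts are given, and would also fail if points is a plain
(coords, labels) tuple rather than a tensor. A minimal None-safe sketch
of the lookup, assuming the (coords, labels) unpacking used in
forward(); _prompt_device is a hypothetical helper, not part of this
patch:

    def _prompt_device(self, points):
        # Prefer the device of the incoming prompt tensors.
        if points is not None:
            coords, _ = points  # same unpacking as in forward()
            return coords.device
        # No point prompts: fall back to the module's own device.
        return self._get_device()

With both operands already on points.device, the trailing
.to(points.device) on the torch.cat result is redundant.
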
 osrt/sam/prompt_encoder.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/osrt/sam/prompt_encoder.py b/osrt/sam/prompt_encoder.py
index f03ce71..cc1ed52 100644
--- a/osrt/sam/prompt_encoder.py
+++ b/osrt/sam/prompt_encoder.py
@@ -151,11 +151,11 @@ class PromptEncoder(nn.Module):
             Bx(embed_dim)x(embed_H)x(embed_W)
         """
         bs = self._get_batch_size(points, boxes, masks)
-        sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device())
+        sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=points.device)
         if points is not None:
             coords, labels = points
-            point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
-            sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
+            point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)).to(points.device)
+            sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1).to(points.device)
         if boxes is not None:
             box_embeddings = self._embed_boxes(boxes)
             sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)
-- 
GitLab