From 17290f74425673e914d60a23f9a56c9fb7315a14 Mon Sep 17 00:00:00 2001 From: alexcbb <alexchapin@hotmail.fr> Date: Wed, 19 Jul 2023 08:24:26 +0200 Subject: [PATCH] Fix device on prompt encoder --- osrt/sam/prompt_encoder.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/osrt/sam/prompt_encoder.py b/osrt/sam/prompt_encoder.py index f03ce71..cc1ed52 100644 --- a/osrt/sam/prompt_encoder.py +++ b/osrt/sam/prompt_encoder.py @@ -151,11 +151,11 @@ class PromptEncoder(nn.Module): Bx(embed_dim)x(embed_H)x(embed_W) """ bs = self._get_batch_size(points, boxes, masks) - sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device()) + sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=points.device) if points is not None: coords, labels = points - point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)) - sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1) + point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)).to(points.device) + sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1).to(points.device) if boxes is not None: box_embeddings = self._embed_boxes(boxes) sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1) -- GitLab