diff --git a/osrt/sam/prompt_encoder.py b/osrt/sam/prompt_encoder.py
index f03ce7148560454481cee54b1c634e5fbd23b5de..cc1ed52314da88c0c84a79f4511e724a7fe6a670 100644
--- a/osrt/sam/prompt_encoder.py
+++ b/osrt/sam/prompt_encoder.py
@@ -151,11 +151,11 @@ class PromptEncoder(nn.Module):
           Bx(embed_dim)x(embed_H)x(embed_W)
         """
         bs = self._get_batch_size(points, boxes, masks)
-        sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device())
+        sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=points[0].device if points is not None else self._get_device())
         if points is not None:
             coords, labels = points
-            point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
-            sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
+            point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)).to(sparse_embeddings.device)
+            sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
         if boxes is not None:
             box_embeddings = self._embed_boxes(boxes)
             sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)