Skip to content
Snippets Groups Projects
Commit 17290f74 authored by Alexandre Chapin's avatar Alexandre Chapin :race_car:
Browse files

Fix device on prompt encoder

parent 56908b4f
No related branches found
No related tags found
No related merge requests found
...@@ -151,11 +151,11 @@ class PromptEncoder(nn.Module): ...@@ -151,11 +151,11 @@ class PromptEncoder(nn.Module):
Bx(embed_dim)x(embed_H)x(embed_W) Bx(embed_dim)x(embed_H)x(embed_W)
""" """
bs = self._get_batch_size(points, boxes, masks) bs = self._get_batch_size(points, boxes, masks)
sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device()) sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=points.device)
if points is not None: if points is not None:
coords, labels = points coords, labels = points
point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)) point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)).to(points.device)
sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1) sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1).to(points.device)
if boxes is not None: if boxes is not None:
box_embeddings = self._embed_boxes(boxes) box_embeddings = self._embed_boxes(boxes)
sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1) sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment