Skip to content
Snippets Groups Projects
Commit b109748f authored by Alexandre Chapin's avatar Alexandre Chapin :race_car:
Browse files

Fix device problem

parent aea78759
No related branches found
No related tags found
No related merge requests found
......@@ -85,7 +85,7 @@ class PromptEncoder(nn.Module):
labels = torch.cat([labels, padding_label], dim=1)
point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size)
point_embedding[labels == -1] = 0.0
point_embedding[labels == -1] += self.not_a_point_embed.weight
point_embedding[labels == -1] += self.not_a_point_embed.weight.to(point_embedding.device)
point_embedding[labels == 0] += self.point_embeddings[0].weight
point_embedding[labels == 1] += self.point_embeddings[1].weight
return point_embedding
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment