Skip to content
Snippets Groups Projects
Commit 56908b4f authored by Alexandre Chapin's avatar Alexandre Chapin :race_car:
Browse files

Last fix

parent 83b52a03
No related branches found
No related tags found
No related merge requests found
......@@ -81,9 +81,11 @@ class PromptEncoder(nn.Module):
if pad:
padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
points = torch.cat([points, padding_point], dim=1)
labels = torch.cat([labels, padding_label], dim=1)
points = torch.cat([points, padding_point], dim=1).to(points.device)
labels = torch.cat([labels, padding_label], dim=1).to(labels.device)
point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size)
self.not_a_point_embed.to(points.device)
self.point_embeddings.to(points.device)
point_embedding[labels == -1] = 0.0
point_embedding[labels == -1] += self.not_a_point_embed.weight
point_embedding[labels == 0] += self.point_embeddings[0].weight
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment