From b109748f061f916313728c83bad722e4d0f3f04e Mon Sep 17 00:00:00 2001 From: alexcbb <alexchapin@hotmail.fr> Date: Tue, 18 Jul 2023 17:31:23 +0200 Subject: [PATCH] Fix device problem --- osrt/sam/prompt_encoder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/osrt/sam/prompt_encoder.py b/osrt/sam/prompt_encoder.py index 3d36827..d8954af 100644 --- a/osrt/sam/prompt_encoder.py +++ b/osrt/sam/prompt_encoder.py @@ -85,7 +85,7 @@ class PromptEncoder(nn.Module): labels = torch.cat([labels, padding_label], dim=1) point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size) point_embedding[labels == -1] = 0.0 - point_embedding[labels == -1] += self.not_a_point_embed.weight + point_embedding[labels == -1] += self.not_a_point_embed.weight.to(point_embedding.device) point_embedding[labels == 0] += self.point_embeddings[0].weight point_embedding[labels == 1] += self.point_embeddings[1].weight return point_embedding -- GitLab