diff --git a/osrt/encoder.py b/osrt/encoder.py
index deb591831b95b4daf60af4a019e520ec67519552..706f516ca8fbee0391d4cbeac27475a5cc84ee22 100644
--- a/osrt/encoder.py
+++ b/osrt/encoder.py
@@ -151,6 +151,10 @@ class FeatureMasking(nn.Module):
         # We first initialize the automatic mask generator from SAM
         # TODO : change the loading here !!!!
         self.mask_generator = sam_model_registry[sam_model](checkpoint=sam_path)
+        device = self.mask_generator.device
+        self.mask_generator.image_encoder.to(device)
+        self.mask_generator.prompt_encoder.to(device)
+        self.mask_generator.mask_decoder.to(device)

         self.preprocess = ResizeAndPad(self.mask_generator.image_encoder.img_size)
         self.resize = ResizeLongestSide(self.mask_generator.image_encoder.img_size)
diff --git a/osrt/sam/prompt_encoder.py b/osrt/sam/prompt_encoder.py
index d8954af804285be6e7048ebabc7a83df9c1d7ddb..3d36827977edb679005c7a38fcfcdfb37a819f71 100644
--- a/osrt/sam/prompt_encoder.py
+++ b/osrt/sam/prompt_encoder.py
@@ -85,7 +85,7 @@ class PromptEncoder(nn.Module):
             labels = torch.cat([labels, padding_label], dim=1)
         point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size)
         point_embedding[labels == -1] = 0.0
-        point_embedding[labels == -1] += self.not_a_point_embed.weight.to(point_embedding.device)
+        point_embedding[labels == -1] += self.not_a_point_embed.weight
         point_embedding[labels == 0] += self.point_embeddings[0].weight
         point_embedding[labels == 1] += self.point_embeddings[1].weight
         return point_embedding
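
For context (not part of the diff): a minimal sketch of the device-alignment idea behind both hunks. The first hunk moves every SAM submodule onto the model's device right after loading; once all parameters share a device, the per-tensor `.to(point_embedding.device)` cast removed by the second hunk is redundant. The sketch assumes the upstream segment-anything API (`sam_model_registry`); the checkpoint path is hypothetical.

```python
# Minimal sketch, not part of the PR: once every SAM submodule lives on the
# same device, embedding weights can be added to activations without an
# explicit .to(...) cast, which is why the second hunk can drop it.
# Assumptions: upstream segment-anything API; "sam_vit_b.pth" is a
# hypothetical checkpoint path.
import torch
from segment_anything import sam_model_registry

device = "cuda" if torch.cuda.is_available() else "cpu"
sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b.pth")
sam.to(device)  # moves image_encoder, prompt_encoder, and mask_decoder together

# Every parameter now reports the same device, so a weight such as
# sam.prompt_encoder.not_a_point_embed.weight needs no per-use cast.
ref = next(sam.parameters()).device
assert all(p.device == ref for p in sam.parameters())
```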