diff --git a/osrt/sam/prompt_encoder.py b/osrt/sam/prompt_encoder.py
index 3d36827977edb679005c7a38fcfcdfb37a819f71..d8954af804285be6e7048ebabc7a83df9c1d7ddb 100644
--- a/osrt/sam/prompt_encoder.py
+++ b/osrt/sam/prompt_encoder.py
@@ -85,7 +85,7 @@ class PromptEncoder(nn.Module):
             labels = torch.cat([labels, padding_label], dim=1)
         point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size)
         point_embedding[labels == -1] = 0.0
-        point_embedding[labels == -1] += self.not_a_point_embed.weight
+        point_embedding[labels == -1] += self.not_a_point_embed.weight.to(point_embedding.device)
         point_embedding[labels == 0] += self.point_embeddings[0].weight
         point_embedding[labels == 1] += self.point_embeddings[1].weight
         return point_embedding