From 83b52a03331aadb971c0dfaeba0027dc06608b6b Mon Sep 17 00:00:00 2001
From: alexcbb <alexchapin@hotmail.fr>
Date: Tue, 18 Jul 2023 17:40:14 +0200
Subject: [PATCH] Put all modules on the same device

---
 osrt/encoder.py            | 4 ++++
 osrt/sam/prompt_encoder.py | 2 +-
 2 files changed, 5 insertions(+), 1 deletion(-)

diff --git a/osrt/encoder.py b/osrt/encoder.py
index deb5918..706f516 100644
--- a/osrt/encoder.py
+++ b/osrt/encoder.py
@@ -151,6 +151,10 @@ class FeatureMasking(nn.Module):
         # We first initialize the automatic mask generator from SAM
         # TODO : change the loading here !!!!
         self.mask_generator = sam_model_registry[sam_model](checkpoint=sam_path) 
+        device = self.mask_generator.device
+        self.mask_generator.image_encoder.to(device)
+        self.mask_generator.prompt_encoder.to(device)
+        self.mask_generator.mask_decoder.to(device)
         self.preprocess = ResizeAndPad(self.mask_generator.image_encoder.img_size)
         self.resize = ResizeLongestSide(self.mask_generator.image_encoder.img_size)
         
diff --git a/osrt/sam/prompt_encoder.py b/osrt/sam/prompt_encoder.py
index d8954af..3d36827 100644
--- a/osrt/sam/prompt_encoder.py
+++ b/osrt/sam/prompt_encoder.py
@@ -85,7 +85,7 @@ class PromptEncoder(nn.Module):
             labels = torch.cat([labels, padding_label], dim=1)
         point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size)
         point_embedding[labels == -1] = 0.0
-        point_embedding[labels == -1] += self.not_a_point_embed.weight.to(point_embedding.device)
+        point_embedding[labels == -1] += self.not_a_point_embed.weight
         point_embedding[labels == 0] += self.point_embeddings[0].weight
         point_embedding[labels == 1] += self.point_embeddings[1].weight
         return point_embedding
-- 
GitLab