From 05b96da0b1f0232d37d1035700d561349c414e77 Mon Sep 17 00:00:00 2001
From: alexcbb <alexchapin@hotmail.fr>
Date: Wed, 19 Jul 2023 10:52:06 +0200
Subject: [PATCH] Logs

---
 osrt/encoder.py          | 4 ----
 osrt/sam/mask_decoder.py | 4 ++--
 train.py                 | 1 +
 3 files changed, 3 insertions(+), 6 deletions(-)

diff --git a/osrt/encoder.py b/osrt/encoder.py
index 706f516..deb5918 100644
--- a/osrt/encoder.py
+++ b/osrt/encoder.py
@@ -151,10 +151,6 @@ class FeatureMasking(nn.Module):
         # We first initialize the automatic mask generator from SAM
         # TODO : change the loading here !!!!
         self.mask_generator = sam_model_registry[sam_model](checkpoint=sam_path) 
-        device = self.mask_generator.device
-        self.mask_generator.image_encoder.to(device)
-        self.mask_generator.prompt_encoder.to(device)
-        self.mask_generator.mask_decoder.to(device)
         self.preprocess = ResizeAndPad(self.mask_generator.image_encoder.img_size)
         self.resize = ResizeLongestSide(self.mask_generator.image_encoder.img_size)
         
diff --git a/osrt/sam/mask_decoder.py b/osrt/sam/mask_decoder.py
index d0739c8..b6848f5 100644
--- a/osrt/sam/mask_decoder.py
+++ b/osrt/sam/mask_decoder.py
@@ -133,8 +133,8 @@ class MaskDecoder(nn.Module):
         pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
         b, c, h, w = src.shape
 
-
-        self.transformer.to(src.device)
+        print(f"Src device : {src.device}")
+        print(f"Transformer MaskDecoder device : {next(self.transformer.parameters()).device}")
         # Run the transformer
         hs, src = self.transformer(src, pos_src, tokens)
         iou_token_out = hs[:, 0, :]
diff --git a/train.py b/train.py
index 4351afb..ac9f416 100755
--- a/train.py
+++ b/train.py
@@ -126,6 +126,7 @@ def main():
     optimizer = AdamW(params, lr=lr_scheduler.get_cur_lr(0)) 
 
     model, optimizer = fabric.setup(model, optimizer)
+    model = fabric.to_device(model)
 
     #########################
     ### Training
-- 
GitLab