diff --git a/osrt/encoder.py b/osrt/encoder.py
index 706f516ca8fbee0391d4cbeac27475a5cc84ee22..deb591831b95b4daf60af4a019e520ec67519552 100644
--- a/osrt/encoder.py
+++ b/osrt/encoder.py
@@ -151,10 +151,6 @@ class FeatureMasking(nn.Module):
         # We first initialize the automatic mask generator from SAM
         # TODO : change the loading here !!!!
         self.mask_generator = sam_model_registry[sam_model](checkpoint=sam_path) 
-        device = self.mask_generator.device
-        self.mask_generator.image_encoder.to(device)
-        self.mask_generator.prompt_encoder.to(device)
-        self.mask_generator.mask_decoder.to(device)
         self.preprocess = ResizeAndPad(self.mask_generator.image_encoder.img_size)
         self.resize = ResizeLongestSide(self.mask_generator.image_encoder.img_size)
         
diff --git a/osrt/sam/mask_decoder.py b/osrt/sam/mask_decoder.py
index d0739c84374b9e95db54cdf98d4141bd4652c99e..b6848f569564763e721afbf1d6f28267e287a0b2 100644
--- a/osrt/sam/mask_decoder.py
+++ b/osrt/sam/mask_decoder.py
@@ -133,8 +133,8 @@ class MaskDecoder(nn.Module):
         pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
         b, c, h, w = src.shape
 
-
-        self.transformer.to(src.device)
+        print(f"Src device : {src.device}")
+        print(f"Transformer MaskDecoder device : {next(self.transformer.parameters()).device}")
         # Run the transformer
         hs, src = self.transformer(src, pos_src, tokens)
         iou_token_out = hs[:, 0, :]
diff --git a/train.py b/train.py
index 4351afbf4d9fb84de942c02be0546bf81887c61a..ac9f416e040ceccae8f007c97434d01528a93e86 100755
--- a/train.py
+++ b/train.py
@@ -126,6 +126,7 @@ def main():
     optimizer = AdamW(params, lr=lr_scheduler.get_cur_lr(0)) 
 
     model, optimizer = fabric.setup(model, optimizer)
+    model = fabric.to_device(model)
 
     #########################
     ### Training