diff --git a/runs/test/config.json b/runs/test/config.json
index 43c2e06f7e5d034a23d11d7bd00a43a2af6d57d5..de98f8e2c6df682615ec26ac0a483e29f9c30a2c 100644
--- a/runs/test/config.json
+++ b/runs/test/config.json
@@ -7,9 +7,12 @@
         }
     },
     "model":{
-        "encoder": "osrt",
+        "encoder": "sam",
         "encoder_kwargs": {
-            "pos_start_octave": -5,
+            "points_per_side": 12,
+            "box_nms_thresh": 0.7,
+            "stability_score_thresh": 0.9,
+            "pred_iou_thresh": 0.88,
             "num_slots": 6
         },
         "decoder": "slot_mixer",
@@ -19,7 +22,7 @@
     },
     "training":{
         "num_workers": 4, 
-        "batch_size": 64,
+        "batch_size": 16,
         "num_gpu": 8,
         "model_selection_metric": "psnr",
         "model_selection_mode": "max",
diff --git a/train_lit.py b/train_lit.py
index 67321c7b35dcad92b53d69fe6a438a551e164b99..b290fbdddf49c59053cc3f080363fa101c3625d5 100644
--- a/train_lit.py
+++ b/train_lit.py
@@ -131,8 +131,7 @@ def train_sam(
             ### Encode input informations and extract masks
             if isinstance(model.encoder, FeatureMasking):
                 input_images = input_images.permute(0, 1, 3, 4, 2) # from [b, k, c, h, w] to [b, k,  h, w, c]
-                h, w, c = input_images[0][0].shape
-                masks_info, z = model.encoder(input_images,(h, w), input_camera_pos, input_rays, extract_masks=True)
+                masks_info, z = model.encoder(input_images, input_camera_pos, input_rays, extract_masks=True)
             else:
                 z = model.encoder(input_images, input_camera_pos, input_rays)
 
@@ -245,7 +244,7 @@ def main(cfg) -> None:
         os.makedirs(cfg['training']['out_dir'], exist_ok=True)
 
     with fabric.device:
-        model = OSRT(cfg)
+        model = OSRT(cfg["model"])
         
     #########################
     ### Loading the dataset