diff --git a/automatic_mask_train.py b/automatic_mask_train.py
index b85f53c214be857867ff1f415635d2895790846c..1110115b001d048d25e115fdda97d47befcea4d9 100644
--- a/automatic_mask_train.py
+++ b/automatic_mask_train.py
@@ -63,6 +63,7 @@ if __name__ == '__main__':
         'pred_iou_thresh': 0.88,
         'sam_model': model_type,
         'sam_path': checkpoint,
+        'points_per_batch': 16
     }
     cfg['decoder_kwargs'] = {
         'pos_start_octave': -5,
@@ -91,7 +92,7 @@ if __name__ == '__main__':
     images= []
     from torchvision import transforms
     transform = transforms.ToTensor()
-    for j in range(1):
+    for j in range(2):
         image = images_path[j]
         img = cv2.imread(image)
         img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
@@ -99,8 +100,11 @@ if __name__ == '__main__':
     #images.append(np.expand_dims(img, axis=0))
     #images_np = np.array(images)
     images_t = torch.stack(images).to(device)
+    print(f"Shape image {images_t.shape}")
 
-    _, h, w = images_t[0][0].shape
+    images_t = images_t.permute(0, 1, 3, 4, 2)
+
+    h, w, c = images_t[0][0].shape
     #print(f"Begin shape {images_np.shape}")
     points = model.encoder.mask_generator.points_grid
     new_points= []
@@ -115,14 +119,18 @@ if __name__ == '__main__':
     # TODO : set ray and camera directions
     #with torch.no_grad():
     with torch.cuda.amp.autocast():
-        masks, slots = model.encoder(images_t, (h, w), None, None)
+        masks, slots = model.encoder(images_t, (h, w), None, None, extract_masks=True)
     end = time.time()
     print(f"Inference time : {int((end-start) * 1000)}ms")
 
     if args.visualize:
-        plt.figure(figsize=(15,15))
-        plt.imshow(img)
-        show_anns(masks[0][0]) # show masks
-        show_points(new_points, plt.gca()) # show points
-        #plt.axis('off')
-        plt.show()
\ No newline at end of file
+        for j in range(2):
+            image = images_path[j]
+            img = cv2.imread(image)
+            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+            plt.figure(figsize=(15,15))
+            plt.imshow(img)
+            show_anns(masks[j][0]) # show masks
+            show_points(new_points, plt.gca()) # show points
+            #plt.axis('off')
+            plt.show()
\ No newline at end of file
diff --git a/osrt/encoder.py b/osrt/encoder.py
index a96650e3c43aa6cc4eee6834cbadd47a34aff773..aa258ee9355d00b4d7cb6828f35a91bc04736b21 100644
--- a/osrt/encoder.py
+++ b/osrt/encoder.py
@@ -131,25 +131,6 @@ class FeatureMasking(nn.Module):
         # We first initialize the automatic mask generator from SAM
         # TODO : change the loading here !!!!
         sam = sam_model_registry[sam_model](checkpoint=sam_path)
-        """if sam_model == "default" or sam_model == "vit_h":
-            encoder_embed_dim=1280
-            encoder_depth=32
-            encoder_num_heads=16
-            encoder_global_attn_indexes=[7, 15, 23, 31]
-        elif sam_model == "vit_l":
-            encoder_embed_dim=1024
-            encoder_depth=24
-            encoder_num_heads=16
-            encoder_global_attn_indexes=[5, 11, 17, 23]
-        else:
-            encoder_embed_dim=768
-            encoder_depth=12
-            encoder_num_heads=12
-            encoder_global_attn_indexes=[2, 5, 8, 11]
-        prompt_embed_dim = 256
-        image_size = 1024
-        vit_patch_size = 16
-        image_embedding_size = image_size // vit_patch_size"""
         self.mask_generator = SamAutomaticMask(copy.deepcopy(sam.image_encoder),
                                                copy.deepcopy(sam.prompt_encoder),
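
Note (not part of the patch): a minimal sketch of the tensor-layout change the patch makes in automatic_mask_train.py, assuming the stacked views come out of torchvision's ToTensor as [batch, views, C, H, W]; the permute presumably moves them to the channels-last [batch, views, H, W, C] layout consumed by the SAM-based mask generator, and the per-view height/width are then read from a single view. The sizes below are purely illustrative.

    import torch

    # Illustrative sizes: 1 scene, 2 views, 3-channel 480x640 images.
    images_t = torch.rand(1, 2, 3, 480, 640)    # [batch, views, C, H, W] from stacking ToTensor() outputs
    images_t = images_t.permute(0, 1, 3, 4, 2)  # -> [batch, views, H, W, C], channels-last as in the patch
    h, w, c = images_t[0][0].shape              # spatial size of one view, passed to the encoder as (h, w)
    print(h, w, c)                              # 480 640 3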