diff --git a/automatic_mask_train.py b/automatic_mask_train.py
index 1110115b001d048d25e115fdda97d47befcea4d9..9e0ee5974160169948175804051a75ef53c1a6f2 100644
--- a/automatic_mask_train.py
+++ b/automatic_mask_train.py
@@ -1,13 +1,12 @@
 import argparse
 import torch
-from osrt.encoder import SamAutomaticMask, FeatureMasking
 from osrt.model import OSRT
-from segment_anything import sam_model_registry
 import time
 import matplotlib.pyplot as plt
 import numpy as np
 import cv2
 
+
 def show_anns(masks):
     ax = plt.gca()
     ax.set_autoscale_on(False)    
@@ -57,13 +56,13 @@ if __name__ == '__main__':
     cfg['encoder'] = 'sam'
     cfg['decoder'] = 'slot_mixer'
     cfg['encoder_kwargs'] = {
-        'points_per_side': 12,
+        'points_per_side': 32,  # 32x32 grid -> 1024 point prompts per image
         'box_nms_thresh': 0.7,
         'stability_score_thresh': 0.9,
         'pred_iou_thresh': 0.88,
         'sam_model': model_type, 
         'sam_path': checkpoint,
-        'points_per_batch': 16
+        'points_per_batch': 12  # points run per forward pass; lower values reduce peak GPU memory
     }
     cfg['decoder_kwargs'] = {
         'pos_start_octave': -5,
@@ -71,10 +70,10 @@ if __name__ == '__main__':
     model = OSRT(cfg)#FeatureMasking(points_per_side=12, box_nms_thresh=0.7, stability_score_thresh= 0.9, pred_iou_thresh=0.88, points_per_batch=64)
     model.to(device, non_blocking=True)
 
-    num_encoder_params = sum(p.numel() for p in model.encoder.parameters())
+    """num_encoder_params = sum(p.numel() for p in model.encoder.parameters())
     num_decoder_params = sum(p.numel() for p in model.decoder.parameters())
 
-    """print('Number of parameters:')
+    print('Number of parameters:')
     print(f'\tEncoder: {num_encoder_params}')
 
     num_mask_encoder_params = sum(p.numel() for p in model.encoder.mask_generator.parameters())
@@ -87,12 +86,12 @@ if __name__ == '__main__':
     print(f'\t\t\tMask Decoder: {num_mask_params}.')
     print(f'\t\t\tPrompt Encoder: {num_prompt_params}.')
     print(f'\t\tSlot Attention: {num_slotatt_params}.')
-    print(f'\tDecoder: {num_decoder_params}')
-"""
+    print(f'\tDecoder: {num_decoder_params}')"""
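+    # parameter counting above left commented out for reference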
+
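+    # load a few YCB-Video frames and convert them into a batched tensor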
     images= []
     from torchvision import transforms
     transform = transforms.ToTensor()
-    for j in range(2):
+    for j in range(10):
         image = images_path[j]
         img = cv2.imread(image)
         img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
@@ -119,12 +118,12 @@ if __name__ == '__main__':
     # TODO : set ray and camera directions
     #with torch.no_grad():
     with torch.cuda.amp.autocast(): 
-        masks, slots = model.encoder(images_t, (h, w), None, None, extract_masks=True)
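+        # time the mask generator alone; extract_embeddings=False skips the feature-embedding path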
+        masks = model.encoder.mask_generator(images_t, (h, w), None, None, extract_embeddings=False)
     end = time.time()
     print(f"Inference time : {int((end-start) * 1000)}ms")
     
     if args.visualize:
-        for j in range(2):
+        for j in range(10):
             image = images_path[j]
             img = cv2.imread(image)
             img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
diff --git a/osrt/encoder.py b/osrt/encoder.py
index aa258ee9355d00b4d7cb6828f35a91bc04736b21..f5b63c6ee3ff1339d4c4d70eaeceb4f48eede921 100644
--- a/osrt/encoder.py
+++ b/osrt/encoder.py
@@ -123,8 +123,8 @@ class FeatureMasking(nn.Module):
                  num_slots=32, 
                  slot_dim=1536, 
                  slot_iters=3,  
-                 sam_model="default", 
-                 sam_path="sam_vit_h_4b8939.pth",
+                 sam_model="vit_t", 
+                 sam_path="mobile_sam.pt",
                  randomize_initial_slots=False):
         super().__init__() 
 
@@ -252,8 +252,6 @@ class SamAutomaticMask(nn.Module):
         self.mask_decoder = mask_decoder
         for param in self.mask_decoder.parameters():
             param.requires_grad = True
-        self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
-        self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
 
         # Transform image to a square by putting it to the longest side
         #self.resize = transforms.Resize(self.image_encoder.img_size, interpolation=transforms.InterpolationMode.BILINEAR)
@@ -280,6 +278,9 @@ class SamAutomaticMask(nn.Module):
                 nn.Linear(2500, self.token_dim),
             )
 
+        self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
+        self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
+
         # Space positional embedding
         self.ray_encoder = RayEncoder(pos_octaves=15, pos_start_octave=pos_start_octave,
                                       ray_octaves=15)
diff --git a/sam_test.py b/sam_test.py
index fa34060be430b1898c9faf275cab317ca6abc202..56d4d8682666f13534b2ec891e9cf8350922acc4 100644
--- a/sam_test.py
+++ b/sam_test.py
@@ -1,18 +1,14 @@
-import argparse
 import torch
 from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
-from torchvision import transforms
-from PIL import Image
-import time
 import matplotlib.pyplot as plt
-import matplotlib as mpl
 import numpy as np
 import cv2
+import time
 
-def show_anns(masks):
-    if len(masks) == 0:
+def show_anns(anns):
+    if len(anns) == 0:
         return
-    sorted_anns = sorted(masks, key=(lambda x: x['area']), reverse=True)
+    sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
     ax = plt.gca()
     ax.set_autoscale_on(False)
 
@@ -24,100 +20,44 @@ def show_anns(masks):
         img[m] = color_mask
     ax.imshow(img)
 
-def show_points(coords, ax, marker_size=100):
-    ax.scatter(coords[:, 0], coords[:, 1], color='#2ca02c', marker='.', s=marker_size)
-    
-
-if __name__ == '__main__':
-    # Arguments
-    parser = argparse.ArgumentParser(
-        description='Test Segment Anything Auto Mask simplified implementation'
-    )
-    parser.add_argument('--model', default='vit_b', type=str, help='Model to use')
-    parser.add_argument('--path_model', default='.', type=str, help='Path to the model')
-
-    args = parser.parse_args()
-    device = "cuda"
-    torch.backends.cuda.matmul.allow_tf32 = True
-    torch.backends.cudnn.allow_tf32 = True
-
-    model_type = args.model
-    if args.model == 'vit_h':
-        checkpoint = args.path_model + '/sam_vit_h_4b8939.pth'
-    elif args.model == 'vit_b':
-        checkpoint = args.path_model + '/sam_vit_b_01ec64.pth'
-    else:
-        checkpoint = args.path_model + '/sam_vit_l_0b3195.pth'
 
-    ycb_path = "/home/achapin/Documents/Datasets/YCB_Video_Dataset/"
-    images_path = []
-    with open(ycb_path + "image_sets/train.txt", 'r') as f:
-        for line in f.readlines():
-            line = line.strip()
-            images_path.append(ycb_path + 'data/' + line + "-color.png")
 
-    import random
-    random.shuffle(images_path)
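+# MobileSAM weights: a lightweight SAM variant built around a tiny ViT image encoder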
+model_type = "vit_t"
+sam_checkpoint = "./mobile_sam.pt"
 
-    sam = sam_model_registry[model_type](checkpoint=checkpoint)
-    sam.to(device=device)
-    mask_generator = SamAutomaticMaskGenerator(sam, points_per_side=12, box_nms_thresh=0.7, crop_n_layers=0, points_per_batch=128, pred_iou_thresh=0.88)
+device = "cuda" if torch.cuda.is_available() else "cpu"
 
-    transform = transforms.Compose([
-        transforms.ToTensor(),
-    ])
-    labels = [1 for i in range(len(mask_generator.point_grids))]
-    with torch.no_grad():
-        for image in images_path:
-            img = cv2.imread(image)
-            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
-            """img_depth = cv2.imread(image.replace("color", "depth"))
-            img_depth = cv2.cvtColor(img_depth, cv2.COLOR_BGR2GRAY)"""
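+# instantiate MobileSAM from the registry and switch it to inference mode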
+mobile_sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
+mobile_sam.to(device=device)
+mobile_sam.eval()
 
-            h, w, _ = img.shape
-            points = mask_generator.point_grids[0]
-            new_points= []
-            for val in points:
-                x, y = val[0], val[1]
-                x *= w
-                y *= h
-                new_points.append([x, y])
-            new_points = np.array(new_points)
+mask_generator = SamAutomaticMaskGenerator(mobile_sam, points_per_side=16, points_per_batch=12)
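+# collect the RGB frame paths listed in the YCB-Video train split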
+ycb_path = "/home/achapin/Documents/Datasets/YCB_Video_Dataset/"
+images_path = []
+with open(ycb_path + "image_sets/train.txt", 'r') as f:
+    for line in f.readlines():
+        line = line.strip()
+        images_path.append(ycb_path + 'data/' + line + "-color.png")
 
-            start = time.time()
-            masks = mask_generator.generate(img)
-            end = time.time()
-            print(f"Inference time : {int((end-start) * 1000)}ms")
-
-            plt.figure(figsize=(15,15))
-            plt.imshow(img)
-            show_anns(masks)
-            show_points(new_points, plt.gca())
-            plt.axis('off')
-            plt.show()
-
-            """fig, ax = plt.subplots()
-            cmap = plt.cm.get_cmap('plasma')
-            img = ax.imshow(img_depth, cmap=cmap)
-            cbar = fig.colorbar(img, ax=ax)
-            depth_array_new = img.get_array()
-            plt.show()
-
-            depth_array_new = cv2.cvtColor(depth_array_new, cv2.COLOR_GRAY2RGB)
-            plt.imshow(depth_array_new)
-            plt.show()
-            print(depth_array_new.shape)
-
-            start = time.time()
-            masks = mask_generator.generate(depth_array_new)
-            end = time.time()
-            print(f"Inference time : {int((end-start) * 1000)}ms")
-
-            
-            plt.figure(figsize=(15,15))
-            plt.imshow(depth_array_new)
-            show_anns(masks)
-            show_points(new_points, plt.gca())
-            plt.axis('off')
-            plt.show()"""
+#import random
+#random.shuffle(images_path)  # uncomment both lines to randomize the evaluation order
 
+images = []
+for j in range(20):
+    image = images_path[j]
+    img = cv2.imread(image)
+    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+    images.append(img)
+
+# SamAutomaticMaskGenerator.generate expects a single HxWxC uint8 RGB array,
+# so the frames are processed one at a time rather than as a batched tensor
+start = time.time()
+for img in images:
+    masks = mask_generator.generate(img)
+end = time.time()
+print(f"Inference time for {len(images)} images: {int((end - start) * 1000)}ms")
+
+# visualize the masks predicted for the last frame
+plt.figure(figsize=(15, 15))
+plt.imshow(img)
+show_anns(masks)
+plt.axis('off')
+plt.show()