diff --git a/.gitignore b/.gitignore
index daa7f462c86eaf2b5fb13386aefb02423f316eb3..ded95b5c8558122940b59f598f87605d69d90c16 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,4 +1,5 @@
 *.pt
+*.pth
 .ipynb_checkpoints
 /data
 /wandb
diff --git a/.gitmodules b/.gitmodules
deleted file mode 100644
index e2de6e56b61f5f6a995c66d32d58b7cf79582a66..0000000000000000000000000000000000000000
--- a/.gitmodules
+++ /dev/null
@@ -1,3 +0,0 @@
-[submodule "segment-anything"]
-	path = segment-anything
-	url = git@github.com:facebookresearch/segment-anything.git
diff --git a/README.md b/README.md
index c43ed0956bdca3733a5a9fa8b00b985ff332349c..f3b7874a06966944e245d2b40ec3b28b67018ce6 100644
--- a/README.md
+++ b/README.md
@@ -7,6 +7,7 @@ All credit for the model goes to the original authors.
 <img src="https://drive.google.com/uc?id=1Gsoxlab6c3wOL0Bdj6SEV8L1RsI-mhWF" alt="MSN Example" width="900"/>
 
 ## Setup
+`git clone git@github.com:alexcbb/OSRT-experiments.git`
 After cloning the repository and creating a new conda environment, the following steps will get you started:
 
 ### Data
@@ -27,6 +28,13 @@ required to load OSRT's MultiShapeNet data, though the CPU version suffices.
 
 Rendering videos additionally depends on `ffmpeg>=4.3` being available in your `$PATH`.
 
+To install the Segment Anything dependencies:
+```bash
+cd segment-anything/
+pip install -e .
+```
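+If the `segment-anything` repository is not checked out locally (it is no longer tracked as a
+git submodule), the package can instead be installed directly from GitHub. The SAM checkpoints
+expected by `automatic_mask_train.py` and `sam_test.py` (ViT-H / ViT-L / ViT-B) can be
+downloaded from the official release; the commands below are one suggested setup, adjust the
+target directory to match `--path_model`:
+```bash
+pip install git+https://github.com/facebookresearch/segment-anything.git
+
+wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
+wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth
+wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth
+```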
+
 ## Running Experiments
 Each run's config, checkpoints, and visualization are stored in a dedicated directory. Recommended configs can be found under `runs/[dataset]/[model]`.
 
diff --git a/array.txt b/array.txt
new file mode 100644
index 0000000000000000000000000000000000000000..8f52fa4c2d67ed09eaee8206684309132a60b92b
--- /dev/null
+++ b/array.txt
@@ -0,0 +1,480 @@
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+1.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+1.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
+0.000000000000000000e+00
diff --git a/automatic_mask_train.py b/automatic_mask_train.py
new file mode 100644
index 0000000000000000000000000000000000000000..8df774b2337558bcec063e009feecef0e87420d7
--- /dev/null
+++ b/automatic_mask_train.py
@@ -0,0 +1,98 @@
+import argparse
+import torch
+from osrt.encoder import SamAutomaticMask
+from segment_anything import sam_model_registry
+from torchvision import transforms
+import time
+import matplotlib.pyplot as plt
+import numpy as np
+import cv2
+
+def show_anns(masks):
+    if len(masks) == 0:
+        return
+    ax = plt.gca()
+    ax.set_autoscale_on(False)
+    sorted_anns = sorted(masks, key=(lambda x: x['area']), reverse=True)
+    img = np.ones((sorted_anns[0]['segmentation'].shape[0], sorted_anns[0]['segmentation'].shape[1], 4))
+    img[:,:,3] = 0
+    for ann in sorted_anns:
+        m = ann['segmentation']
+        color_mask = np.concatenate([np.random.random(3), [0.95]])
+        img[m] = color_mask
+    ax.imshow(img)
+
+def show_points(coords, ax, marker_size=100):
+    ax.scatter(coords[:, 0], coords[:, 1], color='#2ca02c', marker='.', s=marker_size)
+    
+
+if __name__ == '__main__':
+    # Arguments
+    parser = argparse.ArgumentParser(
+        description='Test Segment Anything Auto Mask simplified implementation'
+    )
+    parser.add_argument('--model', default='vit_l', type=str, help='Model to use')
+    parser.add_argument('--path_model', default='.', type=str, help='Path to the model')
+
+    args = parser.parse_args()
+    device = "cuda"
+    torch.backends.cuda.matmul.allow_tf32 = True
+    torch.backends.cudnn.allow_tf32 = True
+
+    model_type = args.model
+    if args.model == 'vit_h':
+        checkpoint = args.path_model + '/sam_vit_h_4b8939.pth'
+    elif args.model == 'vit_b':
+        checkpoint = args.path_model + '/sam_vit_b_01ec64.pth'
+    else:
+        checkpoint = args.path_model + '/sam_vit_l_0b3195.pth'
+
+    ycb_path = "/home/achapin/Documents/Datasets/YCB_Video_Dataset/"
+    images_path = []
+    with open(ycb_path + "image_sets/train.txt", 'r') as f:
+        for line in f.readlines():
+            line = line.strip()
+            images_path.append(ycb_path + 'data/' + line + "-color.png")
+
+    import random
+    random.shuffle(images_path)
+
+    sam = sam_model_registry[model_type](checkpoint=checkpoint)
+    sam.to(device=device)
+    #mask_generator = SamAutomaticMaskGenerator(sam, points_per_side=12, box_nms_thresh=0.7, crop_n_layers=0, points_per_batch=128, pred_iou_thresh=0.88)
+    sam_mask = SamAutomaticMask(sam.image_encoder, sam.prompt_encoder, sam.mask_decoder, box_nms_thresh=0.7, stability_score_thresh=0.9, pred_iou_thresh=0.88, points_per_side=8, points_per_batch=64, min_mask_region_area=4000)
+    sam_mask.to(device)
+    
+    transform = transforms.Compose([
+        transforms.ToTensor(),
+    ])
+    labels = [1 for i in range(len(sam_mask.points_grid))]
+    with torch.no_grad():
+        for image in images_path:
+            img = cv2.imread(image)
+            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+
+            h, w, _ = img.shape
+            points = sam_mask.points_grid
+            new_points = []
+            for val in points:
+                x, y = val[0], val[1]
+                x *= w
+                y *= h
+                new_points.append([x, y])
+            new_points = np.array(new_points)
+            img_batch = []
+            img_el = {}
+            img_el["image"] = img
+            img_el["original_size"] = (h, w)
+            img_batch.append(img_el)
+
+            start = time.time()
+            masks = sam_mask(img_batch)
+            end = time.time()
+            print(f"Inference time: {int((end-start) * 1000)}ms")
+            plt.figure(figsize=(15,15))
+            plt.imshow(img)
+            show_anns(masks[0]["annotations"])
+            show_points(new_points, plt.gca())
+            plt.axis('off')
+            plt.show()
+
diff --git a/osrt/encoder.py b/osrt/encoder.py
index 1dff3bc325e574b586487ab5bfcb199a6321f47b..8f2dcb8e84fec9ec913f16400d78d89731bd2ec0 100644
--- a/osrt/encoder.py
+++ b/osrt/encoder.py
@@ -1,9 +1,25 @@
 import numpy as np
 import torch
 import torch.nn as nn
+from typing import Any, Dict, List, Optional, Tuple
+from torch.nn import functional as F
+from torchvision.transforms.functional import resize, to_pil_image  # type: ignore
+from torchvision.ops.boxes import batched_nms
+
 from osrt.layers import RayEncoder, Transformer, SlotAttention
+from osrt.utils.common import batch_iterator, MaskData, calculate_stability_score
+
 from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
+from segment_anything.modeling import Sam
+from segment_anything.modeling.image_encoder import ImageEncoderViT
+from segment_anything.modeling.mask_decoder import MaskDecoder
+from segment_anything.modeling.prompt_encoder import PromptEncoder
+from segment_anything.utils.transforms import ResizeLongestSide
+from segment_anything.utils.amg import batched_mask_to_box, remove_small_regions, mask_to_rle_pytorch, area_from_rle
 
+import cv2
 
 class SRTConvBlock(nn.Module):
     def __init__(self, idim, hdim=None, odim=None):
@@ -92,60 +108,554 @@ class OSRTEncoder(nn.Module):
         slot_latents = self.slot_attention(set_latents)
         return slot_latents
 
+
 class FeatureMasking(nn.Module):
-    def __init__(self, pos_start_octave=0, num_slots=6, num_conv_blocks=3, num_att_blocks=5, slot_dim=1536, slot_iters=1, sam_model="default", sam_path="sam_vit_h_4b8939.pth",
+    def __init__(self, 
+                 #pos_start_octave=0, 
+                 points_per_side=8,
+                 box_nms_thresh = 0.7,
+                 stability_score_thresh = 0.9,
+                 pred_iou_thresh=0.88,
+                 points_per_batch=64,
+                 min_mask_region_area=4000,
+                 num_slots=6, 
+                 slot_dim=1536, 
+                 slot_iters=1, 
+                 sam_model="default", 
+                 sam_path="sam_vit_h_4b8939.pth",
                  randomize_initial_slots=False):
-        super().__init__()
+        super().__init__() 
 
-        ##########################################################
-        ### TODO : REPLACE
+        # We first initialize the automatic mask generator from SAM
         sam = sam_model_registry[sam_model](checkpoint=sam_path) 
-        self.mask_generator = SamAutomaticMaskGenerator(sam)
+        self.mask_generator = SamAutomaticMask(sam.image_encoder, 
+                                               sam.prompt_encoder, 
+                                               sam.mask_decoder, 
+                                               box_nms_thresh=box_nms_thresh, 
+                                               stability_score_thresh = stability_score_thresh,
+                                               pred_iou_thresh=pred_iou_thresh,
+                                               points_per_side=points_per_side,
+                                               points_per_batch=points_per_batch,
+                                               min_mask_region_area=min_mask_region_area)
+    
+        # self.ray_encoder = RayEncoder(pos_octaves=15, pos_start_octave=pos_start_octave,
+        #                               ray_octaves=15)
+        # We will see if this is useful later on... for now, keep it simple
 
-        # They first encode the rays
-        self.ray_encoder = RayEncoder(pos_octaves=15, pos_start_octave=pos_start_octave,
-                                      ray_octaves=15)
-        # This is concatenated to the image features
-        # They then extract features for each images
-        conv_blocks = [SRTConvBlock(idim=183, hdim=96)]
-        cur_hdim = 192
-        for i in range(1, num_conv_blocks):
-            conv_blocks.append(SRTConvBlock(idim=cur_hdim, odim=None))
-            cur_hdim *= 2
+        self.slot_attention = SlotAttention(num_slots, slot_dim=slot_dim, iters=slot_iters,
+                                            randomize_initial_slots=randomize_initial_slots)
 
-        self.conv_blocks = nn.Sequential(*conv_blocks)
+    def forward(self, images):
+        # Generate masks for the input images
+        masks = self.mask_generator(images)
 
-        self.per_patch_linear = nn.Conv2d(cur_hdim, 768, kernel_size=1)
 
-        self.transformer = Transformer(768, depth=num_att_blocks, heads=12, dim_head=64,
-                                       mlp_dim=1536, selfatt=True)
+        set_latents = None
+        num_masks = None
+        
+        # TODO : set the number of slots according to the number of masks
+        self.slot_attention.change_slots_number(num_masks)
 
-        ### TODO : REPLACE
-        ##########################################################
+        # [batch_size, num_inputs, dim] 
+        slot_latents = self.slot_attention(set_latents)
 
-        self.slot_attention = SlotAttention(num_slots, slot_dim=slot_dim, iters=slot_iters,
-                                            randomize_initial_slots=randomize_initial_slots)
+        return slot_latents
 
-    def forward(self, images, camera_pos, rays):
+
+class SamAutomaticMask(nn.Module):
+    mask_threshold: float = 0.0
+    image_format: str = "RGB"
+
+    def __init__(
+        self,
+        image_encoder: ImageEncoderViT,
+        prompt_encoder: PromptEncoder,
+        mask_decoder: MaskDecoder,
+        pixel_mean: List[float] = [123.675, 116.28, 103.53],
+        pixel_std: List[float] = [58.395, 57.12, 57.375],
+        points_per_side: Optional[int] = 0,
+        points_per_batch: int = 64,
+        pred_iou_thresh: float = 0.88,
+        stability_score_thresh: float = 0.95,
+        stability_score_offset: float = 1.0,
+        box_nms_thresh: float = 0.7,
+        min_mask_region_area: int = 0
+    ) -> None:
+        """
+        This class adapts the SAM implementation from the original repository to our needs:
+        - Training only the MaskDecoder (the image and prompt encoders are frozen)
+        - Performing automatic mask discovery (adapted from SamAutomaticMaskGenerator in the original repo)
+
+        SAM predicts object masks from an image and input prompts.
+
+        Arguments:
+          image_encoder (ImageEncoderViT): The backbone used to encode the
+            image into image embeddings that allow for efficient mask prediction.
+          prompt_encoder (PromptEncoder): Encodes various types of input prompts.
+          mask_decoder (MaskDecoder): Predicts masks from the image embeddings
+            and encoded prompts.
+          pixel_mean (list(float)): Mean values for normalizing pixels in the input image.
+          pixel_std (list(float)): Std values for normalizing pixels in the input image.
+          points_per_side (int): Number of points sampled along one side of the prompt grid
+            (points_per_side**2 points in total); 0 disables the grid.
+          points_per_batch (int): Number of point prompts run through the decoder at once.
+          pred_iou_thresh / stability_score_thresh / box_nms_thresh: Thresholds used to filter
+            low-quality, unstable and duplicate masks, as in SamAutomaticMaskGenerator.
+          min_mask_region_area (int): If > 0, disconnected regions and holes smaller than this
+            area are removed from the masks (requires OpenCV).
+        """
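+        # Typical construction (see automatic_mask_train.py for a full example): build a SAM
+        # model through sam_model_registry and pass its sub-modules here, e.g.
+        #   sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
+        #   sam_mask = SamAutomaticMask(sam.image_encoder, sam.prompt_encoder, sam.mask_decoder)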
+        super().__init__()
+        self.image_encoder = image_encoder
+        self.prompt_encoder = prompt_encoder
+
+        # Freeze the image encoder and prompt encoder
+        for param in self.image_encoder.parameters():
+            param.requires_grad = False
+        for param in self.prompt_encoder.parameters():
+            param.requires_grad = False
+
+        self.mask_decoder = mask_decoder
+        self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
+        self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
+
+        # Transform image to a square by putting it to the longest side
+        self.transform = ResizeLongestSide(self.image_encoder.img_size)
         
-        masks = self.mask_generator.generate(images)
-        batch_size, num_images = images.shape[:2]
+        if points_per_side > 0:
+            self.points_grid = self.create_points_grid(points_per_side)
+        else:
+            self.points_grid = None
+        self.points_per_batch = points_per_batch
+        self.pred_iou_thresh = pred_iou_thresh
+        self.stability_score_thresh = stability_score_thresh
+        self.stability_score_offset = stability_score_offset
+        self.box_nms_thresh = box_nms_thresh
+        self.min_mask_region_area = min_mask_region_area
+
+
+    @property
+    def device(self) -> Any:
+        return self.pixel_mean.device
+
+    def forward(
+        self,
+        batched_input: List[Dict[str, Any]],
+        extract_embeddings: bool = True
+    ) -> List[Dict[str, torch.Tensor]]:
+        """
+        Predicts masks end-to-end for the provided images, prompting the model
+        with a regular grid of points (similar to SamAutomaticMaskGenerator).
+
+        Arguments:
+          batched_input (list(dict)): A list over input images, each a
+            dictionary with the following keys.
+              'image': The image as an HxWx3 (RGB) numpy array.
+              'original_size': (tuple(int, int)) The original size of
+                the image before transformation, as (H, W).
+          extract_embeddings (bool): Whether per-mask embeddings should be added
+            to the annotations (not implemented yet, see the TODO below).
 
-        x = images.flatten(0, 1)
-        camera_pos = camera_pos.flatten(0, 1)
-        rays = rays.flatten(0, 1)
+        Returns:
+          (list(dict)): A list over input images, where each element is a
+            dictionary with a single key.
+              'annotations': (list(dict)) One record per kept mask, with keys
+                'segmentation' (binary mask at the original image resolution),
+                'area', 'predicted_iou', 'point_coords' and 'stability_score'.
+        """
+        # TODO : add a way to extract mask embeddings
+        # Extract image embeddings
+        input_images = torch.cat([self.preprocess(x["image"]) for x in batched_input], dim=0)
+        with torch.no_grad():
+            image_embeddings = self.image_encoder(input_images)
+            # TODO : also extract the image embedding before channel reduction (embed_no_red),
+            # cf. https://github.com/facebookresearch/segment-anything/issues/283
+
+        outputs = []
+        for image_record, curr_embedding in zip(batched_input, image_embeddings):
+            # TODO : check whether point prompts are provided in the batch (to override the current points_grid)
+
+            im_size = self.transform.apply_image(image_record["image"]).shape[:2]
+            points_scale = np.array(im_size)[None, ::-1]
+            points_for_image = self.points_grid * points_scale
+
+            mask_data = MaskData()
+            for (points,) in batch_iterator(self.points_per_batch, points_for_image):
+                batch_data = self.process_batch(points, im_size, curr_embedding, image_record["original_size"])
+                mask_data.cat(batch_data)
+                del batch_data
+
+            # Remove duplicates within this crop.
+            keep_by_nms = batched_nms(
+                mask_data["boxes"].float(),
+                mask_data["iou_preds"],
+                torch.zeros_like(mask_data["boxes"][:, 0]),  # categories
+                iou_threshold=self.box_nms_thresh,
+            )
+            mask_data.filter(keep_by_nms)
+
+            mask_data.to_numpy()
+            # TODO : find a way to extract masks
+            #new_masks = self.complete_holes(mask_data["masks"])
+            #print(new_masks)
+
+            # Filter small disconnected regions and holes in masks
+            if self.min_mask_region_area > 0:
+                mask_data = self.postprocess_small_regions(
+                    mask_data,
+                    self.min_mask_region_area,
+                    self.box_nms_thresh,
+                )
+
+            mask_data["segmentations"] = mask_data["masks"]
+            # embed_no_red (the embedding before channel reduction) is not extracted yet,
+            # so the reduced per-image embedding is used for now.
+            mask_embed = self.extract_mask_embedding(mask_data, curr_embedding, scale_box=1.5)
+
+            # Write mask records
+            curr_anns = []
+            for idx in range(len(mask_data["segmentations"])):
+                ann = {
+                    "segmentation": mask_data["segmentations"][idx],
+                    "area": area_from_rle(mask_data["rles"][idx]),
+                    "predicted_iou": mask_data["iou_preds"][idx].item(),
+                    "point_coords": [mask_data["points"][idx].tolist()],
+                    "stability_score": mask_data["stability_score"][idx].item()
+                }
+                if extract_embeddings:
+                    # TODO : add embeddings into the annotations
+                    pass
+                curr_anns.append(ann)
+            outputs.append(
+                {
+                    "annotations": curr_anns
+                }
+            )
+        return outputs
+
+    def postprocess_masks(
+        self,
+        masks: torch.Tensor,
+        input_size: Tuple[int, ...],
+        original_size: Tuple[int, ...],
+    ) -> torch.Tensor:
+        """
+        Remove padding and upscale masks to the original image size.
 
-        ray_enc = self.ray_encoder(camera_pos, rays)
-        x = torch.cat((x, ray_enc), 1) 
-        x = self.conv_blocks(x)
-        x = self.per_patch_linear(x)
-        x = x.flatten(2, 3).permute(0, 2, 1)
+        Arguments:
+          masks (torch.Tensor): Batched masks from the mask_decoder,
+            in BxCxHxW format.
+          input_size (tuple(int, int)): The size of the image input to the
+            model, in (H, W) format. Used to remove padding.
+          original_size (tuple(int, int)): The original size of the image
+            before resizing for input to the model, in (H, W) format.
 
-        patches_per_image, channels_per_patch = x.shape[1:]
-        x = x.reshape(batch_size, num_images * patches_per_image, channels_per_patch)
+        Returns:
+          (torch.Tensor): Batched masks in BxCxHxW format, where (H, W)
+            is given by original_size.
+        """
+        masks = F.interpolate(
+            masks,
+            (self.image_encoder.img_size, self.image_encoder.img_size),
+            mode="bilinear",
+            align_corners=False,
+        )
+        masks = masks[..., : input_size[0], : input_size[1]]
+        masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False)
+        return masks
+
+    @staticmethod
+    def postprocess_small_regions(
+        mask_data: MaskData, min_area: int, nms_thresh: float
+    ) -> MaskData:
+        """
+        Removes small disconnected regions and holes in masks, then reruns
+        box NMS to remove any new duplicates.
 
-        x = self.transformer(x)
+        Edits mask_data in place.
 
-        slot_latents = self.slot_attention(x)
+        Requires open-cv as a dependency.
+        """
+        if len(mask_data["rles"]) == 0:
+            return mask_data
+
+        # Filter small disconnected regions and holes
+        new_masks = []
+        scores = []
+        for mask in mask_data["masks"]:
+            mask, changed = remove_small_regions(mask, min_area, mode="holes")
+            unchanged = not changed
+            mask, changed = remove_small_regions(mask, min_area, mode="islands")
+            unchanged = unchanged and not changed
+
+            new_masks.append(torch.as_tensor(mask).unsqueeze(0))
+            # Give score=0 to changed masks and score=1 to unchanged masks
+            # so NMS will prefer ones that didn't need postprocessing
+            scores.append(float(unchanged))
+
+        # Recalculate boxes and remove any new duplicates
+        masks = torch.cat(new_masks, dim=0)
+        boxes = batched_mask_to_box(masks)
+        keep_by_nms = batched_nms(
+            boxes.float(),
+            torch.as_tensor(scores),
+            torch.zeros_like(boxes[:, 0]),  # categories
+            iou_threshold=nms_thresh,
+        )
+
+        # Only recalculate masks, RLEs and boxes for masks that have changed
+        for i_mask in keep_by_nms:
+            if scores[i_mask] == 0.0:
+                mask_torch = masks[i_mask].unsqueeze(0)
+                mask_data["masks"][i_mask] = masks[i_mask].cpu().numpy()
+                mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0]
+                mask_data["boxes"][i_mask] = boxes[i_mask].cpu().numpy()  # update res directly
+        mask_data.filter(keep_by_nms)
+
+        return mask_data
+
+
+    def preprocess(self, x: torch.Tensor) -> torch.Tensor:
+        """Normalize pixel values and pad to a square input."""
+        # Rescale the image relative to the longest side
+        x = self.transform.apply_image(x)
+        x = torch.as_tensor(x, device=self.device)
+        x = x.permute(2, 0, 1).contiguous()[None, :, :, :]
+
+        # Normalize colors
+        x = (x - self.pixel_mean) / self.pixel_std
+
+        # Pad
+        h, w = x.shape[-2:]
+        padh = self.image_encoder.img_size - h
+        padw = self.image_encoder.img_size - w
+        x = F.pad(x, (0, padw, 0, padh))
+        return x
 
-        return slot_latents
+    def create_points_grid(self, number_points):
+        """Generates a 2D grid of points evenly spaced in [0,1]x[0,1]."""
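+        # Worked example: number_points=2 gives offset 0.25 and the four points
+        # [[0.25, 0.25], [0.75, 0.25], [0.25, 0.75], [0.75, 0.75]] (x, y in [0, 1]).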
+        offset = 1 / (2 * number_points)
+        points_one_side = np.linspace(offset, 1 - offset, number_points)
+        points_x = np.tile(points_one_side[None, :], (number_points, 1))
+        points_y = np.tile(points_one_side[:, None], (1, number_points))
+        points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2)
+        return points
+
+    def process_batch(
+        self,
+        points: np.ndarray,
+        im_size: Tuple[int, ...],
+        curr_embedding,
+        curr_orig_size
+        ):
+
+        # Run model on this batch
+        transformed_points = self.transform.apply_coords(points, im_size)
+        in_points = torch.as_tensor(transformed_points, device=self.device)
+        in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device)
+
+        masks, iou_preds, _ = self.predict_masks(
+            in_points[:, None, :],
+            in_labels[:, None],
+            curr_orig_size,
+            im_size,
+            curr_embedding,
+            multimask_output=True,
+            return_logits=True
+        )
+
+        data = MaskData(
+            masks=masks.flatten(0, 1),
+            iou_preds=iou_preds.flatten(0, 1),
+            points=torch.as_tensor(points.repeat(masks.shape[1], axis=0)),
+        )
+        del masks
+
+        # Filter by predicted IoU
+        if self.pred_iou_thresh > 0.0:
+            keep_mask = data["iou_preds"] > self.pred_iou_thresh
+            data.filter(keep_mask)
+
+        # Calculate stability score
+        data["stability_score"] = calculate_stability_score(
+            data["masks"], self.mask_threshold, self.stability_score_offset
+        )
+        if self.stability_score_thresh > 0.0:
+            keep_mask = data["stability_score"] >= self.stability_score_thresh
+            data.filter(keep_mask)
+
+        # Threshold masks and calculate boxes
+        data["masks"] = data["masks"] > self.mask_threshold
+        data["boxes"] = batched_mask_to_box(data["masks"])
+
+        data["rles"] = mask_to_rle_pytorch(data["masks"])
+
+        return data
+
+    def predict_masks(
+        self,
+        point_coords: Optional[torch.Tensor],
+        point_labels: Optional[torch.Tensor],
+        curr_orig_size,
+        curr_input_size,
+        curr_embedding,
+        boxes: Optional[torch.Tensor] = None,
+        mask_input: Optional[torch.Tensor] = None,
+        multimask_output: bool = True,
+        return_logits: bool = False,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """
+        Predict masks for the given input prompts, using the current image.
+        Input prompts are batched torch tensors and are expected to already be
+        transformed to the input frame using ResizeLongestSide.
+
+        Arguments:
+          point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the
+            model. Each point is in (X,Y) in pixels.
+          point_labels (torch.Tensor or None): A BxN array of labels for the
+            point prompts. 1 indicates a foreground point and 0 indicates a
+            background point.
+          boxes (np.ndarray or None): A Bx4 array given a box prompt to the
+            model, in XYXY format.
+          mask_input (np.ndarray): A low resolution mask input to the model, typically
+            coming from a previous prediction iteration. Has form Bx1xHxW, where
+            for SAM, H=W=256. Masks returned by a previous iteration of the
+            predict method do not need further transformation.
+          multimask_output (bool): If true, the model will return three masks.
+            For ambiguous input prompts (such as a single click), this will often
+            produce better masks than a single prediction. If only a single
+            mask is needed, the model's predicted quality score can be used
+            to select the best mask. For non-ambiguous prompts, such as multiple
+            input prompts, multimask_output=False can give better results.
+          return_logits (bool): If true, returns un-thresholded masks logits
+            instead of a binary mask.
+
+        Returns:
+          (torch.Tensor): The output masks in BxCxHxW format, where C is the
+            number of masks, and (H, W) is the original image size.
+          (torch.Tensor): An array of shape BxC containing the model's
+            predictions for the quality of each mask.
+          (torch.Tensor): An array of shape BxCxHxW, where C is the number
+            of masks and H=W=256. These low res logits can be passed to
+            a subsequent iteration as mask input.
+        """
+        if point_coords is not None:
+            points = (point_coords, point_labels)
+        else:
+            points = None
+        # Embed prompts
+        sparse_embeddings, dense_embeddings = self.prompt_encoder(
+            points=points,
+            boxes=boxes,
+            masks=mask_input,
+        )
+
+        # Predict masks
+        low_res_masks, iou_predictions = self.mask_decoder(
+            image_embeddings=curr_embedding.unsqueeze(0),
+            image_pe=self.prompt_encoder.get_dense_pe(),
+            sparse_prompt_embeddings=sparse_embeddings,
+            dense_prompt_embeddings=dense_embeddings,
+            multimask_output=multimask_output,
+        )
+
+        # Upscale the masks to the original image resolution
+        masks = self.postprocess_masks(low_res_masks, curr_input_size, curr_orig_size)
+
+        if not return_logits:
+            masks = masks > self.mask_threshold
+
+        return masks, iou_predictions, low_res_masks
+
+    def extract_mask_embedding(self, mask_data, image_embed, scale_box=1.5):
+        """
+        Predicts the embeddings from each mask given the global embedding and
+        a scale factor around each mask.
+
+        Arguments:
+          mask_data : the data of the masks extracted by SAM
+          image_embed : embedding of the corresponding image
+          scale_box : factor by which the bounding box will be scaled
+
+        Returns:
+          embeddings : the embeddings for each mask extracted from the image
+        """
+        for idx in range(len(mask_data["segmentations"])):
+            mask = mask_data["segmentations"][idx]
+            box = mask_data["boxes"][idx]
+
+            def scale_bounding_box(box, scale_factor):
+                x1, y1, x2, y2 = box
+
+                width = x2 - x1
+                height = y2 - y1
+
+                new_width = width * scale_factor
+                new_height = height * scale_factor
+
+                new_x1 = x1 - (new_width - width) / 2
+                new_y1 = y1 - (new_height - height) / 2
+                new_x2 = new_x1 + new_width
+                new_y2 = new_y1 + new_height
+
+                return new_x1, new_y1, new_x2, new_y2
+
+            # Scale the bounding box around the mask
+            scaled_box = scale_bounding_box(box, scale_box)
+            print(image_embed.shape)
+
+        masks_embedding = None
+        return masks_embedding
+
+    def complete_holes(self, masks):
+        """
+        The purpose of this function is to segment EVERYTHING in the image, without leaving any remaining holes.
+        """
+        total_mask = masks[0]
+        for idx in range(len(masks)):
+            if idx > 0:
+                total_mask += masks[idx]
+
+        des = total_mask.astype(np.uint8)*255
+        kernel = np.ones((4, 4), np.uint8)
+        img_dilate = cv2.dilate(des, kernel, iterations=1)
+
+        import matplotlib.pyplot as plt
+        plt.imshow(img_dilate)
+        plt.show()
+
+        inverse_dilate = np.zeros((total_mask.shape), dtype=np.uint8)
+        inverse_dilate = np.logical_not(img_dilate).astype(np.uint8)*255
+
+        contours, _ = cv2.findContours(inverse_dilate, cv2.RETR_EXTERNAL,cv2.CHAIN_APPROX_SIMPLE)
+        result_masks = []
+        for contour in contours:
+            area = cv2.contourArea(contour)
+            if area > 4000:
+                mask = np.zeros(total_mask.shape, dtype=np.uint8)
+                cv2.drawContours(mask, [contour], 0, 255, -1)
+                result_masks.append(mask)
+
+        new_masks_data = MaskData(
+            masks=torch.tensor(result_masks),
+            iou_preds=torch.tensor([0.9 for i in range(len(result_masks))])
+        )
+
+        new_masks_data["stability_score"] = calculate_stability_score(
+            new_masks_data["masks"], self.mask_threshold, self.stability_score_offset
+        )
+
+        new_masks_data["boxes"] = batched_mask_to_box(new_masks_data["masks"])
+
+        new_masks_data["rles"] = mask_to_rle_pytorch(new_masks_data["masks"])
+
+        return new_masks_data.to_numpy()
\ No newline at end of file
diff --git a/osrt/layers.py b/osrt/layers.py
index 2cfdecc1610caf22087d7befdea7ffe388169431..019d5c025d9981e19d90bdcdf31a840d1917f0a7 100644
--- a/osrt/layers.py
+++ b/osrt/layers.py
@@ -251,4 +251,5 @@ class SlotAttention(nn.Module):
 
         return slots
 
-
+    def change_slots_number(self, num_slots): 
+        self.num_slots = num_slots
diff --git a/osrt/utils/common.py b/osrt/utils/common.py
index 6ac5489d9833f72ec015d5fd4044e62cff69d0c5..832d5e3ef87c17503ac1dc244554916f10163916 100644
--- a/osrt/utils/common.py
+++ b/osrt/utils/common.py
@@ -4,9 +4,106 @@ import os
 import torch
 import torch.distributed as dist
 
+from typing import Any, List, Generator, ItemsView
+import numpy as np
+import math
+from copy import deepcopy
 
 __LOG10 = math.log(10)
 
+# Method extracted from SAM repo in segment_anything.utils.amg.py
+def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:
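+    # Yields successive per-batch slices of each input; for example,
+    # list(batch_iterator(2, [1, 2, 3, 4, 5])) == [[[1, 2]], [[3, 4]], [[5]]].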
+    assert len(args) > 0 and all(
+        len(a) == len(args[0]) for a in args
+    ), "Batched iteration must have inputs of all the same size."
+    n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0)
+    for b in range(n_batches):
+        yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args]
+
+# Class extracted from SAM repo in segment_anything.utils.amg.py
+class MaskData:
+    """
+    A structure for storing masks and their related data in batched format.
+    Implements basic filtering and concatenation.
+    """
+
+    def __init__(self, **kwargs) -> None:
+        for v in kwargs.values():
+            assert isinstance(
+                v, (list, np.ndarray, torch.Tensor)
+            ), "MaskData only supports list, numpy arrays, and torch tensors."
+        self._stats = dict(**kwargs)
+
+    def __setitem__(self, key: str, item: Any) -> None:
+        assert isinstance(
+            item, (list, np.ndarray, torch.Tensor)
+        ), "MaskData only supports list, numpy arrays, and torch tensors."
+        self._stats[key] = item
+
+    def __delitem__(self, key: str) -> None:
+        del self._stats[key]
+
+    def __getitem__(self, key: str) -> Any:
+        return self._stats[key]
+
+    def items(self) -> ItemsView[str, Any]:
+        return self._stats.items()
+
+    def filter(self, keep: torch.Tensor) -> None:
+        for k, v in self._stats.items():
+            if v is None:
+                self._stats[k] = None
+            elif isinstance(v, torch.Tensor):
+                self._stats[k] = v[torch.as_tensor(keep, device=v.device)]
+            elif isinstance(v, np.ndarray):
+                self._stats[k] = v[keep.detach().cpu().numpy()]
+            elif isinstance(v, list) and keep.dtype == torch.bool:
+                self._stats[k] = [a for i, a in enumerate(v) if keep[i]]
+            elif isinstance(v, list):
+                self._stats[k] = [v[i] for i in keep]
+            else:
+                raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.")
+
+    def cat(self, new_stats: "MaskData") -> None:
+        for k, v in new_stats.items():
+            if k not in self._stats or self._stats[k] is None:
+                self._stats[k] = deepcopy(v)
+            elif isinstance(v, torch.Tensor):
+                self._stats[k] = torch.cat([self._stats[k], v], dim=0)
+            elif isinstance(v, np.ndarray):
+                self._stats[k] = np.concatenate([self._stats[k], v], axis=0)
+            elif isinstance(v, list):
+                self._stats[k] = self._stats[k] + deepcopy(v)
+            else:
+                raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.")
+
+    def to_numpy(self) -> None:
+        for k, v in self._stats.items():
+            if isinstance(v, torch.Tensor):
+                self._stats[k] = v.detach().cpu().numpy()
+
+# Method extracted from SAM repo in segment_anything.utils.amg.py
+def calculate_stability_score(
+    masks: torch.Tensor, mask_threshold: float, threshold_offset: float
+) -> torch.Tensor:
+    """
+    Computes the stability score for a batch of masks. The stability
+    score is the IoU between the binary masks obtained by thresholding
+    the predicted mask logits at high and low values.
+    """
+    # One mask is always contained inside the other.
+    # Save memory by preventing unnecessary cast to torch.int64
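+    # Because the high-threshold mask is nested inside the low-threshold one, the IoU
+    # reduces to |mask > thr + offset| / |mask > thr - offset|.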
+    intersections = (
+        (masks > (mask_threshold + threshold_offset))
+        .sum(-1, dtype=torch.int16)
+        .sum(-1, dtype=torch.int32)
+    )
+    unions = (
+        (masks > (mask_threshold - threshold_offset))
+        .sum(-1, dtype=torch.int16)
+        .sum(-1, dtype=torch.int32)
+    )
+    return intersections / unions
 
 def mse2psnr(x):
     return -10.*torch.log(x)/__LOG10
diff --git a/sam_test.py b/sam_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..a62e962695918cc8412443e86f4a61eb85093b59
--- /dev/null
+++ b/sam_test.py
@@ -0,0 +1,123 @@
+import argparse
+import torch
+from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
+from torchvision import transforms
+from PIL import Image
+import time
+import matplotlib.pyplot as plt
+import matplotlib as mpl
+import numpy as np
+import cv2
+
+def show_anns(masks):
+    if len(masks) == 0:
+        return
+    sorted_anns = sorted(masks, key=(lambda x: x['area']), reverse=True)
+    ax = plt.gca()
+    ax.set_autoscale_on(False)
+
+    img = np.ones((sorted_anns[0]['segmentation'].shape[0], sorted_anns[0]['segmentation'].shape[1], 4))
+    img[:,:,3] = 0
+    for ann in sorted_anns:
+        m = ann['segmentation']
+        color_mask = np.concatenate([np.random.random(3), [0.95]])
+        img[m] = color_mask
+    ax.imshow(img)
+
+def show_points(coords, ax, marker_size=100):
+    ax.scatter(coords[:, 0], coords[:, 1], color='#2ca02c', marker='.', s=marker_size)
+    
+
+if __name__ == '__main__':
+    # Arguments
+    parser = argparse.ArgumentParser(
+        description='Test the original Segment Anything automatic mask generator'
+    )
+    parser.add_argument('--model', default='vit_b', type=str, help='Model to use')
+    parser.add_argument('--path_model', default='.', type=str, help='Path to the model')
+
+    args = parser.parse_args()
+    device = "cuda"
+    torch.backends.cuda.matmul.allow_tf32 = True
+    torch.backends.cudnn.allow_tf32 = True
+
+    model_type = args.model
+    if args.model == 'vit_h':
+        checkpoint = args.path_model + '/sam_vit_h_4b8939.pth'
+    elif args.model == 'vit_b':
+        checkpoint = args.path_model + '/sam_vit_b_01ec64.pth'
+    else:
+        checkpoint = args.path_model + '/sam_vit_l_0b3195.pth'
+
+    ycb_path = "/home/achapin/Documents/Datasets/YCB_Video_Dataset/"
+    images_path = []
+    with open(ycb_path + "image_sets/train.txt", 'r') as f:
+        for line in f.readlines():
+            line = line.strip()
+            images_path.append(ycb_path + 'data/' + line + "-color.png")
+
+    import random
+    #random.shuffle(images_path)
+
+    sam = sam_model_registry[model_type](checkpoint=checkpoint)
+    sam.to(device=device)
+    mask_generator = SamAutomaticMaskGenerator(sam, points_per_side=12, box_nms_thresh=0.7, crop_n_layers=0, points_per_batch=128, pred_iou_thresh=0.88)
+
+    transform = transforms.Compose([
+        transforms.ToTensor(),
+    ])
+    labels = [1 for i in range(len(mask_generator.point_grids))]
+    with torch.no_grad():
+        for image in images_path:
+            img = cv2.imread(image)
+            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+            """img_depth = cv2.imread(image.replace("color", "depth"))
+            img_depth = cv2.cvtColor(img_depth, cv2.COLOR_BGR2GRAY)"""
+
+            h, w, _ = img.shape
+            points = mask_generator.point_grids[0]
+            new_points = []
+            for val in points:
+                x, y = val[0], val[1]
+                x *= w
+                y *= h
+                new_points.append([x, y])
+            new_points = np.array(new_points)
+
+            start = time.time()
+            masks = mask_generator.generate(img)
+            end = time.time()
+            print(f"Inference time: {int((end-start) * 1000)}ms")
+
+            plt.figure(figsize=(15,15))
+            plt.imshow(img)
+            show_anns(masks)
+            show_points(new_points, plt.gca())
+            plt.axis('off')
+            plt.show()
+
+            """fig, ax = plt.subplots()
+            cmap = plt.cm.get_cmap('plasma')
+            img = ax.imshow(img_depth, cmap=cmap)
+            cbar = fig.colorbar(img, ax=ax)
+            depth_array_new = img.get_array()
+            plt.show()
+
+            depth_array_new = cv2.cvtColor(depth_array_new, cv2.COLOR_GRAY2RGB)
+            plt.imshow(depth_array_new)
+            plt.show()
+            print(depth_array_new.shape)
+
+            start = time.time()
+            masks = mask_generator.generate(depth_array_new)
+            end = time.time()
+            print(f"Inference time : {int((end-start) * 1000)}ms")
+
+            
+            plt.figure(figsize=(15,15))
+            plt.imshow(depth_array_new)
+            show_anns(masks)
+            show_points(new_points, plt.gca())
+            plt.axis('off')
+            plt.show()"""
+
diff --git a/segment-anything b/segment-anything
deleted file mode 160000
index 6fdee8f2727f4506cfbbe553e23b895e27956588..0000000000000000000000000000000000000000
--- a/segment-anything
+++ /dev/null
@@ -1 +0,0 @@
-Subproject commit 6fdee8f2727f4506cfbbe553e23b895e27956588