From ece9ede389d145c99bd7962814646de0ccab1007 Mon Sep 17 00:00:00 2001
From: alexcbb <alexchapin@hotmail.fr>
Date: Mon, 17 Jul 2023 16:01:12 +0200
Subject: [PATCH] Try to adapt code from SAM lightning

---
 osrt/encoder.py              | 493 +++++++++++------------------------
 osrt/model.py                |   3 +
 osrt/sam/utils/onnx.py       |   2 +-
 osrt/sam/utils/transforms.py |  22 ++
 osrt/utils/losses.py         |  47 ++++
 osrt/utils/training.py       |  18 ++
 6 files changed, 244 insertions(+), 341 deletions(-)
 create mode 100644 osrt/utils/losses.py

diff --git a/osrt/encoder.py b/osrt/encoder.py
index 652e7d4..2388dd3 100644
--- a/osrt/encoder.py
+++ b/osrt/encoder.py
@@ -20,7 +20,7 @@ from osrt.sam import sam_model_registry
 from osrt.sam.image_encoder import ImageEncoderViT
 from osrt.sam.mask_decoder import MaskDecoder
 from osrt.sam.prompt_encoder import PromptEncoder
-from osrt.sam.utils.transforms import ResizeLongestSide
+from osrt.sam.utils.transforms import ResizeLongestSide, ResizeAndPad
 from osrt.sam.utils.amg import batched_mask_to_box, remove_small_regions, mask_to_rle_pytorch, area_from_rle
 
 from torch.nn.utils.rnn import pad_sequence
@@ -115,150 +115,30 @@ class OSRTEncoder(nn.Module):
 
 class FeatureMasking(nn.Module):
     def __init__(self, 
-                 points_per_side=12,
-                 box_nms_thresh = 0.7,
-                 stability_score_thresh = 0.9,
-                 pred_iou_thresh=0.88,
-                 points_per_batch=64,
-                 min_mask_region_area=0,
-                 num_slots=32, 
-                 slot_dim=1536, 
-                 slot_iters=3,  
-                 sam_model="default", 
-                 sam_path="sam_vit_h_4b8939.pth",
-                 randomize_initial_slots=False):
+                points_per_side: Optional[int]=12,
+                box_nms_thresh: float = 0.7,
+                stability_score_thresh: float = 0.9,
+                stability_score_offset: float = 1.0,
+                pred_iou_thresh: float =0.88,
+                points_per_batch: int =64,
+                min_mask_region_area: int=0,
+                num_slots: int=32, 
+                slot_dim: int=1536, 
+                slot_iters: int=3,  
+                sam_model="default", 
+                sam_path="sam_vit_h_4b8939.pth",
+                randomize_initial_slots=False):
         super().__init__() 
 
-        # We first initialize the automatic mask generator from SAM
-        # TODO : change the loading here !!!!
-        sam = sam_model_registry[sam_model](checkpoint=sam_path) 
-        
-        self.mask_generator = SamAutomaticMask(copy.deepcopy(sam.image_encoder), 
-                                               copy.deepcopy(sam.prompt_encoder), 
-                                               copy.deepcopy(sam.mask_decoder), 
-                                               box_nms_thresh=box_nms_thresh, 
-                                               stability_score_thresh = stability_score_thresh,
-                                               pred_iou_thresh=pred_iou_thresh,
-                                               points_per_side=points_per_side,
-                                               points_per_batch=points_per_batch,
-                                               min_mask_region_area=min_mask_region_area)
-        del sam
-        
-        self.slot_attention = SlotAttention(num_slots, input_dim=self.mask_generator.token_dim, slot_dim=slot_dim, iters=slot_iters,
-                                            randomize_initial_slots=randomize_initial_slots)
-
-    def forward(self, images, original_size=None, camera_pos=None, rays=None, extract_masks=False):
-        """
-        Args:
-            images: [batch_size, num_images, height, width, 3]. Assume the first image is canonical.
-            original_size: tuple(height, width) The original size of the image before transformation.
-            camera_pos: [batch_size, num_images, 3]
-            rays: [batch_size, num_images, height, width, 3]
-        Returns:
-            annotations: [batch_size, num_image] An array containing a dict for each image in each batch
-                with the following annotations :
-                    segmentation
-                    area
-                    predicted_iou
-                    point_coords
-                    stability_score
-                    embeddings (Optionnal)
-        """
-        
-        # Generate images
-        masks = self.mask_generator(images, original_size, camera_pos, rays, extract_embeddings=True) # [B, N]
-        
-        B, N = masks.shape
-        dim = masks[0][0]["embeddings"][0].shape[0]
-
-        # We infer each batch separetely as it handle a different number of slots
-        set_latents = None
-        # TODO : set the number of slots according to either we want min or max
-        #with torch.no_grad():
-        #num_slots = 100000
-        embedding_batch = []
-        masks_batch = []
-        for b in range(B):
-            latents_batch = torch.empty((0, dim), device=self.mask_generator.device)
-            for n in range(N):
-                embeds = masks[b][n]["embeddings"]
-                #num_slots = min(len(embeds), num_slots)
-                for embed in embeds:
-                    latents_batch = torch.cat((latents_batch, embed.unsqueeze(0)), 0)
-            masks_batch.append(torch.zeros(latents_batch.shape[:1]))
-            embedding_batch.append(latents_batch)
-        set_latents = pad_sequence(embedding_batch, batch_first=True, padding_value=0.0)
-        attention_mask = pad_sequence(masks_batch, batch_first=True, padding_value=1.0)
-        
-        # [batch_size, num_inputs = num_mask_embed x num_im, dim] 
-        #self.slot_attention.change_slots_number(num_slots)
-        slot_latents = self.slot_attention(set_latents, attention_mask)
-        
-        if extract_masks:
-            return masks, slot_latents
-        else:
-            del masks
-            return slot_latents
-
-    
-class SamAutomaticMask(nn.Module):
-    mask_threshold: float = 0.0
-
-    def __init__(
-        self,
-        image_encoder: ImageEncoderViT,
-        prompt_encoder: PromptEncoder,
-        mask_decoder: MaskDecoder,
-        pixel_mean: List[float] = [123.675, 116.28, 103.53],
-        pixel_std: List[float] = [58.395, 57.12, 57.375],
-        token_dim = 1280,
-        points_per_side: Optional[int] = 0,
-        points_per_batch: int = 64,
-        pred_iou_thresh: float = 0.88,
-        stability_score_thresh: float = 0.95,
-        stability_score_offset: float = 1.0,
-        box_nms_thresh: float = 0.7,
-        min_mask_region_area: int = 0, 
-        pos_start_octave=0,
+        token_dim = 1280
+        pos_start_octave=0
         patch_size = 16
-    ) -> None:
-        """
-        This class adapts SAM implementation from original repository but adapting it to our needs :
-        - Training only the MaskDecoder
-        - Performing automatic Mask Discovery (combined with AutomaticMask from original repo)
-
-        SAM predicts object masks from an image and input prompts.
 
-        Arguments:
-          image_encoder (ImageEncoderViT): The backbone used to encode the
-            image into image embeddings that allow for efficient mask prediction.
-          prompt_encoder (PromptEncoder): Encodes various types of input prompts.
-          mask_decoder (MaskDecoder): Predicts masks from the image embeddings
-            and encoded prompts.
-          pixel_mean (list(float)): Mean values for normalizing pixels in the input image.
-          pixel_std (list(float)): Std values for normalizing pixels in the input image.
-        """
-        super().__init__()
-
-        # SAM part
-        self.image_encoder = image_encoder
-        self.prompt_encoder = prompt_encoder
-
-        # Freeze the image encoder and prompt encoder
-        """for param in self.image_encoder.parameters():
-            param.requires_grad = False
-        for param in self.prompt_encoder.parameters():
-            param.requires_grad = False"""
-
-        self.mask_decoder = mask_decoder
-        """for param in self.mask_decoder.parameters():
-            param.requires_grad = True"""
-        self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
-        self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
-
-        # Transform image to a square by putting it to the longest side
-        #self.resize = transforms.Resize(self.image_encoder.img_size, interpolation=transforms.InterpolationMode.BILINEAR)
-        self.transform = ResizeLongestSide(self.image_encoder.img_size)
+        # We first initialize the automatic mask generator from SAM
+        # TODO : change the loading here !!!!
+        self.mask_generator = sam_model_registry[sam_model](checkpoint=sam_path) 
+        self.preprocess = ResizeAndPad(self.mask_generator.image_encoder.img_size)
+        self.resize = ResizeLongestSide(self.mask_generator.image_encoder.img_size)
         
         if points_per_side > 0:
             self.points_grid = create_points_grid(points_per_side)
@@ -271,7 +151,7 @@ class SamAutomaticMask(nn.Module):
         self.box_nms_thresh = box_nms_thresh
         self.min_mask_region_area = min_mask_region_area
 
-        self.patch_embed_dim = (self.image_encoder.img_size // patch_size)**2
+        self.patch_embed_dim = (self.mask_generator.image_encoder.img_size // patch_size)**2
         self.token_dim = token_dim
         self.tokenizer = nn.Sequential(
                 nn.Linear(self.patch_embed_dim, 3072),
@@ -284,132 +164,128 @@ class SamAutomaticMask(nn.Module):
         # Space positional embedding
         self.ray_encoder = RayEncoder(pos_octaves=15, pos_start_octave=pos_start_octave,
                                       ray_octaves=15)
+        
+        self.slot_attention = SlotAttention(num_slots, input_dim=self.token_dim, slot_dim=slot_dim, iters=slot_iters,
+                                            randomize_initial_slots=randomize_initial_slots)
 
-    @property
-    def device(self) -> Any:
-        return self.pixel_mean.device
 
-    def forward(
-        self,
-        images,
-        orig_size,
-        camera_pos=None,
-        rays=None,
-        extract_embeddings = False):
+    def forward(self, images, camera_pos=None, rays=None, extract_masks=False):
         """
         Args:
-            images: [batch_size, num_images, 3, height, width]. Assume the first image is canonical.
-            original_size: tuple(height, width) The original size of the image before transformation.
+            images: [batch_size, num_images, height, width, 3]. Assume the first image is canonical.
             camera_pos: [batch_size, num_images, 3]
             rays: [batch_size, num_images, height, width, 3]
+            extract_mask: boolean Wether to extract masks or not
         Returns:
-            annotations: [batch_size, num_image] An array containing a dict for each image in each batch
-                with the following annotations :
-                    segmentation
-                    area
-                    predicted_iou
-                    point_coords
-                    stability_score
-                    embeddings (Optionnal)
+            
         """
+        
+        ##################################
+        #     Extract masks with SAM     #
+        ##################################
+
+        ### Get images size
         B, N, H, W, C  = images.shape
-        images = images.reshape(B*N, H, W, C) # [B x N, C, H, W] 
+        orig_size = (H, W)
+        images = images.reshape(B*N, H, W, C) # [B x N, C, H, W]         
+        im_size = self.resize(images[0]).shape[-3:-1]
 
-        # Pre-process the images for the image encoder
-        input_images = torch.stack([self.preprocess(self.transform.apply_image(x)) for x in images]) 
-        im_size = self.transform.apply_image(images[0]).shape[-3:-1]
+        ### Pre-process images for the image encoder (Resize and Pad)
+        images = torch.stack([self.preprocess(x) for x in images])
 
+        ### Encode images 
+        image_embeddings, embed_no_red = self.mask_generator.image_encoder(images, before_channel_reduc=True) # [B x N, C, H, W] 
+
+        i = 0
+        masks = np.empty((B, N), dtype=object)
         points_scale = np.array(im_size)[None, ::-1]
         points_for_image = self.points_grid * points_scale
 
-        with torch.no_grad():
-            image_embeddings, embed_no_red = self.image_encoder(input_images, before_channel_reduc=True) # [B x N, C, H, W] 
-
-        annotations = np.empty((B, N), dtype=object)
-        i = 0
+        ### Extract mask for each image in the batch
         for curr_embedding, curr_emb_no_red in zip(image_embeddings, embed_no_red):
-            mask_data = MaskData()
+            data = MaskData()
             for (points,) in batch_iterator(self.points_per_batch, points_for_image):
                 batch_data = self.process_batch(points, im_size, curr_embedding, orig_size)
-                mask_data.cat(batch_data)
+                data.cat(batch_data)
                 del batch_data
-            del curr_embedding
 
-            # Remove duplicates
+            ### Remove duplicates
             keep_by_nms = batched_nms(
-                mask_data["boxes"].float(),
-                mask_data["iou_preds"],
-                torch.zeros_like(mask_data["boxes"][:, 0]),  # categories
+                data["boxes"].float(),
+                data["iou_preds"],
+                torch.zeros_like(data["boxes"][:, 0]),  # categories
                 iou_threshold=self.box_nms_thresh,
             )
-            mask_data.filter(keep_by_nms)
+            data.filter(keep_by_nms)
 
-            mask_data["segmentations"] = mask_data["masks"]
+            data["segmentations"] = data["masks"]
 
-            # Extract mask embeddings
-            if extract_embeddings:
-                self.tokenize(mask_data, curr_emb_no_red, im_size, scale_box=1.5)
-                # TODO : vérifier le positional encoding 3D
-                #self.position_embeding_3d(mask_data["embeddings"], camera_pos[batch][num_im], rays[batch][num_im])
+            ### Extract mask embeddings
+            self.tokenize(data, curr_emb_no_red, im_size, scale_box=1.5)
+            # TODO : Add 3D positional encoding in the process
+            #self.position_embeding_3d(mask_data["embeddings"], camera_pos[batch][num_im], rays[batch][num_im])
 
-            mask_data.to_numpy()
+            data.to_numpy()
 
-            # Filter small disconnected regions and holes in masks NOT USED
+            ### Filter small disconnected regions and holes in masks NOT USED
             if self.min_mask_region_area > 0:
-                mask_data = self.postprocess_small_regions(
-                    mask_data,
+                data = self.postprocess_small_regions(
+                    data,
                     self.min_mask_region_area,
                     self.box_nms_thresh,
                 )
             
-            # TODO : have a more efficient way to store the data
-            # Write mask records
-            if extract_embeddings:
-                curr_ann = {
-                    "embeddings": mask_data["embeddings"],
-                    "segmentations": mask_data["segmentations"]
-                }
-            else :
-                curr_ann = {
-                    "segmentations": mask_data["segmentations"]
-                }
+            ### Write mask records
+            curr_ann = {
+                "embeddings": data["embeddings"],
+                "segmentations": data["segmentations"],
+                "iou_preds": data["iou_preds"]
+            }
             batch = math.floor((i / N))
             num_im = i % N
-            annotations[batch][num_im] = curr_ann
+            masks[batch][num_im] = curr_ann
             i+=1
-        return annotations # [B, N, 1] : dict containing diverse annotations such as segmentation, area or also embedding 
 
-    def postprocess_masks(
-        self,
-        masks: torch.Tensor,
-        input_size: Tuple[int, ...],
-        original_size: Tuple[int, ...],
-    ) -> torch.Tensor:
-        """
-        Remove padding and upscale masks to the original image size.
+        ### Get parameters of mask and embedding size
+        B, N = masks.shape
+        dim = masks[0][0]["embeddings"][0].shape[0]
 
-        Arguments:
-          masks (torch.Tensor): Batched masks from the mask_decoder,
-            in BxCxHxW format.
-          input_size (tuple(int, int)): The size of the image input to the
-            model, in (H, W) format. Used to remove padding.
-          original_size (tuple(int, int)): The original size of the image
-            before resizing for input to the model, in (H, W) format.
+        ##################################
+        #    Extract mask embeddings     #
+        ##################################
 
-        Returns:
-          (torch.Tensor): Batched masks in BxCxHxW format, where (H, W)
-            is given by original_size.
-        """
-        masks = F.interpolate(
-            masks,
-            (self.image_encoder.img_size, self.image_encoder.img_size),
-            mode="bilinear",
-            align_corners=True,
-        )
+        # TODO : set the number of slots according to either we want min or max size (by default a static number of slots)
+
+        ### Pad batches to have the same length
+        set_latents = None
+        embedding_batch = []
+        masks_batch = []
+        for b in range(B):
+            latents_batch = torch.empty((0, dim))
+            for n in range(N):
+                embeds = masks[b][n]["embeddings"]
+                for embed in embeds:
+                    latents_batch = torch.cat((latents_batch, embed.unsqueeze(0)), 0)
+            masks_batch.append(torch.zeros(latents_batch.shape[:1]))
+            embedding_batch.append(latents_batch)
+        ### Contains the embeddings values padded
+        set_latents = pad_sequence(embedding_batch, batch_first=True, padding_value=0.0)
+        ### Contains the masks to avoid to apply attention to not embedded values
+        attention_mask = pad_sequence(masks_batch, batch_first=True, padding_value=1.0)
         
-        masks = masks[..., : input_size[0], : input_size[1]]
-        masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=True)
-        return masks
+        # [batch_size, num_inputs = num_mask_embed x num_im, dim] 
+        #self.slot_attention.change_slots_number(num_slots)
+
+        ###################################
+        # Apply slot latent on embeddings #
+        ###################################
+        slot_latents = self.slot_attention(set_latents, attention_mask)
+        
+        if extract_masks:
+            return masks, slot_latents
+        else:
+            del masks
+            return slot_latents
 
     @staticmethod
     def postprocess_small_regions(
@@ -460,23 +336,6 @@ class SamAutomaticMask(nn.Module):
 
         return mask_data
 
-    def preprocess(self, x: torch.Tensor) -> torch.Tensor:
-        """Normalize pixel values and pad to a square input."""
-        # Rescale the image relative to the longest side
-        # TODO : apply this preprocess to the dataset before training
-        x = torch.as_tensor(x, device=self.device)
-        x = x.permute(2, 0, 1).contiguous()#[None, :, :, :]
-        
-        # Normalize colors
-        x = (x - self.pixel_mean) / self.pixel_std
-
-        # Pad
-        h, w = x.shape[-2:]
-        padh = self.image_encoder.img_size - h
-        padw = self.image_encoder.img_size - w
-        x = F.pad(x, (0, padw, 0, padh))
-        return x
-
     def process_batch(
         self,
         points: np.ndarray,
@@ -485,18 +344,41 @@ class SamAutomaticMask(nn.Module):
         curr_orig_size
         ):
 
-        # Run model on this batch
-        in_points = torch.as_tensor(self.transform.apply_coords(points, im_size), device=self.device)
-        in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device)
+        ### Prepare the points of the batch
+        in_points = torch.as_tensor(self.transform.apply_coords(points, im_size))
+        in_labels = torch.ones(in_points.shape[0], dtype=torch.int)
         
-        masks, iou_preds, _ = self.predict_masks(
-            in_points[:, None, :],
-            in_labels[:, None],
+        point_coords = in_points[:, None, :]
+        point_labels = in_labels[:, None]
+
+        if point_coords is not None:
+            points = (point_coords, point_labels)
+        else:
+            points = None
+
+        ### Embed prompts
+        sparse_embeddings, dense_embeddings = self.mask_generator.prompt_encoder(
+            points=points,
+            boxes=None,
+            masks=None,
+        )
+
+        ### Predict masks
+        low_res_masks, iou_preds = self.mask_generator.mask_decoder(
+            image_embeddings=curr_embedding.unsqueeze(0),
+            image_pe=self.mask_generator.prompt_encoder.get_dense_pe(),
+            sparse_prompt_embeddings=sparse_embeddings,
+            dense_prompt_embeddings=dense_embeddings,
+            multimask_output=False,
+        )
+        
+        ### Upscale the masks to the original image resolution
+        # TODO : verify result, originally put first back to encoder input size to then put to original size linked with the one in tokenize
+        masks = F.interpolate(
+            low_res_masks,
             curr_orig_size,
-            im_size,
-            curr_embedding,
-            multimask_output=True,
-            return_logits=True
+            mode="bilinear",
+            align_corners=False,
         )
 
         data = MaskData(
@@ -506,102 +388,27 @@ class SamAutomaticMask(nn.Module):
         )
         del masks
 
-        # Filter by predicted IoU
+        ### Filter by predicted IoU
         if self.pred_iou_thresh > 0.0:
             keep_mask = data["iou_preds"] > self.pred_iou_thresh
             data.filter(keep_mask)
 
-        # Calculate stability score
+        ### Calculate stability score
         data["stability_score"] = calculate_stability_score(
-            data["masks"], self.mask_threshold, self.stability_score_offset
+            data["masks"], self.mask_generator.mask_threshold, self.stability_score_offset
         )
         if self.stability_score_thresh > 0.0:
             keep_mask = data["stability_score"] >= self.stability_score_thresh
             data.filter(keep_mask)
 
-        # Threshold masks and calculate boxes
-        data["masks"] = data["masks"] > self.mask_threshold
+        ### Threshold masks and calculate boxes
+        data["masks"] = data["masks"] > self.mask_generator.mask_threshold
         data["boxes"] = batched_mask_to_box(data["masks"])
 
         data["rles"] = mask_to_rle_pytorch(data["masks"])
 
         return data
 
-    def predict_masks(
-        self,
-        point_coords: Optional[torch.Tensor],
-        point_labels: Optional[torch.Tensor],
-        curr_orig_size,
-        curr_input_size,
-        curr_embedding,
-        boxes: Optional[torch.Tensor] = None,
-        mask_input: Optional[torch.Tensor] = None,
-        multimask_output: bool = True,
-        return_logits: bool = False,
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        """
-        Predict masks for the given input prompts, using the current image.
-        Input prompts are batched torch tensors and are expected to already be
-        transformed to the input frame using ResizeLongestSide.
-
-        Arguments:
-          point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the
-            model. Each point is in (X,Y) in pixels.
-          point_labels (torch.Tensor or None): A BxN array of labels for the
-            point prompts. 1 indicates a foreground point and 0 indicates a
-            background point.
-          boxes (np.ndarray or None): A Bx4 array given a box prompt to the
-            model, in XYXY format.
-          mask_input (np.ndarray): A low resolution mask input to the model, typically
-            coming from a previous prediction iteration. Has form Bx1xHxW, where
-            for SAM, H=W=256. Masks returned by a previous iteration of the
-            predict method do not need further transformation.
-          multimask_output (bool): If true, the model will return three masks.
-            For ambiguous input prompts (such as a single click), this will often
-            produce better masks than a single prediction. If only a single
-            mask is needed, the model's predicted quality score can be used
-            to select the best mask. For non-ambiguous prompts, such as multiple
-            input prompts, multimask_output=False can give better results.
-          return_logits (bool): If true, returns un-thresholded masks logits
-            instead of a binary mask.
-
-        Returns:
-          (torch.Tensor): The output masks in BxCxHxW format, where C is the
-            number of masks, and (H, W) is the original image size.
-          (torch.Tensor): An array of shape BxC containing the model's
-            predictions for the quality of each mask.
-          (torch.Tensor): An array of shape BxCxHxW, where C is the number
-            of masks and H=W=256. These low res logits can be passed to
-            a subsequent iteration as mask input.
-        """
-        if point_coords is not None:
-            points = (point_coords, point_labels)
-        else:
-            points = None
-        # Embed prompts
-        #with torch.no_grad():
-        sparse_embeddings, dense_embeddings = self.prompt_encoder(
-            points=points,
-            boxes=boxes,
-            masks=mask_input,
-        )
-        # Predict masks
-        low_res_masks, iou_predictions = self.mask_decoder(
-            image_embeddings=curr_embedding.unsqueeze(0),
-            image_pe=self.prompt_encoder.get_dense_pe(),
-            sparse_prompt_embeddings=sparse_embeddings,
-            dense_prompt_embeddings=dense_embeddings,
-            multimask_output=multimask_output,
-        )
-
-        # Upscale the masks to the original image resolution
-        masks = self.postprocess_masks(low_res_masks, curr_input_size, curr_orig_size)
-
-        if not return_logits:
-            masks = masks > self.mask_threshold
-
-        return masks, iou_predictions.detach(), low_res_masks.detach()
-
     def tokenize(self, mask_data, image_embed, input_size, scale_box=1.5):
         """
         Predicts the embeddings from each mask given the global embedding and
@@ -622,7 +429,13 @@ class SamAutomaticMask(nn.Module):
         out_size = mean_embed.shape
 
         # Put the mean image embedding back to the input format 
-        scaled_img_emb = self.postprocess_masks(mean_embed.unsqueeze(0).unsqueeze(0), input_size, (orig_H, orig_W)).squeeze()
+        # TODO : check results linked with the one in process_batch
+        scaled_img_emb = F.interpolate(
+            mean_embed.unsqueeze(0).unsqueeze(0),
+            (orig_H, orig_W),
+            mode="bilinear",
+            align_corners=False,
+        ).squeeze()
 
         mask_data["embeddings"] = []
 
@@ -646,7 +459,7 @@ class SamAutomaticMask(nn.Module):
             token_embed += pos_embedding
 
             # Apply mask to image embedding
-            mask_data["embeddings"].append(torch.from_numpy(token_embed).to(self.device)) # [token_dim] 
+            mask_data["embeddings"].append(torch.from_numpy(token_embed)) # [token_dim] 
 
     def position_embeding_3d(self, tokens, camera_pos, rays):
         """
@@ -663,4 +476,4 @@ class SamAutomaticMask(nn.Module):
         pos_rot_x, pos_rot_y, pos_rot_z = get_positional_embedding(rot_x, self.token_dim), get_positional_embedding(rot_y, self.token_dim), get_positional_embedding(rot_z, self.token_dim)
 
         pos_embed = (pos_rot_x + pos_rot_y + pos_rot_z + pos_x + pox_y + pos_z) // 6
-        tokens += pos_embed
\ No newline at end of file
+        tokens += pos_embed
diff --git a/osrt/model.py b/osrt/model.py
index 6fb24ec..b01d161 100644
--- a/osrt/model.py
+++ b/osrt/model.py
@@ -51,6 +51,9 @@ class OSRT(nn.Module):
             self.decoder = SlotMixerDecoder(**cfg['decoder_kwargs'])
         else:
             raise ValueError(f'Unknown decoder type: {decoder_type}')
+        
+        self.encoder.train()
+        self.decoder.train()
 
 class LitOSRT(pl.LightningModule):
     def __init__(self, encoder:nn.Module, decoder: nn.Module, cfg: Dict, extract_masks:bool =False):
diff --git a/osrt/sam/utils/onnx.py b/osrt/sam/utils/onnx.py
index 3196bdf..396a373 100644
--- a/osrt/sam/utils/onnx.py
+++ b/osrt/sam/utils/onnx.py
@@ -10,7 +10,7 @@ from torch.nn import functional as F
 
 from typing import Tuple
 
-from ..modeling import Sam
+from .. import Sam
 from .amg import calculate_stability_score
 
 
diff --git a/osrt/sam/utils/transforms.py b/osrt/sam/utils/transforms.py
index e801bdf..87ee5c7 100644
--- a/osrt/sam/utils/transforms.py
+++ b/osrt/sam/utils/transforms.py
@@ -11,7 +11,29 @@ from torchvision.transforms.functional import resize, to_pil_image  # type: igno
 
 from copy import deepcopy
 from typing import Tuple
+import torchvision.transforms as transforms
 
+class ResizeAndPad:
+    def __init__(self, target_size):
+        self.target_size = target_size
+        self.transform = ResizeLongestSide(target_size)
+        self.to_tensor = transforms.ToTensor()
+
+    def __call__(self, image):
+        # Resize image
+        image = self.transform.apply_image(image)
+        image = self.to_tensor(image)
+
+        # Pad image to form a square
+        _, h, w = image.shape
+        max_dim = max(w, h)
+        pad_w = (max_dim - w) // 2
+        pad_h = (max_dim - h) // 2
+
+        padding = (pad_w, pad_h, max_dim - w - pad_w, max_dim - h - pad_h)
+        image = transforms.Pad(padding)(image)
+
+        return image
 
 class ResizeLongestSide:
     """
diff --git a/osrt/utils/losses.py b/osrt/utils/losses.py
new file mode 100644
index 0000000..cd94527
--- /dev/null
+++ b/osrt/utils/losses.py
@@ -0,0 +1,47 @@
+"""
+Code extracted from : https://github.com/luca-medeiros/lightning-sam/blob/main/lightning_sam/losses.py
+"""
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+ALPHA = 0.8
+GAMMA = 2
+
+
+class FocalLoss(nn.Module):
+
+    def __init__(self, weight=None, size_average=True):
+        super().__init__()
+
+    def forward(self, inputs, targets, alpha=ALPHA, gamma=GAMMA, smooth=1):
+        inputs = F.sigmoid(inputs)
+        inputs = torch.clamp(inputs, min=0, max=1)
+        #flatten label and prediction tensors
+        inputs = inputs.view(-1)
+        targets = targets.view(-1)
+
+        BCE = F.binary_cross_entropy(inputs, targets, reduction='mean')
+        BCE_EXP = torch.exp(-BCE)
+        focal_loss = alpha * (1 - BCE_EXP)**gamma * BCE
+
+        return focal_loss
+
+
+class DiceLoss(nn.Module):
+
+    def __init__(self, weight=None, size_average=True):
+        super().__init__()
+
+    def forward(self, inputs, targets, smooth=1):
+        inputs = F.sigmoid(inputs)
+        inputs = torch.clamp(inputs, min=0, max=1)
+        #flatten label and prediction tensors
+        inputs = inputs.view(-1)
+        targets = targets.view(-1)
+
+        intersection = (inputs * targets).sum()
+        dice = (2. * intersection + smooth) / (inputs.sum() + targets.sum() + smooth)
+
+        return 1 - dice
\ No newline at end of file
diff --git a/osrt/utils/training.py b/osrt/utils/training.py
index 70392ee..48240ca 100644
--- a/osrt/utils/training.py
+++ b/osrt/utils/training.py
@@ -16,6 +16,24 @@ from collections import defaultdict
 import time
 import os
 
+class AverageMeter:
+    """Computes and stores the average and current value."""
+
+    def __init__(self):
+        self.reset()
+
+    def reset(self):
+        self.val = 0
+        self.avg = 0
+        self.sum = 0
+        self.count = 0
+
+    def update(self, val, n=1):
+        self.val = val
+        self.sum += val * n
+        self.count += n
+        self.avg = self.sum / self.count
+        
 def compute_loss(
         batch, 
         batch_idx : int, 
-- 
GitLab