diff --git a/osrt/model.py b/osrt/model.py
index b01d1617025bc90ec98af05a3c9d3ed60d4d89e3..98dcdaa6bd01040cb149f4be4e57909f9ddb7d1e 100644
--- a/osrt/model.py
+++ b/osrt/model.py
@@ -7,7 +7,8 @@ import osrt.layers as layers
 import lightning.pytorch as pl
 import torch
 import torch.optim as optim
-from osrt.utils.common import mse2psnr, compute_adjusted_rand_index
+from osrt.utils.common import mse2psnr
+from osrt.utils.losses import compute_ari
 from torch.optim.lr_scheduler import LambdaLR
 from typing import Dict
 
@@ -54,10 +55,11 @@ class OSRT(nn.Module):
         
         self.encoder.train()
         self.decoder.train()
-
+
+"""
 class LitOSRT(pl.LightningModule):
     def __init__(self, encoder: nn.Module, decoder: nn.Module, cfg: Dict, extract_masks: bool = False):
-        """OSRT Model
+        OSRT Model
         The encoder and decoder are specified in the config file via the paths to their classes
 
         Args:
@@ -65,7 +67,7 @@ class LitOSRT(pl.LightningModule):
             decoder: class of the decoder to use
             cfg: config file containing information about the model
             extract_masks: whether to use masks for training
-        """
+        
         super().__init__()
         self.save_hyperparameters()
         self.cfg = cfg
@@ -80,14 +82,14 @@ class LitOSRT(pl.LightningModule):
         return self.encoder(x) # Returns: slot_latents
     
     def compute_loss(self, batch):
-        """
         Args:
             batch: dict containing the information for training --> input images, rays and positions
             extract_masks (bool): whether to use masks to compute the segmentation loss or not
         Returns:
             loss: the loss value
             loss_terms: a dict containing more loss values
-        """
+
+
         device = self.device
         render_kwargs = self.trainer.datamodule.train_dataset.render_kwargs
 
@@ -131,10 +133,10 @@ class LitOSRT(pl.LightningModule):
 
             # These are not actually used as part of the training loss.
             # We just add them to the dict to report them.
-            loss_terms['ari'] = compute_adjusted_rand_index(true_seg.transpose(1, 2),
-                                                            pred_seg.transpose(1, 2))
+            loss_terms['ari'] = compute_ari(true_seg.transpose(1, 2),
+                                            pred_seg.transpose(1, 2))
 
-            loss_terms['fg_ari'] = compute_adjusted_rand_index(true_seg.transpose(1, 2)[:, 1:],
-                                                               pred_seg.transpose(1, 2))
+            loss_terms['fg_ari'] = compute_ari(true_seg.transpose(1, 2)[:, 1:],
+                                               pred_seg.transpose(1, 2))
             
             # TODO: add new ARI metrics
@@ -222,3 +224,4 @@ class LitOSRT(pl.LightningModule):
             'interval': 'step'
         }
     }
+"""
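
Note: osrt.utils.losses.compute_ari itself is not shown in this patch. From the call sites above it appears to keep the signature of the old compute_adjusted_rand_index, taking masks shaped [batch, n_groups, n_points] (hence the transpose(1, 2)), with the foreground-only variant dropping the first (background) group. A minimal sketch of a standard adjusted Rand index under that assumed signature; the module path and shapes are inferred from this diff, not confirmed:

import torch

def compute_ari(true_mask: torch.Tensor, pred_mask: torch.Tensor) -> torch.Tensor:
    # true_mask, pred_mask: [batch, n_groups, n_points], (soft) one-hot over groups.
    # Returns a [batch] tensor of ARI scores; 1.0 means perfect agreement.
    n_ij = torch.einsum('bip,bjp->bij', true_mask, pred_mask)  # contingency table
    a = n_ij.sum(dim=2)   # points per true group
    b = n_ij.sum(dim=1)   # points per predicted group
    n = a.sum(dim=1)      # total number of points

    comb2 = lambda x: x * (x - 1) / 2  # "choose 2" on float tensors
    rindex = comb2(n_ij).sum(dim=(1, 2))
    aindex = comb2(a).sum(dim=1)
    bindex = comb2(b).sum(dim=1)
    expected = aindex * bindex / comb2(n)
    max_rindex = (aindex + bindex) / 2

    # When both clusterings are trivial the denominator is 0; define ARI as 1 there.
    denom = max_rindex - expected
    return torch.where(denom == 0, torch.ones_like(rindex), (rindex - expected) / denom)
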
diff --git a/train_lit.py b/train_lit.py
index b5086eef97254a034e87dd0b8bb7f7d909644370..3694705b3bd722098ac9d6487337581ac7b77a8a 100644
--- a/train_lit.py
+++ b/train_lit.py
@@ -29,7 +29,7 @@ __LOG10 = math.log(10)
 
 def validate(fabric: L.Fabric, model: OSRT, val_dataloader: DataLoader, epoch: int = 0):
     # TODO: also use segmentation to select the model, following how it's done in training
-    """model.eval()
+    model.eval()
     mses = AverageMeter()
     psnrs = AverageMeter()
 
@@ -71,8 +71,7 @@ def validate(fabric: L.Fabric, model: OSRT, val_dataloader: DataLoader, epoch: i
     state_dict = model.state_dict()
     if fabric.global_rank == 0:
         torch.save(state_dict, os.path.join(cfg.out_dir, f"epoch-{epoch:06d}-psnr{psnrs.avg:.2f}-mse{mses.avg:.2f}-ckpt.pth"))
-    model.train()"""
-    pass
+    model.train()
 
 
 def train_sam(
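
The __LOG10 constant defined at the top of train_lit.py suggests the usual MSE-to-PSNR conversion feeds the psnrs meter in validate. A one-line sketch, assuming images normalized to [0, 1]; the real helper, osrt.utils.common.mse2psnr, is imported in osrt/model.py and not shown in this diff:

import math
import torch

__LOG10 = math.log(10)

def mse2psnr(mse: torch.Tensor) -> torch.Tensor:
    # Peak signal-to-noise ratio for a peak value of 1: -10 * log10(MSE).
    return -10.0 * torch.log(mse) / __LOG10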