diff --git a/.gitmodules b/.gitmodules
index a8ff4b8fe903913972aeaee9f41e0109c6ac84f1..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -1,3 +0,0 @@
-[submodule "segment-anything"]
-	path = segment-anything
-	url = https://github.com/facebookresearch/segment-anything.git
diff --git a/runs/test/config.json b/runs/test/config.json
new file mode 100644
index 0000000000000000000000000000000000000000..6246b88b65bdf17cf8f6d50d39ea85c0214afa8c
--- /dev/null
+++ b/runs/test/config.json
@@ -0,0 +1,38 @@
+{
+    "data": {
+        "dataset": "clevr3d",
+        "num_points": 2000 ,
+        "kwargs": {
+            "downsample": 1
+        }
+    },
+    "model":{
+        "encoder": "osrt",
+        "encoder_kwargs": {
+            "pos_start_octave": -5,
+            "num_slots": 6
+        },
+        "decoder": "slot_mixer",
+        "decoder_kwargs":{
+            "pos_start_octave": -5
+        }
+    },
+    "training":{
+        "num_workers": 4, 
+        "batch_size": 64,
+        "num_gpu": 8,
+        "model_selection_metric": "psnr",
+        "model_selection_mode": "max",
+        "print_every": 10,
+        "visualize_every": 5000,
+        "validate_every": 5000,
+        "checkpoint_every": 1000,
+        "backup_every": 25000,
+        "max_it": 333000000,
+        "decay_it": 4000000,
+        "lr_warmup": 5000,
+        "precision": "16-mixed",
+        "out_dir": "."
+    }
+
+}
\ No newline at end of file
diff --git a/train_lit.py b/train_lit.py
index ef1d7ba42a3d572c21276cac3c9d06fbc0280a65..da98f7785c642f6a4b845d4a54356719114e1967 100644
--- a/train_lit.py
+++ b/train_lit.py
@@ -1,99 +1,245 @@
 """
-Code inspired from Lit-Llama training script : https://github.com/Lightning-AI/lit-llama/blob/main/finetune/full.py
+Code inspired by and adapted from https://github.com/luca-medeiros/lightning-sam/blob/main/lightning_sam/train.py
 """
-import sys
-from pathlib import Path
+
 import os
 import time
-from functools import partial
+import json
+import argparse
+import math
 
 import lightning as L
-from lightning.fabric.strategies import FSDPStrategy
-import numpy as np
+import segmentation_models_pytorch as smp
 import torch
+import torch.nn.functional as F
+from lightning.fabric.fabric import _FabricOptimizer
+from lightning.fabric.loggers import TensorBoardLogger
 from torch.utils.data import DataLoader
-from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
-from jsonargparse.cli import CLI
-import json
 
-# support running without installing as a package
-wd = Path(__file__).parent.parent.resolve()
-sys.path.append(str(wd))
+from osrt.model import OSRT
+from osrt.encoder import FeatureMasking
+from osrt import data
+from osrt.utils.training import AverageMeter
+from osrt.utils.losses import DiceLoss, FocalLoss
 
-from generate import generate
-from lit_llama.model import Block, LLaMA, LLaMAConfig
-from lit_llama.tokenizer import Tokenizer
-from lit_llama.utils import save_model_checkpoint
-from scripts.prepare_alpaca import generate_prompt
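+# 'high' lets matmuls use TensorFloat32 on supported GPUs, trading a little
+# float32 precision for speed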
+torch.set_float32_matmul_precision('high')
 
-from osrt.layers import Transformer
-from osrt import data
-from osrt.model import OSRT
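+# PSNR = -10 * log10(mse); torch only exposes the natural log, so the
+# validation loop divides by log(10)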
+__LOG10 = math.log(10)
 
-from segment_anything.modeling.transformer import TwoWayTransformer
-
-
-instruction_tuning = True
-eval_interval = 1000
-save_interval = 1000
-eval_iters = 100
-log_interval = 100
-
-# Hyperparameters
-learning_rate = 3e-5
-micro_batch_size = 4
-"""gradient_accumulation_iters = batch_size // micro_batch_size
-assert gradient_accumulation_iters > 0"""
-epoch_size = 50000  # train dataset size
-num_epochs = 5
-#max_iters = num_epochs * (epoch_size // micro_batch_size) // devices
-weight_decay = 0.0
-block_size = 512
-warmup_iters = 100
-
-class LrScheduler():
-    """ Implements a learning rate schedule with warum up and decay """
-    def __init__(self, peak_lr=4e-4, peak_it=10000, decay_rate=0.5, decay_it=100000):
-        self.peak_lr = peak_lr
-        self.peak_it = peak_it
-        self.decay_rate = decay_rate
-        self.decay_it = decay_it
-
-    def get_cur_lr(self, it):
-        if it < self.peak_it:  # Warmup period
-            return self.peak_lr * (it / self.peak_it)
-        it_since_peak = it - self.peak_it
-        return self.peak_lr * (self.decay_rate ** (it_since_peak / self.decay_it))
-
-
-def main(
-    config_path:str,
-    data_dir: str = "data/alpaca",
-    out_dir: str = "out/full/alpaca",
-    checkpoint :str = None
+def validate(cfg, fabric: L.Fabric, model: OSRT, val_dataloader: DataLoader, epoch: int = 0):
+    # TODO : also evaluate segmentation for model selection, mirroring the training loop
+    model.eval()
+    mses = AverageMeter()
+    psnrs = AverageMeter()
+
+    sceneids = []
+
+    with torch.no_grad():
+        for batch_idx, batch in enumerate(val_dataloader):
+            sceneids.append(batch['sceneid'])
+
+            input_images = batch.get('input_images')
+            input_camera_pos = batch.get('input_camera_pos')
+            input_rays = batch.get('input_rays')
+            target_pixels = batch.get('target_pixels')
+
+            if isinstance(model.encoder, FeatureMasking):
+                input_images = input_images.permute(0, 1, 3, 4, 2)  # from [b, k, c, h, w] to [b, k, h, w, c]
+                h, w, c = input_images[0][0].shape
+                z = model.encoder(input_images, (h, w), input_camera_pos, input_rays)
+            else:
+                z = model.encoder(input_images, input_camera_pos, input_rays)
+
+            target_camera_pos = batch.get('target_camera_pos')
+            target_rays = batch.get('target_rays')
+
+            pred_pixels, extras = model.decoder(z, target_camera_pos, target_rays)  # , **self.render_kwargs
+
+            ### Compute per-scene MSE on the target pixels, then PSNR = -10 * log10(mse)
+            loss_mse = ((pred_pixels - target_pixels)**2).mean((1, 2))
+            psnr = -10. * torch.log(loss_mse) / __LOG10
+            mses.update(loss_mse.mean().item())
+            psnrs.update(psnr.mean().item())
+            fabric.print(f"Val [{epoch}] - [{batch_idx}/{len(val_dataloader)}] : psnr {psnr.mean().item():.4f}, mse: {loss_mse.mean().item():.6f}")
+    
+    fabric.print(f'Validation [{epoch}]: Mean psnr: [{psnrs.avg:.4f}] -- Mean mse: [{mses.avg:.4f}]')
+
+
+    fabric.print(f"Saving checkpoint to {cfg.out_dir}")
+    state_dict = model.state_dict()
+    if fabric.global_rank == 0:
+        torch.save(state_dict, os.path.join(cfg.out_dir, f"epoch-{epoch:06d}-psnr{psnrs.avg:.2f}-mse{mses.avg:.2f}-ckpt.pth"))
+    model.train()
+
+
+def train_sam(
+    cfg,
+    fabric: L.Fabric,
+    model: OSRT,
+    optimizer: _FabricOptimizer,
+    scheduler: _FabricOptimizer,
+    train_dataloader: DataLoader,
+    val_dataloader: DataLoader,
 ):
+    """The SAM training loop."""
+
+    focal_loss = FocalLoss()
+    dice_loss = DiceLoss()
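+    # SAM-style mask supervision (as in lightning-sam): focal loss handles the
+    # per-pixel class imbalance, dice loss optimises mask overlap directly;
+    # the two are combined below in the SAM paper's 20:1 ratio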
+    # max_it is a budget of optimizer steps; convert it to epochs via the
+    # number of batches per epoch
+    nb_epochs = max(1, cfg["training"]["max_it"] // len(train_dataloader))
+    for epoch in range(1, nb_epochs + 1):
+        # TODO : add psnr loss ?
+        batch_time = AverageMeter()
+        data_time = AverageMeter()
+        focal_losses = AverageMeter()
+        dice_losses = AverageMeter()
+        mse_losses = AverageMeter()
+        total_losses = AverageMeter()
+        end = time.time()
+        validated = False
+
+        for batch_idx, batch in enumerate(train_dataloader):
+            if epoch > 1 and epoch % cfg["training"]["validate_every"] == 0 and not validated:
+                validate(cfg, fabric, model, val_dataloader, epoch)
+                validated = True
+
+            data_time.update(time.time() - end)
+
+            # TODO : adapt to our model
+            input_images = batch.get('input_images')
+            input_camera_pos = batch.get('input_camera_pos')
+            input_rays = batch.get('input_rays')
+            target_pixels = batch.get('target_pixels')
+
+            if isinstance(model.encoder, FeatureMasking):
+                input_images = input_images.permute(0, 1, 3, 4, 2)  # from [b, k, c, h, w] to [b, k, h, w, c]
+                h, w, c = input_images[0][0].shape
+                masks_info, z = model.encoder(input_images, (h, w), input_camera_pos, input_rays, extract_masks=True)
+            else:
+                z = model.encoder(input_images, input_camera_pos, input_rays)
+
+            target_camera_pos = batch.get('target_camera_pos')
+            target_rays = batch.get('target_rays')
+
+            loss_focal = torch.tensor(0., device=fabric.device)
+            loss_dice = torch.tensor(0., device=fabric.device)
+            pred_pixels, extras = model.decoder(z, target_camera_pos, target_rays)  # , **self.render_kwargs
+
+            ### Compute MSE on pixels, reduced to a scalar so backward() works
+            loss_mse = ((pred_pixels - target_pixels)**2).mean()
+
+            batch_size = input_images.shape[0]
+
+            if 'segmentation' in extras:
+                # TODO : for visualisation only, could be interesting to check the real GT
+                #true_seg = batch['target_masks'].float()
+
+                pred_masks = extras['segmentation']
+                # Total number of predicted masks, used to normalise both mask losses
+                num_masks = sum(len(pred_mask) for pred_mask in pred_masks)
+                for pred_mask, gt_mask in zip(pred_masks, masks_info["segmentations"]):
+                    loss_focal += focal_loss(pred_mask, gt_mask, num_masks)
+                    loss_dice += dice_loss(pred_mask, gt_mask, num_masks)
+
+                # TODO : check the values of the losses and see if the 20:1 focal/dice scale is ok
+                loss_total = 20. * loss_focal + loss_dice + loss_mse
+                # TODO : also check ARI and FG-ARI values, plus newer metrics from recent papers
+                """loss_terms['ari'] = compute_adjusted_rand_index(true_seg.transpose(1, 2),
+                                                                pred_seg.transpose(1, 2))
+
+                loss_terms['fg_ari'] = compute_adjusted_rand_index(true_seg.transpose(1, 2)[:, 1:],
+                                                                    pred_seg.transpose(1, 2))"""
+            else:
+                # Fall back to the reconstruction loss alone so that loss_total
+                # is always defined before backward()
+                loss_total = loss_mse
+
+            optimizer.zero_grad()
+            fabric.backward(loss_total)
+            optimizer.step()
+            scheduler.step()
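+            # The warmup/decay schedule is defined in optimizer steps, so the
+            # scheduler advances once per batch rather than once per epoch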
+            batch_time.update(time.time() - end)
+            end = time.time()
+
+            focal_losses.update(loss_focal.item(), batch_size)
+            dice_losses.update(loss_dice.item(), batch_size)
+            mse_losses.update(loss_mse.item(), batch_size)
+            total_losses.update(loss_total.item(), batch_size)
+
+            fabric.print(f'Epoch: [{epoch}][{batch_idx+1}/{len(train_dataloader)}]'
+                         f' | Time [{batch_time.val:.3f}s ({batch_time.avg:.3f}s)]'
+                         f' | Data [{data_time.val:.3f}s ({data_time.avg:.3f}s)]'
+                         f' | Focal Loss [{focal_losses.val:.4f} ({focal_losses.avg:.4f})]'
+                         f' | Dice Loss [{dice_losses.val:.4f} ({dice_losses.avg:.4f})]'
+                         f' | MSE Loss [{mse_losses.val:.4f} ({mse_losses.avg:.4f})]'
+                         f' | Total Loss [{total_losses.val:.4f} ({total_losses.avg:.4f})]')
+
+def configure_opt(cfg, model: OSRT):
+    decay_it = cfg['training'].get('decay_it', 4000000)
+    peak_it = cfg['training'].get('lr_warmup', 2500)
+    peak_lr = 1e-4
+    decay_rate = 0.16
 
-    with open(config_path, 'r') as f:
-        cfg = json.load(f)
+    # Exponential decay with linear warmup (replaces the former LrScheduler).
+    # LambdaLR multiplies the optimizer's base lr by this factor, so it must
+    # return a relative multiplier (1.0 at the peak), not an absolute lr
+    def lr_lambda(step):
+        if step < peak_it:  # Warmup period
+            return step / peak_it
+        it_since_peak = step - peak_it
+        return decay_rate ** (it_since_peak / decay_it)
+
+    optimizer = torch.optim.Adam(model.parameters(), lr=peak_lr)
+    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
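+    # Example: with peak_it=2500 and decay_it=4e6, the lr ramps linearly from
+    # 0 to peak_lr over 2500 steps, then decays by a factor 0.16 every 4e6 steps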
 
+    return optimizer, scheduler
+
+
+def main(cfg) -> None:
+
+    #########################
+    ### Setup parameters
+    #########################
     num_devices = cfg['training']['num_gpu'] if 'num_gpu' in cfg['training'] else 1
     num_workers = cfg['training']['num_workers'] if 'num_workers' in cfg['training'] else 1
     batch_size = cfg['training']['batch_size'] // num_devices
-
-    auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={Transformer, TwoWayTransformer})
-    strategy = FSDPStrategy(auto_wrap_policy=auto_wrap_policy, activation_checkpointing={Transformer, TwoWayTransformer}, limit_all_gathers=True)
-
-    # TODO : activer precision bf16
-    fabric = L.Fabric(accelerator="cuda", devices=num_devices, precision=cfg["training"]["precision"], strategy=strategy)
+    
+    #########################
+    ### Launch the model
+    #########################
+    fabric = L.Fabric(accelerator="gpu",
+                      devices=num_devices,
+                      strategy="auto",
+                      loggers=[TensorBoardLogger(cfg['training']['out_dir'], name="lightning-sam")])
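+    # Fabric handles device placement and picks a distributed strategy for the
+    # requested device count; logs go to the TensorBoard logger configured above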
     fabric.launch()
     fabric.seed_everything(1337 + fabric.global_rank)
 
     if fabric.global_rank == 0:
-        os.makedirs(out_dir, exist_ok=True)
+        os.makedirs(cfg['training']['out_dir'], exist_ok=True)
 
-    ###################
-    #   Import Dataset
-    ###################
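+    # Building the model under fabric.device materialises its parameters
+    # directly on the accelerator, avoiding a CPU round-trip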
+    with fabric.device:
+        model = OSRT(cfg)
+        
+    #########################
+    ### Loading the dataset
+    #########################
     train_dataset = data.get_dataset('train', cfg['data'])
     val_dataset = data.get_dataset('val', cfg['data'])
     test_dataset = data.get_dataset('test', cfg['data'])
@@ -107,162 +253,43 @@ def main(
 
     train_loader, val_loader, test_loader = fabric.setup_dataloaders(train_loader, val_loader, test_loader)
 
+    data_vis_val = next(iter(vis_loader_val))  # Validation set data for visualization
+    data_vis_val = fabric.to_device(data_vis_val)
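+    # TODO : data_vis_val is prepared for visualisation but not yet consumed
+    # by the training loop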
 
-    if checkpoint:
-        checkpoint = torch.load(checkpoint)
-
-    with fabric.device:
-        torch.set_default_tensor_type(torch.HalfTensor)
-        model = OSRT(cfg['model']).bfloat16()
-        torch.set_default_tensor_type(torch.FloatTensor)
-        if checkpoint:
-            model.load_state_dict(checkpoint, strict=False) 
-
-    model = fabric.setup_module(model)
-
-    params = [p for p in model.parameters() if p.requires_grad]
 
-    # Setup scheduler
-    warmup_iters = cfg['training']['decay_it'] if 'decay_it' in cfg['training'] else 4000000
-    peak_it = cfg['training']['lr_warmup'] if 'lr_warmup' in cfg['training'] else 2500
-    lr_scheduler = LrScheduler(peak_lr=1e-4, peak_it=peak_it, decay_it=warmup_iters, decay_rate=0.16)
-
-    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, foreach=False)
-    optimizer = fabric.setup_optimizers(optimizer)
-
-    train(fabric, model, optimizer, train_loader, val_loader, out_dir)
-
-    # Save the final checkpoint at the end of training
-    save_model_checkpoint(fabric, model, os.path.join(out_dir, "lit-llama-full-finetuned.pth"))
-
-
-def train(
-    fabric: L.Fabric,
-    model: torch.nn.Module,
-    optimizer: torch.optim.Optimizer,
-    train_data: DataLoader, # TODO : maybe use np.array
-    val_data: DataLoader,
-    out_dir: str,
-) -> None:
-    """The training loop.
-
-    Loosely based on the nanoGPT implementation: https://github.com/karpathy/nanoGPT.
-    """
-    step_count = 0
-    model.train()
+    #########################
+    ### Prepare the optimizer
+    #########################
+    optimizer, scheduler = configure_opt(cfg, model)
+    model, optimizer = fabric.setup(model, optimizer)
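+    # fabric.setup moves the model to the device and wraps the optimizer; the
+    # scheduler keeps a reference to the underlying torch optimizer, so
+    # stepping it after setup still works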
 
-    for iter_num in range(max_iters):
+    #########################
+    ### Training
+    #########################
+    train_sam(cfg, fabric, model, optimizer, scheduler, train_loader, val_loader)
+    validate(cfg, fabric, model, val_loader, epoch=0)
 
-        is_accumulating = (iter_num + 1) % gradient_accumulation_iters != 0
-
-        if step_count <= warmup_iters:
-            # linear warmup
-            lr = learning_rate * step_count / warmup_iters
-            for param_group in optimizer.param_groups:
-                param_group['lr'] = lr
-
-        t0 = time.time()
-        
-        input_ids, targets = get_batch(fabric, train_data)
-        with fabric.no_backward_sync(model, enabled=is_accumulating):
-            logits = model(input_ids)
-            loss = loss_fn(logits, targets)
-            fabric.backward(loss / gradient_accumulation_iters)
-
-        if not is_accumulating:
-            optimizer.step()
-            optimizer.zero_grad()
-            step_count += 1
-
-            if step_count % eval_interval == 0:
-                val_loss = validate(fabric, model, val_data)
-                fabric.print(f"step {iter_num}: val loss {val_loss:.4f}")
-                fabric.barrier()
-
-            if step_count % save_interval == 0:
-                print(f"Saving weights to {out_dir}")
-                save_model_checkpoint(fabric, model, os.path.join(out_dir, f"iter-{iter_num:06d}-ckpt.pth"))
-
-        dt = time.time() - t0
-        if iter_num % log_interval == 0:
-            fabric.print(f"iter {iter_num}: loss {loss.item():.4f}, time: {dt*1000:.2f}ms")
-
-def generate_response(model, instruction):
-    tokenizer = Tokenizer("checkpoints/lit-llama/tokenizer.model")
-    sample = {"instruction": instruction, "input": ""}
-    prompt = instruction
-    if instruction_tuning:
-        prompt = generate_prompt(sample)
-    encoded = tokenizer.encode(prompt, bos=True, eos=False, device=model.device)
-
-    output = generate(
-        model,
-        idx=encoded,
-        max_seq_length=block_size,
-        max_new_tokens=100,
+if __name__ == "__main__":
+    ### Arguments
+    parser = argparse.ArgumentParser(
+        description='Train a 3D scene representation model.'
     )
-    output = tokenizer.decode(output)
-    return output # output.split("### Response:")[1].strip()
-
+    parser.add_argument('config', type=str, help='Path to config file.')
+    parser.add_argument('--wandb', action='store_true', help='Log run to Weights and Biases.')
+    parser.add_argument('--checkpoint', type=str, default='', help='Path to a model checkpoint')
 
-@torch.no_grad()
-def validate(fabric: L.Fabric, model: torch.nn.Module, val_data: np.ndarray) -> torch.Tensor:
-    fabric.print("Validating ...")
-    model.eval()
-    losses = torch.zeros(eval_iters)
-    for k in range(eval_iters):
-        input_ids, targets = get_batch(fabric, val_data)
-        logits = model(input_ids)
-        loss = loss_fn(logits, targets)
-        losses[k] = loss.item()
-    out = losses.mean()
-
-    # produce an example:
-    instruction = "Recommend a movie for me to watch during the weekend and explain the reason."
+    args = parser.parse_args()
     
-    output = generate_response(model, instruction)
-    fabric.print(instruction)
-    fabric.print(output)
-
-    model.train()
-    return out.item()
-
-def loss_fn(logits, targets):
-    # shift the targets such that output n predicts token n+1
-    logits = logits[..., :-1, :].contiguous()
-    targets = targets[..., 1:].contiguous()
-    loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
-    return loss
-
-
-def get_batch(fabric: L.Fabric, data: list):
-    ix = torch.randint(len(data), (micro_batch_size,))
-
-    input_ids = [data[i]["input_ids"].type(torch.int64) for i in ix]
-    labels = [data[i]["labels"].type(torch.int64) for i in ix]
-
-    max_len = max(len(s) for s in input_ids)
-
-    def pad_right(x, pad_id):
-        # pad right based on the longest sequence
-        n = max_len - len(x)
-        return torch.cat((x, torch.full((n,), pad_id, dtype=x.dtype)))
-
-    x = torch.stack([pad_right(x, pad_id=0) for x in input_ids])
-    y = torch.stack([pad_right(x, pad_id=-1) for x in labels])
-    x, y = fabric.to_device((x.pin_memory(), y.pin_memory()))
-    return x, y
-
-
-def load_datasets(data_dir):
-    train_data = torch.load(os.path.join(data_dir, "train.pt"))
-    val_data = torch.load(os.path.join(data_dir, "test.pt"))
-    return train_data, val_data
-
-
-if __name__ == "__main__":
-    # Uncomment this line if you see an error: "Expected is_sm80 to be true, but got false"
-    # torch.backends.cuda.enable_flash_sdp(False)
-    torch.set_float32_matmul_precision("high")
-
-    CLI(main)
\ No newline at end of file
+    #########################
+    ### Creating utility var
+    #########################
+    with open(args.config, 'r') as f:
+        cfg = json.load(f)
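+    # TODO : wire args.wandb and args.checkpoint into main(); both are parsed
+    # but currently unused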
+    main(cfg)
\ No newline at end of file