From 4d10dcf92ff2514d4e5ba177f764825746fce176 Mon Sep 17 00:00:00 2001
From: alexcbb <alexchapin@hotmail.fr>
Date: Tue, 18 Jul 2023 10:30:10 +0200
Subject: [PATCH] Fix improt issue

---
 osrt/utils/training.py | 26 ++++----------------------
 train_lit.py           | 22 +++++++++++++++++++---
 2 files changed, 23 insertions(+), 25 deletions(-)

diff --git a/osrt/utils/training.py b/osrt/utils/training.py
index 48240ca..8352301 100644
--- a/osrt/utils/training.py
+++ b/osrt/utils/training.py
@@ -3,7 +3,7 @@ from torch.utils.data import DataLoader
 import lightning as L
 import torchmetrics
 
-from osrt.utils.common import mse2psnr, compute_adjusted_rand_index
+from osrt.utils.common import mse2psnr, compute_ari
 from osrt.utils import nerf
 from osrt.encoder import FeatureMasking
 import osrt.utils.visualize as vis
@@ -15,24 +15,6 @@ import math
 from collections import defaultdict
 import time
 import os
-
-class AverageMeter:
-    """Computes and stores the average and current value."""
-
-    def __init__(self):
-        self.reset()
-
-    def reset(self):
-        self.val = 0
-        self.avg = 0
-        self.sum = 0
-        self.count = 0
-
-    def update(self, val, n=1):
-        self.val = val
-        self.sum += val * n
-        self.count += n
-        self.avg = self.sum / self.count
         
 def compute_loss(
         batch, 
@@ -70,10 +52,10 @@ def compute_loss(
         pred_seg = extras['segmentation']
         true_seg = batch['target_masks'].float()
 
-        loss_terms['ari'] = compute_adjusted_rand_index(true_seg.transpose(1, 2),
+        loss_terms['ari'] = compute_ari(true_seg.transpose(1, 2),
                                                         pred_seg.transpose(1, 2))
 
-        loss_terms['fg_ari'] = compute_adjusted_rand_index(true_seg.transpose(1, 2)[:, 1:],
+        loss_terms['fg_ari'] = compute_ari(true_seg.transpose(1, 2)[:, 1:],
                                                             pred_seg.transpose(1, 2))
     return loss, loss_terms
 
@@ -302,7 +284,7 @@ def visualize(model, vis_data, out_dir, render_args, fabric, config):
                 pred_seg = extras['segmentation'].cpu()
                 columns.append((f'pred seg {angle_deg}°', pred_seg.argmax(-1).numpy(), 'clustering'))
                 if i == 0:
-                    ari = compute_adjusted_rand_index(
+                    ari = compute_ari(
                         input_mask.cpu().flatten(1, 2).transpose(1, 2)[:, 1:],
                         pred_seg.flatten(1, 2).transpose(1, 2))
                     row_labels = ['2D Fg-ARI={:.1f}'.format(x.item() * 100) for x in ari]
diff --git a/train_lit.py b/train_lit.py
index 3694705..67321c7 100644
--- a/train_lit.py
+++ b/train_lit.py
@@ -20,13 +20,31 @@ from torch.utils.data import DataLoader
 from osrt.model import OSRT
 from osrt.encoder import FeatureMasking
 from osrt import data
-from osrt.utils.training import AverageMeter
 from osrt.utils.losses import compute_focal_loss, compute_ari, compute_dice_loss
 
 torch.set_float32_matmul_precision('high')
 
 __LOG10 = math.log(10)
 
+class AverageMeter:
+    """Computes and stores the average and current value."""
+
+    def __init__(self):
+        self.reset()
+
+    def reset(self):
+        self.val = 0
+        self.avg = 0
+        self.sum = 0
+        self.count = 0
+
+    def update(self, val, n=1):
+        self.val = val
+        self.sum += val * n
+        self.count += n
+        self.avg = self.sum / self.count
+     
+
 def validate(fabric: L.Fabric, model: OSRT, val_dataloader: DataLoader, epoch: int = 0):
     # TODO : add segmentation also to select the model following how it's done in the training
     model.eval()
@@ -134,7 +152,6 @@ def train_sam(
 
             ### Evaluate segmentation
             if 'segmentation' in extras:
-
                 fabric.print(f"Pred segmentation shape {extras['segmentation'].shape}")
                 pred_masks = extras['segmentation'] # [B, nb_rays, nb_slots]
 
@@ -252,7 +269,6 @@ def main(cfg) -> None:
     data_vis_val = next(iter(vis_loader_val))  # Validation set data for visualization
     data_vis_val = fabric.to_device(data_vis_val)
 
-
     #########################
     ### Prepare the optimizer
     #########################
-- 
GitLab