From 4d10dcf92ff2514d4e5ba177f764825746fce176 Mon Sep 17 00:00:00 2001 From: alexcbb <alexchapin@hotmail.fr> Date: Tue, 18 Jul 2023 10:30:10 +0200 Subject: [PATCH] Fix improt issue --- osrt/utils/training.py | 26 ++++---------------------- train_lit.py | 22 +++++++++++++++++++--- 2 files changed, 23 insertions(+), 25 deletions(-) diff --git a/osrt/utils/training.py b/osrt/utils/training.py index 48240ca..8352301 100644 --- a/osrt/utils/training.py +++ b/osrt/utils/training.py @@ -3,7 +3,7 @@ from torch.utils.data import DataLoader import lightning as L import torchmetrics -from osrt.utils.common import mse2psnr, compute_adjusted_rand_index +from osrt.utils.common import mse2psnr, compute_ari from osrt.utils import nerf from osrt.encoder import FeatureMasking import osrt.utils.visualize as vis @@ -15,24 +15,6 @@ import math from collections import defaultdict import time import os - -class AverageMeter: - """Computes and stores the average and current value.""" - - def __init__(self): - self.reset() - - def reset(self): - self.val = 0 - self.avg = 0 - self.sum = 0 - self.count = 0 - - def update(self, val, n=1): - self.val = val - self.sum += val * n - self.count += n - self.avg = self.sum / self.count def compute_loss( batch, @@ -70,10 +52,10 @@ def compute_loss( pred_seg = extras['segmentation'] true_seg = batch['target_masks'].float() - loss_terms['ari'] = compute_adjusted_rand_index(true_seg.transpose(1, 2), + loss_terms['ari'] = compute_ari(true_seg.transpose(1, 2), pred_seg.transpose(1, 2)) - loss_terms['fg_ari'] = compute_adjusted_rand_index(true_seg.transpose(1, 2)[:, 1:], + loss_terms['fg_ari'] = compute_ari(true_seg.transpose(1, 2)[:, 1:], pred_seg.transpose(1, 2)) return loss, loss_terms @@ -302,7 +284,7 @@ def visualize(model, vis_data, out_dir, render_args, fabric, config): pred_seg = extras['segmentation'].cpu() columns.append((f'pred seg {angle_deg}°', pred_seg.argmax(-1).numpy(), 'clustering')) if i == 0: - ari = compute_adjusted_rand_index( + ari = compute_ari( input_mask.cpu().flatten(1, 2).transpose(1, 2)[:, 1:], pred_seg.flatten(1, 2).transpose(1, 2)) row_labels = ['2D Fg-ARI={:.1f}'.format(x.item() * 100) for x in ari] diff --git a/train_lit.py b/train_lit.py index 3694705..67321c7 100644 --- a/train_lit.py +++ b/train_lit.py @@ -20,13 +20,31 @@ from torch.utils.data import DataLoader from osrt.model import OSRT from osrt.encoder import FeatureMasking from osrt import data -from osrt.utils.training import AverageMeter from osrt.utils.losses import compute_focal_loss, compute_ari, compute_dice_loss torch.set_float32_matmul_precision('high') __LOG10 = math.log(10) +class AverageMeter: + """Computes and stores the average and current value.""" + + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + def validate(fabric: L.Fabric, model: OSRT, val_dataloader: DataLoader, epoch: int = 0): # TODO : add segmentation also to select the model following how it's done in the training model.eval() @@ -134,7 +152,6 @@ def train_sam( ### Evaluate segmentation if 'segmentation' in extras: - fabric.print(f"Pred segmentation shape {extras['segmentation'].shape}") pred_masks = extras['segmentation'] # [B, nb_rays, nb_slots] @@ -252,7 +269,6 @@ def main(cfg) -> None: data_vis_val = next(iter(vis_loader_val)) # Validation set data for visualization data_vis_val = fabric.to_device(data_vis_val) - ######################### ### Prepare the optimizer ######################### -- GitLab