diff --git a/osrt/model.py b/osrt/model.py index b01d1617025bc90ec98af05a3c9d3ed60d4d89e3..98dcdaa6bd01040cb149f4be4e57909f9ddb7d1e 100644 --- a/osrt/model.py +++ b/osrt/model.py @@ -7,7 +7,8 @@ import osrt.layers as layers import lightning.pytorch as pl import torch import torch.optim as optim -from osrt.utils.common import mse2psnr, compute_adjusted_rand_index +from osrt.utils.common import mse2psnr +from osrt.utils.losses import compute_ari from torch.optim.lr_scheduler import LambdaLR from typing import Dict @@ -54,10 +55,11 @@ class OSRT(nn.Module): self.encoder.train() self.decoder.train() - + +""" class LitOSRT(pl.LightningModule): def __init__(self, encoder:nn.Module, decoder: nn.Module, cfg: Dict, extract_masks:bool =False): - """OSRT Model + OSRT Model The definition of the encoder/decoder are defined in the config file with the path to classes Args: @@ -65,7 +67,7 @@ class LitOSRT(pl.LightningModule): decoder: class of the decoder to use cfg: config file containing informations of the model extract_masks: wether to use masks for training - """ + super().__init__() self.save_hyperparameters() self.cfg = cfg @@ -80,14 +82,14 @@ class LitOSRT(pl.LightningModule): return self.encoder(x) # Returns: slot_latents def compute_loss(self, batch): - """ Args: batch: dict containing the informations for training --> input images, rays and position extract_masks (Bool): whether to use masks to compute the segmentation loss or not Returns: loss: the loss value loss_terms: a dict containing more loss values - """ + + device = self.device render_kwargs = self.trainer.datamodule.train_dataset.render_kwargs @@ -131,10 +133,10 @@ class LitOSRT(pl.LightningModule): # These are not actually used as part of the training loss. # We just add the to the dict to report them. - loss_terms['ari'] = compute_adjusted_rand_index(true_seg.transpose(1, 2), + loss_terms['ari'] = compute_ari(true_seg.transpose(1, 2), pred_seg.transpose(1, 2)) - loss_terms['fg_ari'] = compute_adjusted_rand_index(true_seg.transpose(1, 2)[:, 1:], + loss_terms['fg_ari'] = compute_ari(true_seg.transpose(1, 2)[:, 1:], pred_seg.transpose(1, 2)) # TODO : add new ari metrics @@ -222,3 +224,4 @@ class LitOSRT(pl.LightningModule): 'interval': 'step' } } +""" \ No newline at end of file diff --git a/train_lit.py b/train_lit.py index b5086eef97254a034e87dd0b8bb7f7d909644370..3694705b3bd722098ac9d6487337581ac7b77a8a 100644 --- a/train_lit.py +++ b/train_lit.py @@ -29,7 +29,7 @@ __LOG10 = math.log(10) def validate(fabric: L.Fabric, model: OSRT, val_dataloader: DataLoader, epoch: int = 0): # TODO : add segmentation also to select the model following how it's done in the training - """model.eval() + model.eval() mses = AverageMeter() psnrs = AverageMeter() @@ -71,8 +71,7 @@ def validate(fabric: L.Fabric, model: OSRT, val_dataloader: DataLoader, epoch: i state_dict = model.state_dict() if fabric.global_rank == 0: torch.save(state_dict, os.path.join(cfg.out_dir, f"epoch-{epoch:06d}-psnr{psnrs.avg:.2f}-mse{mses.avg:.2f}-ckpt.pth")) - model.train()""" - pass + model.train() def train_sam(