Skip to content
Snippets Groups Projects
Commit cf16a8f4 authored by Alexandre Chapin's avatar Alexandre Chapin :race_car:
Browse files

Fix import

parent 5d6ea2d3
No related branches found
No related tags found
No related merge requests found
...@@ -7,7 +7,8 @@ import osrt.layers as layers ...@@ -7,7 +7,8 @@ import osrt.layers as layers
import lightning.pytorch as pl import lightning.pytorch as pl
import torch import torch
import torch.optim as optim import torch.optim as optim
from osrt.utils.common import mse2psnr, compute_adjusted_rand_index from osrt.utils.common import mse2psnr
from osrt.utils.losses import compute_ari
from torch.optim.lr_scheduler import LambdaLR from torch.optim.lr_scheduler import LambdaLR
from typing import Dict from typing import Dict
...@@ -54,10 +55,11 @@ class OSRT(nn.Module): ...@@ -54,10 +55,11 @@ class OSRT(nn.Module):
self.encoder.train() self.encoder.train()
self.decoder.train() self.decoder.train()
"""
class LitOSRT(pl.LightningModule): class LitOSRT(pl.LightningModule):
def __init__(self, encoder:nn.Module, decoder: nn.Module, cfg: Dict, extract_masks:bool =False): def __init__(self, encoder:nn.Module, decoder: nn.Module, cfg: Dict, extract_masks:bool =False):
"""OSRT Model OSRT Model
The definition of the encoder/decoder are defined in the config file with the path to classes The definition of the encoder/decoder are defined in the config file with the path to classes
Args: Args:
...@@ -65,7 +67,7 @@ class LitOSRT(pl.LightningModule): ...@@ -65,7 +67,7 @@ class LitOSRT(pl.LightningModule):
decoder: class of the decoder to use decoder: class of the decoder to use
cfg: config file containing informations of the model cfg: config file containing informations of the model
extract_masks: wether to use masks for training extract_masks: wether to use masks for training
"""
super().__init__() super().__init__()
self.save_hyperparameters() self.save_hyperparameters()
self.cfg = cfg self.cfg = cfg
...@@ -80,14 +82,14 @@ class LitOSRT(pl.LightningModule): ...@@ -80,14 +82,14 @@ class LitOSRT(pl.LightningModule):
return self.encoder(x) # Returns: slot_latents return self.encoder(x) # Returns: slot_latents
def compute_loss(self, batch): def compute_loss(self, batch):
"""
Args: Args:
batch: dict containing the informations for training --> input images, rays and position batch: dict containing the informations for training --> input images, rays and position
extract_masks (Bool): whether to use masks to compute the segmentation loss or not extract_masks (Bool): whether to use masks to compute the segmentation loss or not
Returns: Returns:
loss: the loss value loss: the loss value
loss_terms: a dict containing more loss values loss_terms: a dict containing more loss values
"""
device = self.device device = self.device
render_kwargs = self.trainer.datamodule.train_dataset.render_kwargs render_kwargs = self.trainer.datamodule.train_dataset.render_kwargs
...@@ -131,10 +133,10 @@ class LitOSRT(pl.LightningModule): ...@@ -131,10 +133,10 @@ class LitOSRT(pl.LightningModule):
# These are not actually used as part of the training loss. # These are not actually used as part of the training loss.
# We just add the to the dict to report them. # We just add the to the dict to report them.
loss_terms['ari'] = compute_adjusted_rand_index(true_seg.transpose(1, 2), loss_terms['ari'] = compute_ari(true_seg.transpose(1, 2),
pred_seg.transpose(1, 2)) pred_seg.transpose(1, 2))
loss_terms['fg_ari'] = compute_adjusted_rand_index(true_seg.transpose(1, 2)[:, 1:], loss_terms['fg_ari'] = compute_ari(true_seg.transpose(1, 2)[:, 1:],
pred_seg.transpose(1, 2)) pred_seg.transpose(1, 2))
# TODO : add new ari metrics # TODO : add new ari metrics
...@@ -222,3 +224,4 @@ class LitOSRT(pl.LightningModule): ...@@ -222,3 +224,4 @@ class LitOSRT(pl.LightningModule):
'interval': 'step' 'interval': 'step'
} }
} }
"""
\ No newline at end of file
...@@ -29,7 +29,7 @@ __LOG10 = math.log(10) ...@@ -29,7 +29,7 @@ __LOG10 = math.log(10)
def validate(fabric: L.Fabric, model: OSRT, val_dataloader: DataLoader, epoch: int = 0): def validate(fabric: L.Fabric, model: OSRT, val_dataloader: DataLoader, epoch: int = 0):
# TODO : add segmentation also to select the model following how it's done in the training # TODO : add segmentation also to select the model following how it's done in the training
"""model.eval() model.eval()
mses = AverageMeter() mses = AverageMeter()
psnrs = AverageMeter() psnrs = AverageMeter()
...@@ -71,8 +71,7 @@ def validate(fabric: L.Fabric, model: OSRT, val_dataloader: DataLoader, epoch: i ...@@ -71,8 +71,7 @@ def validate(fabric: L.Fabric, model: OSRT, val_dataloader: DataLoader, epoch: i
state_dict = model.state_dict() state_dict = model.state_dict()
if fabric.global_rank == 0: if fabric.global_rank == 0:
torch.save(state_dict, os.path.join(cfg.out_dir, f"epoch-{epoch:06d}-psnr{psnrs.avg:.2f}-mse{mses.avg:.2f}-ckpt.pth")) torch.save(state_dict, os.path.join(cfg.out_dir, f"epoch-{epoch:06d}-psnr{psnrs.avg:.2f}-mse{mses.avg:.2f}-ckpt.pth"))
model.train()""" model.train()
pass
def train_sam( def train_sam(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment