Commit cf16a8f4 authored by Alexandre Chapin

Fix import

parent 5d6ea2d3
@@ -7,7 +7,8 @@ import osrt.layers as layers
 import lightning.pytorch as pl
 import torch
 import torch.optim as optim
-from osrt.utils.common import mse2psnr, compute_adjusted_rand_index
+from osrt.utils.common import mse2psnr
+from osrt.utils.losses import compute_ari
 from torch.optim.lr_scheduler import LambdaLR
 from typing import Dict
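The ARI helper now comes from osrt.utils.losses as compute_ari, while mse2psnr stays in osrt.utils.common. Neither helper's body is part of this commit; for reference, a minimal sketch of the conventional mse2psnr (consistent with the __LOG10 constant shown in the training script further down, but an assumption about the actual implementation):

import math
import torch

__LOG10 = math.log(10)

def mse2psnr(mse: torch.Tensor) -> torch.Tensor:
    # PSNR = -10 * log10(MSE) for images normalized to [0, 1].
    # Sketch only: the real osrt.utils.common.mse2psnr is not shown in this diff.
    return -10.0 * torch.log(mse) / __LOG10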
@@ -54,10 +55,11 @@ class OSRT(nn.Module):
         self.encoder.train()
         self.decoder.train()
 """
 class LitOSRT(pl.LightningModule):
     def __init__(self, encoder: nn.Module, decoder: nn.Module, cfg: Dict, extract_masks: bool = False):
-        """OSRT Model
+        OSRT Model
         The encoder and decoder classes are specified in the config file via their class paths
         Args:
@@ -65,7 +67,7 @@ class LitOSRT(pl.LightningModule):
             decoder: class of the decoder to use
             cfg: config file containing information about the model
             extract_masks: whether to use masks for training
         """
         super().__init__()
         self.save_hyperparameters()
         self.cfg = cfg
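The __init__ docstring above says the encoder and decoder classes are given in the config file as class paths. A hedged usage sketch of how the module might be instantiated from such a config (the config keys and class paths below are illustrative assumptions, not taken from this repository):

import importlib
import torch.nn as nn

def class_from_path(path: str):
    # Resolve a "package.module.ClassName" string into the class object.
    module_name, class_name = path.rsplit('.', 1)
    return getattr(importlib.import_module(module_name), class_name)

# Illustrative config; the real YAML layout is not visible in this diff.
cfg = {
    'model': {
        'encoder': 'osrt.encoder.OSRTEncoder',       # hypothetical class path
        'decoder': 'osrt.decoder.SlotMixerDecoder',  # hypothetical class path
    },
}
encoder: nn.Module = class_from_path(cfg['model']['encoder'])()
decoder: nn.Module = class_from_path(cfg['model']['decoder'])()
model = LitOSRT(encoder=encoder, decoder=decoder, cfg=cfg, extract_masks=False)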
@@ -80,14 +82,14 @@ class LitOSRT(pl.LightningModule):
         return self.encoder(x)  # Returns: slot_latents
 
     def compute_loss(self, batch):
         """
         Args:
             batch: dict containing the training data --> input images, rays and positions
             extract_masks (bool): whether to use masks to compute the segmentation loss or not
         Returns:
             loss: the loss value
             loss_terms: a dict containing additional loss values
         """
         device = self.device
         render_kwargs = self.trainer.datamodule.train_dataset.render_kwargs
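Most of compute_loss is collapsed in this view. A hedged sketch of the SRT/OSRT-style objective it likely wraps (a pixel-wise MSE on rendered rays, with PSNR reported via mse2psnr); the batch keys and encoder/decoder call signatures are assumptions, not confirmed by this diff:

def compute_loss_sketch(self, batch):
    # Sketch only: batch keys and model signatures follow common SRT-style
    # pipelines and are assumptions.
    device = self.device
    render_kwargs = self.trainer.datamodule.train_dataset.render_kwargs

    slot_latents = self.encoder(batch['input_images'].to(device),
                                batch['input_camera_pos'].to(device),
                                batch['input_rays'].to(device))
    pred_pixels, extras = self.decoder(slot_latents,
                                       batch['target_camera_pos'].to(device),
                                       batch['target_rays'].to(device),
                                       **render_kwargs)

    target_pixels = batch['target_pixels'].to(device)
    loss = ((pred_pixels - target_pixels) ** 2).mean()
    loss_terms = {'mse': loss, 'psnr': mse2psnr(loss)}
    return loss, loss_terms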
@@ -131,10 +133,10 @@ class LitOSRT(pl.LightningModule):
             # These are not actually used as part of the training loss.
             # We just add them to the dict to report them.
-            loss_terms['ari'] = compute_adjusted_rand_index(true_seg.transpose(1, 2),
+            loss_terms['ari'] = compute_ari(true_seg.transpose(1, 2),
                                             pred_seg.transpose(1, 2))
-            loss_terms['fg_ari'] = compute_adjusted_rand_index(true_seg.transpose(1, 2)[:, 1:],
+            loss_terms['fg_ari'] = compute_ari(true_seg.transpose(1, 2)[:, 1:],
                                                pred_seg.transpose(1, 2))
             # TODO: add new ARI metrics
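compute_ari itself lives in osrt.utils.losses and is not part of this diff. A minimal sketch of an adjusted Rand index over one-hot (or soft) segmentations, in the form common to slot-attention-style codebases and matching the [batch, n_groups, n_points] tensors passed above (an assumption about the helper, not the repository's actual code):

import torch

def compute_ari_sketch(true_mask: torch.Tensor, pred_mask: torch.Tensor) -> torch.Tensor:
    # true_mask, pred_mask: [batch, n_groups, n_points] assignment masks.
    # Returns one ARI value per batch element.
    n_points = true_mask.shape[-1]
    # Contingency table of co-assignment counts between true and predicted groups.
    nij = torch.einsum('bip,bjp->bij', true_mask.float(), pred_mask.float())
    a = nij.sum(dim=2)   # points per true group
    b = nij.sum(dim=1)   # points per predicted group
    rindex = (nij * (nij - 1)).sum(dim=(1, 2))
    aindex = (a * (a - 1)).sum(dim=1)
    bindex = (b * (b - 1)).sum(dim=1)
    expected = aindex * bindex / (n_points * (n_points - 1))
    max_rindex = 0.5 * (aindex + bindex)
    denom = max_rindex - expected
    # A single all-encompassing group makes the score degenerate; report 1 there.
    return torch.where(denom == 0, torch.ones_like(denom), (rindex - expected) / denom)

The fg_ari variant above drops the first true group (index 0, presumably the background) via the [:, 1:] slice before computing the same score.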
@@ -222,3 +224,4 @@ class LitOSRT(pl.LightningModule):
                 'interval': 'step'
             }
         }
+"""
\ No newline at end of file
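The final hunk only shows the tail of what appears to be configure_optimizers, ending in a per-step scheduler entry. A hedged sketch of a Lightning configure_optimizers that would produce that trailing structure (the optimizer choice, learning rate, and warmup schedule are assumptions, consistent with the optim and LambdaLR imports at the top of the file):

def configure_optimizers(self):
    # Sketch only: lr, warmup length, and schedule shape are assumptions.
    optimizer = optim.Adam(self.parameters(), lr=1e-4)

    def lr_lambda(step: int) -> float:
        warmup = 2500
        return min(1.0, (step + 1) / warmup)  # linear warmup, then constant

    scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda)
    return {
        'optimizer': optimizer,
        'lr_scheduler': {
            'scheduler': scheduler,
            'interval': 'step',  # step the scheduler every optimizer step
        }
    }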
@@ -29,7 +29,7 @@ __LOG10 = math.log(10)
 def validate(fabric: L.Fabric, model: OSRT, val_dataloader: DataLoader, epoch: int = 0):
     # TODO: also add segmentation to model selection, following how it is done during training
-    """model.eval()
+    model.eval()
     mses = AverageMeter()
     psnrs = AverageMeter()
@@ -71,8 +71,7 @@ def validate(fabric: L.Fabric, model: OSRT, val_dataloader: DataLoader, epoch: i
     state_dict = model.state_dict()
     if fabric.global_rank == 0:
         torch.save(state_dict, os.path.join(cfg.out_dir, f"epoch-{epoch:06d}-psnr{psnrs.avg:.2f}-mse{mses.avg:.2f}-ckpt.pth"))
-    model.train()"""
-    pass
+    model.train()
 
 def train_sam(
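The two hunks in the second file remove the triple quotes that had commented out the body of validate. A hedged sketch of the loop the visible fragments imply (model.eval(), MSE/PSNR averaging, a rank-0 checkpoint save, then back to train mode); the batch keys, the AverageMeter helper, and the cfg.out_dir global follow the surrounding script and are assumptions:

@torch.no_grad()
def validate_sketch(fabric, model, val_dataloader, epoch=0):
    model.eval()
    mses, psnrs = AverageMeter(), AverageMeter()
    for batch in val_dataloader:
        pred_pixels, _ = model(batch)                      # hypothetical forward signature
        mse = ((pred_pixels - batch['target_pixels']) ** 2).mean().item()
        mses.update(mse)
        psnrs.update(-10.0 * math.log(mse) / __LOG10)
    state_dict = model.state_dict()
    if fabric.global_rank == 0:
        torch.save(state_dict, os.path.join(
            cfg.out_dir, f"epoch-{epoch:06d}-psnr{psnrs.avg:.2f}-mse{mses.avg:.2f}-ckpt.pth"))
    model.train()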