Commit 5d6ea2d3 authored by Alexandre Chapin

Setup training

parent f9d36ede
from osrt.data.nmr import NMRDataset
from osrt.data.multishapenet import MultishapenetDataset
from osrt.data.obsurf import Clevr3dDataset
from osrt.data.datamodule import DataModule
......
......@@ -28,11 +28,7 @@ def get_dataset(mode, cfg, max_len=None, full_scale=False):
kwargs = dict()
# Create dataset
if dataset_type == 'nmr':
dataset = data.NMRDataset(dataset_folder, mode, points_per_item=points_per_item,
max_len=max_len, full_scale=full_scale,
canonical_view=canonical_view, **kwargs)
elif dataset_type == 'msn':
if dataset_type == 'msn':
dataset = data.MultishapenetDataset(dataset_folder, mode, points_per_item=points_per_item,
full_scale=full_scale, canonical_view=canonical_view, **kwargs)
elif dataset_type == 'osrt':
......
"""
Code modified from MedSAM repository :
https://github.com/bowang-lab/MedSAM/blob/main/utils/precompute_img_embed.py
"""
import numpy as np
import os
join = os.path.join
from tqdm import tqdm
import torch
from segment_anything import sam_model_registry
from segment_anything.utils.transforms import ResizeLongestSide
import argparse
import cv2
parser = argparse.ArgumentParser(
description='Extract image embeddings from SAM model'
)
parser.add_argument('-i', '--img_path', type=str, default='', help='Path to the folder containing images')
parser.add_argument('-o', '--save_path', type=str, default='', help='Path to the folder containing the final embeddings')
parser.add_argument('--model_type', type=str, default='vit_h', help='model type')
parser.add_argument('--path_model', type=str, default='.', help='path to the pre-trained SAM model')
args = parser.parse_args()
model_type = args.model_type
if args.model_type == 'vit_h':
    checkpoint = args.path_model + '/sam_vit_h_4b8939.pth'
elif args.model_type == 'vit_b':
    checkpoint = args.path_model + '/sam_vit_b_01ec64.pth'
else:
    model_type = 'vit_l'
    checkpoint = args.path_model + '/sam_vit_l_0b3195.pth'
img_path = args.img_path
img_files = sorted(os.listdir(img_path))
sam_model = sam_model_registry[model_type](checkpoint=checkpoint).to('cuda:0')
sam_transform = ResizeLongestSide(sam_model.image_encoder.img_size)
# compute image embeddings
for name in tqdm(img_files):
    img = cv2.imread(join(img_path, name))  # read image from the input folder
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    resize_img = sam_transform.apply_image(img)
    resize_img_tensor = torch.as_tensor(resize_img.transpose(2, 0, 1)).to('cuda:0')
    # model input: (1, 3, 1024, 1024)
    input_image = sam_model.preprocess(resize_img_tensor[None, :, :, :])  # (1, 3, 1024, 1024)
    assert input_image.shape == (1, 3, sam_model.image_encoder.img_size, sam_model.image_encoder.img_size), \
        'input image should be resized to 1024*1024'
    with torch.no_grad():
        embedding = sam_model.image_encoder(input_image)
    # save the embedding as .npy, one file per image
    np.save(join(args.save_path, name.split('.')[0] + '.npy'), embedding.cpu().numpy()[0])
\ No newline at end of file
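For reference, a minimal sketch of how the precomputed embeddings could be read back (paths are placeholders; the [256, 64, 64] shape assumes SAM's ViT-H image encoder output, saved per image as above):

import os
import numpy as np

save_path = './embeddings'  # hypothetical output folder passed as --save_path
for fname in sorted(os.listdir(save_path)):
    emb = np.load(os.path.join(save_path, fname))
    # Each file holds one image embedding, expected to be [256, 64, 64] for ViT-H.
    print(fname, emb.shape)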
......@@ -79,7 +79,7 @@ class MultishapenetDataset(IterableDataset):
target_camera_pos = np.reshape(data['ray_origins'][target_views], (-1, 3))
num_pixels = target_pixels.shape[0]
sampled_idxs = None
if not self.full_scale:
sampled_idxs = np.random.choice(np.arange(num_pixels),
size=(self.num_target_pixels,),
......@@ -103,6 +103,7 @@ class MultishapenetDataset(IterableDataset):
'target_camera_pos': target_camera_pos, # [p, 3]
'target_rays': target_rays, # [p, 3]
'sceneid': sceneid, # int
'sampled_idx': sampled_idxs # [p]
}
if self.canonical:
......
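The 'sampled_idx' entry added above records which flattened pixels were drawn as target rays, so downstream code (e.g. the training loop further down) can align full-resolution ground-truth masks with the sampled rays. A minimal sketch with assumed shapes, using random tensors in place of real dataset outputs:

import torch

B, HW, K, P = 2, 240 * 320, 11, 2048        # batch, flattened pixels, entities, sampled rays (assumed)
gt_masks = torch.rand(B, HW, K)             # stand-in for per-pixel ground-truth masks
sampled_idx = torch.randint(0, HW, (B, P))  # stand-in for the collated 'sampled_idx' field

# Gather, per batch element, the mask rows that correspond to the sampled target rays.
batch_idx = torch.arange(B).unsqueeze(-1)            # [B, 1], broadcasts against [B, P]
gt_masks_per_ray = gt_masks[batch_idx, sampled_idx]  # [B, P, K]
print(gt_masks_per_ray.shape)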
import numpy as np
import imageio
import yaml
from torch.utils.data import Dataset
import os
from osrt.utils.nerf import transform_points
class NMRDataset(Dataset):
def __init__(self, path, mode, points_per_item=2048, max_len=None,
canonical_view=True, full_scale=False):
""" Loads the NMR dataset as found at
https://s3.eu-central-1.amazonaws.com/avg-projects/differentiable_volumetric_rendering/data/NMR_Dataset.zip
Hosted by Niemeyer et al. (https://github.com/autonomousvision/differentiable_volumetric_rendering)
Args:
path (str): Path to dataset.
mode (str): 'train', 'val', or 'test'.
points_per_item (int): Number of target points per scene.
max_len (int): Limit to the number of entries in the dataset.
canonical_view (bool): Return data in canonical camera coordinates (like in SRT), as opposed
to world coordinates.
full_scale (bool): Return all available target points, instead of sampling.
"""
self.path = path
self.mode = mode
self.points_per_item = points_per_item
self.max_len = max_len
self.canonical = canonical_view
self.full_scale = full_scale
with open(os.path.join(path, 'metadata.yaml'), 'r') as f:
metadata = yaml.load(f, Loader=yaml.CLoader)
class_ids = [entry['id'] for entry in metadata.values()]
self.scene_paths = []
for class_id in class_ids:
with open(os.path.join(path, class_id, f'softras_{mode}.lst')) as f:
cur_scene_ids = f.readlines()
cur_scene_ids = [scene_id.rstrip() for scene_id in cur_scene_ids if len(scene_id) > 1]
cur_scene_paths = [os.path.join(class_id, scene_id) for scene_id in cur_scene_ids]
self.scene_paths.extend(cur_scene_paths)
self.num_scenes = len(self.scene_paths)
print(f'NMR {mode} dataset loaded: {self.num_scenes} scenes.')
self.render_kwargs = {
'min_dist': 2.,
'max_dist': 4.}
# Rotation matrix making z=0 the ground plane.
# Ensures that the scenes are laid out in the same way as the other datasets,
# which is convenient for visualization.
self.rot_mat = np.array([[1, 0, 0, 0],
[0, 0, -1, 0],
[0, 1, 0, 0],
[0, 0, 0, 1]])
def __len__(self):
if self.max_len is not None:
return self.max_len
return self.num_scenes * 24
def __getitem__(self, idx):
scene_idx = idx % self.num_scenes
view_idx = idx // self.num_scenes
target_views = np.array(list(set(range(24)) - set([view_idx])))
scene_path = os.path.join(self.path, self.scene_paths[scene_idx])
images = [np.asarray(imageio.imread(
os.path.join(scene_path, 'image', f'{i:04d}.png'))) for i in range(24)]
images = np.stack(images, 0).astype(np.float32) / 255.
input_image = np.transpose(images[view_idx], (2, 0, 1))
cameras = np.load(os.path.join(scene_path, 'cameras.npz'))
cameras = {k: v for k, v in cameras.items()} # Load all matrices into memory
for i in range(24): # Apply rotation matrix to rotate coordinate system
cameras[f'world_mat_inv_{i}'] = self.rot_mat @ cameras[f'world_mat_inv_{i}']
# The transpose here is not technically necessary, since the rotation matrix is symmetric
cameras[f'world_mat_{i}'] = cameras[f'world_mat_{i}'] @ np.transpose(self.rot_mat)
rays = []
height = width = 64
xmap = np.linspace(-1, 1, width)
ymap = np.linspace(-1, 1, height)
xmap, ymap = np.meshgrid(xmap, ymap)
for i in range(24):
cur_rays = np.stack((xmap, ymap, np.ones_like(xmap)), -1)
cur_rays = transform_points(cur_rays,
cameras[f'world_mat_inv_{i}'] @ cameras[f'camera_mat_inv_{i}'],
translate=False)
cur_rays = cur_rays[..., :3]
cur_rays = cur_rays / np.linalg.norm(cur_rays, axis=-1, keepdims=True)
rays.append(cur_rays)
rays = np.stack(rays, axis=0).astype(np.float32)
camera_pos = [cameras[f'world_mat_inv_{i}'][:3, -1] for i in range(24)]
camera_pos = np.stack(camera_pos, axis=0).astype(np.float32)
# camera_pos and rays are now in world coordinates.
if self.canonical: # Transform to canonical camera coordinates
canonical_extrinsic = cameras[f'world_mat_{view_idx}'].astype(np.float32)
camera_pos = transform_points(camera_pos, canonical_extrinsic)
rays = transform_points(rays, canonical_extrinsic, translate=False)
rays_flat = np.reshape(rays[target_views], (-1, 3))
pixels_flat = np.reshape(images[target_views], (-1, 3))
cpos_flat = np.tile(np.expand_dims(camera_pos[target_views], 1), (1, height * width, 1))
cpos_flat = np.reshape(cpos_flat, (len(target_views) * height * width, 3))
num_points = rays_flat.shape[0]
if not self.full_scale:
replace = num_points < self.points_per_item
sampled_idxs = np.random.choice(np.arange(num_points),
size=(self.points_per_item,),
replace=replace)
rays_sel = rays_flat[sampled_idxs]
pixels_sel = pixels_flat[sampled_idxs]
cpos_sel = cpos_flat[sampled_idxs]
else:
rays_sel = rays_flat
pixels_sel = pixels_flat
cpos_sel = cpos_flat
result = {
'input_images': np.expand_dims(input_image, 0), # [1, 3, h, w]
'input_camera_pos': np.expand_dims(camera_pos[view_idx], 0), # [1, 3]
'input_rays': np.expand_dims(rays[view_idx], 0), # [1, h, w, 3]
'target_pixels': pixels_sel, # [p, 3]
'target_camera_pos': cpos_sel, # [p, 3]
'target_rays': rays_sel, # [p, 3]
'sceneid': idx, # int
}
if self.canonical:
result['transform'] = canonical_extrinsic # [3, 4] (optional)
return result
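A minimal usage sketch for the dataset above, assuming the NMR archive has been extracted to a local folder (the path and loader settings are placeholders):

from torch.utils.data import DataLoader
from osrt.data.nmr import NMRDataset

dataset = NMRDataset('data/nmr', 'train', points_per_item=2048, canonical_view=True)
loader = DataLoader(dataset, batch_size=32, num_workers=2, shuffle=True)

batch = next(iter(loader))
print(batch['input_images'].shape)   # [32, 1, 3, 64, 64]
print(batch['target_pixels'].shape)  # [32, 2048, 3]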
......@@ -132,7 +132,7 @@ class Clevr3dDataset(Dataset):
all_masks = np.reshape(masks, (self.num_views * 240 * 320, self.max_num_entities))
num_points = all_rays.shape[0]
sampled_idxs = None
if not self.full_scale:
# If we have fewer points than we want, sample with replacement
replace = num_points < self.points_per_item
......@@ -160,6 +160,7 @@ class Clevr3dDataset(Dataset):
'target_rays': target_rays, # [p, 3]
'target_masks': target_masks, # [p, self.max_num_entities]
'sceneid': idx, # int
'sampled_idx': sampled_idxs # [p]
}
if self.canonical:
......
......@@ -253,120 +253,3 @@ def reduce_dict(input_dict, average=True):
reduced_dict = {k: v for k, v in zip(keys, values)}
return reduced_dict
def compute_adjusted_rand_index(true_mask, pred_mask):
"""
Computes the adjusted rand index (ARI) of a given image segmentation, ignoring the background.
Implementation following https://github.com/deepmind/multi_object_datasets/blob/master/segmentation_metrics.py#L20
Args:
true_mask: Tensor of shape [batch_size, n_true_groups, n_points] containing true
one-hot coded cluster assignments, with background being indicated by zero vectors.
pred_mask: Tensor of shape [batch_size, n_pred_groups, n_points] containing predicted
cluster assignments encoded as categorical probabilities.
"""
batch_size, n_true_groups, n_points = true_mask.shape
n_pred_groups = pred_mask.shape[1]
if n_points <= n_true_groups and n_points <= n_pred_groups:
raise ValueError(
"adjusted_rand_index requires n_groups < n_points. We don't handle "
"the special cases that can occur when you have one cluster "
"per datapoint.")
true_group_ids = true_mask.argmax(1)
pred_group_ids = pred_mask.argmax(1)
# Convert to one-hot ('oh') representations
true_mask_oh = true_mask.float()
pred_mask_oh = torch.eye(n_pred_groups).to(pred_mask)[pred_group_ids].transpose(1, 2).float() # TODO : this float was not there before, to check
n_points_fg = true_mask_oh.sum((1, 2))
nij = torch.einsum('bip,bjp->bji', pred_mask_oh, true_mask_oh)
nij = nij.double() # Cast to double, since the expected_rindex can introduce numerical inaccuracies
a = nij.sum(1)
b = nij.sum(2)
rindex = (nij * (nij - 1)).sum((1, 2))
aindex = (a * (a - 1)).sum(1)
bindex = (b * (b - 1)).sum(1)
expected_rindex = aindex * bindex / (n_points_fg * (n_points_fg - 1))
max_rindex = (aindex + bindex) / 2
ari = (rindex - expected_rindex) / (max_rindex - expected_rindex)
# We can get NaN in case max_rindex == expected_rindex. This happens when both true and
# predicted segmentations consist of only a single segment. Since we are allowing the
# true segmentation to contain zeros (i.e. background) which we ignore, it suffices
# if the foreground pixels belong to a single segment.
# We check for this case, and instead set the ARI to 1.
def _fg_all_equal(values, bg):
"""
Check if all pixels in values that do not belong to the background (bg is False) have the same
segmentation id.
Args:
values: Segmentations ids given as integer Tensor of shape [batch_size, n_points]
bg: Binary tensor indicating background, shape [batch_size, n_points]
"""
fg_ids = (values + 1) * (1 - bg.int()) # Move fg ids to [1, n], set bg ids to 0
example_fg_id = fg_ids.max(1, keepdim=True)[0] # Get the id of an arbitrary fg cluster.
return torch.logical_or(fg_ids == example_fg_id[..., :1], # All pixels should match that id...
bg # ...or belong to the background.
).all(-1)
background = (true_mask.sum(1) == 0)
both_single_cluster = torch.logical_and(_fg_all_equal(true_group_ids, background),
_fg_all_equal(pred_group_ids, background))
# Ensure that we are only (close to) getting NaNs in exactly the case described above.
matching = (both_single_cluster == torch.isclose(max_rindex, expected_rindex))
if not matching.all().item():
offending_idx = matching.int().argmin()
return torch.where(both_single_cluster, torch.ones_like(ari), ari)
def precision_recall(segmentation_gt: torch.Tensor, segmentation_pred: torch.Tensor, mode: str, adjusted: bool):
""" Compute the (Adjusted) Rand Precision/Recall.
Implementation obtained from the paper: Sensitivity of Slot-Based Object-Centric Models to their Number of Slots
Args:
- segmentation_gt: Int tensor with shape (batch_size, height, width) containing the ground-truth segmentations.
- segmentation_pred: Int tensor with shape (batch_size, height, width) containing the predicted segmentations.
- mode: Either "precision" or "recall" depending on which metric shall be computed.
- adjusted: Return values for adjusted or non-adjusted metric.
Returns:
Float tensor with shape (batch_size), containing the (Adjusted) Rand Precision/Recall per sample.
"""
max_classes = max(segmentation_gt.max(), segmentation_pred.max()) + 1
oh_segmentation_gt = F.one_hot(segmentation_gt, max_classes)
oh_segmentation_pred = F.one_hot(segmentation_pred, max_classes)
coincidence = torch.einsum("bhwk,bhwc->bkc", oh_segmentation_gt, oh_segmentation_pred)
coincidence_gt = coincidence.sum(-1)
coincidence_pred = coincidence.sum(-2)
m_squared = torch.sum(coincidence**2, (1, 2))
m = torch.sum(coincidence, (1, 2))
# How many pairs of pixels have the same label assigned in the ground-truth segmentation.
P = torch.sum(coincidence_gt * (coincidence_gt - 1), -1)
# How many pairs of pixels have the same label assigned in the predicted segmentation.
Q = torch.sum(coincidence_pred * (coincidence_pred - 1), -1)
expected_m_squared = (P + m) * (Q + m) / (m * (m - 2)) + (m**2 - Q - P -2 * m) / (m - 1)
if mode == "precision":
gamma = P + m
elif mode == "recall":
gamma = Q + m
else:
raise ValueError("Invalid mode.")
if adjusted:
return (m_squared - expected_m_squared) / (gamma - expected_m_squared)
else:
return (m_squared - m) / (gamma - m)
\ No newline at end of file
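A small sanity check for the Rand precision/recall metric on toy 2x2 segmentations (after this commit the function lives in osrt.utils.losses; the values below are arbitrary):

import torch
from osrt.utils.losses import precision_recall

seg_gt = torch.tensor([[[0, 0], [1, 1]]])    # ground truth: two segments
seg_pred = torch.tensor([[[0, 0], [1, 2]]])  # prediction splits one segment in two

print(precision_recall(seg_gt, seg_pred, mode="precision", adjusted=False))  # -> 0.5
print(precision_recall(seg_gt, seg_pred, mode="recall", adjusted=False))     # -> 1.0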
......@@ -9,39 +9,146 @@ import torch.nn.functional as F
ALPHA = 0.8
GAMMA = 2
# Removed by this commit: the nn.Module loss wrappers, replaced by the plain functions below.
class FocalLoss(nn.Module):
    def __init__(self, weight=None, size_average=True):
        super().__init__()

    def forward(self, inputs, targets, alpha=ALPHA, gamma=GAMMA, smooth=1):
        inputs = F.sigmoid(inputs)
        inputs = torch.clamp(inputs, min=0, max=1)
        # flatten label and prediction tensors
        inputs = inputs.view(-1)
        targets = targets.view(-1)
        BCE = F.binary_cross_entropy(inputs, targets, reduction='mean')
        BCE_EXP = torch.exp(-BCE)
        focal_loss = alpha * (1 - BCE_EXP)**gamma * BCE
        return focal_loss

class DiceLoss(nn.Module):
    def __init__(self, weight=None, size_average=True):
        super().__init__()

    def forward(self, inputs, targets, smooth=1):
        inputs = F.sigmoid(inputs)
        inputs = torch.clamp(inputs, min=0, max=1)
        # flatten label and prediction tensors
        inputs = inputs.view(-1)
        targets = targets.view(-1)
        intersection = (inputs * targets).sum()
        dice = (2. * intersection + smooth) / (inputs.sum() + targets.sum() + smooth)
        return 1 - dice

# Added by this commit: function-based losses, plus the ARI metric (the same implementation
# as the compute_adjusted_rand_index removed above).
def compute_focal_loss(inputs, targets, alpha=ALPHA, gamma=GAMMA, smooth=1):
    inputs = F.sigmoid(inputs)
    inputs = torch.clamp(inputs, min=0, max=1)
    # flatten label and prediction tensors
    inputs = inputs.view(-1)
    targets = targets.view(-1)

    BCE = F.binary_cross_entropy(inputs, targets, reduction='mean')
    BCE_EXP = torch.exp(-BCE)
    focal_loss = alpha * (1 - BCE_EXP)**gamma * BCE

    return focal_loss

def compute_dice_loss(inputs, targets, smooth=1):
    inputs = F.sigmoid(inputs)
    inputs = torch.clamp(inputs, min=0, max=1)
    # flatten label and prediction tensors
    inputs = inputs.view(-1)
    targets = targets.view(-1)

    intersection = (inputs * targets).sum()
    dice = (2. * intersection + smooth) / (inputs.sum() + targets.sum() + smooth)
    return 1 - dice

def compute_ari(true_mask, pred_mask):
    """
    Computes the adjusted rand index (ARI) of a given image segmentation, ignoring the background.
    Implementation following https://github.com/deepmind/multi_object_datasets/blob/master/segmentation_metrics.py#L20
    Args:
        true_mask: Tensor of shape [batch_size, n_true_groups, n_points] containing true
            one-hot coded cluster assignments, with background being indicated by zero vectors.
        pred_mask: Tensor of shape [batch_size, n_pred_groups, n_points] containing predicted
            cluster assignments encoded as categorical probabilities.
    """
    batch_size, n_true_groups, n_points = true_mask.shape
    n_pred_groups = pred_mask.shape[1]
    if n_points <= n_true_groups and n_points <= n_pred_groups:
        raise ValueError(
            "adjusted_rand_index requires n_groups < n_points. We don't handle "
            "the special cases that can occur when you have one cluster "
            "per datapoint.")

    true_group_ids = true_mask.argmax(1)
    pred_group_ids = pred_mask.argmax(1)

    # Convert to one-hot ('oh') representations
    true_mask_oh = true_mask.float()
    pred_mask_oh = torch.eye(n_pred_groups).to(pred_mask)[pred_group_ids].transpose(1, 2).float()  # TODO: this float was not there before, to check

    n_points_fg = true_mask_oh.sum((1, 2))
    nij = torch.einsum('bip,bjp->bji', pred_mask_oh, true_mask_oh)
    nij = nij.double()  # Cast to double, since the expected_rindex can introduce numerical inaccuracies

    a = nij.sum(1)
    b = nij.sum(2)

    rindex = (nij * (nij - 1)).sum((1, 2))
    aindex = (a * (a - 1)).sum(1)
    bindex = (b * (b - 1)).sum(1)
    expected_rindex = aindex * bindex / (n_points_fg * (n_points_fg - 1))
    max_rindex = (aindex + bindex) / 2
    ari = (rindex - expected_rindex) / (max_rindex - expected_rindex)

    # We can get NaN in case max_rindex == expected_rindex. This happens when both true and
    # predicted segmentations consist of only a single segment. Since we are allowing the
    # true segmentation to contain zeros (i.e. background) which we ignore, it suffices
    # if the foreground pixels belong to a single segment.
    # We check for this case, and instead set the ARI to 1.
    def _fg_all_equal(values, bg):
        """
        Check if all pixels in values that do not belong to the background (bg is False) have the same
        segmentation id.
        Args:
            values: Segmentation ids given as integer Tensor of shape [batch_size, n_points]
            bg: Binary tensor indicating background, shape [batch_size, n_points]
        """
        fg_ids = (values + 1) * (1 - bg.int())  # Move fg ids to [1, n], set bg ids to 0
        example_fg_id = fg_ids.max(1, keepdim=True)[0]  # Get the id of an arbitrary fg cluster.
        return torch.logical_or(fg_ids == example_fg_id[..., :1],  # All pixels should match that id...
                                bg  # ...or belong to the background.
                                ).all(-1)

    background = (true_mask.sum(1) == 0)
    both_single_cluster = torch.logical_and(_fg_all_equal(true_group_ids, background),
                                            _fg_all_equal(pred_group_ids, background))

    # Ensure that we are only (close to) getting NaNs in exactly the case described above.
    matching = (both_single_cluster == torch.isclose(max_rindex, expected_rindex))
    if not matching.all().item():
        offending_idx = matching.int().argmin()

    return torch.where(both_single_cluster, torch.ones_like(ari), ari)
def precision_recall(segmentation_gt: torch.Tensor, segmentation_pred: torch.Tensor, mode: str, adjusted: bool):
    """ Compute the (Adjusted) Rand Precision/Recall.
    Implementation obtained from the paper: Sensitivity of Slot-Based Object-Centric Models to their Number of Slots
    Args:
        - segmentation_gt: Int tensor with shape (batch_size, height, width) containing the ground-truth segmentations.
        - segmentation_pred: Int tensor with shape (batch_size, height, width) containing the predicted segmentations.
        - mode: Either "precision" or "recall" depending on which metric shall be computed.
        - adjusted: Return values for adjusted or non-adjusted metric.
    Returns:
        Float tensor with shape (batch_size), containing the (Adjusted) Rand Precision/Recall per sample.
    """
    max_classes = max(segmentation_gt.max(), segmentation_pred.max()) + 1
    oh_segmentation_gt = F.one_hot(segmentation_gt, max_classes)
    oh_segmentation_pred = F.one_hot(segmentation_pred, max_classes)

    coincidence = torch.einsum("bhwk,bhwc->bkc", oh_segmentation_gt, oh_segmentation_pred)
    coincidence_gt = coincidence.sum(-1)
    coincidence_pred = coincidence.sum(-2)

    m_squared = torch.sum(coincidence**2, (1, 2))
    m = torch.sum(coincidence, (1, 2))
    # How many pairs of pixels have the same label assigned in the ground-truth segmentation.
    P = torch.sum(coincidence_gt * (coincidence_gt - 1), -1)
    # How many pairs of pixels have the same label assigned in the predicted segmentation.
    Q = torch.sum(coincidence_pred * (coincidence_pred - 1), -1)

    expected_m_squared = (P + m) * (Q + m) / (m * (m - 2)) + (m**2 - Q - P - 2 * m) / (m - 1)

    if mode == "precision":
        gamma = P + m
    elif mode == "recall":
        gamma = Q + m
    else:
        raise ValueError("Invalid mode.")

    if adjusted:
        return (m_squared - expected_m_squared) / (gamma - expected_m_squared)
    else:
        return (m_squared - m) / (gamma - m)
\ No newline at end of file
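A minimal sanity-check sketch for the new loss functions (random tensors stand in for real mask logits and binary targets):

import torch
from osrt.utils.losses import compute_dice_loss, compute_focal_loss

logits = torch.randn(4096)                   # hypothetical flattened mask logits
targets = (torch.rand(4096) > 0.5).float()   # hypothetical binary targets

print(compute_focal_loss(logits, targets))   # scalar focal loss
print(compute_dice_loss(logits, targets))    # scalar dice loss in [0, 1]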
......@@ -32,7 +32,7 @@
"decay_it": 4000000,
"lr_warmup": 5000,
"precision": "16-mixed",
"out_dir": "."
"out_dir": "./logs"
}
}
\ No newline at end of file
......@@ -8,8 +8,9 @@ import json
import argparse
import math
import numpy as np
import lightning as L
import segmentation_models_pytorch as smp
import torch
import torch.nn.functional as F
from lightning.fabric.fabric import _FabricOptimizer
......@@ -20,7 +21,7 @@ from osrt.model import OSRT
from osrt.encoder import FeatureMasking
from osrt import data
from osrt.utils.training import AverageMeter
from osrt.utils.losses import DiceLoss, FocalLoss
from osrt.utils.losses import compute_focal_loss, compute_ari, compute_dice_loss
torch.set_float32_matmul_precision('high')
......@@ -28,7 +29,7 @@ __LOG10 = math.log(10)
def validate(fabric: L.Fabric, model: OSRT, val_dataloader: DataLoader, epoch: int = 0):
# TODO : add segmentation also to select the model following how it's done in the training
model.eval()
"""model.eval()
mses = AverageMeter()
psnrs = AverageMeter()
......@@ -70,7 +71,8 @@ def validate(fabric: L.Fabric, model: OSRT, val_dataloader: DataLoader, epoch: i
state_dict = model.state_dict()
if fabric.global_rank == 0:
torch.save(state_dict, os.path.join(cfg.out_dir, f"epoch-{epoch:06d}-psnr{psnrs.avg:.2f}-mse{mses.avg:.2f}-ckpt.pth"))
model.train()
model.train()"""
pass
def train_sam(
......@@ -81,12 +83,11 @@ def train_sam(
scheduler: _FabricOptimizer,
train_dataloader: DataLoader,
val_dataloader: DataLoader,
batch_size: int
):
"""The SAM training loop."""
nb_epochs = cfg["training"]["max_it"] // batch_size
focal_loss = FocalLoss()
dice_loss = DiceLoss()
nb_epochs = cfg["training"]["max_it"] // cfg["training"]["batch_size"]
for epoch in range(1, nb_epochs):
# TODO : add psnr loss ?
batch_time = AverageMeter()
......@@ -105,12 +106,12 @@ def train_sam(
data_time.update(time.time() - end)
# TODO : adapt to our model
### Extract input data
input_images = data.get('input_images')
input_camera_pos = data.get('input_camera_pos')
input_rays = data.get('input_rays')
target_pixels = data.get('target_pixels')
### Encode input information and extract masks
if isinstance(model.encoder, FeatureMasking):
input_images = input_images.permute(0, 1, 3, 4, 2) # from [b, k, c, h, w] to [b, k, h, w, c]
h, w, c = input_images[0][0].shape
......@@ -118,41 +119,55 @@ def train_sam(
else:
z = model.encoder(input_images, input_camera_pos, input_rays)
target_camera_pos = data.get('target_camera_pos')
target_rays = data.get('target_rays')
### Extract target data
target_pixels = data.get('target_pixels') # [p, 3]
target_camera_pos = data.get('target_camera_pos') # [p, 3]
target_rays = data.get('target_rays') # [p, 3]
### Decode slots and reconstruct image + segmentation mask
loss_mse = torch.tensor(0., device=fabric.device)
loss_focal = torch.tensor(0., device=fabric.device)
loss_dice = torch.tensor(0., device=fabric.device)
pred_pixels, extras = model.decoder(z, target_camera_pos, target_rays)#, **self.render_kwargs)
### Compute MSE on pixels
loss_mse = loss_mse + ((pred_pixels - target_pixels)**2).mean((1, 2))
batch_size = input_images.shape[0]
### Compute loss
loss_mse += ((pred_pixels - target_pixels)**2).mean((1, 2)).mean(0)
### Evaluate segmentation
if 'segmentation' in extras:
# TODO : for visualisation only, could be interesting to check real GT
#true_seg = data['target_masks'].float()
pred_masks = extras['segmentation']
fabric.print(f"Pred segmentation shape {extras['segmentation'].shape}")
pred_masks = extras['segmentation'] # [B, nb_rays, nb_slots]
sample_idx = data['sampled_idx'] # [nb_pixels_sampled]
# TODO : extract for each batch the number of masks
fabric.print(f"GT segmentation shape {masks_info['segmentations'].shape}")
gt_masks = masks_info["segmentations"] # [B, nb_img, HxW]
gt_masks = gt_masks.permute(0, 2, 1) # [B, HxW, nb_img]
fabric.print(f"True segmentation shape {data['target_masks'].shape}")
# TODO : evaluate also with true seg
true_seg = data['target_masks'].float() # [B, nb_rays, nb_masks]
# TODO : check the content of num_masks
num_masks = sum(len(pred_mask) for pred_mask in pred_masks)
for pred_mask, gt_mask in zip(pred_masks, masks_info["segmentations"]):
loss_focal += focal_loss(pred_mask, gt_mask, num_masks)
loss_dice += dice_loss(pred_mask, gt_mask, num_masks)
"""num_masks = sum(len(pred_mask) for pred_mask in pred_mask)
for pred_mask, gt_mask in zip(pred_masks, gt_masks):
loss_focal += compute_focal_loss(pred_mask, gt_mask, num_masks)
loss_dice += compute_dice_loss(pred_mask, gt_mask, num_masks)
# TODO : check the values of the loss and see if scale is ok
loss_total = 20. * loss_focal + loss_dice + loss_mse
# TODO : check also with ARI, FG-ARI values and new from recent paper
"""loss_terms['ari'] = compute_adjusted_rand_index(true_seg.transpose(1, 2),
pred_seg.transpose(1, 2))
loss_terms['fg_ari'] = compute_adjusted_rand_index(true_seg.transpose(1, 2)[:, 1:],
pred_seg.transpose(1, 2))"""
# TODO : use recent versions of ARI
ari = compute_ari(gt_masks.transpose(1, 2),
pred_masks.transpose(1, 2))
fg_ari = compute_ari(gt_masks.transpose(1, 2)[:, 1:],
pred_masks.transpose(1, 2))"""
optimizer.zero_grad()
fabric.backward(loss_total)
# TODO : check with a combined loss with segmentation
fabric.backward(loss_mse)
optimizer.step()
scheduler.step()
batch_time.update(time.time() - end)
......@@ -161,7 +176,7 @@ def train_sam(
focal_losses.update(loss_focal.item(), batch_size)
dice_losses.update(loss_dice.item(), batch_size)
mse_losses.update(loss_mse.item(), batch_size)
total_losses.update(loss_total.item(), batch_size)
#total_losses.update(loss_total.item(), batch_size)
fabric.print(f'Epoch: [{epoch}][{iter+1}/{len(train_dataloader)}]'
f' | Time [{batch_time.val:.3f}s ({batch_time.avg:.3f}s)]'
......@@ -185,7 +200,7 @@ def configure_opt(cfg, model: OSRT):
return peak_lr * (decay_rate ** (it_since_peak / warmup_iters))
# TODO : check begin value of lr
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=decay_rate)
optimizer = torch.optim.Adam(model.parameters(), lr=8e-4, weight_decay=decay_rate)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
return optimizer, scheduler
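For context, a sketch of a comparable warmup-then-decay schedule built with LambdaLR; this is not the repository's configure_opt, and the constants are placeholders (lr_warmup=5000 and decay_it=4000000 mirror the config above, the rest is assumed):

import torch

model = torch.nn.Linear(8, 8)                    # stand-in model
peak_lr, warmup_iters, decay_rate, decay_it = 8e-4, 5000, 0.16, 4000000

def lr_lambda(it):
    if it < warmup_iters:
        return it / warmup_iters                 # linear warmup up to peak_lr
    return decay_rate ** ((it - warmup_iters) / decay_it)  # slow exponential decay afterwards

optimizer = torch.optim.Adam(model.parameters(), lr=peak_lr)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)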
......@@ -225,13 +240,16 @@ def main(cfg) -> None:
cfg['render_args'] = train_dataset.render_kwargs
train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=num_workers)
val_loader = DataLoader(val_dataset, batch_size=batch_size, num_workers=num_workers)
test_loader = DataLoader(test_dataset, batch_size=batch_size, num_workers=num_workers)
vis_loader_val = DataLoader(val_dataset, batch_size=12, num_workers=num_workers)
train_loader, val_loader, test_loader = fabric.setup_dataloaders(train_loader, val_loader, test_loader)
train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True)
train_loader = fabric._setup_dataloader(train_loader)
val_loader = DataLoader(val_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True)
val_loader = fabric._setup_dataloader(val_loader)
test_loader = DataLoader(test_dataset, batch_size=batch_size, num_workers=num_workers)
test_loader = fabric._setup_dataloader(test_loader)
vis_loader_val = DataLoader(val_dataset, batch_size=12, num_workers=num_workers)
data_vis_val = next(iter(vis_loader_val)) # Validation set data for visualization
data_vis_val = fabric.to_device(data_vis_val)
......@@ -245,7 +263,7 @@ def main(cfg) -> None:
#########################
### Training
#########################
train_sam(cfg, fabric, model, optimizer, scheduler, train_loader, val_loader)
train_sam(cfg, fabric, model, optimizer, scheduler, train_loader, val_loader, batch_size)
validate(fabric, model, val_loader, epoch=0)
if __name__ == "__main__":
......