Commit 2c7468b2 authored by Alexandre Chapin

Check fsdp

parent 43291e42
import torch
import torch.distributed as dist
import numpy as np
from tqdm import tqdm
@@ -11,80 +12,28 @@ import os
import math
from collections import defaultdict
class OSRTSamTrainer:
def __init__(self, model, optimizer, cfg, device, out_dir, render_kwargs):
self.model = model
self.optimizer = optimizer
self.config = cfg
self.device = device
self.out_dir = out_dir
self.render_kwargs = render_kwargs
if 'num_coarse_samples' in cfg['training']:
self.render_kwargs['num_coarse_samples'] = cfg['training']['num_coarse_samples']
if 'num_fine_samples' in cfg['training']:
self.render_kwargs['num_fine_samples'] = cfg['training']['num_fine_samples']
def evaluate(self, val_loader, **kwargs):
''' Performs an evaluation.
Args:
val_loader (dataloader): pytorch dataloader
'''
self.model.eval()
eval_lists = defaultdict(list)
+def train(args, model, rank, world_size, train_loader, optimizer, epoch):
+    ddp_loss = torch.zeros(2).to(rank)
+    for batch_idx, data in enumerate(train_loader):  # each batch is a dict of tensors, not a (data, target) pair
+        model.train()
+        optimizer.zero_grad()
-        loader = val_loader if get_rank() > 0 else tqdm(val_loader)
-        sceneids = []
+        input_images = data.get('input_images').to(rank)
+        input_camera_pos = data.get('input_camera_pos').to(rank)
+        input_rays = data.get('input_rays').to(rank)
+        target_pixels = data.get('target_pixels').to(rank)
for data in loader:
sceneids.append(data['sceneid'])
eval_step_dict = self.eval_step(data, **kwargs)
for k, v in eval_step_dict.items():
eval_lists[k].append(v)
sceneids = torch.cat(sceneids, 0).cuda()
sceneids = torch.cat(gather_all(sceneids), 0)
print(f'Evaluated {len(torch.unique(sceneids))} unique scenes.')
eval_dict = {k: torch.cat(v, 0) for k, v in eval_lists.items()}
eval_dict = reduce_dict(eval_dict, average=True) # Average across processes
eval_dict = {k: v.mean().item() for k, v in eval_dict.items()} # Average across batch_size
print('Evaluation results:')
print(eval_dict)
return eval_dict
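    # Editor's sketch (assumption, not part of this diff): `gather_all` and
    # `reduce_dict` come from the repo's distributed utilities. Following the
    # usual torch.distributed pattern, `reduce_dict` would look roughly like:
    #     def reduce_dict(d, average=True):
    #         world_size = dist.get_world_size()
    #         out = {}
    #         for k, v in d.items():
    #             t = v.clone()
    #             dist.all_reduce(t, op=dist.ReduceOp.SUM)
    #             out[k] = t / world_size if average else t
    #         return out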
def train_step(self, data, it):
self.model.train()
self.optimizer.zero_grad()
loss, loss_terms = self.compute_loss(data, it)
loss = loss.mean(0)
loss_terms = {k: v.mean(0).item() for k, v in loss_terms.items()}
loss.backward()
self.optimizer.step()
return loss.item(), loss_terms
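    # Editor's note: compute_loss runs the forward pass under
    # torch.cuda.amp.autocast, but the backward pass above is unscaled. If fp16
    # underflow becomes a problem, the standard companion pattern (an
    # assumption, not in this commit) is a gradient scaler:
    #     scaler = torch.cuda.amp.GradScaler()
    #     scaler.scale(loss).backward()
    #     scaler.step(self.optimizer)
    #     scaler.update()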
def compute_loss(self, data, it):
device = self.device
input_images = data.get('input_images').to(device)
input_camera_pos = data.get('input_camera_pos').to(device)
input_rays = data.get('input_rays').to(device)
target_pixels = data.get('target_pixels').to(device)
input_images = input_images.permute(0, 1, 3, 4, 2) # from [b, k, c, h, w] to [b, k, h, w, c]
h, w, c = input_images[0][0].shape
        with torch.cuda.amp.autocast():
-            z = self.model.encoder(input_images, (h, w), input_camera_pos, input_rays)
+            z = model.encoder(input_images, input_camera_pos, input_rays)
-        target_camera_pos = data.get('target_camera_pos').to(device)
-        target_rays = data.get('target_rays').to(device)
+        target_camera_pos = data.get('target_camera_pos').to(rank)
+        target_rays = data.get('target_rays').to(rank)
loss = 0.
loss_terms = dict()
        with torch.cuda.amp.autocast():
-            pred_pixels, extras = self.model.decoder(z, target_camera_pos, target_rays, **self.render_kwargs)
+            pred_pixels, extras = model.decoder(z, target_camera_pos, target_rays)  # , **self.render_kwargs)
loss = loss + ((pred_pixels - target_pixels)**2).mean((1, 2))
loss_terms['mse'] = loss
@@ -95,7 +44,7 @@ class OSRTSamTrainer:
if 'segmentation' in extras:
pred_seg = extras['segmentation']
-            true_seg = data['target_masks'].to(device).float()
+            true_seg = data['target_masks'].to(rank).float()
            # These are not actually used as part of the training loss.
            # We just add them to the dict to report them.
@@ -103,126 +52,20 @@ class OSRTSamTrainer:
pred_seg.transpose(1, 2))
loss_terms['fg_ari'] = compute_adjusted_rand_index(true_seg.transpose(1, 2)[:, 1:],
pred_seg.transpose(1, 2))
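            # (Editor's note) fg_ari is computed after dropping mask channel 0,
            # which is assumed to be the background slot, so the score covers
            # foreground objects only.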
return loss, loss_terms
def eval_step(self, data, full_scale=False):
with torch.no_grad():
loss, loss_terms = self.compute_loss(data, 1000000)
mse = loss_terms['mse']
psnr = mse2psnr(mse)
return {'psnr': psnr, 'mse': mse, **loss_terms}
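    # Editor's sketch: `mse2psnr` is imported from the repo's NeRF utilities.
    # For pixel values normalized to [0, 1], it is assumed to be the standard
    # conversion:
    #     def mse2psnr(mse):
    #         return -10. * torch.log10(mse)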
def render_image(self, z, camera_pos, rays, **render_kwargs):
"""
Args:
z [n, k, c]: set structured latent variables
camera_pos [n, 3]: camera position
rays [n, h, w, 3]: ray directions
render_kwargs: kwargs passed on to decoder
"""
batch_size, height, width = rays.shape[:3]
rays = rays.flatten(1, 2)
camera_pos = camera_pos.unsqueeze(1).repeat(1, rays.shape[1], 1)
max_num_rays = self.config['data']['num_points'] * \
self.config['training']['batch_size'] // (rays.shape[0] * get_world_size())
num_rays = rays.shape[1]
img = torch.zeros_like(rays)
all_extras = []
for i in range(0, num_rays, max_num_rays):
img[:, i:i+max_num_rays], extras = self.model.decoder(
z, camera_pos[:, i:i+max_num_rays], rays[:, i:i+max_num_rays],
**render_kwargs)
all_extras.append(extras)
agg_extras = {}
for key in all_extras[0]:
agg_extras[key] = torch.cat([extras[key] for extras in all_extras], 1)
agg_extras[key] = agg_extras[key].view(batch_size, height, width, -1)
img = img.view(img.shape[0], height, width, 3)
return img, agg_extras
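    # Editor's note: render_image decodes a full frame in chunks of at most
    # `max_num_rays` rays, so rendering fits in the same memory budget as a
    # training step. Hypothetical usage, with `z` from the encoder and
    # per-pixel `rays` of shape [n, h, w, 3]:
    #     img, extras = trainer.render_image(z, camera_pos, rays, **trainer.render_kwargs)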
def visualize(self, data, mode='val'):
self.model.eval()
with torch.no_grad():
device = self.device
input_images = data.get('input_images').to(device)
input_camera_pos = data.get('input_camera_pos').to(device)
input_rays = data.get('input_rays').to(device)
camera_pos_base = input_camera_pos[:, 0]
input_rays_base = input_rays[:, 0]
if 'transform' in data:
                # If the data has been transformed into a different coordinate
                # system, one where rotating around the z axis doesn't make sense,
                # we first undo this transform, then rotate, and then reapply it.
transform = data['transform'].to(device)
inv_transform = torch.inverse(transform)
camera_pos_base = nerf.transform_points_torch(camera_pos_base, inv_transform)
input_rays_base = nerf.transform_points_torch(
input_rays_base, inv_transform.unsqueeze(1).unsqueeze(2), translate=False)
else:
transform = None
input_images_np = np.transpose(input_images.cpu().numpy(), (0, 1, 3, 4, 2))
z = self.model.encoder(input_images, input_camera_pos, input_rays)
batch_size, num_input_images, height, width, _ = input_rays.shape
num_angles = 6
columns = []
for i in range(num_input_images):
header = 'input' if num_input_images == 1 else f'input {i+1}'
columns.append((header, input_images_np[:, i], 'image'))
if 'input_masks' in data:
input_mask = data['input_masks'][:, 0]
columns.append(('true seg 0°', input_mask.argmax(-1), 'clustering'))
row_labels = None
for i in range(num_angles):
angle = i * (2 * math.pi / num_angles)
angle_deg = (i * 360) // num_angles
camera_pos_rot = nerf.rotate_around_z_axis_torch(camera_pos_base, angle)
rays_rot = nerf.rotate_around_z_axis_torch(input_rays_base, angle)
if transform is not None:
camera_pos_rot = nerf.transform_points_torch(camera_pos_rot, transform)
rays_rot = nerf.transform_points_torch(
rays_rot, transform.unsqueeze(1).unsqueeze(2), translate=False)
img, extras = self.render_image(z, camera_pos_rot, rays_rot, **self.render_kwargs)
columns.append((f'render {angle_deg}°', img.cpu().numpy(), 'image'))
if 'depth' in extras:
depth_img = extras['depth'].unsqueeze(-1) / self.render_kwargs['max_dist']
depth_img = depth_img.view(batch_size, height, width, 1)
columns.append((f'depths {angle_deg}°', depth_img.cpu().numpy(), 'image'))
if 'segmentation' in extras:
pred_seg = extras['segmentation'].cpu()
columns.append((f'pred seg {angle_deg}°', pred_seg.argmax(-1).numpy(), 'clustering'))
if i == 0:
ari = compute_adjusted_rand_index(
input_mask.flatten(1, 2).transpose(1, 2)[:, 1:],
pred_seg.flatten(1, 2).transpose(1, 2))
row_labels = ['2D Fg-ARI={:.1f}'.format(x.item() * 100) for x in ari]
output_img_path = os.path.join(self.out_dir, f'renders-{mode}')
vis.draw_visualization_grid(columns, output_img_path, row_labels=row_labels)
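        # (Editor's note) each entry in `columns` is assumed to be a
        # (header, batch_of_arrays, kind) triple, where kind 'image' denotes RGB
        # data and 'clustering' denotes integer label maps.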
+        loss = loss.mean(0)
+        loss_terms = {k: v.mean(0).item() for k, v in loss_terms.items()}
+        loss.backward()
+        optimizer.step()
+        ddp_loss[0] += loss.item()
+        ddp_loss[1] += len(input_images)
+    dist.all_reduce(ddp_loss, op=dist.ReduceOp.SUM)
+    if rank == 0:
+        print('Train Epoch: {} \tLoss: {:.6f}'.format(epoch, ddp_loss[0] / ddp_loss[1]))
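# Editor's sketch (assumption, prompted by the "Check fsdp" commit message):
# how `train` might be driven once a process group exists. The FSDP wrapper,
# optimizer, and epoch loop below are illustrative and not part of this diff.
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def _run_fsdp_sketch(args, model, train_loader):
    local_rank, world_size = init_ddp()  # init_ddp is defined later in this file
    torch.cuda.set_device(local_rank)
    model = FSDP(model.to(local_rank))  # shard parameters and gradients across ranks
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    for epoch in range(getattr(args, 'epochs', 1)):
        train(args, model, local_rank, world_size, train_loader, optimizer, epoch)
    cleanup()  # added by this commit: destroys the process group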
class SRTTrainer:
def __init__(self, model, optimizer, cfg, device, out_dir, render_kwargs):
self.model = model
@@ -189,6 +189,8 @@ def init_ddp():
setup_dist_print(local_rank == 0)
return local_rank, world_size
+def cleanup():
+    dist.destroy_process_group()
def setup_dist_print(is_main):
import builtins as __builtin__
...
Subproject commit 6fdee8f2727f4506cfbbe553e23b895e27956588
Subproject commit f7b29ba9df1496489af8c71a4bdabed7e8b017b1