Skip to content
Snippets Groups Projects
Commit 09bce025 authored by Alexandre Chapin's avatar Alexandre Chapin :race_car:
Browse files

Add fsdp on classical branch

parent 2c7468b2
No related branches found
No related tags found
No related merge requests found
......@@ -11,61 +11,7 @@ from osrt.utils.common import get_rank, get_world_size
import os
import math
from collections import defaultdict
def train(args, model, rank, world_size, train_loader, optimizer, epoch):
    """Run one training epoch and print the epoch-average loss on rank 0.

    For every batch the encoder and decoder are run under CUDA autocast, the
    per-sample MSE between predicted and target pixels is computed (plus the
    coarse-image MSE when the decoder returns ``extras['coarse_img']``), the
    batch mean is back-propagated and the optimizer is stepped.

    Args:
        args: Not used by this function; kept for interface compatibility.
        model: Object exposing ``encoder`` and ``decoder`` callables as well
            as ``train()``.
        rank: Device/rank every tensor is moved to with ``.to(rank)``; the
            process whose rank equals 0 prints the summary line.
        world_size: Not used by this function; kept for interface
            compatibility.
        train_loader: Iterable yielding ``(data, target)`` pairs, where
            ``data`` supports ``.get(key)`` / ``[key]`` lookups returning
            tensors. ``target`` is ignored.
        optimizer: Optimizer whose ``zero_grad()`` / ``step()`` are called
            once per batch.
        epoch: Epoch number, used only in the printed summary.

    Returns:
        None.
    """
    # Slot 0: sum of per-sample losses; slot 1: number of samples seen.
    ddp_loss = torch.zeros(2).to(rank)

    # Nothing in the loop switches the model out of training mode, so this
    # only needs to happen once per epoch rather than once per batch.
    model.train()

    # NOTE(review): each loader item is unpacked as a (data, target) pair but
    # only `data` is read, and it is read like a dict -- confirm the loader
    # really yields pairs rather than a bare dict.
    for data, _ in train_loader:
        optimizer.zero_grad()

        input_images = data.get('input_images').to(rank)
        input_camera_pos = data.get('input_camera_pos').to(rank)
        input_rays = data.get('input_rays').to(rank)
        target_pixels = data.get('target_pixels').to(rank)

        # NOTE(review): autocast is used without a GradScaler -- confirm that
        # gradient underflow is not a concern for this model.
        with torch.cuda.amp.autocast():
            z = model.encoder(input_images, input_camera_pos, input_rays)

        target_camera_pos = data.get('target_camera_pos').to(rank)
        target_rays = data.get('target_rays').to(rank)

        loss = 0.
        loss_terms = dict()
        with torch.cuda.amp.autocast():
            pred_pixels, extras = model.decoder(z, target_camera_pos, target_rays)

        loss = loss + ((pred_pixels - target_pixels)**2).mean((1, 2))
        loss_terms['mse'] = loss

        if 'coarse_img' in extras:
            coarse_loss = ((extras['coarse_img'] - target_pixels)**2).mean((1, 2))
            loss_terms['coarse_mse'] = coarse_loss
            loss = loss + coarse_loss

        if 'segmentation' in extras:
            pred_seg = extras['segmentation']
            true_seg = data['target_masks'].to(rank).float()

            # These are not actually used as part of the training loss.
            # We just add them to the dict to report them.
            loss_terms['ari'] = compute_adjusted_rand_index(true_seg.transpose(1, 2),
                                                            pred_seg.transpose(1, 2))
            loss_terms['fg_ari'] = compute_adjusted_rand_index(true_seg.transpose(1, 2)[:, 1:],
                                                               pred_seg.transpose(1, 2))

        loss = loss.mean(0)
        # NOTE(review): these scalar diagnostics are computed but neither
        # returned nor logged by this function.
        loss_terms = {k: v.mean(0).item() for k, v in loss_terms.items()}

        loss.backward()
        optimizer.step()

        # `loss` is the mean over the batch, so weight it by the batch size:
        # the summary below divides by the total sample count, and adding the
        # unweighted mean would under-report the loss by roughly a factor of
        # the batch size.
        batch_size = len(input_images)
        ddp_loss[0] += loss.item() * batch_size
        ddp_loss[1] += batch_size

    # Sum both accumulators over all participating processes.
    dist.all_reduce(ddp_loss, op=dist.ReduceOp.SUM)
    if rank == 0:
        print('Train Epoch: {} \tLoss: {:.6f}'.format(epoch, ddp_loss[0] / ddp_loss[1]))
class SRTTrainer:
def __init__(self, model, optimizer, cfg, device, out_dir, render_kwargs):
self.model = model
......
import torch
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel
from torch.distributed.fsdp import FullyShardedDataParallel, CPUOffload
from torch.distributed.fsdp.wrap import default_auto_wrap_policy
import numpy as np
#import bitsandbytes as bnb
......@@ -47,6 +50,7 @@ if __name__ == '__main__':
parser.add_argument('--max-eval', type=int, help='Limit the number of scenes in the evaluation set.')
parser.add_argument('--full-scale', action='store_true', help='Evaluate on full images.')
parser.add_argument('--print-model', action='store_true', help='Print model and parameters on startup.')
parser.add_argument('--strategy', type=str, default='ddp', help='Strategy to use for parallelisation [ddp, dfsdp]')
args = parser.parse_args()
with open(args.config, 'r') as f:
......@@ -128,8 +132,15 @@ if __name__ == '__main__':
print('Model created.')
if world_size > 1:
model.encoder = DistributedDataParallel(model.encoder, device_ids=[rank], output_device=rank, find_unused_parameters=True) # Set find_unused_parameters to True because the ViT is not trained
model.decoder = DistributedDataParallel(model.decoder, device_ids=[rank], output_device=rank, find_unused_parameters=False)
model_encoder_ddp = DistributedDataParallel(model.encoder, device_ids=[rank], output_device=rank, find_unused_parameters=True) # Set find_unused_parameters to True because the ViT is not trained
model_decoder_ddp = DistributedDataParallel(model.decoder, device_ids=[rank], output_device=rank, find_unused_parameters=False)
model.encoder = FullyShardedDataParallel(
model_encoder_ddp(),
fsdp_auto_wrap_policy=default_auto_wrap_policy)
model.decoder = FullyShardedDataParallel(
model_decoder_ddp(),
fsdp_auto_wrap_policy=default_auto_wrap_policy)
encoder_module = model.encoder.module
decoder_module = model.decoder.module
else:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment