import torch
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel
from torch.distributed.fsdp import FullyShardedDataParallel, CPUOffload
from torch.distributed.fsdp.wrap import (
transformer_auto_wrap_policy,
)
import functools
import numpy as np
#import bitsandbytes as bnb
import os
import argparse
import time, datetime
import yaml
from osrt import data
from osrt.model import OSRT
from osrt.trainer import SRTTrainer, OSRTSamTrainer
from osrt.layers import Transformer
from osrt.checkpoint import Checkpoint
from osrt.utils.common import init_ddp
from segment_anything.modeling.image_encoder import Block
from segment_anything.modeling.transformer import TwoWayTransformer
from torch.profiler import profile, tensorboard_trace_handler, ProfilerActivity, schedule
class LrScheduler:
    """ Implements a learning rate schedule with warm-up and decay. """
def __init__(self, peak_lr=4e-4, peak_it=10000, decay_rate=0.5, decay_it=100000):
self.peak_lr = peak_lr
self.peak_it = peak_it
self.decay_rate = decay_rate
self.decay_it = decay_it
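    # Ramp the lr linearly from 0 to peak_lr over the first peak_it iterations, then decay it
    # exponentially: the lr is multiplied by decay_rate once every decay_it iterations past the peak.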
def get_cur_lr(self, it):
if it < self.peak_it: # Warmup period
return self.peak_lr * (it / self.peak_it)
it_since_peak = it - self.peak_it
return self.peak_lr * (self.decay_rate ** (it_since_peak / self.decay_it))
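# Note: this script is intended to be launched once per GPU process, e.g. with a distributed
# launcher such as `torchrun --nproc_per_node=<num_gpus> train.py <path/to/config.yaml>`
# (launcher assumed here; init_ddp() below sets up the process group and returns this
# process's rank and the world size).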
if __name__ == '__main__':
# Arguments
parser = argparse.ArgumentParser(
description='Train a 3D scene representation model.'
)
parser.add_argument('config', type=str, help='Path to config file.')
parser.add_argument('--test', action='store_true', help='When evaluating, use test instead of validation split.')
parser.add_argument('--evalnow', action='store_true', help='Run evaluation on startup.')
parser.add_argument('--visnow', action='store_true', help='Run visualization on startup.')
parser.add_argument('--wandb', action='store_true', help='Log run to Weights and Biases.')
    parser.add_argument('--offline_log', default=True, help='Log W&B runs offline (default: True).')
parser.add_argument('--max-eval', type=int, help='Limit the number of scenes in the evaluation set.')
parser.add_argument('--full-scale', action='store_true', help='Evaluate on full images.')
parser.add_argument('--print-model', action='store_true', help='Print model and parameters on startup.')
    parser.add_argument('--strategy', type=str, default='ddp', help='Strategy to use for parallelisation [ddp, fsdp].')
args = parser.parse_args()
with open(args.config, 'r') as f:
cfg = yaml.load(f, Loader=yaml.CLoader)
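    # init_ddp() initialises torch.distributed for this process and returns its rank and the world size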
rank, world_size = init_ddp()
device = torch.device(f"cuda:{rank}")
args.wandb = args.wandb and rank == 0 # Only log to wandb in main process
if 'max_it' in cfg['training']:
max_it = cfg['training']['max_it']
else:
max_it = 1000000
exp_name = os.path.basename(os.path.dirname(args.config))
out_dir = os.path.dirname(args.config)
    # Divide the global batch size across processes
batch_size = cfg['training']['batch_size'] // world_size
model_selection_metric = cfg['training']['model_selection_metric']
if cfg['training']['model_selection_mode'] == 'maximize':
model_selection_sign = 1
elif cfg['training']['model_selection_mode'] == 'minimize':
model_selection_sign = -1
else:
raise ValueError('model_selection_mode must be either maximize or minimize.')
# Initialize datasets
print('Loading training set...')
train_dataset = data.get_dataset('train', cfg['data'])
eval_split = 'test' if args.test else 'val'
print(f'Loading {eval_split} set...')
eval_dataset = data.get_dataset(eval_split, cfg['data'],
max_len=args.max_eval, full_scale=args.full_scale)
num_workers = cfg['training']['num_workers'] if 'num_workers' in cfg['training'] else 1
print(f'Using {num_workers} workers per process for data loading.')
# Initialize data loaders
train_sampler = val_sampler = None
shuffle = False
if isinstance(train_dataset, torch.utils.data.IterableDataset):
        #assert num_workers == 1, "Our MSN dataset is implemented as a Tensorflow iterable and does not currently support multiple PyTorch workers per process. It also shouldn't need any, since Tensorflow uses multiple workers internally."
print("Iterable dataset")
else:
if world_size > 1:
train_sampler = torch.utils.data.distributed.DistributedSampler(
train_dataset, shuffle=True, drop_last=False)
val_sampler = torch.utils.data.distributed.DistributedSampler(
eval_dataset, shuffle=True, drop_last=False)
else:
shuffle = True
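    # With multiple processes the DistributedSampler handles shuffling, so the loaders themselves
    # only shuffle when running single-process on a map-style dataset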
train_loader = torch.utils.data.DataLoader(
train_dataset, batch_size=batch_size, num_workers=num_workers, pin_memory=True,
sampler=train_sampler, shuffle=shuffle,
worker_init_fn=data.worker_init_fn)#, persistent_workers=True)
val_loader = torch.utils.data.DataLoader(
        eval_dataset, batch_size=max(1, batch_size // 8), num_workers=num_workers,
sampler=val_sampler, shuffle=shuffle,
pin_memory=False, worker_init_fn=data.worker_init_fn)#, persistent_workers=True)
# Loaders for visualization scenes
vis_loader_val = torch.utils.data.DataLoader(
eval_dataset, batch_size=12, shuffle=shuffle, worker_init_fn=data.worker_init_fn)
vis_loader_train = torch.utils.data.DataLoader(
train_dataset, batch_size=12, shuffle=shuffle, worker_init_fn=data.worker_init_fn)
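    # Grab one fixed batch per split; the same scenes are re-rendered at every visualization step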
data_vis_val = next(iter(vis_loader_val)) # Validation set data for visualization
train_dataset.mode = 'val' # Get validation info from training set just this once
    data_vis_train = next(iter(vis_loader_train)) # Training set data for visualization
train_dataset.mode = 'train'
# Create model
model = OSRT(cfg['model']).to(device)
print('Model created.')
if world_size > 1:
        if args.strategy == "fsdp":
            # Shard each Transformer, TwoWayTransformer and ViT Block as its own FSDP unit
            custom_auto_wrap_policy = functools.partial(
                transformer_auto_wrap_policy,
                transformer_layer_cls={
                    Transformer,
                    TwoWayTransformer,
                    Block,
                },
            )
            # Wrap the raw encoder/decoder directly in FSDP; no intermediate DDP wrapper is needed
            model.encoder = FullyShardedDataParallel(
                model.encoder,
                auto_wrap_policy=custom_auto_wrap_policy,
                device_id=rank)
            model.decoder = FullyShardedDataParallel(
                model.decoder,
                auto_wrap_policy=custom_auto_wrap_policy,
                device_id=rank)
else:
model.encoder = DistributedDataParallel(model.encoder, device_ids=[rank], output_device=rank, find_unused_parameters=True) # Set find_unused_parameters to True because the ViT is not trained
model.decoder = DistributedDataParallel(model.decoder, device_ids=[rank], output_device=rank, find_unused_parameters=False)
encoder_module = model.encoder.module
decoder_module = model.decoder.module
else:
encoder_module = model.encoder
decoder_module = model.decoder
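    # encoder_module / decoder_module refer to the unwrapped modules; the Checkpoint below is given
    # these so that saved state dicts do not depend on the DDP/FSDP wrappers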
if 'lr_warmup' in cfg['training']:
peak_it = cfg['training']['lr_warmup']
else:
peak_it = 2500
decay_it = cfg['training']['decay_it'] if 'decay_it' in cfg['training'] else 4000000
lr_scheduler = LrScheduler(peak_lr=1e-4, peak_it=peak_it, decay_it=decay_it, decay_rate=0.16)
    # Initialize training
    params = [p for p in model.parameters() if p.requires_grad]  # only keep trainable parameters
    optimizer = optim.Adam(params, lr=lr_scheduler.get_cur_lr(0))  # TODO: evaluate bnb.optim.Adam8bit(params, lr=lr_scheduler.get_cur_lr(0)) as an alternative
trainer = OSRTSamTrainer(model, optimizer, cfg, device, out_dir, train_dataset.render_kwargs)
checkpoint = Checkpoint(out_dir, device=device, encoder=encoder_module,
decoder=decoder_module, optimizer=optimizer)
# Try to automatically resume
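    # Prefer the final checkpoint if training already reached max_it, otherwise fall back to the rolling model.pt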
try:
if os.path.exists(os.path.join(out_dir, f'model_{max_it}.pt')):
load_dict = checkpoint.load(f'model_{max_it}.pt')
else:
load_dict = checkpoint.load('model.pt')
except FileNotFoundError:
load_dict = dict()
epoch_it = load_dict.get('epoch_it', -1)
it = load_dict.get('it', -1)
time_elapsed = load_dict.get('t', 0.)
run_id = load_dict.get('run_id', None)
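    # When starting from scratch, initialise the best metric to the worst possible value for the selected mode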
metric_val_best = load_dict.get(
'loss_val_best', -model_selection_sign * np.inf)
print(f'Current best validation metric ({model_selection_metric}): {metric_val_best:.8f}.')
if args.wandb:
import wandb
if run_id is None:
run_id = wandb.util.generate_id()
print(f'Sampled new wandb run_id {run_id}.')
else:
print(f'Resuming wandb with existing run_id {run_id}.')
# Tell in which mode to launch the logging in W&B (for offline cluster)
if args.offline_log:
mode = "offline"
else:
mode = "online"
wandb.init(project='osrt', name=os.path.dirname(args.config),
id=run_id, resume=True, mode=mode, sync_tensorboard=True)
wandb.config = cfg
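        # Profiler schedule: skip 1 step, warm up for 1, then record 12 steps (a single cycle);
        # traces are written for TensorBoard under the osrt_<run_id> directory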
prof = profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
schedule=schedule(wait=1, warmup=1, active=12, repeat=1),
on_trace_ready=tensorboard_trace_handler(f"osrt_{run_id}"),
profile_memory=True,
record_shapes=False,
with_stack=False,
with_flops=False)
prof.start()
else:
prof = None
if args.print_model:
print(model)
for name, param in model.named_parameters():
if param.requires_grad:
print(f'{name:80}{str(list(param.data.shape)):20}{int(param.data.numel()):10d}')
num_encoder_params = sum(p.numel() for p in model.encoder.parameters())
num_decoder_params = sum(p.numel() for p in model.decoder.parameters())
print('Number of parameters:')
print(f'\tEncoder: {num_encoder_params}')
    if cfg['model']['encoder'] == 'osrt':
        # Use encoder_module so this also works when the model is not wrapped in DDP/FSDP
        num_srt_encoder_params = sum(p.numel() for p in encoder_module.srt_encoder.parameters())
        num_slotatt_params = sum(p.numel() for p in encoder_module.slot_attention.parameters())
        print(f'\t\tSRT Encoder: {num_srt_encoder_params}.')
        print(f'\t\tSlot Attention: {num_slotatt_params}.')
print(f'\tDecoder: {num_decoder_params}')
print(f'Total: {num_encoder_params + num_decoder_params}')
# Shorthands
print_every = cfg['training']['print_every']
checkpoint_every = cfg['training']['checkpoint_every']
validate_every = cfg['training']['validate_every']
visualize_every = cfg['training']['visualize_every']
backup_every = cfg['training']['backup_every']
training_has_begun = False
# Training loop
while True:
epoch_it += 1
if train_sampler is not None:
train_sampler.set_epoch(epoch_it)
for batch in train_loader:
it += 1
training_has_begun = True
# Special responsibilities for the main process
if rank == 0:
checkpoint_scalars = {'epoch_it': epoch_it,
'it': it,
't': time_elapsed,
'loss_val_best': metric_val_best,
'run_id': run_id}
# Save checkpoint
if (checkpoint_every > 0 and (it % checkpoint_every) == 0) and it > 0:
checkpoint.save('model.pt', **checkpoint_scalars)
print('Checkpoint saved.')
# Backup if necessary
if (backup_every > 0 and (it % backup_every) == 0):
checkpoint.save('model_%d.pt' % it, **checkpoint_scalars)
print('Backup checkpoint saved.')
# Visualize output
if args.visnow or (it > 0 and visualize_every > 0 and (it % visualize_every) == 0):
print('Visualizing...')
trainer.visualize(data_vis_val, mode='val')
trainer.visualize(data_vis_train, mode='train')
# Run evaluation
if args.evalnow or (it > 0 and validate_every > 0 and (it % validate_every) == 0):
print('Evaluating...')
eval_dict = trainer.evaluate(val_loader)
metric_val = eval_dict[model_selection_metric]
print(f'Validation metric ({model_selection_metric}): {metric_val:.4f}')
if args.wandb:
wandb.log(eval_dict, step=it)
if model_selection_sign * (metric_val - metric_val_best) > 0:
metric_val_best = metric_val
if rank == 0:
checkpoint_scalars['loss_val_best'] = metric_val_best
print(f'New best model (loss {metric_val_best:.6f})')
checkpoint.save('model_best.pt', **checkpoint_scalars)
# Update learning rate
new_lr = lr_scheduler.get_cur_lr(it)
for param_group in optimizer.param_groups:
param_group['lr'] = new_lr
# Run training step
t0 = time.perf_counter()
loss, log_dict = trainer.train_step(batch, it)
if prof:
prof.step()
time_elapsed += time.perf_counter() - t0
time_elapsed_str = str(datetime.timedelta(seconds=time_elapsed))
log_dict['lr'] = new_lr
# Print progress
if print_every > 0 and (it % print_every) == 0:
log_str = ['{}={:f}'.format(k, v) for k, v in log_dict.items()]
print(out_dir, 't=%s [Epoch %02d] it=%03d, loss=%.4f'
% (time_elapsed_str, epoch_it, it, loss), log_str)
log_dict['t'] = time_elapsed
if args.wandb:
wandb.log(log_dict, step=it)
args.evalnow = False
args.visnow = False
if it >= max_it:
print('Iteration limit reached. Exiting.')
if rank == 0:
checkpoint.save('model.pt', **checkpoint_scalars)
if prof:
prof.stop()
exit(0)
if prof:
prof.stop()