Skip to content
Snippets Groups Projects
train.py 14.59 KiB
import torch
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel
from torch.distributed.fsdp import FullyShardedDataParallel, CPUOffload
from torch.distributed.fsdp.wrap import (
   transformer_auto_wrap_policy,
)
import functools

import numpy as np
#import bitsandbytes as bnb

import os
import argparse
import time, datetime
import yaml

from osrt import data
from osrt.model import OSRT
from osrt.trainer import SRTTrainer, OSRTSamTrainer
from osrt.layers import Transformer
from osrt.checkpoint import Checkpoint
from osrt.utils.common import init_ddp

from segment_anything.modeling.image_encoder import Block
from segment_anything.modeling.transformer import TwoWayTransformer

from torch.profiler import profile, tensorboard_trace_handler, ProfilerActivity,schedule


class LrScheduler:
    """Learning-rate schedule with a linear warm-up followed by exponential decay.

    * ``it < peak_it``: the rate ramps linearly from 0 towards ``peak_lr``.
    * ``it >= peak_it``: the rate is
      ``peak_lr * decay_rate ** ((it - peak_it) / decay_it)``, i.e. it is
      multiplied by ``decay_rate`` once every ``decay_it`` iterations.
    """

    def __init__(self, peak_lr=4e-4, peak_it=10000, decay_rate=0.5, decay_it=100000):
        # Warm-up target and its length (in iterations).
        self.peak_lr, self.peak_it = peak_lr, peak_it
        # Decay factor applied per ``decay_it`` iterations after the peak.
        self.decay_rate, self.decay_it = decay_rate, decay_it

    def get_cur_lr(self, it):
        """Return the learning rate to use at iteration ``it``."""
        if it < self.peak_it:
            warmup_progress = it / self.peak_it
            return self.peak_lr * warmup_progress
        decay_periods = (it - self.peak_it) / self.decay_it
        return self.peak_lr * (self.decay_rate ** decay_periods)


if __name__ == '__main__':
    def _str2bool(value):
        """Convert a command-line flag value such as 'true'/'false' to a bool.

        argparse hands option values over as strings, so without a converter
        ``--offline_log False`` would produce the non-empty (hence truthy)
        string "False" and W&B logging could never be switched to online.

        Raises:
            argparse.ArgumentTypeError: if the value is not a recognised boolean.
        """
        if isinstance(value, bool):
            return value
        lowered = str(value).strip().lower()
        if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
            return True
        if lowered in ('0', 'false', 'f', 'no', 'n', 'off'):
            return False
        raise argparse.ArgumentTypeError(f'Expected a boolean value, got {value!r}.')

    # Arguments
    parser = argparse.ArgumentParser(
        description='Train a 3D scene representation model.'
    )
    parser.add_argument('config', type=str, help='Path to config file.')
    parser.add_argument('--test', action='store_true', help='When evaluating, use test instead of validation split.')
    parser.add_argument('--evalnow', action='store_true', help='Run evaluation on startup.')
    parser.add_argument('--visnow', action='store_true', help='Run visualization on startup.')
    parser.add_argument('--wandb', action='store_true', help='Log run to Weights and Biases.')
    parser.add_argument('--offline_log', type=_str2bool, default=True, help='Log the W&B logs offline by default')
    parser.add_argument('--max-eval', type=int, help='Limit the number of scenes in the evaluation set.')
    parser.add_argument('--full-scale', action='store_true', help='Evaluate on full images.')
    parser.add_argument('--print-model', action='store_true', help='Print model and parameters on startup.')
    # Any value other than "fsdp" falls back to DDP (see the wrapping code below).
    parser.add_argument('--strategy', type=str, default='ddp', help='Strategy to use for parallelisation [ddp, fsdp]')

    args = parser.parse_args()
    # NOTE: this is not a safe YAML loader; the config is assumed to be a trusted
    # local file. Fall back to the pure-Python loader when PyYAML was built
    # without libyaml (in which case yaml.CLoader does not exist).
    with open(args.config, 'r') as f:
        cfg = yaml.load(f, Loader=getattr(yaml, 'CLoader', yaml.Loader))

    rank, world_size = init_ddp()
    # NOTE(review): the rank from init_ddp() is used directly as the CUDA device
    # index, which presumably assumes a single node -- confirm for multi-node runs.
    device = torch.device(f"cuda:{rank}")

    args.wandb = args.wandb and rank == 0  # Only log to wandb in main process

    max_it = cfg['training'].get('max_it', 1000000)

    out_dir = os.path.dirname(args.config)

    # Divide the global batch size across processes
    batch_size = cfg['training']['batch_size'] // world_size
    if batch_size < 1:
        raise ValueError(
            f"training.batch_size ({cfg['training']['batch_size']}) must be at least "
            f"the number of processes ({world_size}).")

    # model_selection_sign maps both modes onto "larger is better" (see evaluation below).
    model_selection_metric = cfg['training']['model_selection_metric']
    if cfg['training']['model_selection_mode'] == 'maximize':
        model_selection_sign = 1
    elif cfg['training']['model_selection_mode'] == 'minimize':
        model_selection_sign = -1
    else:
        raise ValueError('model_selection_mode must be either maximize or minimize.')

    # Initialize datasets
    print('Loading training set...')
    train_dataset = data.get_dataset('train', cfg['data'])
    eval_split = 'test' if args.test else 'val'
    print(f'Loading {eval_split} set...')
    eval_dataset = data.get_dataset(eval_split, cfg['data'],
                                    max_len=args.max_eval, full_scale=args.full_scale)

    num_workers = cfg['training'].get('num_workers', 1)
    print(f'Using {num_workers} workers per process for data loading.')

    # Initialize data loaders
    train_sampler = val_sampler = None
    shuffle = False
    if isinstance(train_dataset, torch.utils.data.IterableDataset):
        # Iterable datasets (e.g. the TensorFlow-backed MSN dataset) cannot be used
        # with DataLoader samplers or shuffle=True, so neither is set here.
        print("Iterable dataset")
    elif world_size > 1:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset, shuffle=True, drop_last=False)
        val_sampler = torch.utils.data.distributed.DistributedSampler(
            eval_dataset, shuffle=True, drop_last=False)
    else:
        shuffle = True

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, num_workers=num_workers, pin_memory=True,
        sampler=train_sampler, shuffle=shuffle,
        worker_init_fn=data.worker_init_fn)

    val_loader = torch.utils.data.DataLoader(
        eval_dataset, batch_size=max(1, batch_size // 8), num_workers=num_workers,
        sampler=val_sampler, shuffle=shuffle,
        pin_memory=False, worker_init_fn=data.worker_init_fn)

    # Loaders for visualization scenes
    vis_loader_val = torch.utils.data.DataLoader(
        eval_dataset, batch_size=12, shuffle=shuffle, worker_init_fn=data.worker_init_fn)
    vis_loader_train = torch.utils.data.DataLoader(
        train_dataset, batch_size=12, shuffle=shuffle, worker_init_fn=data.worker_init_fn)

    data_vis_val = next(iter(vis_loader_val))  # Validation set data for visualization
    train_dataset.mode = 'val'  # Get validation info from training set just this once
    data_vis_train = next(iter(vis_loader_train))  # Training set data for visualization
    train_dataset.mode = 'train'

    # Create model
    model = OSRT(cfg['model']).to(device)
    print('Model created.')

    if world_size > 1:
        if args.strategy == "fsdp":
            # Shard every transformer layer as its own FSDP unit.
            custom_auto_wrap_policy = functools.partial(
                transformer_auto_wrap_policy,
                transformer_layer_cls={
                    Transformer,
                    TwoWayTransformer,
                    Block
                },
            )
            # FSDP wraps the plain sub-modules directly; DDP and FSDP must not be nested.
            # NOTE(review): frozen (requires_grad=False) ViT parameters mixed with trainable
            # ones in one FSDP unit may need `use_orig_params=True` (PyTorch >= 2.0) -- confirm.
            model.encoder = FullyShardedDataParallel(
                model.encoder,
                auto_wrap_policy=custom_auto_wrap_policy,
                device_id=device)
            model.decoder = FullyShardedDataParallel(
                model.decoder,
                auto_wrap_policy=custom_auto_wrap_policy,
                device_id=device)
        else:
            model.encoder = DistributedDataParallel(model.encoder, device_ids=[rank], output_device=rank, find_unused_parameters=True) # Set find_unused_parameters to True because the ViT is not trained 
            model.decoder = DistributedDataParallel(model.decoder, device_ids=[rank], output_device=rank, find_unused_parameters=False)

        # Both DDP and FSDP expose the wrapped module as `.module`.
        encoder_module = model.encoder.module
        decoder_module = model.decoder.module
    else:
        encoder_module = model.encoder
        decoder_module = model.decoder

    peak_it = cfg['training'].get('lr_warmup', 2500)
    decay_it = cfg['training'].get('decay_it', 4000000)

    lr_scheduler = LrScheduler(peak_lr=1e-4, peak_it=peak_it, decay_it=decay_it, decay_rate=0.16)

    # Initialize training. The optimizer is built after the DDP/FSDP wrapping so it
    # holds the parameters that are actually trained.
    params = [p for p in model.parameters() if p.requires_grad]  # only keep trainable parameters
    # TODO: evaluate bnb.optim.Adam8bit(params, lr=lr_scheduler.get_cur_lr(0)) as a replacement.
    optimizer = optim.Adam(params, lr=lr_scheduler.get_cur_lr(0))
    trainer = OSRTSamTrainer(model, optimizer, cfg, device, out_dir, train_dataset.render_kwargs)
    # NOTE(review): with FSDP, checkpointing the unwrapped `.module` objects may not
    # gather sharded parameters -- confirm how Checkpoint handles FSDP.
    checkpoint = Checkpoint(out_dir, device=device, encoder=encoder_module,
                            decoder=decoder_module, optimizer=optimizer)

    # Try to automatically resume (prefer the final checkpoint if training already finished)
    try:
        if os.path.exists(os.path.join(out_dir, f'model_{max_it}.pt')):
            load_dict = checkpoint.load(f'model_{max_it}.pt')
        else:
            load_dict = checkpoint.load('model.pt')
    except FileNotFoundError:
        load_dict = dict()

    # Counters start at -1 because they are incremented before first use.
    epoch_it = load_dict.get('epoch_it', -1)
    it = load_dict.get('it', -1)
    time_elapsed = load_dict.get('t', 0.)
    run_id = load_dict.get('run_id', None)
    metric_val_best = load_dict.get(
        'loss_val_best', -model_selection_sign * np.inf)

    print(f'Current best validation metric ({model_selection_metric}): {metric_val_best:.8f}.')

    if args.wandb:
        import wandb
        if run_id is None:
            run_id = wandb.util.generate_id()
            print(f'Sampled new wandb run_id {run_id}.')
        else:
            print(f'Resuming wandb with existing run_id {run_id}.')
        # Tell in which mode to launch the logging in W&B (for offline cluster)
        mode = "offline" if args.offline_log else "online"
        wandb.init(project='osrt', name=os.path.dirname(args.config),
                   id=run_id, resume=True, mode=mode, sync_tensorboard=True)
        # NOTE(review): rebinding wandb.config may not record the config in the run;
        # consider wandb.init(config=cfg) or wandb.config.update(cfg).
        wandb.config = cfg

        prof = profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
                       schedule=schedule(wait=1, warmup=1, active=12, repeat=1),
                       on_trace_ready=tensorboard_trace_handler(f"osrt_{run_id}"),
                       profile_memory=True,
                       record_shapes=False,
                       with_stack=False,
                       with_flops=False)
        prof.start()
    else:
        prof = None

    if args.print_model:
        print(model)
        for name, param in model.named_parameters():
            if param.requires_grad:
                print(f'{name:80}{str(list(param.data.shape)):20}{int(param.data.numel()):10d}')

    num_encoder_params = sum(p.numel() for p in model.encoder.parameters())
    num_decoder_params = sum(p.numel() for p in model.decoder.parameters())

    print('Number of parameters:')
    print(f'\tEncoder: {num_encoder_params}')

    if cfg['model']['encoder'] == 'osrt':
        # Use the unwrapped encoder: `model.encoder` only has `.module` when wrapped.
        num_srt_encoder_params = sum(p.numel() for p in encoder_module.srt_encoder.parameters())
        num_slotatt_params = sum(p.numel() for p in encoder_module.slot_attention.parameters())
        print(f'\t\tSRT Encoder: {num_srt_encoder_params}.')
        print(f'\t\tSlot Attention: {num_slotatt_params}.')
    print(f'\tDecoder: {num_decoder_params}')
    print(f'Total: {num_encoder_params + num_decoder_params}')

    # Shorthands
    print_every = cfg['training']['print_every']
    checkpoint_every = cfg['training']['checkpoint_every']
    validate_every = cfg['training']['validate_every']
    visualize_every = cfg['training']['visualize_every']
    backup_every = cfg['training']['backup_every']

    # Training loop
    while True:
        epoch_it += 1
        epoch_had_batches = False
        if train_sampler is not None:
            train_sampler.set_epoch(epoch_it)
        for batch in train_loader:
            it += 1
            epoch_had_batches = True

            # Special responsibilities for the main process
            if rank == 0:
                checkpoint_scalars = {'epoch_it': epoch_it,
                                      'it': it,
                                      't': time_elapsed,
                                      'loss_val_best': metric_val_best,
                                      'run_id': run_id}
                # Save checkpoint
                if checkpoint_every > 0 and it % checkpoint_every == 0 and it > 0:
                    checkpoint.save('model.pt', **checkpoint_scalars)
                    print('Checkpoint saved.')

                # Backup if necessary
                if backup_every > 0 and it % backup_every == 0:
                    checkpoint.save('model_%d.pt' % it, **checkpoint_scalars)
                    print('Backup checkpoint saved.')

                # Visualize output
                if args.visnow or (it > 0 and visualize_every > 0 and it % visualize_every == 0):
                    print('Visualizing...')
                    trainer.visualize(data_vis_val, mode='val')
                    trainer.visualize(data_vis_train, mode='train')

            # Run evaluation (on every process)
            if args.evalnow or (it > 0 and validate_every > 0 and it % validate_every == 0):
                print('Evaluating...')
                eval_dict = trainer.evaluate(val_loader)
                metric_val = eval_dict[model_selection_metric]
                print(f'Validation metric ({model_selection_metric}): {metric_val:.4f}')

                if args.wandb:
                    wandb.log(eval_dict, step=it)

                if model_selection_sign * (metric_val - metric_val_best) > 0:
                    metric_val_best = metric_val
                    if rank == 0:
                        checkpoint_scalars['loss_val_best'] = metric_val_best
                        print(f'New best model (loss {metric_val_best:.6f})')
                        checkpoint.save('model_best.pt', **checkpoint_scalars)

            # Update learning rate
            new_lr = lr_scheduler.get_cur_lr(it)
            for param_group in optimizer.param_groups:
                param_group['lr'] = new_lr

            # Run training step (only the step itself counts towards time_elapsed)
            t0 = time.perf_counter()
            loss, log_dict = trainer.train_step(batch, it)
            if prof is not None:
                prof.step()
            time_elapsed += time.perf_counter() - t0
            time_elapsed_str = str(datetime.timedelta(seconds=time_elapsed))
            log_dict['lr'] = new_lr

            # Print progress
            if print_every > 0 and it % print_every == 0:
                log_str = ['{}={:f}'.format(k, v) for k, v in log_dict.items()]
                print(out_dir, 't=%s [Epoch %02d] it=%03d, loss=%.4f'
                      % (time_elapsed_str, epoch_it, it, loss), log_str)
                log_dict['t'] = time_elapsed
                if args.wandb:
                    wandb.log(log_dict, step=it)

            # Startup-only evaluation/visualization happen at most once.
            args.evalnow = False
            args.visnow = False

            if it >= max_it:
                print('Iteration limit reached. Exiting.')
                if rank == 0:
                    # Include the time spent on the final training step.
                    checkpoint_scalars['t'] = time_elapsed
                    checkpoint.save('model.pt', **checkpoint_scalars)
                if prof is not None:
                    prof.stop()
                raise SystemExit(0)

        # Without this guard an empty loader would spin `while True` forever.
        if not epoch_had_batches:
            raise RuntimeError('The training data loader yielded no batches; check the dataset and batch size.')

        # The profiler schedule (wait=1, warmup=1, active=12, repeat=1) only covers the
        # first steps; stop it once after the first epoch and stop stepping it afterwards.
        if prof is not None:
            prof.stop()
            prof = None