# train_sa.py
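"""Train a Slot Attention auto-encoder on OSRT-style data.

Example invocation (hypothetical paths; the YAML config must provide the keys
read in main(), e.g. training.batch_size, model.num_slots, and data):

    python train_sa.py runs/clevr/config.yaml --ckpt runs/clevr --seed 0
"""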
import datetime
import time
import torch
import torch.optim as optim
import argparse
import yaml

from osrt.model import SlotAttentionAutoEncoder
from osrt import data
from osrt.utils.visualize import visualize_slot_attention
from osrt.utils.common import mse2psnr

from torch.utils.data import DataLoader
import torch.nn.functional as F
from tqdm import tqdm

def train_step(batch, model, optimizer, device):
    """Perform a single training step."""
    input_image = torch.squeeze(batch.get('input_images').to(device), dim=1)
    input_image = F.interpolate(input_image, size=128)

    # Get the prediction of the model and compute the loss.
    preds = model(input_image)
    recon_combined, recons, masks, slots = preds
    input_image = input_image.permute(0, 2, 3, 1)
    loss_value = F.mse_loss(recon_combined, input_image)
    del recons, masks, slots  # Unused.

    # Get and apply gradients.
    optimizer.zero_grad()
    loss_value.backward()
    optimizer.step()

    return loss_value.item()

def eval_step(batch, model, device):
    """Perform a single eval step."""
    input_image = torch.squeeze(batch.get('input_images').to(device), dim=1)
    input_image = F.interpolate(input_image, size=128)

    # Get the prediction of the model (no gradients needed at eval time)
    # and compute the loss.
    with torch.no_grad():
        preds = model(input_image)
    recon_combined, recons, masks, slots = preds
    input_image = input_image.permute(0, 2, 3, 1)
    loss_value = F.mse_loss(recon_combined, input_image)
    del recons, masks, slots  # Unused.
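    # mse2psnr (from osrt.utils.common) is assumed to implement the standard
    # conversion for images normalized to [0, 1]: PSNR = -10 * log10(MSE).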
    psnr = mse2psnr(loss_value)

    return loss_value.item(), psnr.item()

def main():
    # Arguments
    parser = argparse.ArgumentParser(
        description='Train a 3D scene representation model.'
    )
    parser.add_argument('config', type=str, help="Path to the config file.")
    parser.add_argument('--wandb', action='store_true', help='Log run to Weights and Biases.')
    parser.add_argument('--seed', type=int, default=0, help='Random seed.')
    parser.add_argument('--ckpt', type=str, default=".", help='Model checkpoint path')

    args = parser.parse_args()
    with open(args.config, 'r') as f:
        cfg = yaml.load(f, Loader=yaml.CLoader)

    ### Set random seed.
    torch.manual_seed(args.seed)

    ### Hyperparameters of the model.
    batch_size = cfg["training"]["batch_size"]
    num_slots = cfg["model"]["num_slots"]
    num_iterations = cfg["model"]["iters"]
    base_learning_rate = 0.0004
    num_train_steps = cfg["training"]["max_it"]
    warmup_steps = cfg["training"]["warmup_it"]
    decay_rate = cfg["training"]["decay_rate"]
    decay_steps = cfg["training"]["decay_it"]
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    resolution = (128, 128)
    
    #### Create datasets
    train_dataset = data.get_dataset('train', cfg['data'])
    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, num_workers=cfg["training"]["num_workers"],
        shuffle=True, worker_init_fn=data.worker_init_fn)
    
    val_dataset = data.get_dataset('val', cfg['data'])
    val_loader = DataLoader(
        val_dataset, batch_size=batch_size, num_workers=1,
        shuffle=True, worker_init_fn=data.worker_init_fn)
    
    vis_dataset = data.get_dataset('test', cfg['data'])
    vis_loader = DataLoader(
        vis_dataset, batch_size=batch_size, num_workers=cfg["training"]["num_workers"],
        shuffle=True, worker_init_fn=data.worker_init_fn)

    #### Create model
    model = SlotAttentionAutoEncoder(resolution, num_slots, num_iterations).to(device)
    num_params = sum(p.numel() for p in model.parameters())

    print(f'Number of parameters (Slot Attention model): {num_params}')

    optimizer = optim.Adam(model.parameters(), lr=base_learning_rate, eps=1e-08)
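    # The lr given here is only the initial value; it is overwritten at every
    # step by the warm-up/decay schedule in the training loop below.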

    #### Prepare checkpointing.
    global_step = 0
    ckpt = {
        'network': model,
        'optimizer': optimizer,
        'global_step': global_step
    }
    # Save an initial checkpoint. torch.save returns None, so there is no
    # checkpoint-manager object to keep around.
    torch.save(ckpt, args.ckpt + '/ckpt.pth')
    # To resume training, load an existing checkpoint instead:
    # ckpt = torch.load(args.ckpt + '/ckpt.pth')
    model = ckpt['network']
    optimizer = ckpt['optimizer']
    global_step = ckpt['global_step']

    """ TODO : setup wandb
    if args.wandb:
        if run_id is None:
            run_id =  wandb.util.generate_id()
            print(f'Sampled new wandb run_id {run_id}.')
        else:
            print(f'Resuming wandb with existing run_id {run_id}.')
        # Tell in which mode to launch the logging in W&B (for offline cluster)
        if args.offline_log:
            mode = "offline"
        else:
            mode = "online"
        wandb.init(project='osrt', name=os.path.dirname(args.config),
                   id=run_id, resume=True, mode=mode, sync_tensorboard=True) 
        wandb.config = cfg"""

    start = time.time()
    epochs = num_train_steps // len(train_loader)
    for epoch in range(epochs):
        total_loss = 0
        model.train()
        for batch in tqdm(train_loader):
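            # Schedule: linear warm-up to base_learning_rate, multiplied by an
            # exponential decay term:
            #   lr = base_lr * min(1, step / warmup_steps) * decay_rate ** (step / decay_steps)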
            # Learning rate warm-up.
            if global_step < warmup_steps:
                learning_rate = base_learning_rate * global_step / warmup_steps
            else:
                learning_rate = base_learning_rate
            learning_rate = learning_rate * (decay_rate ** (global_step / decay_steps))
            for param_group in optimizer.param_groups:
                param_group['lr'] = learning_rate

            total_loss += train_step(batch, model, optimizer, device)
            global_step += 1

        total_loss /= len(train_loader)
        # We save the checkpoints
        if not epoch % cfg["training"]["checkpoint_every"]:
            # Save the checkpoint of the model.
            ckpt['global_step'] = global_step
            ckpt['model_state_dict'] = model.state_dict()
            torch.save(ckpt, args.ckpt + '/ckpt_' + str(global_step) + '.pth')
            print(f"Saved checkpoint: {args.ckpt + '/ckpt_' + str(global_step) + '.pth'}")

        # We visualize some test data
        if not epoch % cfg["training"]["visualize_every"]:
            image = torch.squeeze(next(iter(vis_loader)).get('input_images').to(device), dim=1)
            image = F.interpolate(image, size=128)
            with torch.no_grad():
                recon_combined, recons, masks, slots = model(image)
            visualize_slot_attention(num_slots, image, recon_combined, recons, masks,
                                     folder_save=args.ckpt, step=global_step, save_file=True)
        # Log the training loss.
        if not epoch % cfg["training"]["print_every"]:
            print(f"[TRAIN] Epoch : {epoch} || Step: {global_step}, Loss: {total_loss}, Time: {datetime.timedelta(seconds=time.time() - start)}")
        # We evaluate on the validation set
        if not epoch % cfg["training"]["validate_every"]:
            val_loss = 0
            val_psnr = 0
            model.eval()
            for batch in tqdm(val_loader):
                mse, psnr = eval_step(batch, model, device)
                val_loss += mse
                val_psnr += psnr
            val_loss /= len(val_loader)
            val_psnr /= len(val_loader)
            print(f"[EVAL] Epoch : {epoch} || Loss (MSE): {val_loss}; PSNR: {val_psnr}, Time: {datetime.timedelta(seconds=time.time() - start)}")
            model.train()
                        
if __name__ == "__main__":
    main()