import argparse
import os
import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F
import yaml
from torch.utils.data import DataLoader

import lightning as pl
from lightning.pytorch.loggers.wandb import WandbLogger
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.utilities.warnings import PossibleUserWarning

from osrt import data
from osrt.model import LitSlotAttentionAutoEncoder
from osrt.utils.common import mse2psnr
from osrt.utils.visualization_utils import visualize_slot_attention

# Ignore warnings that may be false positives:
# cf. https://lightning.ai/docs/pytorch/stable/advanced/speed.html
warnings.filterwarnings("ignore", category=PossibleUserWarning)


def main():
    # Arguments
    parser = argparse.ArgumentParser(
        description='Train a 3D scene representation model.'
    )
    parser.add_argument('config', type=str, help='Path to the YAML config file.')
    parser.add_argument('--wandb', action='store_true', help='Log run to Weights and Biases.')
    parser.add_argument('--seed', type=int, default=0, help='Random seed.')
    parser.add_argument('--ckpt', type=str, default=None, help='Model checkpoint path.')
    args = parser.parse_args()

    with open(args.config, 'r') as f:
        cfg = yaml.load(f, Loader=yaml.CLoader)
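    # The config is expected to provide at least the keys read below. A minimal
    # sketch of the assumed schema (values are illustrative, not reference settings):
    #
    #   data:
    #     dataset: clevr
    #   model:
    #     model_type: sa
    #     num_slots: 7
    #     iters: 3
    #   training:
    #     batch_size: 32
    #     num_gpus: 1
    #     num_workers: 8
    #     max_it: 300000
    #     validate_every: 5000
    #     checkpoint_every: 5000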
    ### Set random seed.
    pl.seed_everything(args.seed, workers=True)

    ### Hyperparameters of the model.
    batch_size = cfg["training"]["batch_size"]
    num_gpus = cfg["training"]["num_gpus"]
    num_slots = cfg["model"]["num_slots"]
    num_iterations = cfg["model"]["iters"]
    num_train_steps = cfg["training"]["max_it"]
    resolution = (128, 128)

    print(f"Number of CPU cores: {os.cpu_count()}")

    #### Create datasets
    train_dataset = data.get_dataset('train', cfg['data'])
    train_loader = DataLoader(
        train_dataset, batch_size=batch_size,
        num_workers=cfg["training"]["num_workers"],
        shuffle=True, worker_init_fn=data.worker_init_fn, pin_memory=True)

    val_dataset = data.get_dataset('val', cfg['data'])
    val_loader = DataLoader(
        val_dataset, batch_size=batch_size, num_workers=8,
        shuffle=False, worker_init_fn=data.worker_init_fn, pin_memory=True)

    #### Create model
    model = LitSlotAttentionAutoEncoder(resolution, num_slots, num_iterations, cfg=cfg)

    if args.ckpt:
        checkpoint = torch.load(args.ckpt)
        model.load_state_dict(checkpoint['state_dict'])

    checkpoint_callback = ModelCheckpoint(
        monitor="val_psnr",
        mode="max",
        dirpath="./checkpoints",
        filename="ckpt-" + str(cfg["data"]["dataset"]) + "-" + str(cfg["model"]["model_type"])
                 + "-{epoch:02d}-psnr{val_psnr:.2f}",
        save_weights_only=True,  # don't save optimizer states, lr schedulers, ...
        every_n_train_steps=cfg["training"]["checkpoint_every"]
    )
    early_stopping = EarlyStopping(monitor="val_psnr", mode="max")

    trainer = pl.Trainer(accelerator="gpu",
                         devices=num_gpus,
                         profiler="simple",
                         default_root_dir="./logs",
                         logger=WandbLogger(project="slot-att") if args.wandb else None,
                         strategy="ddp" if num_gpus > 1 else "auto",
                         callbacks=[checkpoint_callback, early_stopping],
                         log_every_n_steps=100,
                         val_check_interval=cfg["training"]["validate_every"],
                         max_steps=num_train_steps,
                         enable_model_summary=True)

    trainer.fit(model, train_loader, val_loader)

    #### Visualize the slot decomposition on one training sample
    vis_dataset = data.get_dataset('train', cfg['data'])
    vis_loader = DataLoader(
        vis_dataset, batch_size=1, num_workers=1,
        shuffle=True, worker_init_fn=data.worker_init_fn)

    device = model.device
    image = torch.squeeze(next(iter(vis_loader)).get('input_images').to(device), dim=1)
    image = F.interpolate(image, size=128)

    model.eval()
    with torch.no_grad():
        recon_combined, recons, masks, slots, _ = model(image)

    loss = nn.MSELoss()
    loss_value = loss(recon_combined, image)
    psnr = mse2psnr(loss_value)  # PSNR from MSE, assuming pixel values in [0, 1]
    print(f"Reconstruction PSNR: {psnr:.2f} dB")

    visualize_slot_attention(num_slots, image, recon_combined, recons, masks,
                             folder_save=args.ckpt, step=0, save_file=True)


if __name__ == "__main__":
    main()
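# Example invocations (the script name, config path, and checkpoint name are
# illustrative; adapt them to your repository layout):
#
#   python train.py configs/clevr.yaml --wandb
#   python train.py configs/clevr.yaml --ckpt ./checkpoints/last.ckpt --seed 1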