"""Train a Slot Attention autoencoder for 3D scene representation with PyTorch Lightning."""

import argparse

import yaml
from torch.utils.data import DataLoader

import lightning as pl
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers.wandb import WandbLogger

from osrt import data
from osrt.model import LitSlotAttentionAutoEncoder


def main():
    # Arguments
    parser = argparse.ArgumentParser(
        description='Train a 3D scene representation model.'
    )
    parser.add_argument('config', type=str, help='Path to the YAML config file.')
    parser.add_argument('--wandb', action='store_true', help='Log run to Weights and Biases.')
    parser.add_argument('--seed', type=int, default=0, help='Random seed.')
    parser.add_argument('--ckpt', type=str, default='.', help='Directory in which to save checkpoints.')
    args = parser.parse_args()

    with open(args.config, 'r') as f:
        cfg = yaml.load(f, Loader=yaml.CLoader)

    # Set the random seed from the CLI (was previously hard-coded to 42).
    pl.seed_everything(args.seed, workers=True)

    # Hyperparameters of the model. The scheduling parameters (warmup_it,
    # decay_rate, decay_it) are consumed by the LightningModule through `cfg`.
    batch_size = cfg["training"]["batch_size"]
    num_gpus = cfg["training"]["num_gpus"]
    num_slots = cfg["model"]["num_slots"]
    num_iterations = cfg["model"]["iters"]
    num_train_steps = cfg["training"]["max_it"]
    resolution = (128, 128)

    # Create the datasets. The validation loader is not shuffled so that
    # validation metrics are computed over a fixed ordering.
    train_dataset = data.get_dataset('train', cfg['data'])
    train_loader = DataLoader(
        train_dataset, batch_size=batch_size,
        num_workers=cfg["training"]["num_workers"],
        shuffle=True, worker_init_fn=data.worker_init_fn)

    val_dataset = data.get_dataset('val', cfg['data'])
    val_loader = DataLoader(
        val_dataset, batch_size=batch_size, num_workers=1,
        shuffle=False, worker_init_fn=data.worker_init_fn)

    # Create the model.
    model = LitSlotAttentionAutoEncoder(resolution, num_slots, num_iterations, cfg=cfg)

    # Keep the 10 best checkpoints ranked by validation PSNR. Lightning
    # appends the .ckpt extension itself, so none is given in `filename`.
    checkpoint_dir = "checkpoints" if cfg["model"]["model_type"] == "sa" else "checkpoints_tsa"
    checkpoint_callback = ModelCheckpoint(
        save_top_k=10,
        monitor="val_psnr",
        mode="max",
        dirpath=f"{args.ckpt}/{checkpoint_dir}",
        filename="slot_att-clevr3d-{epoch:02d}-psnr{val_psnr:.2f}",
    )

    trainer = pl.Trainer(accelerator="gpu",
                         devices=num_gpus,
                         profiler="simple",
                         default_root_dir="./logs",
                         logger=WandbLogger(project="slot-attention") if args.wandb else None,
                         strategy="ddp" if num_gpus > 1 else "auto",
                         callbacks=[checkpoint_callback],
                         deterministic=True,
                         log_every_n_steps=100,
                         max_steps=num_train_steps)
    trainer.fit(model, train_loader, val_loader)


if __name__ == "__main__":
    main()
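# ---------------------------------------------------------------------------
# Illustrative usage and config sketch (assumptions, not part of the project):
# the key names below are exactly the ones main() reads from the YAML file,
# but every value, the script name, and the config file name are placeholders.
#
#   python train.py example_config.yaml --wandb --seed 0
#
#   data:                      # passed verbatim to osrt.data.get_dataset
#     dataset: clevr3d         # hypothetical key; depends on osrt.data
#   model:
#     model_type: sa           # "sa" -> checkpoints/, else -> checkpoints_tsa/
#     num_slots: 7             # placeholder value
#     iters: 3                 # placeholder value
#   training:
#     batch_size: 32           # placeholder values throughout
#     num_gpus: 1
#     num_workers: 4
#     max_it: 300000
#     warmup_it: 10000         # warmup/decay are read by the LightningModule via cfg
#     decay_rate: 0.5
#     decay_it: 100000
# ---------------------------------------------------------------------------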