"""Run validation for a ``LitSlotAttentionAutoEncoder`` checkpoint.

The script reads a YAML config, builds the ``'val'`` dataset described by
``cfg['data']``, restores model weights from a checkpoint file and calls
``Trainer.validate`` on it.
"""
import argparse

import lightning as pl
import torch
import yaml
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers.wandb import WandbLogger
from torch.utils.data import DataLoader

from osrt import data
from osrt.model import LitSlotAttentionAutoEncoder


def _parse_args():
    """Parse and return the command-line arguments of the script."""
    parser = argparse.ArgumentParser(
        description='Validate a trained 3D scene representation model.'
    )
    parser.add_argument('config', type=str,
                        help='Path to the YAML config file.')
    # NOTE: --wandb is accepted but not consulted anywhere in this script.
    parser.add_argument('--wandb', action='store_true',
                        help='Log run to Weights and Biases.')
    parser.add_argument('--seed', type=int, default=0, help='Random seed.')
    parser.add_argument('--ckpt', type=str, default=".",
                        help='Model checkpoint path')
    return parser.parse_args()


def _load_config(path):
    """Load the YAML configuration stored at ``path`` and return it."""
    # CLoader only exists when PyYAML was built against libyaml; fall back to
    # the pure-Python loader with the same constructor semantics otherwise.
    # NOTE: this loader can construct arbitrary Python objects, so only
    # trusted config files should be passed in.
    loader = getattr(yaml, 'CLoader', yaml.Loader)
    with open(path, 'r') as f:
        return yaml.load(f, Loader=loader)


def main():
    """Validate a checkpointed model on the ``'val'`` dataset.

    Reads the CLI arguments, loads the config, seeds the RNGs with
    ``--seed``, builds the validation dataloader and the model, restores the
    ``'state_dict'`` entry of the checkpoint given by ``--ckpt`` and runs
    ``Trainer.validate`` on GPU.
    """
    args = _parse_args()
    cfg = _load_config(args.config)

    ### Set random seed.
    # Honour the --seed flag (previously a fixed value was used and the flag
    # was silently ignored).
    pl.seed_everything(args.seed, workers=True)

    ### Hyperparameters of the model.
    batch_size = cfg["training"]["batch_size"]
    num_gpus = cfg["training"]["num_gpus"]
    num_slots = cfg["model"]["num_slots"]
    num_iterations = cfg["model"]["iters"]
    resolution = (128, 128)

    #### Create datasets
    test_dataset = data.get_dataset('val', cfg['data'])
    # NOTE(review): shuffling a validation loader is unusual; kept as in the
    # original -- confirm it is intended.
    test_dataloader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        num_workers=cfg["training"]["num_workers"],
        shuffle=True,
        worker_init_fn=data.worker_init_fn,
    )

    #### Create model
    model = LitSlotAttentionAutoEncoder(
        resolution, num_slots, num_iterations, cfg=cfg)
    # Load onto CPU so the checkpoint does not depend on the device it was
    # saved from; the Trainer moves the model to the selected GPUs afterwards.
    checkpoint = torch.load(args.ckpt, map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'])
    model.eval()

    trainer = pl.Trainer(accelerator="gpu", devices=num_gpus, strategy="auto")
    trainer.validate(model, dataloaders=test_dataloader)


if __name__ == "__main__":
    main()