# evaluate_sa.py
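"""Evaluate a trained Slot Attention autoencoder checkpoint on the validation split."""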

import argparse

import lightning as pl
import torch
import yaml
from lightning.pytorch.loggers.wandb import WandbLogger
from torch.utils.data import DataLoader

from osrt import data
from osrt.model import LitSlotAttentionAutoEncoder

def main():
    # Command-line arguments
    parser = argparse.ArgumentParser(
        description='Evaluate a 3D scene representation model.'
    )
    parser.add_argument('config', type=str, help='Path to the config file.')
    parser.add_argument('--wandb', action='store_true', help='Log run to Weights and Biases.')
    parser.add_argument('--seed', type=int, default=0, help='Random seed.')
    parser.add_argument('--ckpt', type=str, required=True, help='Path to the model checkpoint.')

    args = parser.parse_args()
    with open(args.config, 'r') as f:
        cfg = yaml.safe_load(f)

    ### Set random seed.
    pl.seed_everything(args.seed, workers=True)

    ### Hyperparameters of the model.
    batch_size = cfg["training"]["batch_size"]
    num_gpus = cfg["training"]["num_gpus"]
    num_slots = cfg["model"]["num_slots"]
    num_iterations = cfg["model"]["iters"]
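    # The input resolution is fixed at 128x128 here rather than read from the config.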
    resolution = (128, 128)
    
    #### Create dataset
    val_dataset = data.get_dataset('val', cfg['data'])
    val_dataloader = DataLoader(
        val_dataset, batch_size=batch_size, num_workers=cfg["training"]["num_workers"],
        shuffle=False, worker_init_fn=data.worker_init_fn)  # deterministic order for evaluation

    #### Create model and restore checkpoint weights
    model = LitSlotAttentionAutoEncoder(resolution, num_slots, num_iterations, cfg=cfg)
    # Load on CPU; the Trainer moves the model to the GPU afterwards.
    checkpoint = torch.load(args.ckpt, map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'])
    model.eval()

    #### Run validation
    # Wire up the --wandb flag; the project name "osrt-eval" is a placeholder.
    logger = WandbLogger(project="osrt-eval") if args.wandb else True
    trainer = pl.Trainer(accelerator="gpu", devices=num_gpus,
                         strategy="auto", logger=logger)

    trainer.validate(model, dataloaders=val_dataloader)


if __name__ == "__main__":
    main()