# evaluate_sa.py
# Authored by Alexandre Chapin (commit 7a1aa37a).
import argparse
import yaml
from osrt.model import LitSlotAttentionAutoEncoder
from osrt import data
from torch.utils.data import DataLoader
import lightning as pl
from lightning.pytorch.loggers.wandb import WandbLogger
from lightning.pytorch.callbacks import ModelCheckpoint
import torch
def main():
# Arguments
parser = argparse.ArgumentParser(
description='Train a 3D scene representation model.'
)
parser.add_argument('config', type=str, help="Where to save the checkpoints.")
parser.add_argument('--wandb', action='store_true', help='Log run to Weights and Biases.')
parser.add_argument('--seed', type=int, default=0, help='Random seed.')
parser.add_argument('--ckpt', type=str, default=".", help='Model checkpoint path')
args = parser.parse_args()
with open(args.config, 'r') as f:
cfg = yaml.load(f, Loader=yaml.CLoader)
### Set random seed.
pl.seed_everything(42, workers=True)
### Hyperparameters of the model.
batch_size = cfg["training"]["batch_size"]
num_gpus = cfg["training"]["num_gpus"]
num_slots = cfg["model"]["num_slots"]
num_iterations = cfg["model"]["iters"]
resolution = (128, 128)
#### Create datasets
test_dataset = data.get_dataset('val', cfg['data'])
test_dataloader = DataLoader(
test_dataset, batch_size=batch_size, num_workers=cfg["training"]["num_workers"],
shuffle=True, worker_init_fn=data.worker_init_fn)
#### Create model
model = LitSlotAttentionAutoEncoder(resolution, num_slots, num_iterations, cfg=cfg)
checkpoint = torch.load(args.ckpt)
model.load_state_dict(checkpoint['state_dict'])
model.eval()
trainer = pl.Trainer(accelerator="gpu", devices=num_gpus,
strategy="auto")
trainer.validate(model, dataloaders=test_dataloader)
# Script entry point: run the evaluation only when executed directly, and use
# main()'s return value (None -> exit status 0) as the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())