Skip to content
Snippets Groups Projects
Commit 4b89db33 authored by Alexandre Chapin's avatar Alexandre Chapin :race_car:
Browse files

Deactivate early stop, delete evaluate, change dir checkpoint'

parent 0c468c18
No related branches found
No related tags found
No related merge requests found
......@@ -298,7 +298,6 @@ class YCBVideo2D(Dataset):
max_objects (int): Load only scenes with at most this many objects.
"""
self.path = path
print(f"Get path {path}")
self.mode = mode
self.max_objects = max_objects
......
......@@ -72,12 +72,12 @@ def main():
checkpoint_callback = ModelCheckpoint(
monitor="val_psnr",
mode="max",
dirpath="./checkpoints",
filename="ckpt-" + str(cfg["data"]["dataset"])+ "-slots:"+ str(cfg["model"]["num_slots"]) + "-" + str(cfg["model"]["model_type"]) +"-{epoch:02d}-psnr{val_psnr:.2f}",
dirpath=f"./checkpoints_{cfg['data']['dataset']}_{cfg['model']['model_type']}",
filename="ckpt-" + str(cfg["data"]["dataset"])+ "-slots:"+ str(cfg["model"]["num_slots"]) + "-" + str(cfg["model"]["model_type"]) +"-{epoch:02d}-{val_psnr:.2f}",
save_weights_only=True # don't save optimizer states nor lr-scheduler, ...
)
early_stopping = EarlyStopping(monitor="val_psnr", mode="max")
#early_stopping = EarlyStopping(monitor="val_psnr", mode="max")
trainer = pl.Trainer(accelerator="gpu",
devices=num_gpus,
......@@ -85,20 +85,21 @@ def main():
default_root_dir="./logs",
logger=WandbLogger(project="slot-att", offline=True) if args.wandb else None,
strategy="ddp_find_unused_parameters_true" if num_gpus > 1 else "auto",
callbacks=[checkpoint_callback, early_stopping],
callbacks=[checkpoint_callback],#, early_stopping],
log_every_n_steps=100,
max_steps=num_train_steps,
enable_model_summary=True)
trainer.fit(model, train_loader, val_loader)
#### Evaluate the model
print(f"Begin testing : ")
"""print(f"Begin testing : ")
test_dataset = data.get_dataset('test', cfg['data'])
test_loader = DataLoader(
test_dataset, batch_size=batch_size, num_workers=8 if num_workers > 8 else 0,
shuffle=True, worker_init_fn=data.worker_init_fn, pin_memory=True)
trainer.test(ckpt_path="best", dataloaders=test_loader, verbose=True)
trainer.test(ckpt_path="best", dataloaders=test_loader, verbose=True)"""
print(f"Begin visualization : ")
#### Create datasets
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment