From 4b89db3377fd0bdac730bfb96a10d74c31131600 Mon Sep 17 00:00:00 2001 From: alexcbb <alexchapin@hotmail.fr> Date: Fri, 28 Jul 2023 17:07:02 +0200 Subject: [PATCH] Deactivate early stop, delete evaluate, change dir checkpoint' --- osrt/data/ycbv.py | 1 - train_sa.py | 13 +++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/osrt/data/ycbv.py b/osrt/data/ycbv.py index 1c67c7b..a815fd3 100644 --- a/osrt/data/ycbv.py +++ b/osrt/data/ycbv.py @@ -298,7 +298,6 @@ class YCBVideo2D(Dataset): max_objects (int): Load only scenes with at most this many objects. """ self.path = path - print(f"Get path {path}") self.mode = mode self.max_objects = max_objects diff --git a/train_sa.py b/train_sa.py index 5d763a0..e06a949 100644 --- a/train_sa.py +++ b/train_sa.py @@ -72,12 +72,12 @@ def main(): checkpoint_callback = ModelCheckpoint( monitor="val_psnr", mode="max", - dirpath="./checkpoints", - filename="ckpt-" + str(cfg["data"]["dataset"])+ "-slots:"+ str(cfg["model"]["num_slots"]) + "-" + str(cfg["model"]["model_type"]) +"-{epoch:02d}-psnr{val_psnr:.2f}", + dirpath=f"./checkpoints_{cfg['data']['dataset']}_{cfg['model']['model_type']}", + filename="ckpt-" + str(cfg["data"]["dataset"])+ "-slots:"+ str(cfg["model"]["num_slots"]) + "-" + str(cfg["model"]["model_type"]) +"-{epoch:02d}-{val_psnr:.2f}", save_weights_only=True # don't save optimizer states nor lr-scheduler, ... ) - early_stopping = EarlyStopping(monitor="val_psnr", mode="max") + #early_stopping = EarlyStopping(monitor="val_psnr", mode="max") trainer = pl.Trainer(accelerator="gpu", devices=num_gpus, @@ -85,20 +85,21 @@ def main(): default_root_dir="./logs", logger=WandbLogger(project="slot-att", offline=True) if args.wandb else None, strategy="ddp_find_unused_parameters_true" if num_gpus > 1 else "auto", - callbacks=[checkpoint_callback, early_stopping], + callbacks=[checkpoint_callback],#, early_stopping], log_every_n_steps=100, max_steps=num_train_steps, enable_model_summary=True) trainer.fit(model, train_loader, val_loader) + #### Evaluate the model - print(f"Begin testing : ") + """print(f"Begin testing : ") test_dataset = data.get_dataset('test', cfg['data']) test_loader = DataLoader( test_dataset, batch_size=batch_size, num_workers=8 if num_workers > 8 else 0, shuffle=True, worker_init_fn=data.worker_init_fn, pin_memory=True) - trainer.test(ckpt_path="best", dataloaders=test_loader, verbose=True) + trainer.test(ckpt_path="best", dataloaders=test_loader, verbose=True)""" print(f"Begin visualization : ") #### Create datasets -- GitLab