diff --git a/osrt/data/ycbv.py b/osrt/data/ycbv.py index 1c67c7b5309e2ff24e64f2e5bdb86dffb87e76d7..a815fd377b416f0d29820998df51243eebc3038c 100644 --- a/osrt/data/ycbv.py +++ b/osrt/data/ycbv.py @@ -298,7 +298,6 @@ class YCBVideo2D(Dataset): max_objects (int): Load only scenes with at most this many objects. """ self.path = path - print(f"Get path {path}") self.mode = mode self.max_objects = max_objects diff --git a/train_sa.py b/train_sa.py index 5d763a05fb3727e1af42ceaf5a891ea374cfca73..e06a949771811805ddfdf8a5edc804139ce4a6c6 100644 --- a/train_sa.py +++ b/train_sa.py @@ -72,12 +72,12 @@ def main(): checkpoint_callback = ModelCheckpoint( monitor="val_psnr", mode="max", - dirpath="./checkpoints", - filename="ckpt-" + str(cfg["data"]["dataset"])+ "-slots:"+ str(cfg["model"]["num_slots"]) + "-" + str(cfg["model"]["model_type"]) +"-{epoch:02d}-psnr{val_psnr:.2f}", + dirpath=f"./checkpoints_{cfg['data']['dataset']}_{cfg['model']['model_type']}", + filename="ckpt-" + str(cfg["data"]["dataset"])+ "-slots:"+ str(cfg["model"]["num_slots"]) + "-" + str(cfg["model"]["model_type"]) +"-{epoch:02d}-{val_psnr:.2f}", save_weights_only=True # don't save optimizer states nor lr-scheduler, ... ) - early_stopping = EarlyStopping(monitor="val_psnr", mode="max") + #early_stopping = EarlyStopping(monitor="val_psnr", mode="max") trainer = pl.Trainer(accelerator="gpu", devices=num_gpus, @@ -85,20 +85,21 @@ def main(): default_root_dir="./logs", logger=WandbLogger(project="slot-att", offline=True) if args.wandb else None, strategy="ddp_find_unused_parameters_true" if num_gpus > 1 else "auto", - callbacks=[checkpoint_callback, early_stopping], + callbacks=[checkpoint_callback],#, early_stopping], log_every_n_steps=100, max_steps=num_train_steps, enable_model_summary=True) trainer.fit(model, train_loader, val_loader) + #### Evaluate the model - print(f"Begin testing : ") + """print(f"Begin testing : ") test_dataset = data.get_dataset('test', cfg['data']) test_loader = DataLoader( test_dataset, batch_size=batch_size, num_workers=8 if num_workers > 8 else 0, shuffle=True, worker_init_fn=data.worker_init_fn, pin_memory=True) - trainer.test(ckpt_path="best", dataloaders=test_loader, verbose=True) + trainer.test(ckpt_path="best", dataloaders=test_loader, verbose=True)""" print(f"Begin visualization : ") #### Create datasets