From 4b89db3377fd0bdac730bfb96a10d74c31131600 Mon Sep 17 00:00:00 2001
From: alexcbb <alexchapin@hotmail.fr>
Date: Fri, 28 Jul 2023 17:07:02 +0200
Subject: [PATCH] Deactivate early stop, delete evaluate, change dir
 checkpoint'

---
 osrt/data/ycbv.py |  1 -
 train_sa.py       | 13 +++++++------
 2 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/osrt/data/ycbv.py b/osrt/data/ycbv.py
index 1c67c7b..a815fd3 100644
--- a/osrt/data/ycbv.py
+++ b/osrt/data/ycbv.py
@@ -298,7 +298,6 @@ class YCBVideo2D(Dataset):
             max_objects (int): Load only scenes with at most this many objects.
         """
         self.path = path
-        print(f"Get path {path}")
         self.mode = mode
         self.max_objects = max_objects
 
diff --git a/train_sa.py b/train_sa.py
index 5d763a0..e06a949 100644
--- a/train_sa.py
+++ b/train_sa.py
@@ -72,12 +72,12 @@ def main():
     checkpoint_callback = ModelCheckpoint(
         monitor="val_psnr",
         mode="max",
-        dirpath="./checkpoints",
-        filename="ckpt-" +  str(cfg["data"]["dataset"])+ "-slots:"+ str(cfg["model"]["num_slots"]) + "-" + str(cfg["model"]["model_type"]) +"-{epoch:02d}-psnr{val_psnr:.2f}",
+        dirpath=f"./checkpoints_{cfg['data']['dataset']}_{cfg['model']['model_type']}",
+        filename="ckpt-" +  str(cfg["data"]["dataset"])+ "-slots:"+ str(cfg["model"]["num_slots"]) + "-" + str(cfg["model"]["model_type"]) +"-{epoch:02d}-{val_psnr:.2f}",
         save_weights_only=True # don't save optimizer states nor lr-scheduler, ...
     )
 
-    early_stopping = EarlyStopping(monitor="val_psnr", mode="max")
+    #early_stopping = EarlyStopping(monitor="val_psnr", mode="max")
 
     trainer = pl.Trainer(accelerator="gpu", 
                          devices=num_gpus, 
@@ -85,20 +85,21 @@ def main():
                          default_root_dir="./logs", 
                          logger=WandbLogger(project="slot-att", offline=True) if args.wandb else None,
                          strategy="ddp_find_unused_parameters_true" if num_gpus > 1 else "auto", 
-                         callbacks=[checkpoint_callback, early_stopping],
+                         callbacks=[checkpoint_callback],#, early_stopping],
                          log_every_n_steps=100, 
                          max_steps=num_train_steps, 
                          enable_model_summary=True)
 
     trainer.fit(model, train_loader, val_loader)
+
     #### Evaluate the model
-    print(f"Begin testing : ")
+    """print(f"Begin testing : ")
     test_dataset = data.get_dataset('test', cfg['data'])
     test_loader = DataLoader(
         test_dataset, batch_size=batch_size, num_workers=8 if num_workers > 8 else 0,
         shuffle=True, worker_init_fn=data.worker_init_fn, pin_memory=True)
     
-    trainer.test(ckpt_path="best", dataloaders=test_loader, verbose=True)
+    trainer.test(ckpt_path="best", dataloaders=test_loader, verbose=True)"""
 
     print(f"Begin visualization : ")
     #### Create datasets
-- 
GitLab