From b286dd0d9766fb9ad3d8c7959982c0687d6bf70a Mon Sep 17 00:00:00 2001
From: alexcbb <alexchapin@hotmail.fr>
Date: Fri, 28 Jul 2023 15:13:53 +0200
Subject: [PATCH] Remove unused params and clean up some code
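
Drop the unused visualize_every and validate_every options from the
CLEVR and YCB configs and cap max_it at 1000000. Scale the CLEVR runs
up to a full node (48 workers, 8 GPUs, batch size 256); with 48
workers, the training loader gets 48 - 8 = 40 workers and the
validation loader 8. Gather the SLURM scripts under slurm/ and add
launch scripts for the SA and TSA models on CLEVR and YCB. In
train_sa.py, make the "simple" profiler opt-in through a new
--profiler flag and run testing and visualization after training.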

---
 runs/clevr/slot_att/config.yaml              | 10 ++++----
 runs/clevr/slot_att/config_tsa.yaml          | 10 ++++----
 runs/ycb/slot_att/config.yaml                |  4 +---
 runs/ycb/slot_att/config_tsa.yaml            |  4 +---
 render.slurm => slurm/render.slurm           |  0
 slurm/sa_train_clevr.slurm                   | 24 ++++++++++++++++++++
 slurm/sa_train_ycb.slurm                     | 24 ++++++++++++++++++++
 train_clevr.slurm => slurm/train_clevr.slurm |  0
 train_msn.slurm => slurm/train_msn.slurm     |  0
 slurm/tsa_train_clevr.slurm                  | 24 ++++++++++++++++++++
 slurm/tsa_train_ycb.slurm                    | 24 ++++++++++++++++++++
 train_sa.py                                  | 23 ++++++++++++-------
 12 files changed, 121 insertions(+), 26 deletions(-)
 rename render.slurm => slurm/render.slurm (100%)
 create mode 100644 slurm/sa_train_clevr.slurm
 create mode 100644 slurm/sa_train_ycb.slurm
 rename train_clevr.slurm => slurm/train_clevr.slurm (100%)
 rename train_msn.slurm => slurm/train_msn.slurm (100%)
 create mode 100644 slurm/tsa_train_clevr.slurm
 create mode 100644 slurm/tsa_train_ycb.slurm

diff --git a/runs/clevr/slot_att/config.yaml b/runs/clevr/slot_att/config.yaml
index d53f13d..5863fa6 100644
--- a/runs/clevr/slot_att/config.yaml
+++ b/runs/clevr/slot_att/config.yaml
@@ -9,16 +9,14 @@ model:
   hidden_dim: 128
   iters: 3
 training:
-  num_workers: 2 
-  num_gpus: 1
-  batch_size: 8 
-  max_it: 333000000
+  num_workers: 48 
+  num_gpus: 8
+  batch_size: 256 
+  max_it: 1000000
   warmup_it: 10000
   lr_warmup: 5000
   decay_rate: 0.5
   decay_it: 100000
-  visualize_every: 5000
-  validate_every: 5000
   checkpoint_every: 1000
   backup_every: 25000
 
diff --git a/runs/clevr/slot_att/config_tsa.yaml b/runs/clevr/slot_att/config_tsa.yaml
index 5c95927..1cb94fa 100644
--- a/runs/clevr/slot_att/config_tsa.yaml
+++ b/runs/clevr/slot_att/config_tsa.yaml
@@ -9,16 +9,14 @@ model:
   hidden_dim: 128
   iters: 3
 training:
-  num_workers: 2 
-  num_gpus: 1
-  batch_size: 8
-  max_it: 333000000
+  num_workers: 48
+  num_gpus: 8
+  batch_size: 256
+  max_it: 1000000
   warmup_it: 10000
   lr_warmup: 5000
   decay_rate: 0.5
   decay_it: 100000
-  visualize_every: 5000
-  validate_every: 5000
   checkpoint_every: 1000
   backup_every: 25000
 
diff --git a/runs/ycb/slot_att/config.yaml b/runs/ycb/slot_att/config.yaml
index ebc5132..ce3ef99 100644
--- a/runs/ycb/slot_att/config.yaml
+++ b/runs/ycb/slot_att/config.yaml
@@ -12,13 +12,11 @@ training:
   num_workers: 2 
   num_gpus: 1
   batch_size: 8 
-  max_it: 333000000
+  max_it: 1000000
   warmup_it: 10000
   lr_warmup: 5000
   decay_rate: 0.5
   decay_it: 100000
-  visualize_every: 5000
-  validate_every: 5000
   checkpoint_every: 1000
   backup_every: 25000
 
diff --git a/runs/ycb/slot_att/config_tsa.yaml b/runs/ycb/slot_att/config_tsa.yaml
index 0e2a05a..1ec79a8 100644
--- a/runs/ycb/slot_att/config_tsa.yaml
+++ b/runs/ycb/slot_att/config_tsa.yaml
@@ -12,13 +12,11 @@ training:
   num_workers: 2 
   num_gpus: 1
   batch_size: 8
-  max_it: 333000000
+  max_it: 1000000
   warmup_it: 10000
   lr_warmup: 5000
   decay_rate: 0.5
   decay_it: 100000
-  visualize_every: 5000
-  validate_every: 5000
   checkpoint_every: 1000
   backup_every: 25000
 
diff --git a/render.slurm b/slurm/render.slurm
similarity index 100%
rename from render.slurm
rename to slurm/render.slurm
diff --git a/slurm/sa_train_clevr.slurm b/slurm/sa_train_clevr.slurm
new file mode 100644
index 0000000..eca3468
--- /dev/null
+++ b/slurm/sa_train_clevr.slurm
@@ -0,0 +1,24 @@
+#!/bin/bash
+
+#SBATCH --job-name=slot_att_clevr
+#SBATCH --output=logs/job.%j.out
+#SBATCH --error=logs/job.%j.err 
+
+#SBATCH --account=uli@v100
+#SBATCH --partition=gpu_p2
+#SBATCH --gres=gpu:8
+#SBATCH --nodes=1
+#SBATCH --ntasks-per-node=8 # number of MPI tasks per node (= number of GPUs per node)
+#SBATCH --exclusive
+#SBATCH --hint=nomultithread
+#SBATCH -t 20:00:00
+#SBATCH --mail-user=alexandre.chapin@ec-lyon.fr
+#SBATCH --mail-type=FAIL
+#SBATCH --qos=qos_gpu-t3
+
+module purge
+echo ${SLURM_NODELIST}
+#module load cudnn/8.5.0.96-11.7-cuda
+module load pytorch-gpu/py3/2.0.0
+
+srun python train_sa.py runs/clevr/slot_att/config.yaml --wandb
diff --git a/slurm/sa_train_ycb.slurm b/slurm/sa_train_ycb.slurm
new file mode 100644
index 0000000..d44cbbc
--- /dev/null
+++ b/slurm/sa_train_ycb.slurm
@@ -0,0 +1,24 @@
+#!/bin/bash
+
+#SBATCH --job-name=slot_att_ycb
+#SBATCH --output=logs/job.%j.out
+#SBATCH --error=logs/job.%j.err 
+
+#SBATCH --account=uli@v100
+#SBATCH --partition=gpu_p2
+#SBATCH --gres=gpu:8
+#SBATCH --nodes=1
+#SBATCH --ntasks-per-node=8 # number of MPI tasks per node (= number of GPUs per node)
+#SBATCH --exclusive
+#SBATCH --hint=nomultithread
+#SBATCH -t 20:00:00
+#SBATCH --mail-user=alexandre.chapin@ec-lyon.fr
+#SBATCH --mail-type=FAIL
+#SBATCH --qos=qos_gpu-t3
+
+module purge
+echo ${SLURM_NODELIST}
+#module load cudnn/8.5.0.96-11.7-cuda
+module load pytorch-gpu/py3/2.0.0
+
+srun python train_sa.py runs/ycb/slot_att/config.yaml --wandb
diff --git a/train_clevr.slurm b/slurm/train_clevr.slurm
similarity index 100%
rename from train_clevr.slurm
rename to slurm/train_clevr.slurm
diff --git a/train_msn.slurm b/slurm/train_msn.slurm
similarity index 100%
rename from train_msn.slurm
rename to slurm/train_msn.slurm
diff --git a/slurm/tsa_train_clevr.slurm b/slurm/tsa_train_clevr.slurm
new file mode 100644
index 0000000..8f1fcc7
--- /dev/null
+++ b/slurm/tsa_train_clevr.slurm
@@ -0,0 +1,24 @@
+#!/bin/bash
+
+#SBATCH --job-name=trans_slot_att_clevr
+#SBATCH --output=logs/job.%j.out
+#SBATCH --error=logs/job.%j.err 
+
+#SBATCH --account=uli@v100
+#SBATCH --partition=gpu_p2
+#SBATCH --gres=gpu:8
+#SBATCH --nodes=1
+#SBATCH --ntasks-per-node=8 # number of MPI tasks per node (= number of GPUs per node)
+#SBATCH --exclusive
+#SBATCH --hint=nomultithread
+#SBATCH -t 20:00:00
+#SBATCH --mail-user=alexandre.chapin@ec-lyon.fr
+#SBATCH --mail-type=FAIL
+#SBATCH --qos=qos_gpu-t3
+
+module purge
+echo ${SLURM_NODELIST}
+#module load cudnn/8.5.0.96-11.7-cuda
+module load pytorch-gpu/py3/2.0.0
+
+srun python train_sa.py runs/clevr/slot_att/config_tsa.yaml --wandb
diff --git a/slurm/tsa_train_ycb.slurm b/slurm/tsa_train_ycb.slurm
new file mode 100644
index 0000000..673c63a
--- /dev/null
+++ b/slurm/tsa_train_ycb.slurm
@@ -0,0 +1,24 @@
+#!/bin/bash
+
+#SBATCH --job-name=trans_slot_att_ycb
+#SBATCH --output=logs/job.%j.out
+#SBATCH --error=logs/job.%j.err 
+
+#SBATCH --account=uli@v100
+#SBATCH --partition=gpu_p2
+#SBATCH --gres=gpu:8
+#SBATCH --nodes=1
+#SBATCH --ntasks-per-node=8 # number of MPI tasks per node (= number of GPUs per node)
+#SBATCH --exclusive
+#SBATCH --hint=nomultithread
+#SBATCH -t 20:00:00
+#SBATCH --mail-user=alexandre.chapin@ec-lyon.fr
+#SBATCH --mail-type=FAIL
+#SBATCH --qos=qos_gpu-t3
+
+module purge
+echo ${SLURM_NODELIST}
+#module load cudnn/8.5.0.96-11.7-cuda
+module load pytorch-gpu/py3/2.0.0
+
+srun python train_sa.py runs/ycb/slot_att/config_tsa.yaml --wandb
diff --git a/train_sa.py b/train_sa.py
index 0397a7c..650d88e 100644
--- a/train_sa.py
+++ b/train_sa.py
@@ -32,6 +32,7 @@ def main():
     parser.add_argument('--wandb', action='store_true', help='Log run to Weights and Biases.')
     parser.add_argument('--seed', type=int, default=0, help='Random seed.')
     parser.add_argument('--ckpt', type=str, default=None, help='Model checkpoint path')
+    parser.add_argument('--profiler', action='store_true', help='Activate the profiler.')
 
     args = parser.parse_args()
     with open(args.config, 'r') as f:
@@ -53,12 +54,12 @@ def main():
     #### Create datasets
     train_dataset = data.get_dataset('train', cfg['data'])
     train_loader = DataLoader(
-        train_dataset, batch_size=batch_size, num_workers=num_workers-9 if num_workers > 9 else 0,
+        train_dataset, batch_size=batch_size, num_workers=num_workers-8 if num_workers > 8 else 0,
         shuffle=True, worker_init_fn=data.worker_init_fn, pin_memory=True)
     
     val_dataset = data.get_dataset('val', cfg['data'])
     val_loader = DataLoader(
-        val_dataset, batch_size=batch_size, num_workers=8 if num_workers > 9 else 0,
+        val_dataset, batch_size=batch_size, num_workers=8 if num_workers > 8 else 0,
         shuffle=True, worker_init_fn=data.worker_init_fn, pin_memory=True)
 
     #### Create model
@@ -73,31 +74,37 @@ def main():
         mode="max",
         dirpath="./checkpoints",
         filename="ckpt-" +  str(cfg["data"]["dataset"])+ "-slots:"+ str(cfg["model"]["num_slots"]) + "-" + str(cfg["model"]["model_type"]) +"-{epoch:02d}-psnr{val_psnr:.2f}",
-        save_weights_only=True, # don't save optimizer states nor lr-scheduler, ...
-        every_n_train_steps=cfg["training"]["checkpoint_every"]
+        save_weights_only=True # don't save optimizer state, lr scheduler, etc.
     )
 
     early_stopping = EarlyStopping(monitor="val_psnr", mode="max")
 
     trainer = pl.Trainer(accelerator="gpu", 
                          devices=num_gpus, 
-                         profiler="simple", 
+                         profiler="simple" if args.profiler else None, 
                          default_root_dir="./logs", 
                          logger=WandbLogger(project="slot-att", offline=True) if args.wandb else None,
                          strategy="ddp_find_unused_parameters_true" if num_gpus > 1 else "auto", 
                          callbacks=[checkpoint_callback, early_stopping],
                          log_every_n_steps=100, 
-                         val_check_interval=cfg["training"]["validate_every"],
-                         check_val_every_n_epoch=None,
                          max_steps=num_train_steps, 
                          enable_model_summary=True)
 
     trainer.fit(model, train_loader, val_loader)
+    #### Evaluate the model
+    print(f"Begin testing : ")
+    test_dataset = data.get_dataset('test', cfg['data'])
+    test_loader = DataLoader(
+        test_dataset, batch_size=batch_size, num_workers=8 if num_workers > 8 else 0,
+        shuffle=True, worker_init_fn=data.worker_init_fn, pin_memory=True)
+    
+    trainer.test(ckpt_path="best", dataloaders=test_loader, verbose=True)
 
+    print(f"Begin visualization : ")
     #### Create datasets
     vis_dataset = data.get_dataset('train', cfg['data'])
     vis_loader = DataLoader(
-        vis_dataset, batch_size=1, num_workers=1,
+        vis_dataset, batch_size=1, num_workers=0,
         shuffle=True, worker_init_fn=data.worker_init_fn)
     
     device = model.device
-- 
GitLab