diff --git a/runs/clevr/slot_att/config.yaml b/runs/clevr/slot_att/config.yaml
index d53f13d186af8e2157e0ae1e0984ec58c8a9642b..5863fa6bb9a8ec23a604cf967ca45edf3eebfead 100644
--- a/runs/clevr/slot_att/config.yaml
+++ b/runs/clevr/slot_att/config.yaml
@@ -9,16 +9,14 @@ model:
   hidden_dim: 128
   iters: 3
 training:
-  num_workers: 2 
-  num_gpus: 1
-  batch_size: 8 
-  max_it: 333000000
+  num_workers: 48 
+  num_gpus: 8
+  batch_size: 256 
+  max_it: 1000000
   warmup_it: 10000
   lr_warmup: 5000
   decay_rate: 0.5
   decay_it: 100000
-  visualize_every: 5000
-  validate_every: 5000
   checkpoint_every: 1000
   backup_every: 25000
 
diff --git a/runs/clevr/slot_att/config_tsa.yaml b/runs/clevr/slot_att/config_tsa.yaml
index 5c9592739a292f0557e5386040dc180203095b78..1cb94fa1f35339df08fd634fc1db02da2c37f924 100644
--- a/runs/clevr/slot_att/config_tsa.yaml
+++ b/runs/clevr/slot_att/config_tsa.yaml
@@ -9,16 +9,14 @@ model:
   hidden_dim: 128
   iters: 3
 training:
-  num_workers: 2 
-  num_gpus: 1
-  batch_size: 8
-  max_it: 333000000
+  num_workers: 48
+  num_gpus: 8
+  batch_size: 256
+  max_it: 1000000
   warmup_it: 10000
   lr_warmup: 5000
   decay_rate: 0.5
   decay_it: 100000
-  visualize_every: 5000
-  validate_every: 5000
   checkpoint_every: 1000
   backup_every: 25000
 
diff --git a/runs/ycb/slot_att/config.yaml b/runs/ycb/slot_att/config.yaml
index ebc5132f030cfbcb7720e5226fdd97a272414880..ce3ef9944e70213cd512f56586464d735fceb69c 100644
--- a/runs/ycb/slot_att/config.yaml
+++ b/runs/ycb/slot_att/config.yaml
@@ -12,13 +12,11 @@ training:
   num_workers: 2 
   num_gpus: 1
   batch_size: 8 
-  max_it: 333000000
+  max_it: 1000000
   warmup_it: 10000
   lr_warmup: 5000
   decay_rate: 0.5
   decay_it: 100000
-  visualize_every: 5000
-  validate_every: 5000
   checkpoint_every: 1000
   backup_every: 25000
 
diff --git a/runs/ycb/slot_att/config_tsa.yaml b/runs/ycb/slot_att/config_tsa.yaml
index 0e2a05a7a6a25fc626b602c57bb6983049b7b7d8..1ec79a88799c95eb518d4dc695b3e3a3f672a8a3 100644
--- a/runs/ycb/slot_att/config_tsa.yaml
+++ b/runs/ycb/slot_att/config_tsa.yaml
@@ -12,13 +12,11 @@ training:
   num_workers: 2 
   num_gpus: 1
   batch_size: 8
-  max_it: 333000000
+  max_it: 1000000
   warmup_it: 10000
   lr_warmup: 5000
   decay_rate: 0.5
   decay_it: 100000
-  visualize_every: 5000
-  validate_every: 5000
   checkpoint_every: 1000
   backup_every: 25000
 
diff --git a/render.slurm b/slurm/render.slurm
similarity index 100%
rename from render.slurm
rename to slurm/render.slurm
diff --git a/slurm/sa_train_clever.slurm b/slurm/sa_train_clever.slurm
new file mode 100644
index 0000000000000000000000000000000000000000..eca34681e3197aeb47a61616671d443c25571056
--- /dev/null
+++ b/slurm/sa_train_clever.slurm
@@ -0,0 +1,24 @@
+#!/bin/bash
+
+#SBATCH --job-name=slot_att_clevr
+#SBATCH --output=logs/job.%j.out
+#SBATCH --error=logs/job.%j.err 
+
+#SBATCH --account=uli@v100
+#SBATCH --partition=gpu_p2
+#SBATCH --gres=gpu:8
+#SBATCH --nodes=1
+#SBATCH --ntasks-per-node=8  # number of MPI tasks per node (= number of GPUs per node)
+#SBATCH --exclusive
+#SBATCH --hint=nomultithread
+#SBATCH -t 20:00:00
+#SBATCH --mail-user=alexandre.chapin@ec-lyon.fr
+#SBATCH --mail-type=FAIL
+#SBATCH --qos=qos_gpu-t3
+
+module purge
+echo ${SLURM_NODELIST}
+#module load cudnn/8.5.0.96-11.7-cuda
+module load pytorch-gpu/py3/2.0.0
+
+srun python train_sa.py runs/clevr/slot_att/config.yaml --wandb
diff --git a/slurm/sa_train_ycb.slurm b/slurm/sa_train_ycb.slurm
new file mode 100644
index 0000000000000000000000000000000000000000..d44cbbc306eb38ee8b72919b77fb50e711c77514
--- /dev/null
+++ b/slurm/sa_train_ycb.slurm
@@ -0,0 +1,24 @@
+#!/bin/bash
+
+#SBATCH --job-name=slot_att_ycb
+#SBATCH --output=logs/job.%j.out
+#SBATCH --error=logs/job.%j.err 
+
+#SBATCH --account=uli@v100
+#SBATCH --partition=gpu_p2
+#SBATCH --gres=gpu:8
+#SBATCH --nodes=1
+#SBATCH --ntasks-per-node=8  # number of MPI tasks per node (= number of GPUs per node)
+#SBATCH --exclusive
+#SBATCH --hint=nomultithread
+#SBATCH -t 20:00:00
+#SBATCH --mail-user=alexandre.chapin@ec-lyon.fr
+#SBATCH --mail-type=FAIL
+#SBATCH --qos=qos_gpu-t3
+
+module purge
+echo ${SLURM_NODELIST}
+#module load cudnn/8.5.0.96-11.7-cuda
+module load pytorch-gpu/py3/2.0.0
+
+srun python train_sa.py runs/ycb/slot_att/config.yaml --wandb
diff --git a/train_clevr.slurm b/slurm/train_clevr.slurm
similarity index 100%
rename from train_clevr.slurm
rename to slurm/train_clevr.slurm
diff --git a/train_msn.slurm b/slurm/train_msn.slurm
similarity index 100%
rename from train_msn.slurm
rename to slurm/train_msn.slurm
diff --git a/slurm/tsa_train_clevr.slurm b/slurm/tsa_train_clevr.slurm
new file mode 100644
index 0000000000000000000000000000000000000000..8f1fcc7ad7516b927484226d3d8720765d4ce9e8
--- /dev/null
+++ b/slurm/tsa_train_clevr.slurm
@@ -0,0 +1,24 @@
+#!/bin/bash
+
+#SBATCH --job-name=trans_slot_att_clevr
+#SBATCH --output=logs/job.%j.out
+#SBATCH --error=logs/job.%j.err 
+
+#SBATCH --account=uli@v100
+#SBATCH --partition=gpu_p2
+#SBATCH --gres=gpu:8
+#SBATCH --nodes=1
+#SBATCH --ntasks-per-node=8  # number of MPI tasks per node (= number of GPUs per node)
+#SBATCH --exclusive
+#SBATCH --hint=nomultithread
+#SBATCH -t 20:00:00
+#SBATCH --mail-user=alexandre.chapin@ec-lyon.fr
+#SBATCH --mail-type=FAIL
+#SBATCH --qos=qos_gpu-t3
+
+module purge
+echo ${SLURM_NODELIST}
+#module load cudnn/8.5.0.96-11.7-cuda
+module load pytorch-gpu/py3/2.0.0
+
+srun python train_sa.py runs/clevr/slot_att/config_tsa.yaml --wandb
diff --git a/slurm/tsa_train_ycb.slurm b/slurm/tsa_train_ycb.slurm
new file mode 100644
index 0000000000000000000000000000000000000000..673c63a87dcf08b2fdcad7b5b885779bdbbc0e00
--- /dev/null
+++ b/slurm/tsa_train_ycb.slurm
@@ -0,0 +1,24 @@
+#!/bin/bash
+
+#SBATCH --job-name=trans_slot_att_ycb
+#SBATCH --output=logs/job.%j.out
+#SBATCH --error=logs/job.%j.err 
+
+#SBATCH --account=uli@v100
+#SBATCH --partition=gpu_p2
+#SBATCH --gres=gpu:8
+#SBATCH --nodes=1
+#SBATCH --ntasks-per-node=8  # number of MPI tasks per node (= number of GPUs per node)
+#SBATCH --exclusive
+#SBATCH --hint=nomultithread
+#SBATCH -t 20:00:00
+#SBATCH --mail-user=alexandre.chapin@ec-lyon.fr
+#SBATCH --mail-type=FAIL
+#SBATCH --qos=qos_gpu-t3
+
+module purge
+echo ${SLURM_NODELIST}
+#module load cudnn/8.5.0.96-11.7-cuda
+module load pytorch-gpu/py3/2.0.0
+
+srun python train_sa.py runs/ycb/slot_att/config_tsa.yaml --wandb
diff --git a/train_sa.py b/train_sa.py
index 0397a7cf90d4c171efff2af7af4d949b18871ae1..650d88eb6d504fbc870226683cfdb0be3203c275 100644
--- a/train_sa.py
+++ b/train_sa.py
@@ -32,6 +32,7 @@ def main():
     parser.add_argument('--wandb', action='store_true', help='Log run to Weights and Biases.')
     parser.add_argument('--seed', type=int, default=0, help='Random seed.')
     parser.add_argument('--ckpt', type=str, default=None, help='Model checkpoint path')
+    parser.add_argument('--profiler', action='store_true', help='Activate the simple profiler.')
 
     args = parser.parse_args()
     with open(args.config, 'r') as f:
@@ -53,12 +54,12 @@ def main():
     #### Create datasets
     train_dataset = data.get_dataset('train', cfg['data'])
     train_loader = DataLoader(
-        train_dataset, batch_size=batch_size, num_workers=num_workers-9 if num_workers > 9 else 0,
+        train_dataset, batch_size=batch_size, num_workers=num_workers-8 if num_workers > 8 else 0,
         shuffle=True, worker_init_fn=data.worker_init_fn, pin_memory=True)
     
     val_dataset = data.get_dataset('val', cfg['data'])
     val_loader = DataLoader(
-        val_dataset, batch_size=batch_size, num_workers=8 if num_workers > 9 else 0,
+        val_dataset, batch_size=batch_size, num_workers=8 if num_workers > 8 else 0,
         shuffle=True, worker_init_fn=data.worker_init_fn, pin_memory=True)
 
     #### Create model
@@ -73,31 +74,37 @@ def main():
         mode="max",
         dirpath="./checkpoints",
         filename="ckpt-" +  str(cfg["data"]["dataset"])+ "-slots:"+ str(cfg["model"]["num_slots"]) + "-" + str(cfg["model"]["model_type"]) +"-{epoch:02d}-psnr{val_psnr:.2f}",
-        save_weights_only=True, # don't save optimizer states nor lr-scheduler, ...
-        every_n_train_steps=cfg["training"]["checkpoint_every"]
+        save_weights_only=True # don't save optimizer states nor lr-scheduler, ...
     )
 
     early_stopping = EarlyStopping(monitor="val_psnr", mode="max")
 
     trainer = pl.Trainer(accelerator="gpu", 
                          devices=num_gpus, 
-                         profiler="simple", 
+                         profiler="simple" if args.profiler else None,
                          default_root_dir="./logs", 
                          logger=WandbLogger(project="slot-att", offline=True) if args.wandb else None,
                          strategy="ddp_find_unused_parameters_true" if num_gpus > 1 else "auto", 
                          callbacks=[checkpoint_callback, early_stopping],
                          log_every_n_steps=100, 
-                         val_check_interval=cfg["training"]["validate_every"],
-                         check_val_every_n_epoch=None,
                          max_steps=num_train_steps, 
                          enable_model_summary=True)
 
     trainer.fit(model, train_loader, val_loader)
+    #### Evaluate the model
+    print(f"Begin testing : ")
+    test_dataset = data.get_dataset('test', cfg['data'])
+    test_loader = DataLoader(
+        test_dataset, batch_size=batch_size, num_workers=8 if num_workers > 8 else 0,
+        shuffle=True, worker_init_fn=data.worker_init_fn, pin_memory=True)
+    
+    trainer.test(ckpt_path="best", dataloaders=test_loader, verbose=True)
 
+    print(f"Begin visualization : ")
     #### Create datasets
     vis_dataset = data.get_dataset('train', cfg['data'])
     vis_loader = DataLoader(
-        vis_dataset, batch_size=1, num_workers=1,
+        vis_dataset, batch_size=1, num_workers=0,
         shuffle=True, worker_init_fn=data.worker_init_fn)
     
     device = model.device