diff --git a/runs/clevr/slot_att/config.yaml b/runs/clevr/slot_att/config.yaml
index d53f13d186af8e2157e0ae1e0984ec58c8a9642b..5863fa6bb9a8ec23a604cf967ca45edf3eebfead 100644
--- a/runs/clevr/slot_att/config.yaml
+++ b/runs/clevr/slot_att/config.yaml
@@ -9,16 +9,14 @@ model:
   hidden_dim: 128
   iters: 3
 training:
-  num_workers: 2
-  num_gpus: 1
-  batch_size: 8
-  max_it: 333000000
+  num_workers: 48
+  num_gpus: 8
+  batch_size: 256
+  max_it: 1000000
   warmup_it: 10000
   lr_warmup: 5000
   decay_rate: 0.5
   decay_it: 100000
-  visualize_every: 5000
-  validate_every: 5000
   checkpoint_every: 1000
   backup_every: 25000

diff --git a/runs/clevr/slot_att/config_tsa.yaml b/runs/clevr/slot_att/config_tsa.yaml
index 5c9592739a292f0557e5386040dc180203095b78..1cb94fa1f35339df08fd634fc1db02da2c37f924 100644
--- a/runs/clevr/slot_att/config_tsa.yaml
+++ b/runs/clevr/slot_att/config_tsa.yaml
@@ -9,16 +9,14 @@ model:
   hidden_dim: 128
   iters: 3
 training:
-  num_workers: 2
-  num_gpus: 1
-  batch_size: 8
-  max_it: 333000000
+  num_workers: 48
+  num_gpus: 8
+  batch_size: 256
+  max_it: 1000000
   warmup_it: 10000
   lr_warmup: 5000
   decay_rate: 0.5
   decay_it: 100000
-  visualize_every: 5000
-  validate_every: 5000
   checkpoint_every: 1000
   backup_every: 25000

diff --git a/runs/ycb/slot_att/config.yaml b/runs/ycb/slot_att/config.yaml
index ebc5132f030cfbcb7720e5226fdd97a272414880..ce3ef9944e70213cd512f56586464d735fceb69c 100644
--- a/runs/ycb/slot_att/config.yaml
+++ b/runs/ycb/slot_att/config.yaml
@@ -12,13 +12,11 @@ training:
   num_workers: 2
   num_gpus: 1
   batch_size: 8
-  max_it: 333000000
+  max_it: 1000000
   warmup_it: 10000
   lr_warmup: 5000
   decay_rate: 0.5
   decay_it: 100000
-  visualize_every: 5000
-  validate_every: 5000
   checkpoint_every: 1000
   backup_every: 25000

diff --git a/runs/ycb/slot_att/config_tsa.yaml b/runs/ycb/slot_att/config_tsa.yaml
index 0e2a05a7a6a25fc626b602c57bb6983049b7b7d8..1ec79a88799c95eb518d4dc695b3e3a3f672a8a3 100644
--- a/runs/ycb/slot_att/config_tsa.yaml
+++ b/runs/ycb/slot_att/config_tsa.yaml
@@ -12,13 +12,11 @@ training:
   num_workers: 2
   num_gpus: 1
   batch_size: 8
-  max_it: 333000000
+  max_it: 1000000
   warmup_it: 10000
   lr_warmup: 5000
   decay_rate: 0.5
   decay_it: 100000
-  visualize_every: 5000
-  validate_every: 5000
   checkpoint_every: 1000
   backup_every: 25000

diff --git a/render.slurm b/slurm/render.slurm
similarity index 100%
rename from render.slurm
rename to slurm/render.slurm
diff --git a/slurm/sa_train_clever.slurm b/slurm/sa_train_clever.slurm
new file mode 100644
index 0000000000000000000000000000000000000000..eca34681e3197aeb47a61616671d443c25571056
--- /dev/null
+++ b/slurm/sa_train_clever.slurm
@@ -0,0 +1,24 @@
+#!/bin/bash
+
+#SBATCH --job-name=slot_att_clevr
+#SBATCH --output=logs/job.%j.out
+#SBATCH --error=logs/job.%j.err
+
+#SBATCH --account=uli@v100
+#SBATCH --partition=gpu_p2
+#SBATCH --gres=gpu:8
+#SBATCH --nodes=1
+#SBATCH --ntasks-per-node=8  # number of MPI tasks per node (= number of GPUs per node)
+#SBATCH --exclusive
+#SBATCH --hint=nomultithread
+#SBATCH -t 20:00:00
+#SBATCH --mail-user=alexandre.chapin@ec-lyon.fr
+#SBATCH --mail-type=FAIL
+#SBATCH --qos=qos_gpu-t3
+
+module purge
+echo ${SLURM_NODELIST}
+#module load cudnn/8.5.0.96-11.7-cuda
+module load pytorch-gpu/py3/2.0.0
+
+srun python train_sa.py runs/clevr/slot_att/config.yaml --wandb
diff --git a/slurm/sa_train_ycb.slurm b/slurm/sa_train_ycb.slurm
new file mode 100644
index 0000000000000000000000000000000000000000..d44cbbc306eb38ee8b72919b77fb50e711c77514
--- /dev/null
+++ b/slurm/sa_train_ycb.slurm
@@ -0,0 +1,24 @@
+#!/bin/bash
+
+#SBATCH --job-name=slot_att_ycb
+#SBATCH --output=logs/job.%j.out
+#SBATCH --error=logs/job.%j.err
+
+#SBATCH --account=uli@v100
+#SBATCH --partition=gpu_p2
+#SBATCH --gres=gpu:8
+#SBATCH --nodes=1
+#SBATCH --ntasks-per-node=8  # number of MPI tasks per node (= number of GPUs per node)
+#SBATCH --exclusive
+#SBATCH --hint=nomultithread
+#SBATCH -t 20:00:00
+#SBATCH --mail-user=alexandre.chapin@ec-lyon.fr
+#SBATCH --mail-type=FAIL
+#SBATCH --qos=qos_gpu-t3
+
+module purge
+echo ${SLURM_NODELIST}
+#module load cudnn/8.5.0.96-11.7-cuda
+module load pytorch-gpu/py3/2.0.0
+
+srun python train_sa.py runs/ycb/slot_att/config.yaml --wandb
diff --git a/train_clevr.slurm b/slurm/train_clevr.slurm
similarity index 100%
rename from train_clevr.slurm
rename to slurm/train_clevr.slurm
diff --git a/train_msn.slurm b/slurm/train_msn.slurm
similarity index 100%
rename from train_msn.slurm
rename to slurm/train_msn.slurm
diff --git a/slurm/tsa_train_clevr.slurm b/slurm/tsa_train_clevr.slurm
new file mode 100644
index 0000000000000000000000000000000000000000..8f1fcc7ad7516b927484226d3d8720765d4ce9e8
--- /dev/null
+++ b/slurm/tsa_train_clevr.slurm
@@ -0,0 +1,24 @@
+#!/bin/bash
+
+#SBATCH --job-name=trans_slot_att_clevr
+#SBATCH --output=logs/job.%j.out
+#SBATCH --error=logs/job.%j.err
+
+#SBATCH --account=uli@v100
+#SBATCH --partition=gpu_p2
+#SBATCH --gres=gpu:8
+#SBATCH --nodes=1
+#SBATCH --ntasks-per-node=8  # number of MPI tasks per node (= number of GPUs per node)
+#SBATCH --exclusive
+#SBATCH --hint=nomultithread
+#SBATCH -t 20:00:00
+#SBATCH --mail-user=alexandre.chapin@ec-lyon.fr
+#SBATCH --mail-type=FAIL
+#SBATCH --qos=qos_gpu-t3
+
+module purge
+echo ${SLURM_NODELIST}
+#module load cudnn/8.5.0.96-11.7-cuda
+module load pytorch-gpu/py3/2.0.0
+
+srun python train_sa.py runs/clevr/slot_att/config_tsa.yaml --wandb
diff --git a/slurm/tsa_train_ycb.slurm b/slurm/tsa_train_ycb.slurm
new file mode 100644
index 0000000000000000000000000000000000000000..673c63a87dcf08b2fdcad7b5b885779bdbbc0e00
--- /dev/null
+++ b/slurm/tsa_train_ycb.slurm
@@ -0,0 +1,24 @@
+#!/bin/bash
+
+#SBATCH --job-name=trans_slot_att_ycb
+#SBATCH --output=logs/job.%j.out
+#SBATCH --error=logs/job.%j.err
+
+#SBATCH --account=uli@v100
+#SBATCH --partition=gpu_p2
+#SBATCH --gres=gpu:8
+#SBATCH --nodes=1
+#SBATCH --ntasks-per-node=8  # number of MPI tasks per node (= number of GPUs per node)
+#SBATCH --exclusive
+#SBATCH --hint=nomultithread
+#SBATCH -t 20:00:00
+#SBATCH --mail-user=alexandre.chapin@ec-lyon.fr
+#SBATCH --mail-type=FAIL
+#SBATCH --qos=qos_gpu-t3
+
+module purge
+echo ${SLURM_NODELIST}
+#module load cudnn/8.5.0.96-11.7-cuda
+module load pytorch-gpu/py3/2.0.0
+
+srun python train_sa.py runs/ycb/slot_att/config_tsa.yaml --wandb
diff --git a/train_sa.py b/train_sa.py
index 0397a7cf90d4c171efff2af7af4d949b18871ae1..650d88eb6d504fbc870226683cfdb0be3203c275 100644
--- a/train_sa.py
+++ b/train_sa.py
@@ -32,6 +32,7 @@ def main():
     parser.add_argument('--wandb', action='store_true', help='Log run to Weights and Biases.')
     parser.add_argument('--seed', type=int, default=0, help='Random seed.')
     parser.add_argument('--ckpt', type=str, default=None, help='Model checkpoint path')
+    parser.add_argument('--profiler', action='store_true', help='Activate the profiler.')
     args = parser.parse_args()

     with open(args.config, 'r') as f:
@@ -53,12 +54,12 @@ def main():
     #### Create datasets
     train_dataset = data.get_dataset('train', cfg['data'])
     train_loader = DataLoader(
-        train_dataset, batch_size=batch_size, num_workers=num_workers-9 if num_workers > 9 else 0,
+        train_dataset, batch_size=batch_size, num_workers=num_workers-8 if num_workers > 8 else 0,
         shuffle=True, worker_init_fn=data.worker_init_fn, pin_memory=True)

     val_dataset = data.get_dataset('val', cfg['data'])
     val_loader = DataLoader(
-        val_dataset, batch_size=batch_size, num_workers=8 if num_workers > 9 else 0,
+        val_dataset, batch_size=batch_size, num_workers=8 if num_workers > 8 else 0,
         shuffle=True, worker_init_fn=data.worker_init_fn, pin_memory=True)

     #### Create model
@@ -73,31 +74,37 @@ def main():
         mode="max",
         dirpath="./checkpoints",
         filename="ckpt-" + str(cfg["data"]["dataset"])+ "-slots:"+ str(cfg["model"]["num_slots"]) + "-" + str(cfg["model"]["model_type"]) +"-{epoch:02d}-psnr{val_psnr:.2f}",
-        save_weights_only=True, # don't save optimizer states nor lr-scheduler, ...
-        every_n_train_steps=cfg["training"]["checkpoint_every"]
+        save_weights_only=True  # don't save optimizer or lr-scheduler states
     )
     early_stopping = EarlyStopping(monitor="val_psnr", mode="max")

     trainer = pl.Trainer(accelerator="gpu",
                          devices=num_gpus,
-                         profiler="simple",
+                         profiler="simple" if args.profiler else None,
                          default_root_dir="./logs",
                          logger=WandbLogger(project="slot-att", offline=True) if args.wandb else None,
                          strategy="ddp_find_unused_parameters_true" if num_gpus > 1 else "auto",
                          callbacks=[checkpoint_callback, early_stopping],
                          log_every_n_steps=100,
-                         val_check_interval=cfg["training"]["validate_every"],
-                         check_val_every_n_epoch=None,
                          max_steps=num_train_steps,
                          enable_model_summary=True)

     trainer.fit(model, train_loader, val_loader)

+    #### Evaluate the model
+    print("Begin testing:")
+    test_dataset = data.get_dataset('test', cfg['data'])
+    test_loader = DataLoader(
+        test_dataset, batch_size=batch_size, num_workers=8 if num_workers > 8 else 0,
+        shuffle=True, worker_init_fn=data.worker_init_fn, pin_memory=True)
+
+    trainer.test(ckpt_path="best", dataloaders=test_loader, verbose=True)
+
     print(f"Begin visualization : ")
     #### Create datasets
     vis_dataset = data.get_dataset('train', cfg['data'])
     vis_loader = DataLoader(
-        vis_dataset, batch_size=1, num_workers=1,
+        vis_dataset, batch_size=1, num_workers=0,
         shuffle=True, worker_init_fn=data.worker_init_fn)

     device = model.device
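
A minimal usage sketch for the new `--profiler` switch (an illustration, not part of the patch; paths taken from the configs above). The flag is opt-in: without it the Trainer receives profiler=None, and with it the Lightning "simple" profiler is enabled:

    # hypothetical local launch outside Slurm; drop --profiler for a normal run
    python train_sa.py runs/clevr/slot_att/config.yaml --wandb --profiler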