diff --git a/train_sa.py b/train_sa.py index 17886ac7323842004073005b6906e148f752ad99..269319caec62329c5ca1c60aada061cc9a9b358e 100644 --- a/train_sa.py +++ b/train_sa.py @@ -60,7 +60,6 @@ def main(): #### Create model model = LitSlotAttentionAutoEncoder(resolution, num_slots, num_iterations, cfg=cfg) - wandb_logger = WandbLogger(project="slot-attention") checkpoint_callback = ModelCheckpoint( save_top_k=10, @@ -71,7 +70,7 @@ def main(): ) trainer = pl.Trainer(accelerator="gpu", devices=num_gpus, profiler="simple", - default_root_dir="./logs", logger=wandb_logger, + default_root_dir="./logs", logger=WandbLogger(project="slot-attention") if args.wandb else None, strategy="ddp" if num_gpus > 1 else "default", callbacks=[checkpoint_callback], deterministic=True, log_every_n_steps=100, max_steps=num_train_steps)