diff --git a/train_sa.py b/train_sa.py index 5963cf580f6422d0809ccbe2f465bc14e7be7161..0397a7cf90d4c171efff2af7af4d949b18871ae1 100644 --- a/train_sa.py +++ b/train_sa.py @@ -84,7 +84,7 @@ def main(): profiler="simple", default_root_dir="./logs", logger=WandbLogger(project="slot-att", offline=True) if args.wandb else None, - strategy="ddp" if num_gpus > 1 else "auto", + strategy="ddp_find_unused_parameters_true" if num_gpus > 1 else "auto", callbacks=[checkpoint_callback, early_stopping], log_every_n_steps=100, val_check_interval=cfg["training"]["validate_every"],