diff --git a/train_sa.py b/train_sa.py index 8e93e2ce7972781f749da49dd741912c77ee84f4..25f5c124da5a746f89a7aaf736d3c2ce42825306 100644 --- a/train_sa.py +++ b/train_sa.py @@ -70,7 +70,7 @@ def main(): trainer = pl.Trainer(accelerator="gpu", devices=num_gpus, profiler="simple", default_root_dir="./logs", logger=WandbLogger(project="slot-att") if args.wandb else None, - strategy="ddp" if num_gpus > 1 else "default", callbacks=[checkpoint_callback], + strategy="ddp_find_unused_parameters_true" if num_gpus > 1 else "default", callbacks=[checkpoint_callback], log_every_n_steps=100, max_steps=num_train_steps) trainer.fit(model, train_loader, val_loader)