diff --git a/requirements.txt b/requirements.txt index 49db6f2a909ace3c07f8fe9dfae438dfe3fe7957..39e8a425c2c5c86a4533b52aef1ac5c676893e65 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,3 +8,4 @@ matplotlib tqdm opencv-python bitsandbytes +segment_anything diff --git a/train_sa.py b/train_sa.py index a35c7c3e22b705e395329b537be91c2c558f9143..8e93e2ce7972781f749da49dd741912c77ee84f4 100644 --- a/train_sa.py +++ b/train_sa.py @@ -69,8 +69,8 @@ def main(): ) trainer = pl.Trainer(accelerator="gpu", devices=num_gpus, profiler="simple", - default_root_dir="./logs", logger=WandbLogger(project="slot-attention") if args.wandb else None, - strategy="ddp" if num_gpus > 1 else "default", callbacks=[checkpoint_callback], deterministic=True, + default_root_dir="./logs", logger=WandbLogger(project="slot-att") if args.wandb else None, + strategy="ddp" if num_gpus > 1 else "default", callbacks=[checkpoint_callback], log_every_n_steps=100, max_steps=num_train_steps) trainer.fit(model, train_loader, val_loader)