diff --git a/train_sa.py b/train_sa.py index 14c097210a87c314438d0980ae42fd029139058f..5a3f9a76e7e72c763f6f4ad37d7bfa2cf7265350 100644 --- a/train_sa.py +++ b/train_sa.py @@ -97,7 +97,7 @@ def main(): shuffle=True, worker_init_fn=data.worker_init_fn) #### Create model - model = SlotAttentionAutoEncoder(resolution, num_slots, num_iterations).to(device) + model = SlotAttentionAutoEncoder(resolution, num_slots, num_iterations, cfg=cfg).to(device) num_params = sum(p.numel() for p in model.parameters()) print('Number of parameters:')