diff --git a/train_sa.py b/train_sa.py
index 14c097210a87c314438d0980ae42fd029139058f..5a3f9a76e7e72c763f6f4ad37d7bfa2cf7265350 100644
--- a/train_sa.py
+++ b/train_sa.py
@@ -97,7 +97,7 @@ def main():
         shuffle=True, worker_init_fn=data.worker_init_fn)
 
     #### Create model
-    model = SlotAttentionAutoEncoder(resolution, num_slots, num_iterations).to(device)
+    model = SlotAttentionAutoEncoder(resolution, num_slots, num_iterations, cfg=cfg).to(device)
     num_params = sum(p.numel() for p in model.parameters())
 
     print('Number of parameters:')