diff --git a/train_sa.py b/train_sa.py
index ad03192032f1e171721ee3f801dffa6d1400182d..1cf29f6363dc0648576e786cb665013f136da107 100644
--- a/train_sa.py
+++ b/train_sa.py
@@ -27,7 +27,7 @@ def main():
     parser.add_argument('config', type=str, help="Where to save the checkpoints.")
     parser.add_argument('--wandb', action='store_true', help='Log run to Weights and Biases.')
     parser.add_argument('--seed', type=int, default=0, help='Random seed.')
-    parser.add_argument('--ckpt', type=str, default=".", help='Model checkpoint path')
+    parser.add_argument('--ckpt', type=str, default=None, help='Model checkpoint path')
 
     args = parser.parse_args()
     with open(args.config, 'r') as f:
@@ -61,6 +61,11 @@ def main():
     #### Create model
     model = LitSlotAttentionAutoEncoder(resolution, num_slots, num_iterations, cfg=cfg)
 
+    if args.ckpt:
+        checkpoint = torch.load(args.ckpt)
+        model.load_state_dict(checkpoint['state_dict'])
+
+
     checkpoint_callback = ModelCheckpoint(
         save_top_k=10,
         monitor="val_psnr",