From 42faaee0d52f13543e0d11aec8b3d1adc5129372 Mon Sep 17 00:00:00 2001
From: alexcbb <alexchapin@hotmail.fr>
Date: Wed, 26 Jul 2023 14:24:49 +0200
Subject: [PATCH] Load model weights from an optional --ckpt path

Make --ckpt default to None instead of "." so the loading branch is
skipped when no checkpoint is given, and restore the model's weights
from the checkpoint before training when a path is supplied.

---
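Notes:

Usage sketch for the new flag; the config and checkpoint paths below
are placeholders, not files shipped with this patch:

    python train_sa.py configs/sa.yaml --ckpt checkpoints/best.ckpt

With --ckpt omitted the argument is now None, so the model trains from
its randomly initialized weights.
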
 train_sa.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/train_sa.py b/train_sa.py
index ad03192..1cf29f6 100644
--- a/train_sa.py
+++ b/train_sa.py
@@ -27,7 +27,7 @@ def main():
     parser.add_argument('config', type=str, help="Path to the config file.")
     parser.add_argument('--wandb', action='store_true', help='Log run to Weights and Biases.')
     parser.add_argument('--seed', type=int, default=0, help='Random seed.')
-    parser.add_argument('--ckpt', type=str, default=".", help='Model checkpoint path')
+    parser.add_argument('--ckpt', type=str, default=None, help='Model checkpoint path')
 
     args = parser.parse_args()
     with open(args.config, 'r') as f:
@@ -61,6 +61,11 @@ def main():
     #### Create model
     model = LitSlotAttentionAutoEncoder(resolution, num_slots, num_iterations, cfg=cfg)
 
+    # Restore model weights when a checkpoint path is provided
+    if args.ckpt:
+        checkpoint = torch.load(args.ckpt)
+        model.load_state_dict(checkpoint['state_dict'])
+
     checkpoint_callback = ModelCheckpoint(
         save_top_k=10,
         monitor="val_psnr",
-- 
GitLab
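
A standalone sketch of the resume logic, for context only. It assumes
PyTorch plus a Lightning-style checkpoint (weights stored under the
'state_dict' key); load_weights and the map_location choice are this
note's additions, not part of the patch:

    import torch

    def load_weights(model, ckpt_path):
        # Lightning checkpoints keep the weights under 'state_dict'.
        # map_location='cpu' avoids CUDA deserialization errors when
        # resuming on a CPU-only machine.
        checkpoint = torch.load(ckpt_path, map_location='cpu')
        model.load_state_dict(checkpoint['state_dict'])
        return model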