From 42faaee0d52f13543e0d11aec8b3d1adc5129372 Mon Sep 17 00:00:00 2001 From: alexcbb <alexchapin@hotmail.fr> Date: Wed, 26 Jul 2023 14:24:49 +0200 Subject: [PATCH] Handle checkpoint --- train_sa.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/train_sa.py b/train_sa.py index ad03192..1cf29f6 100644 --- a/train_sa.py +++ b/train_sa.py @@ -27,7 +27,7 @@ def main(): parser.add_argument('config', type=str, help="Where to save the checkpoints.") parser.add_argument('--wandb', action='store_true', help='Log run to Weights and Biases.') parser.add_argument('--seed', type=int, default=0, help='Random seed.') - parser.add_argument('--ckpt', type=str, default=".", help='Model checkpoint path') + parser.add_argument('--ckpt', type=str, default=None, help='Model checkpoint path') args = parser.parse_args() with open(args.config, 'r') as f: @@ -61,6 +61,11 @@ def main(): #### Create model model = LitSlotAttentionAutoEncoder(resolution, num_slots, num_iterations, cfg=cfg) + if args.ckpt: + checkpoint = torch.load(args.ckpt) + model.load_state_dict(checkpoint['state_dict']) + + checkpoint_callback = ModelCheckpoint( save_top_k=10, monitor="val_psnr", -- GitLab