diff --git a/train_sa.py b/train_sa.py index ad03192032f1e171721ee3f801dffa6d1400182d..1cf29f6363dc0648576e786cb665013f136da107 100644 --- a/train_sa.py +++ b/train_sa.py @@ -27,7 +27,7 @@ def main(): parser.add_argument('config', type=str, help="Where to save the checkpoints.") parser.add_argument('--wandb', action='store_true', help='Log run to Weights and Biases.') parser.add_argument('--seed', type=int, default=0, help='Random seed.') - parser.add_argument('--ckpt', type=str, default=".", help='Model checkpoint path') + parser.add_argument('--ckpt', type=str, default=None, help='Model checkpoint path') args = parser.parse_args() with open(args.config, 'r') as f: @@ -61,6 +61,11 @@ def main(): #### Create model model = LitSlotAttentionAutoEncoder(resolution, num_slots, num_iterations, cfg=cfg) + if args.ckpt: + checkpoint = torch.load(args.ckpt) + model.load_state_dict(checkpoint['state_dict']) + + checkpoint_callback = ModelCheckpoint( save_top_k=10, monitor="val_psnr",