Commit 42faaee0 authored by Alexandre Chapin

Handle checkpoint

parent f54bc1b2
@@ -27,7 +27,7 @@ def main():
     parser.add_argument('config', type=str, help="Where to save the checkpoints.")
     parser.add_argument('--wandb', action='store_true', help='Log run to Weights and Biases.')
     parser.add_argument('--seed', type=int, default=0, help='Random seed.')
-    parser.add_argument('--ckpt', type=str, default=".", help='Model checkpoint path')
+    parser.add_argument('--ckpt', type=str, default=None, help='Model checkpoint path')
     args = parser.parse_args()
     with open(args.config, 'r') as f:
@@ -61,6 +61,11 @@ def main():
     #### Create model
     model = LitSlotAttentionAutoEncoder(resolution, num_slots, num_iterations, cfg=cfg)
+    if args.ckpt:
+        checkpoint = torch.load(args.ckpt)
+        model.load_state_dict(checkpoint['state_dict'])
     checkpoint_callback = ModelCheckpoint(
         save_top_k=10,
         monitor="val_psnr",
...
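For context, the restore path added here is the standard manual way to reload a PyTorch Lightning checkpoint. Below is a minimal, self-contained sketch of that pattern alongside Lightning's built-in load_from_checkpoint; the module, layer, and checkpoint path are placeholders, not code from this repository.

    # Sketch of restoring a Lightning checkpoint, mirroring the pattern
    # this commit adds. LitAutoEncoder, its layer, and ckpt_path are
    # placeholders, not code from this repository.
    import torch
    import pytorch_lightning as pl

    class LitAutoEncoder(pl.LightningModule):  # stand-in for LitSlotAttentionAutoEncoder
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(4, 4)

    ckpt_path = "checkpoints/last.ckpt"  # placeholder path

    # What the commit does: rebuild the module, then copy the weights in.
    # Lightning checkpoints store the model weights under 'state_dict'.
    model = LitAutoEncoder()
    checkpoint = torch.load(ckpt_path)
    model.load_state_dict(checkpoint["state_dict"])

    # Lightning's built-in alternative, which also restores hyperparameters
    # saved with self.save_hyperparameters():
    model = LitAutoEncoder.load_from_checkpoint(ckpt_path)

Note that changing the --ckpt default from "." to None is what makes the new `if args.ckpt:` guard work: with no flag the guard is falsy and training starts from scratch, whereas the old default of "." was truthy and would have sent torch.load after a directory.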