diff --git a/main.py b/main.py index bff9cd6e14c75e1d8f1ae76bed652686956a71f9..bf1509ecc22174e4a878da868fb8dc62014bc8e6 100644 --- a/main.py +++ b/main.py @@ -243,17 +243,17 @@ def run_int(epochs, eval_inter, save_inter, model, data_train, data_val, optimiz train_int(model, data_train, e, optimizer, criterion, metric, wandb=wandb) if e % eval_inter == 0: loss = eval_int(model, data_val, e, criterion, metric, wandb=wandb) - # if loss < best_loss: - # best_epoch = e - # if wandb is not None: - # save(model, optimizer, epochs, 'model_int' + wandb + '.pt') - # else: - # save(model, optimizer, epochs, 'model_int.pt') - # if wandb is not None: - # model_final = load('model_int' + wandb + '.pt') - # else: - # model_final = load('model_int.pt') - # print('Best epoch : ',e) + if loss < best_loss: + best_epoch = e + if wandb is not None: + save(model, optimizer, epochs, 'model_int' + wandb + '.pt') + else: + save(model, optimizer, epochs, 'model_int.pt') + if wandb is not None: + model_final = load('model_int' + wandb + '.pt') + else: + model_final = load('model_int.pt') + print('Best epoch : ',e) def main_rt(args):