diff --git a/main.py b/main.py index 14e0159a8e38f410906a198602332c1a5268304e..83e77e4f47ca14827453c43f3473db0fc9185fbc 100644 --- a/main.py +++ b/main.py @@ -81,8 +81,6 @@ def run(epochs, eval_inter, save_inter, model, data_train, data_val, data_test, mem = losses_rt torch.save(model.state_dict(), output.strip('.csv')+'pt') print('model saved') - if e % save_inter == 0: - save(model, 'model_common_' + str(e) + '.pt') model.load_state_dict(torch.load(output.strip('.csv')+'pt', weights_only=True)) save_pred(model, data_test, output, criterion_rt, metric_rt, wandb) @@ -162,7 +160,7 @@ def save_pred(model, data_val, output_path, criterion_rt, metric_rt, wandb=None - loss_rt = criterion_rt(rt, pred_rt) + loss_rt = criterion_rt(rt, pr_rt) losses_rt += loss_rt.item() dist_rt = metric_rt(rt, pred_rt) dist_rt_acc += dist_rt.item()