From 1cb6c3ab1b91cb381476ca1dd8691bee6fe31053 Mon Sep 17 00:00:00 2001 From: Schneider Leo <leo.schneider@etu.ec-lyon.fr> Date: Fri, 13 Dec 2024 14:02:30 +0100 Subject: [PATCH] data augmented --- main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/main.py b/main.py index 92b6b22..14e0159 100644 --- a/main.py +++ b/main.py @@ -84,7 +84,7 @@ def run(epochs, eval_inter, save_inter, model, data_train, data_val, data_test, if e % save_inter == 0: save(model, 'model_common_' + str(e) + '.pt') model.load_state_dict(torch.load(output.strip('.csv')+'pt', weights_only=True)) - save_pred(model, data_test, output) + save_pred(model, data_test, output, criterion_rt, metric_rt, wandb) def main(args): -- GitLab