From 1cb6c3ab1b91cb381476ca1dd8691bee6fe31053 Mon Sep 17 00:00:00 2001
From: Schneider Leo <leo.schneider@etu.ec-lyon.fr>
Date: Fri, 13 Dec 2024 14:02:30 +0100
Subject: [PATCH] data augmented

---
 main.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/main.py b/main.py
index 92b6b22..14e0159 100644
--- a/main.py
+++ b/main.py
@@ -84,7 +84,7 @@ def run(epochs, eval_inter, save_inter, model, data_train, data_val, data_test,
         if e % save_inter == 0:
             save(model, 'model_common_' + str(e) + '.pt')
     model.load_state_dict(torch.load(output.strip('.csv')+'pt', weights_only=True))
-    save_pred(model, data_test, output)
+    save_pred(model, data_test, output, criterion_rt, metric_rt, wandb)
 
 
 def main(args):
-- 
GitLab