From b3a3c45ae3ad6824ff7112fa2b93a1ab8dadb601 Mon Sep 17 00:00:00 2001
From: Schneider Leo <leo.schneider@etu.ec-lyon.fr>
Date: Wed, 12 Mar 2025 11:56:23 +0100
Subject: [PATCH] model cuda loading

---
 main.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/main.py b/main.py
index ba90ffa..7246be3 100644
--- a/main.py
+++ b/main.py
@@ -89,13 +89,13 @@ def run(args):
     plt.plot(train_acc)
     plt.plot(train_acc)
     plt.show()
-    plt.savefig('output/training_plot.png')
+    plt.savefig('output/training_plot_{}_.png'.format(args.output))
 
     load_model(model, args.save_path)
-    make_prediction(model,data_test)
+    make_prediction(model,data_test, 'output/confusion_matrix_{}_.png'.format(args.output))
 
 
-def make_prediction(model, data):
+def make_prediction(model, data, f_name):
     y_pred = []
     y_true = []
 
@@ -121,7 +121,7 @@ def make_prediction(model, data):
                          columns=[i for i in classes])
     plt.figure(figsize=(12, 7))
     sn.heatmap(df_cm, annot=True)
-    plt.savefig('confusion_matrix.png')
+    plt.savefig(f_name)
 
 
 def save_model(model, path):
-- 
GitLab