From b3a3c45ae3ad6824ff7112fa2b93a1ab8dadb601 Mon Sep 17 00:00:00 2001 From: Schneider Leo <leo.schneider@etu.ec-lyon.fr> Date: Wed, 12 Mar 2025 11:56:23 +0100 Subject: [PATCH] model cuda loading --- main.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/main.py b/main.py index ba90ffa..7246be3 100644 --- a/main.py +++ b/main.py @@ -89,13 +89,13 @@ def run(args): plt.plot(train_acc) plt.plot(train_acc) plt.show() - plt.savefig('output/training_plot.png') + plt.savefig('output/training_plot_{}_.png'.format(args.output)) load_model(model, args.save_path) - make_prediction(model,data_test) + make_prediction(model,data_test, 'output/confusion_matrix_{}_.png'.format(args.output)) -def make_prediction(model, data): +def make_prediction(model, data, f_name): y_pred = [] y_true = [] @@ -121,7 +121,7 @@ def make_prediction(model, data): columns=[i for i in classes]) plt.figure(figsize=(12, 7)) sn.heatmap(df_cm, annot=True) - plt.savefig('confusion_matrix.png') + plt.savefig(f_name) def save_model(model, path): -- GitLab