diff --git a/main.py b/main.py
index ba90ffaf5e1bf17edab389546101ad77cc31b7ad..7246be316d9cdecd8e030ea3a537680a21c0b8b7 100644
--- a/main.py
+++ b/main.py
@@ -89,13 +89,13 @@ def run(args):
     plt.plot(train_acc)
     plt.plot(train_acc)
     plt.show()
-    plt.savefig('output/training_plot.png')
+    plt.savefig('output/training_plot_{}_.png'.format(args.output))
 
     load_model(model, args.save_path)
-    make_prediction(model,data_test)
+    make_prediction(model,data_test, 'output/confusion_matrix_{}_.png'.format(args.output))
 
 
-def make_prediction(model, data):
+def make_prediction(model, data, f_name):
     y_pred = []
     y_true = []
 
@@ -121,7 +121,7 @@ def make_prediction(model, data):
                          columns=[i for i in classes])
     plt.figure(figsize=(12, 7))
     sn.heatmap(df_cm, annot=True)
-    plt.savefig('confusion_matrix.png')
+    plt.savefig(f_name)
 
 
 def save_model(model, path):