diff --git a/main.py b/main.py
index 2ad0475975b0c326c5254d6059b088dff097dc10..06be979f5043e6f67c09c351f15d1df539b4c89e 100644
--- a/main.py
+++ b/main.py
@@ -1,4 +1,5 @@
 import matplotlib.pyplot as plt
+import numpy as np
 
 from config.config import load_args
 from dataset.dataset import load_data
@@ -6,7 +7,9 @@ import torch
 import torch.nn as nn
 from models.model import Classification_model
 import torch.optim as optim
-
+from sklearn.metrics import confusion_matrix
+import seaborn as sn
+import pandas as pd
 
 
 
@@ -31,7 +34,7 @@ def train(model, data_train, optimizer, loss_function, epoch):
         optimizer.step()
     losses = losses/len(data_train.dataset)
     acc = acc/len(data_train.dataset)
-    print('Train epoch ',epoch, 'loss : ',losses,' acc : ',acc)
+    print('Test epoch {}, loss : {:.3f} acc : {:.3f}'.format(epoch,losses,acc))
     return losses, acc
 
 def test(model, data_test, loss_function, epoch):
@@ -52,7 +55,7 @@ def test(model, data_test, loss_function, epoch):
         losses += loss.item()
     losses = losses/len(data_test.dataset)
     acc = acc/len(data_test.dataset)
-    print('Test epoch ',epoch,'loss : ',losses,' acc : ',acc)
+    print('Test epoch {}, loss : {:.3f} acc : {:.3f}'.format(epoch,losses,acc))
     return losses,acc
 
 def run(args):
@@ -85,8 +88,44 @@ def run(args):
     plt.plot(val_acc)
     plt.plot(train_acc)
     plt.plot(train_acc)
+    plt.show()
+    plt.savefig('output/training_plot.png')
+
+    load_model(model, args.save_path)
+    make_prediction(model,data_test)
+
+
+def make_prediction(model, data):
+    y_pred = []
+    y_true = []
+
+    # iterate over test data
+    for im, label in data:
+        label = label.long()
+        if torch.cuda.is_available():
+            im = im.cuda()
+        output = model(im)
+
+        output = (torch.max(torch.exp(output), 1)[1]).data.cpu().numpy()
+        y_pred.extend(output)
+
+        label = label.data.cpu().numpy()
+        y_true.extend(label)  # Save Truth
+    # constant for classes
+
+    classes = data.dataset.dataset.classes
+
+    # Build confusion matrix
+    cf_matrix = confusion_matrix(y_true, y_pred)
+    df_cm = pd.DataFrame(cf_matrix / np.sum(cf_matrix, axis=1)[:, None], index=[i for i in classes],
+                         columns=[i for i in classes])
+    plt.figure(figsize=(12, 7))
+    sn.heatmap(df_cm, annot=True)
+    plt.savefig('output.png')
+
 
 def save_model(model, path):
+    print('Model saved')
     torch.save(model.state_dict(), path)
 
 def load_model(model, path):
diff --git a/requirements.txt b/requirements.txt
index 6780dc37ff73f4e4abb0c84a415145c2a67d5ac0..340cebc299e1dfd9d72386f0ecf2bce1b8a5d4ab 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -5,4 +5,5 @@ pyopenms~=3.3.0
 openpyxl
 torch~=2.6.0
 torchvision~=0.21.0
-pillow~=11.1.0
\ No newline at end of file
+pillow~=11.1.0
+seaborn~=0.13.2
\ No newline at end of file