diff --git a/main.py b/main.py index 2ad0475975b0c326c5254d6059b088dff097dc10..06be979f5043e6f67c09c351f15d1df539b4c89e 100644 --- a/main.py +++ b/main.py @@ -1,4 +1,5 @@ import matplotlib.pyplot as plt +import numpy as np from config.config import load_args from dataset.dataset import load_data @@ -6,7 +7,9 @@ import torch import torch.nn as nn from models.model import Classification_model import torch.optim as optim - +from sklearn.metrics import confusion_matrix +import seaborn as sn +import pandas as pd @@ -31,7 +34,7 @@ def train(model, data_train, optimizer, loss_function, epoch): optimizer.step() losses = losses/len(data_train.dataset) acc = acc/len(data_train.dataset) - print('Train epoch ',epoch, 'loss : ',losses,' acc : ',acc) + print('Test epoch {}, loss : {:.3f} acc : {:.3f}'.format(epoch,losses,acc)) return losses, acc def test(model, data_test, loss_function, epoch): @@ -52,7 +55,7 @@ def test(model, data_test, loss_function, epoch): losses += loss.item() losses = losses/len(data_test.dataset) acc = acc/len(data_test.dataset) - print('Test epoch ',epoch,'loss : ',losses,' acc : ',acc) + print('Test epoch {}, loss : {:.3f} acc : {:.3f}'.format(epoch,losses,acc)) return losses,acc def run(args): @@ -85,8 +88,44 @@ def run(args): plt.plot(val_acc) plt.plot(train_acc) plt.plot(train_acc) + plt.show() + plt.savefig('output/training_plot.png') + + load_model(model, args.save_path) + make_prediction(model,data_test) + + +def make_prediction(model, data): + y_pred = [] + y_true = [] + + # iterate over test data + for im, label in data: + label = label.long() + if torch.cuda.is_available(): + im = im.cuda() + output = model(im) + + output = (torch.max(torch.exp(output), 1)[1]).data.cpu().numpy() + y_pred.extend(output) + + label = label.data.cpu().numpy() + y_true.extend(label) # Save Truth + # constant for classes + + classes = data.dataset.dataset.classes + + # Build confusion matrix + cf_matrix = confusion_matrix(y_true, y_pred) + df_cm = pd.DataFrame(cf_matrix / np.sum(cf_matrix, axis=1)[:, None], index=[i for i in classes], + columns=[i for i in classes]) + plt.figure(figsize=(12, 7)) + sn.heatmap(df_cm, annot=True) + plt.savefig('output.png') + def save_model(model, path): + print('Model saved') torch.save(model.state_dict(), path) def load_model(model, path): diff --git a/requirements.txt b/requirements.txt index 6780dc37ff73f4e4abb0c84a415145c2a67d5ac0..340cebc299e1dfd9d72386f0ecf2bce1b8a5d4ab 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,4 +5,5 @@ pyopenms~=3.3.0 openpyxl torch~=2.6.0 torchvision~=0.21.0 -pillow~=11.1.0 \ No newline at end of file +pillow~=11.1.0 +seaborn~=0.13.2 \ No newline at end of file