Skip to content
Snippets Groups Projects
Commit 38ca8ef3 authored by Schneider Leo's avatar Schneider Leo
Browse files

model cuda loading

parent 4efef23b
No related branches found
No related tags found
No related merge requests found
import matplotlib.pyplot as plt
import numpy as np
from config.config import load_args
from dataset.dataset import load_data
......@@ -6,7 +7,9 @@ import torch
import torch.nn as nn
from models.model import Classification_model
import torch.optim as optim
from sklearn.metrics import confusion_matrix
import seaborn as sn
import pandas as pd
......@@ -31,7 +34,7 @@ def train(model, data_train, optimizer, loss_function, epoch):
optimizer.step()
losses = losses/len(data_train.dataset)
acc = acc/len(data_train.dataset)
print('Train epoch ',epoch, 'loss : ',losses,' acc : ',acc)
print('Test epoch {}, loss : {:.3f} acc : {:.3f}'.format(epoch,losses,acc))
return losses, acc
def test(model, data_test, loss_function, epoch):
......@@ -52,7 +55,7 @@ def test(model, data_test, loss_function, epoch):
losses += loss.item()
losses = losses/len(data_test.dataset)
acc = acc/len(data_test.dataset)
print('Test epoch ',epoch,'loss : ',losses,' acc : ',acc)
print('Test epoch {}, loss : {:.3f} acc : {:.3f}'.format(epoch,losses,acc))
return losses,acc
def run(args):
......@@ -85,8 +88,44 @@ def run(args):
plt.plot(val_acc)
plt.plot(train_acc)
plt.plot(train_acc)
plt.show()
plt.savefig('output/training_plot.png')
load_model(model, args.save_path)
make_prediction(model,data_test)
def make_prediction(model, data):
y_pred = []
y_true = []
# iterate over test data
for im, label in data:
label = label.long()
if torch.cuda.is_available():
im = im.cuda()
output = model(im)
output = (torch.max(torch.exp(output), 1)[1]).data.cpu().numpy()
y_pred.extend(output)
label = label.data.cpu().numpy()
y_true.extend(label) # Save Truth
# constant for classes
classes = data.dataset.dataset.classes
# Build confusion matrix
cf_matrix = confusion_matrix(y_true, y_pred)
df_cm = pd.DataFrame(cf_matrix / np.sum(cf_matrix, axis=1)[:, None], index=[i for i in classes],
columns=[i for i in classes])
plt.figure(figsize=(12, 7))
sn.heatmap(df_cm, annot=True)
plt.savefig('output.png')
def save_model(model, path):
print('Model saved')
torch.save(model.state_dict(), path)
def load_model(model, path):
......
......@@ -5,4 +5,5 @@ pyopenms~=3.3.0
openpyxl
torch~=2.6.0
torchvision~=0.21.0
pillow~=11.1.0
\ No newline at end of file
pillow~=11.1.0
seaborn~=0.13.2
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment