Skip to content
Snippets Groups Projects
Commit 38ca8ef3 authored by Schneider Leo's avatar Schneider Leo
Browse files

model cuda loading

parent 4efef23b
No related branches found
No related tags found
No related merge requests found
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np
from config.config import load_args from config.config import load_args
from dataset.dataset import load_data from dataset.dataset import load_data
...@@ -6,7 +7,9 @@ import torch ...@@ -6,7 +7,9 @@ import torch
import torch.nn as nn import torch.nn as nn
from models.model import Classification_model from models.model import Classification_model
import torch.optim as optim import torch.optim as optim
from sklearn.metrics import confusion_matrix
import seaborn as sn
import pandas as pd
...@@ -31,7 +34,7 @@ def train(model, data_train, optimizer, loss_function, epoch): ...@@ -31,7 +34,7 @@ def train(model, data_train, optimizer, loss_function, epoch):
optimizer.step() optimizer.step()
losses = losses/len(data_train.dataset) losses = losses/len(data_train.dataset)
acc = acc/len(data_train.dataset) acc = acc/len(data_train.dataset)
print('Train epoch ',epoch, 'loss : ',losses,' acc : ',acc) print('Test epoch {}, loss : {:.3f} acc : {:.3f}'.format(epoch,losses,acc))
return losses, acc return losses, acc
def test(model, data_test, loss_function, epoch): def test(model, data_test, loss_function, epoch):
...@@ -52,7 +55,7 @@ def test(model, data_test, loss_function, epoch): ...@@ -52,7 +55,7 @@ def test(model, data_test, loss_function, epoch):
losses += loss.item() losses += loss.item()
losses = losses/len(data_test.dataset) losses = losses/len(data_test.dataset)
acc = acc/len(data_test.dataset) acc = acc/len(data_test.dataset)
print('Test epoch ',epoch,'loss : ',losses,' acc : ',acc) print('Test epoch {}, loss : {:.3f} acc : {:.3f}'.format(epoch,losses,acc))
return losses,acc return losses,acc
def run(args): def run(args):
...@@ -85,8 +88,44 @@ def run(args): ...@@ -85,8 +88,44 @@ def run(args):
plt.plot(val_acc) plt.plot(val_acc)
plt.plot(train_acc) plt.plot(train_acc)
plt.plot(train_acc) plt.plot(train_acc)
plt.show()
plt.savefig('output/training_plot.png')
load_model(model, args.save_path)
make_prediction(model,data_test)
def make_prediction(model, data):
y_pred = []
y_true = []
# iterate over test data
for im, label in data:
label = label.long()
if torch.cuda.is_available():
im = im.cuda()
output = model(im)
output = (torch.max(torch.exp(output), 1)[1]).data.cpu().numpy()
y_pred.extend(output)
label = label.data.cpu().numpy()
y_true.extend(label) # Save Truth
# constant for classes
classes = data.dataset.dataset.classes
# Build confusion matrix
cf_matrix = confusion_matrix(y_true, y_pred)
df_cm = pd.DataFrame(cf_matrix / np.sum(cf_matrix, axis=1)[:, None], index=[i for i in classes],
columns=[i for i in classes])
plt.figure(figsize=(12, 7))
sn.heatmap(df_cm, annot=True)
plt.savefig('output.png')
def save_model(model, path): def save_model(model, path):
print('Model saved')
torch.save(model.state_dict(), path) torch.save(model.state_dict(), path)
def load_model(model, path): def load_model(model, path):
......
...@@ -5,4 +5,5 @@ pyopenms~=3.3.0 ...@@ -5,4 +5,5 @@ pyopenms~=3.3.0
openpyxl openpyxl
torch~=2.6.0 torch~=2.6.0
torchvision~=0.21.0 torchvision~=0.21.0
pillow~=11.1.0 pillow~=11.1.0
\ No newline at end of file seaborn~=0.13.2
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment