"""Train, evaluate, and analyse a 9-class image classifier.

Trains with SGD + cross-entropy, checkpoints the best validation model,
then reloads it and writes a normalized confusion-matrix heatmap.
"""

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sn
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import confusion_matrix

from config.config import load_args
from dataset.dataset import load_data
from models.model import Classification_model


def train(model, data_train, optimizer, loss_function, epoch):
    """Run one training epoch.

    Args:
        model: classifier producing per-class logits.
        data_train: DataLoader of (image, label) batches.
        optimizer: optimizer stepping the model parameters.
        loss_function: criterion taking (logits, labels).
        epoch: epoch index, for logging only.

    Returns:
        (mean loss per sample, accuracy) for the epoch.
    """
    model.train()
    losses = 0.
    acc = 0.
    for im, label in data_train:
        label = label.long()
        if torch.cuda.is_available():
            im, label = im.cuda(), label.cuda()
        pred_logits = model(im)
        pred_class = torch.argmax(pred_logits, dim=1)
        acc += (pred_class == label).sum().item()
        loss = loss_function(pred_logits, label)
        losses += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # Note: sums batch-mean losses, then divides by the number of samples —
    # kept from the original implementation for comparable logged values.
    losses = losses / len(data_train.dataset)
    acc = acc / len(data_train.dataset)
    print('Train epoch {}, loss : {:.3f} acc : {:.3f}'.format(epoch, losses, acc))
    return losses, acc


def test(model, data_test, loss_function, epoch):
    """Evaluate the model on the test loader.

    Uses torch.no_grad() instead of mutating param.requires_grad, so the
    model's parameters are left untouched for subsequent training epochs.

    Returns:
        (mean loss per sample, accuracy).
    """
    model.eval()
    losses = 0.
    acc = 0.
    with torch.no_grad():
        for im, label in data_test:
            label = label.long()
            if torch.cuda.is_available():
                im, label = im.cuda(), label.cuda()
            pred_logits = model(im)
            pred_class = torch.argmax(pred_logits, dim=1)
            acc += (pred_class == label).sum().item()
            loss = loss_function(pred_logits, label)
            losses += loss.item()
    losses = losses / len(data_test.dataset)
    acc = acc / len(data_test.dataset)
    print('Test epoch {}, loss : {:.3f} acc : {:.3f}'.format(epoch, losses, acc))
    return losses, acc


def run(args):
    """Full pipeline: train, checkpoint the best model, plot curves, predict.

    Expects args to provide: pretrain_path, dataset_dir, batch_size,
    epoches, eval_inter, save_path.
    """
    model = Classification_model(n_class=9)
    if torch.cuda.is_available():
        model = model.cuda()
    best_acc = 0
    train_acc = []
    train_loss = []
    val_acc = []
    val_loss = []
    if args.pretrain_path is not None:
        load_model(model, args.pretrain_path)
    data_train, data_test = load_data(base_dir=args.dataset_dir,
                                      batch_size=args.batch_size)
    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

    for e in range(args.epoches):
        loss, acc = train(model, data_train, optimizer, loss_function, e)
        train_loss.append(loss)
        train_acc.append(acc)
        if e % args.eval_inter == 0:
            loss, acc = test(model, data_test, loss_function, e)
            val_loss.append(loss)
            val_acc.append(acc)
            # Keep only the checkpoint with the best validation accuracy.
            if acc > best_acc:
                save_model(model, args.save_path)
                best_acc = acc

    plt.plot(train_acc, label='train acc')
    plt.plot(val_acc, label='val acc')
    plt.legend()
    # savefig must come before show: show() flushes the current figure,
    # so saving afterwards produces a blank image.
    plt.savefig('output/training_plot.png')
    plt.show()

    load_model(model, args.save_path)
    make_prediction(model, data_test)


def make_prediction(model, data):
    """Predict over `data` and save a row-normalized confusion-matrix heatmap.

    Writes the heatmap to 'output.png'.
    """
    model.eval()
    y_pred = []
    y_true = []
    # Iterate over test data without building autograd graphs.
    with torch.no_grad():
        for im, label in data:
            label = label.long()
            if torch.cuda.is_available():
                im = im.cuda()
            output = model(im)
            # argmax of the logits; exp() before argmax is redundant
            # (monotone), so it is dropped.
            pred = torch.argmax(output, dim=1).cpu().numpy()
            y_pred.extend(pred)
            y_true.extend(label.cpu().numpy())

    # Class names; assumes `data` wraps a Subset of an ImageFolder-style
    # dataset — TODO confirm against load_data.
    classes = data.dataset.dataset.classes

    # Row-normalize so each cell is the fraction of that true class;
    # guard against empty rows (class absent from the test split).
    cf_matrix = confusion_matrix(y_true, y_pred)
    row_sums = cf_matrix.sum(axis=1, keepdims=True)
    row_sums[row_sums == 0] = 1
    df_cm = pd.DataFrame(cf_matrix / row_sums,
                         index=[i for i in classes],
                         columns=[i for i in classes])
    plt.figure(figsize=(12, 7))
    sn.heatmap(df_cm, annot=True)
    plt.savefig('output.png')


def save_model(model, path):
    """Save the model's state_dict to `path`."""
    print('Model saved')
    torch.save(model.state_dict(), path)


def load_model(model, path):
    """Load a state_dict checkpoint from `path` into `model` in place."""
    model.load_state_dict(torch.load(path, weights_only=True))


if __name__ == '__main__':
    args = load_args()
    run(args)