"""Training / evaluation script for an image-classification model.

Trains ``Classification_model`` with SGD + cross-entropy, validates every
``args.eval_inter`` epochs, checkpoints the best-accuracy weights, then plots
the learning curves and a row-normalized confusion matrix on the test split.
"""
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sn
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import confusion_matrix

from config.config import load_args
from dataset.dataset import load_data
from models.model import Classification_model


def train(model, data_train, optimizer, loss_function, epoch):
    """Run one training epoch.

    Args:
        model: the classifier (moved to CUDA by the caller when available).
        data_train: DataLoader over the training split.
        optimizer: optimizer stepping ``model``'s parameters.
        loss_function: criterion taking (logits, target) — e.g. CrossEntropyLoss.
        epoch: epoch index, used only for logging.

    Returns:
        (mean_loss, accuracy) averaged over ``len(data_train.dataset)``.
    """
    model.train()
    losses = 0.
    acc = 0.
    # Re-enable gradients in case a previous caller froze the parameters.
    for param in model.parameters():
        param.requires_grad = True
    for im, label in data_train:
        label = label.long()
        if torch.cuda.is_available():
            im, label = im.cuda(), label.cuda()
        pred_logits = model(im)  # call the module, not .forward(), so hooks run
        pred_class = torch.argmax(pred_logits, dim=1)
        acc += (pred_class == label).sum().item()
        loss = loss_function(pred_logits, label)
        losses += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    losses = losses / len(data_train.dataset)
    acc = acc / len(data_train.dataset)
    print('Train epoch {}, loss : {:.3f} acc : {:.3f}'.format(epoch, losses, acc))
    return losses, acc


def test(model, data_test, loss_function, epoch):
    """Evaluate the model on the test split.

    Bug fix: the original permanently set ``requires_grad = False`` on every
    parameter, leaking frozen state to any later caller that did not undo it.
    ``torch.no_grad()`` achieves the same (no autograd bookkeeping) without
    mutating the model.

    Returns:
        (mean_loss, accuracy) averaged over ``len(data_test.dataset)``.
    """
    model.eval()
    losses = 0.
    acc = 0.
    with torch.no_grad():
        for im, label in data_test:
            label = label.long()
            if torch.cuda.is_available():
                im, label = im.cuda(), label.cuda()
            pred_logits = model(im)
            pred_class = torch.argmax(pred_logits, dim=1)
            acc += (pred_class == label).sum().item()
            loss = loss_function(pred_logits, label)
            losses += loss.item()
    losses = losses / len(data_test.dataset)
    acc = acc / len(data_test.dataset)
    print('Test epoch {}, loss : {:.3f} acc : {:.3f}'.format(epoch, losses, acc))
    return losses, acc


def run(args):
    """Run the full experiment described by ``args``.

    Trains for ``args.epoches`` epochs, validates every ``args.eval_inter``
    epochs, saves the best-accuracy checkpoint to ``args.save_path``, then
    reloads it to produce the final confusion matrix.
    """
    data_train, data_test = load_data(base_dir=args.dataset_dir,
                                      batch_size=args.batch_size)
    # NOTE(review): .dataset.dataset.classes suggests the loader wraps a
    # Subset around an ImageFolder-style dataset — TODO confirm in load_data.
    model = Classification_model(model=args.model,
                                 n_class=len(data_train.dataset.dataset.classes))
    if args.pretrain_path is not None:
        load_model(model, args.pretrain_path)
    if torch.cuda.is_available():
        model = model.cuda()
    best_acc = 0
    train_acc = []
    train_loss = []
    val_acc = []
    val_loss = []
    loss_function = nn.CrossEntropyLoss()
    # Bug fix: the learning rate was hard-coded to 0.001 even though args.lr
    # exists (and is recorded in the output file names below); honour args.lr.
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9)
    for e in range(args.epoches):
        loss, acc = train(model, data_train, optimizer, loss_function, e)
        train_loss.append(loss)
        train_acc.append(acc)
        if e % args.eval_inter == 0:
            loss, acc = test(model, data_test, loss_function, e)
            val_loss.append(loss)
            val_acc.append(acc)
            if acc > best_acc:
                save_model(model, args.save_path)
                best_acc = acc
    # Bug fixes: (1) the original plotted train_acc three times and never the
    # losses; (2) plt.show() ran BEFORE plt.savefig(), so the saved PNG was
    # blank (show() closes/clears the figure). Save first, then show.
    plt.plot(train_acc, label='train acc')
    plt.plot(val_acc, label='val acc')
    plt.plot(train_loss, label='train loss')
    plt.plot(val_loss, label='val loss')
    plt.legend()
    plt.savefig('output/training_plot_noise_{}_lr_{}_model_{}.png'.format(
        args.noise_thresold, args.lr, args.model))
    plt.show()
    load_model(model, args.save_path)
    make_prediction(model, data_test,
                    'output/confusion_matrix_noise_{}_lr_{}_model_{}.png'.format(
                        args.noise_thresold, args.lr, args.model))


def make_prediction(model, data, f_name):
    """Render a row-normalized confusion matrix over ``data`` to ``f_name``.

    Bug fix: inference now runs under ``model.eval()`` + ``torch.no_grad()``
    (the original built autograd graphs it never used). ``torch.exp`` before
    the max was redundant — exp is monotonic, so the argmax is identical.
    """
    model.eval()
    y_pred = []
    y_true = []
    with torch.no_grad():
        for im, label in data:
            label = label.long()
            if torch.cuda.is_available():
                im = im.cuda()
            output = model(im)
            pred = torch.argmax(output, dim=1).data.cpu().numpy()
            y_pred.extend(pred)
            y_true.extend(label.data.cpu().numpy())
    # Class names for the axis labels.
    classes = data.dataset.dataset.classes
    # Row-normalize so each row (true class) sums to 1.
    cf_matrix = confusion_matrix(y_true, y_pred)
    df_cm = pd.DataFrame(cf_matrix / np.sum(cf_matrix, axis=1)[:, None],
                         index=list(classes),
                         columns=list(classes))
    plt.figure(figsize=(12, 7))
    sn.heatmap(df_cm, annot=True)
    plt.savefig(f_name)


def save_model(model, path):
    """Persist ``model``'s state_dict to ``path``."""
    print('Model saved')
    torch.save(model.state_dict(), path)


def load_model(model, path):
    """Load a state_dict from ``path`` into ``model`` (weights only, in place)."""
    model.load_state_dict(torch.load(path, weights_only=True))


if __name__ == '__main__':
    args = load_args()
    run(args)