"""Training / evaluation script for ``Classification_model_duo_contrastive``.

The model is called as ``model(imaer, imana, img_ref)`` and returns 2-column
logits.  Training runs on mini-batches; at test time every item of the test
loader is evaluated against all of its candidate references at once, and the
candidate with the highest column-0 logit is taken as the predicted class.
"""
import math
import os

import matplotlib.pyplot as plt
import numpy as np
from config import load_args_contrastive
from dataset_ref import load_data_duo
import torch
import torch.nn as nn
from model import Classification_model_duo_contrastive
import torch.optim as optim
from sklearn.metrics import confusion_matrix
import seaborn as sn
import pandas as pd


def _to_device(*tensors):
    """Move every tensor to the GPU when CUDA is available; otherwise return them unchanged."""
    if torch.cuda.is_available():
        return tuple(t.cuda() for t in tensors)
    return tensors


def _ensure_parent_dir(path):
    """Create the directory that will contain ``path`` if it does not exist yet."""
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)


def train_duo(model, data_train, optimizer, loss_function, epoch):
    """Run one training epoch.

    Args:
        model: network called as ``model(imaer, imana, img_ref)`` returning logits.
        data_train: iterable (DataLoader) of ``(imaer, imana, img_ref, label)``
            batches; ``data_train.dataset`` must support ``len``.
        optimizer: optimizer over ``model``'s parameters.
        loss_function: criterion applied as ``loss_function(logits, label)``.
        epoch: epoch index, used only for logging.

    Returns:
        ``(loss, acc)``: the loss averaged per sample and the fraction of
        samples whose ``argmax`` logit equals ``label``.
    """
    model.train()
    # Re-enable gradients for every parameter, including any that were frozen
    # before this call.
    for param in model.parameters():
        param.requires_grad = True

    total_loss = 0.
    n_correct = 0
    for imaer, imana, img_ref, label in data_train:
        label = label.long()
        imaer, imana, img_ref, label = _to_device(imaer, imana, img_ref, label)

        # Call the module (not .forward) so registered hooks run.
        pred_logits = model(imaer, imana, img_ref)
        pred_class = torch.argmax(pred_logits, dim=1)
        n_correct += (pred_class == label).sum().item()

        loss = loss_function(pred_logits, label)
        # Assumes a mean-reduced criterion (the default nn.CrossEntropyLoss
        # used by run_duo): weight by batch size so dividing by the dataset
        # length below gives a per-sample mean.
        total_loss += loss.item() * label.size(0)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    n_samples = len(data_train.dataset)
    losses = total_loss / n_samples
    acc = n_correct / n_samples
    print('Train epoch {}, loss : {:.3f} acc : {:.3f}'.format(epoch, losses, acc))
    return losses, acc


def test_duo(model, data_test, loss_function, epoch):
    """Evaluate the model on the reference-based test loader.

    The first two axes of every tensor from ``data_test`` are swapped so that
    the candidate references become the batch axis (presumably the test loader
    uses batch_size=1 -- TODO confirm in ``dataset_ref``).

    Returns:
        ``(loss, acc, acc_contrastive)`` where
        * ``loss`` is the loss averaged per candidate pair,
        * ``acc`` is the fraction of test items for which the candidate with
          the highest column-0 logit is the one where ``label`` is minimal,
        * ``acc_contrastive`` is the per-pair accuracy of ``argmax(logits)``.

    Raises:
        ValueError: if ``data_test`` yields no pairs.
    """
    model.eval()
    total_loss = 0.
    n_correct = 0
    n_pair_correct = 0
    n_pairs = 0
    # no_grad instead of toggling requires_grad: avoids building autograd
    # graphs without mutating the model's parameter flags.
    with torch.no_grad():
        for imaer, imana, img_ref, label in data_test:
            imaer = imaer.transpose(0, 1)
            imana = imana.transpose(0, 1)
            img_ref = img_ref.transpose(0, 1)
            label = label.transpose(0, 1).squeeze().long()
            imaer, imana, img_ref, label = _to_device(imaer, imana, img_ref, label)

            pred_logits = model(imaer, imana, img_ref)

            label_class = int(torch.argmin(label))
            pred_class = int(torch.argmax(pred_logits[:, 0]))
            n_correct += int(pred_class == label_class)

            n_pair_correct += (torch.argmax(pred_logits, dim=1) == label).sum().item()
            n_pairs += label.numel()

            # Weight the (mean-reduced) loss by the number of pairs so the
            # result is a per-pair mean, on the same scale as train_duo's.
            total_loss += loss_function(pred_logits, label).item() * label.numel()

    if n_pairs == 0:
        raise ValueError('test loader yielded no samples')

    losses = total_loss / n_pairs
    acc = n_correct / len(data_test.dataset)
    acc_contrastive = n_pair_correct / n_pairs
    print('Test epoch {}, loss : {:.3f} acc : {:.3f} acc contrastive : {:.3f}'.format(epoch, losses, acc, acc_contrastive))
    return losses, acc, acc_contrastive


def _plot_training_curves(train_acc, train_loss, val_epochs, val_acc, val_cont_acc, val_loss, f_name):
    """Plot accuracy/loss curves against their real epoch numbers, save to ``f_name``, then show."""
    train_epochs = list(range(len(train_acc)))

    plt.clf()
    plt.subplot(2, 1, 1)
    plt.plot(train_epochs, train_acc, label='train cont acc')
    # Validation only runs every eval_inter epochs: plot it at those epochs.
    plt.plot(val_epochs, val_cont_acc, label='val cont acc')
    plt.plot(val_epochs, val_acc, label='val classification acc')
    plt.title('Train and validation accuracy')
    plt.xlabel('epoch')
    plt.ylabel('accuracy')
    plt.legend(loc="upper left")
    plt.ylim(0, 1.05)
    plt.tight_layout()

    plt.subplot(2, 1, 2)
    plt.plot(train_epochs, train_loss, label='train')
    plt.plot(val_epochs, val_loss, label='val')
    plt.title('Train and validation loss')
    plt.xlabel('epoch')
    plt.ylabel('loss')
    plt.legend(loc="upper left")
    plt.tight_layout()

    _ensure_parent_dir(f_name)
    # Save BEFORE show(): a blocking show() closes the figure, and a later
    # savefig would write a new, empty figure.
    plt.savefig(f_name)
    plt.show()


def run_duo(args):
    """Train, validate, plot and finally evaluate the contrastive duo model.

    Args:
        args: namespace from ``load_args_contrastive``; reads dataset_train_dir,
            dataset_val_dir, dataset_ref_dir, batch_size, positive_prop, model,
            pretrain_path, epoches, eval_inter, lr, save_path and
            noise_threshold.
    """
    # load data
    data_train, data_test_batch = load_data_duo(base_dir_train=args.dataset_train_dir,
                                                base_dir_test=args.dataset_val_dir,
                                                batch_size=args.batch_size,
                                                ref_dir=args.dataset_ref_dir,
                                                positive_prop=args.positive_prop)

    # load model
    model = Classification_model_duo_contrastive(model=args.model, n_class=2)
    model.double()
    # load weight
    if args.pretrain_path is not None:
        load_model(model, args.pretrain_path)
    # move parameters to GPU
    if torch.cuda.is_available():
        model = model.cuda()

    # init accumulators; inf guarantees the first evaluated model is saved
    best_loss = math.inf
    train_acc = []
    train_loss = []
    val_epochs = []
    val_acc = []
    val_cont_acc = []
    val_loss = []

    # init training
    loss_function = nn.CrossEntropyLoss()
    # Use the configured learning rate; it is also embedded in the output file
    # names below, so a hard-coded value would mislabel the results.
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9)

    # train model
    for e in range(args.epoches):
        loss, acc = train_duo(model, data_train, optimizer, loss_function, e)
        train_loss.append(loss)
        train_acc.append(acc)
        if e % args.eval_inter == 0:
            loss, acc, acc_contrastive = test_duo(model, data_test_batch, loss_function, e)
            val_epochs.append(e)
            val_loss.append(loss)
            val_acc.append(acc)
            val_cont_acc.append(acc_contrastive)
            if loss < best_loss:
                save_model(model, args.save_path)
                best_loss = loss

    # plot and save training figs
    _plot_training_curves(train_acc, train_loss, val_epochs, val_acc, val_cont_acc, val_loss,
                          'output/training_plot_contrastive_noise_{}_lr_{}_model_{}.png'.format(args.noise_threshold, args.lr, args.model))

    # load and evaluate best model
    load_model(model, args.save_path)
    make_prediction_duo(model, data_test_batch,
                        'output/confusion_matrix_contractive_noise_{}_lr_{}_model_{}.png'.format(args.noise_threshold, args.lr, args.model),
                        'output/confidence_matrix_contractive_noise_{}_lr_{}_model_{}.png'.format(args.noise_threshold, args.lr, args.model))


def _save_heatmap(df, annot, f_name):
    """Draw ``df`` as an annotated seaborn heatmap in a fresh figure, save it to ``f_name`` and close it."""
    fig = plt.figure(figsize=(14, 9))
    sn.heatmap(df, annot=annot)
    _ensure_parent_dir(f_name)
    fig.savefig(f_name)
    # Close explicitly so repeated calls do not accumulate open figures.
    plt.close(fig)


def make_prediction_duo(model, data, f_name, f_name2):
    """Evaluate ``model`` on ``data`` and save confusion / confidence heatmaps.

    For each test item the true class is ``argmin(label)`` and the predicted
    class is the candidate with the highest column-0 logit.  The confidence
    matrix row ``i`` is the mean softmax column-0 score vector over items of
    true class ``i`` (NaN if the class never occurs).

    Args:
        model: trained model; switched to eval mode here.
        data: test loader, same layout as for ``test_duo``; ``data.dataset``
            must expose ``classes`` (assumed to have one entry per candidate
            reference -- TODO confirm).
        f_name: output path for the confusion-matrix image.
        f_name2: output path for the confidence-matrix image.
    """
    # Ensure inference mode even if the last training epoch was not followed
    # by an evaluation (test_duo is what previously left the model in eval).
    model.eval()

    # Number of candidate references per item, read from the first item.
    _, _, _, first_label = next(iter(data))
    n_class = first_label.shape[1]

    confidence_pred_list = [[] for _ in range(n_class)]
    y_pred = []
    y_true = []
    soft_max = nn.Softmax(dim=1)

    # iterate over test data
    with torch.no_grad():
        for imaer, imana, img_ref, label in data:
            imaer = imaer.transpose(0, 1)
            imana = imana.transpose(0, 1)
            img_ref = img_ref.transpose(0, 1)
            label = label.transpose(0, 1).squeeze().long()
            specie = int(torch.argmin(label))
            imaer, imana, img_ref, label = _to_device(imaer, imana, img_ref, label)

            output = model(imaer, imana, img_ref)
            confidence = soft_max(output)
            confidence_pred_list[specie].append(confidence[:, 0].cpu().numpy())

            # Mono class output (only most positive pair)
            y_pred.append(int(torch.argmax(output[:, 0])))
            y_true.append(specie)

    classes = data.dataset.classes

    # Explicit labels keep the matrix n_class x n_class even when some class
    # never appears in y_true or y_pred.
    cf_matrix = confusion_matrix(y_true, y_pred, labels=list(range(n_class)))
    row_totals = cf_matrix.sum(axis=1, keepdims=True)
    # Row-normalise; rows with no true samples stay 0 instead of NaN.
    cf_normalised = np.divide(cf_matrix, row_totals,
                              out=np.zeros(cf_matrix.shape), where=row_totals > 0)

    confidence_matrix = np.full((n_class, n_class), np.nan)
    for i, preds in enumerate(confidence_pred_list):
        if preds:
            confidence_matrix[i] = np.mean(preds, axis=0)

    print('Saving Confusion Matrix')
    df_cm = pd.DataFrame(cf_normalised, index=list(classes), columns=list(classes))
    _save_heatmap(df_cm, cf_matrix, f_name)

    print('Saving Confidence Matrix')
    df_cm = pd.DataFrame(confidence_matrix, index=list(classes), columns=list(classes))
    _save_heatmap(df_cm, confidence_matrix, f_name2)


def save_model(model, path):
    """Save ``model.state_dict()`` to ``path``, creating its directory if needed."""
    _ensure_parent_dir(path)
    torch.save(model.state_dict(), path)
    print('Model saved')


def load_model(model, path):
    """Load a state dict written by ``save_model`` from ``path`` into ``model`` in place.

    ``map_location='cpu'`` lets a checkpoint saved on a GPU machine load on a
    CPU-only one; ``load_state_dict`` copies the tensors onto whatever device
    the model's parameters already live on.
    """
    model.load_state_dict(torch.load(path, map_location='cpu', weights_only=True))


if __name__ == '__main__':
    args = load_args_contrastive()
    run_duo(args)