import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sn
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import confusion_matrix

from config import load_args_contrastive
from dataset_ref import load_data_duo_batched, load_data_duo
from model import Classification_model_contrastive, Classification_model_duo_contrastive


def _to_device(*tensors):
    """Return ``tensors`` moved to the GPU when CUDA is available, else unchanged (as a tuple)."""
    if torch.cuda.is_available():
        return tuple(t.cuda() for t in tensors)
    return tensors


def train_duo(model, data_train, optimizer, loss_function, epoch):
    """Run one training epoch over ``data_train``.

    Each batch yields ``(imaer, imana, img_ref, label)``. The model is called as
    ``model.forward(imaer, imana, img_ref)`` and its output is treated as
    per-class scores (argmax over dim 1 is the predicted class).

    Args:
        model: network to train; switched to train mode and all parameters are
            (re-)enabled for gradient computation.
        data_train: DataLoader-like iterable; ``len(data_train.dataset)`` is used
            as the number of samples.
        optimizer: optimizer over ``model.parameters()``.
        loss_function: criterion called as ``loss_function(logits, label)``;
            assumed to return a batch-mean loss (true for the default
            ``nn.CrossEntropyLoss`` used by ``run_duo``).
        epoch: epoch index, used only for logging.

    Returns:
        ``(mean_loss, accuracy)``, both averaged per sample over the dataset.
    """
    model.train()
    total_loss = 0.
    n_correct = 0
    # Re-enable gradients in case parameters were frozen elsewhere.
    for param in model.parameters():
        param.requires_grad = True
    for imaer, imana, img_ref, label in data_train:
        imaer, imana, img_ref, label = _to_device(imaer, imana, img_ref, label.long())
        pred_logits = model.forward(imaer, imana, img_ref)
        pred_class = torch.argmax(pred_logits, dim=1)
        n_correct += (pred_class == label).sum().item()
        loss = loss_function(pred_logits, label)
        # The loss is a per-batch mean: weight it by the batch size so that the
        # division by the dataset size below gives a true per-sample average.
        total_loss += loss.item() * label.size(0)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    n_samples = len(data_train.dataset)
    losses = total_loss / n_samples
    acc = n_correct / n_samples
    print('Train epoch {}, loss : {:.3f} acc : {:.3f}'.format(epoch, losses, acc))
    return losses, acc


def test_duo(model, data_test, loss_function, epoch):
    """Evaluate ``model`` on ``data_test`` without updating it.

    Batches have the same ``(imaer, imana, img_ref, label)`` layout as in
    ``train_duo``. Evaluation runs in eval mode under ``torch.no_grad()``, so no
    autograd graph is built and the parameters' ``requires_grad`` flags are
    left untouched.

    Returns:
        ``(mean_loss, accuracy)``, both averaged per sample over
        ``len(data_test.dataset)``.
    """
    model.eval()
    total_loss = 0.
    n_correct = 0
    with torch.no_grad():
        for imaer, imana, img_ref, label in data_test:
            imaer, imana, img_ref, label = _to_device(imaer, imana, img_ref, label.long())
            pred_logits = model.forward(imaer, imana, img_ref)
            pred_class = torch.argmax(pred_logits, dim=1)
            n_correct += (pred_class == label).sum().item()
            loss = loss_function(pred_logits, label)
            # Weight the batch-mean loss by batch size (see train_duo).
            total_loss += loss.item() * label.size(0)
    n_samples = len(data_test.dataset)
    losses = total_loss / n_samples
    acc = n_correct / n_samples
    print('Test epoch {}, loss : {:.3f} acc : {:.3f}'.format(epoch, losses, acc))
    return losses, acc


def _plot_training(train_acc, train_loss, val_epochs, val_acc, val_loss, f_name):
    """Save train/validation accuracy and loss curves to ``f_name``, then show them.

    Training metrics are recorded every epoch; validation metrics only at the
    epochs listed in ``val_epochs``, so they are plotted against those epochs.
    """
    fig, (ax_acc, ax_loss) = plt.subplots(1, 2, figsize=(12, 5))
    epochs = range(len(train_acc))

    ax_acc.plot(epochs, train_acc, label='train')
    ax_acc.plot(val_epochs, val_acc, label='val')
    ax_acc.set_ylim(0, 1.05)
    ax_acc.set_xlabel('epoch')
    ax_acc.set_ylabel('accuracy')
    ax_acc.legend()

    ax_loss.plot(epochs, train_loss, label='train')
    ax_loss.plot(val_epochs, val_loss, label='val')
    ax_loss.set_xlabel('epoch')
    ax_loss.set_ylabel('loss')
    ax_loss.legend()

    # Save before show(): with interactive backends the figure may be
    # cleared once the window is closed, producing a blank image.
    fig.savefig(f_name)
    plt.show()
    plt.close(fig)


def run_duo(args):
    """Train, evaluate and analyse the duo contrastive classifier.

    Reads from ``args``: ``dataset_dir``, ``dataset_ref_dir``, ``batch_size``,
    ``positive_prop``, ``model``, ``pretrain_path``, ``lr``, ``epoches``,
    ``eval_inter``, ``save_path`` and ``noise_threshold``.

    Trains for ``args.epoches`` epochs, evaluates every ``args.eval_inter``
    epochs, and checkpoints the best-validation-accuracy weights to
    ``args.save_path``. It then saves the training curves, reloads the best
    checkpoint and writes a confusion matrix, both under ``output/``.
    """
    # load data
    data_train, data_test = load_data_duo(base_dir=args.dataset_dir, batch_size=args.batch_size,
                                          ref_dir=args.dataset_ref_dir, positive_prop=args.positive_prop)
    # NOTE(review): per-epoch validation uses the batched test loader, while the
    # final confusion matrix uses ``data_test``; confirm this split is intended.
    _, data_test_batch = load_data_duo_batched(base_dir=args.dataset_dir, ref_dir=args.dataset_ref_dir)

    # load model
    model = Classification_model_duo_contrastive(model=args.model, n_class=2)
    model.double()
    if args.pretrain_path is not None:
        load_model(model, args.pretrain_path)
    if torch.cuda.is_available():
        model = model.cuda()

    # accumulators; -inf guarantees the first evaluation writes a checkpoint,
    # so the final load_model never picks up a stale file from a previous run.
    best_acc = float('-inf')
    train_acc, train_loss = [], []
    val_epochs, val_acc, val_loss = [], [], []

    loss_function = nn.CrossEntropyLoss()
    # Use the configured learning rate (it is also encoded in the output file names).
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9)

    for e in range(args.epoches):
        loss, acc = train_duo(model, data_train, optimizer, loss_function, e)
        train_loss.append(loss)
        train_acc.append(acc)
        if e % args.eval_inter == 0:
            loss, acc = test_duo(model, data_test_batch, loss_function, e)
            val_epochs.append(e)
            val_loss.append(loss)
            val_acc.append(acc)
            if acc > best_acc:
                save_model(model, args.save_path)
                best_acc = acc

    _plot_training(train_acc, train_loss, val_epochs, val_acc, val_loss,
                   'output/training_plot_contrastive_noise_{}_lr_{}_model_{}.png'.format(
                       args.noise_threshold, args.lr, args.model))

    # load and evaluate best model
    load_model(model, args.save_path)
    # Three placeholders for three arguments (the original had a fourth,
    # unfilled '{}', which raised IndexError here).
    make_prediction_duo(model, data_test,
                        'output/confusion_matrix_contractive_noise_{}_lr_{}_model_{}.png'.format(
                            args.noise_threshold, args.lr, args.model))


def make_prediction_duo(model, data, f_name):
    """Predict on ``data`` and save a confusion-matrix heatmap to ``f_name``.

    Rows of the matrix are true labels and columns are predicted labels (the
    scikit-learn convention). Cell colours are row-normalised rates;
    annotations are raw counts.
    """
    model.eval()
    y_pred = []
    y_true = []
    with torch.no_grad():
        for imaer, imana, img_ref, label in data:
            imaer, imana, img_ref = _to_device(imaer, imana, img_ref)
            output = model(imaer, imana, img_ref)
            # argmax over classes (exp() is monotonic, so it was not needed and
            # could overflow on large scores).
            y_pred.extend(torch.argmax(output, dim=1).cpu().numpy())
            y_true.extend(label.long().cpu().numpy())

    # Fix the label set so the matrix is always 2x2, even if a class is absent.
    cf_matrix = confusion_matrix(y_true, y_pred, labels=[0, 1])
    # NOTE(review): rows are indexed 0/1 but columns are labelled 'True'/'False';
    # confirm which class each name refers to.
    df_cm = pd.DataFrame(cf_matrix / np.sum(cf_matrix, axis=1)[:, None],
                         index=[i for i in range(2)], columns=['True', 'False'])
    print('Saving Confusion Matrix')
    fig = plt.figure(figsize=(14, 9))
    sn.heatmap(df_cm, annot=cf_matrix)
    plt.savefig(f_name)
    plt.close(fig)


def save_model(model, path):
    """Save ``model.state_dict()`` to ``path``."""
    print('Model saved')
    torch.save(model.state_dict(), path)


def load_model(model, path):
    """Load a state dict written by ``save_model`` into ``model`` in place."""
    # map_location='cpu' lets GPU-saved checkpoints load on CPU-only machines;
    # load_state_dict then copies the values onto the model's current device.
    model.load_state_dict(torch.load(path, map_location='cpu', weights_only=True))


if __name__ == '__main__':
    args = load_args_contrastive()
    run_duo(args)