main.py
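"""Training and evaluation entry point for the duo contrastive classifier.

Trains Classification_model_duo_contrastive on (imaer, imana, img_ref) image
triplets, evaluates every `eval_inter` epochs, keeps the checkpoint with the
lowest validation loss, and writes the training curves plus the confusion and
confidence matrices under output/.
"""
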
import matplotlib.pyplot as plt
import numpy as np
from config import load_args_contrastive
from dataset_ref import load_data_duo
import torch
import torch.nn as nn
from model import Classification_model_duo_contrastive
import torch.optim as optim
from sklearn.metrics import confusion_matrix
import seaborn as sn
import pandas as pd


def train_duo(model, data_train, optimizer, loss_function, epoch):
    """Run one training epoch; return the mean loss and accuracy."""
    model.train()
    losses = 0.
    acc = 0.
    # re-enable gradients (test_duo freezes the parameters)
    for param in model.parameters():
        param.requires_grad = True
    for imaer, imana, img_ref, label in data_train:
        label = label.long()
        if torch.cuda.is_available():
            imaer = imaer.cuda()
            imana = imana.cuda()
            img_ref = img_ref.cuda()
            label = label.cuda()
        pred_logits = model(imaer, imana, img_ref)
        pred_class = torch.argmax(pred_logits, dim=1)
        acc += (pred_class == label).sum().item()
        loss = loss_function(pred_logits, label)
        losses += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    losses = losses / len(data_train.dataset)
    acc = acc / len(data_train.dataset)
    print('Train epoch {}, loss : {:.3f} acc : {:.3f}'.format(epoch, losses, acc))
    return losses, acc


def test_duo(model, data_test, loss_function, epoch):
    """Run one evaluation pass; return mean loss, classification accuracy and
    pairwise contrastive accuracy."""
    model.eval()
    losses = 0.
    acc = 0.
    acc_contrastive = 0.
    # freeze parameters so no gradients are tracked during evaluation
    for param in model.parameters():
        param.requires_grad = False
    for imaer, imana, img_ref, label in data_test:
        # swap the batch and pair dimensions
        imaer = imaer.transpose(0, 1)
        imana = imana.transpose(0, 1)
        img_ref = img_ref.transpose(0, 1)
        label = label.transpose(0, 1)
        label = label.squeeze()
        label = label.long()
        if torch.cuda.is_available():
            imaer = imaer.cuda()
            imana = imana.cuda()
            img_ref = img_ref.cuda()
            label = label.cuda()
        # true class = index of the pair with the smallest label (the positive pair)
        label_class = torch.argmin(label).data.cpu().numpy()
        pred_logits = model(imaer, imana, img_ref)
        # predicted class = pair with the highest "positive" logit
        pred_class = torch.argmax(pred_logits[:, 0]).tolist()
        acc_contrastive += (torch.argmax(pred_logits, dim=1).data.cpu().numpy() == label.data.cpu().numpy()).sum().item()
        acc += (pred_class == label_class)
        loss = loss_function(pred_logits, label)
        losses += loss.item()
    losses = losses / (label.shape[0] * len(data_test.dataset))
    acc = acc / len(data_test.dataset)
    acc_contrastive = acc_contrastive / (label.shape[0] * len(data_test.dataset))
    print('Test epoch {}, loss : {:.3f} acc : {:.3f} acc contrastive : {:.3f}'.format(epoch, losses, acc, acc_contrastive))
    return losses, acc, acc_contrastive


def run_duo(args):
    # load data
    data_train, data_test_batch = load_data_duo(base_dir_train=args.dataset_train_dir,
                                                base_dir_test=args.dataset_val_dir,
                                                batch_size=args.batch_size,
                                                ref_dir=args.dataset_ref_dir,
                                                positive_prop=args.positive_prop)
    # load model
    model = Classification_model_duo_contrastive(model=args.model, n_class=2)
    model.double()
    # load pretrained weights if a checkpoint is given
    if args.pretrain_path is not None:
        load_model(model, args.pretrain_path)
    # move parameters to GPU
    if torch.cuda.is_available():
        model = model.cuda()
    # init accumulators
    best_loss = 100
    train_acc = []
    train_loss = []
    val_acc = []
    val_cont_acc = []
    val_loss = []
    # init training
    loss_function = nn.CrossEntropyLoss()
    # NB: the learning rate is fixed here; args.lr is only used to name the output files
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
    # train model
    for e in range(args.epoches):
        loss, acc = train_duo(model, data_train, optimizer, loss_function, e)
        train_loss.append(loss)
        train_acc.append(acc)
        if e % args.eval_inter == 0:
            loss, acc, acc_contrastive = test_duo(model, data_test_batch, loss_function, e)
            val_loss.append(loss)
            val_acc.append(acc)
            val_cont_acc.append(acc_contrastive)
            # keep the checkpoint with the lowest validation loss
            if loss < best_loss:
                save_model(model, args.save_path)
                best_loss = loss
    # plot and save training figures
    plt.clf()
    plt.subplot(2, 1, 1)
    plt.plot(train_acc, label='train cont acc')
    plt.plot(val_cont_acc, label='val cont acc')
    plt.plot(val_acc, label='val classification acc')
    plt.title('Train and validation accuracy')
    plt.xlabel('epoch')
    plt.ylabel('accuracy')
    plt.legend(loc="upper left")
    plt.ylim(0, 1.05)
    plt.tight_layout()
    plt.subplot(2, 1, 2)
    plt.plot(train_loss, label='train')
    plt.plot(val_loss, label='val')
    plt.title('Train and validation loss')
    plt.xlabel('epoch')
    plt.ylabel('loss')
    plt.legend(loc="upper left")
    plt.tight_layout()
    # save before show(): closing the interactive window destroys the figure
    plt.savefig('output/training_plot_contrastive_noise_{}_lr_{}_model_{}.png'.format(args.noise_threshold, args.lr, args.model))
    plt.show()
    # load and evaluate the best checkpoint
    load_model(model, args.save_path)
    make_prediction_duo(model, data_test_batch,
                        'output/confusion_matrix_contractive_noise_{}_lr_{}_model_{}.png'.format(args.noise_threshold, args.lr, args.model),
                        'output/confidence_matrix_contractive_noise_{}_lr_{}_model_{}.png'.format(args.noise_threshold, args.lr, args.model))


def make_prediction_duo(model, data, f_name, f_name2):
    """Predict on `data` and save the confusion matrix (f_name) and the mean
    confidence matrix (f_name2) as heatmaps."""
    # read the number of classes from the first batch
    for imaer, imana, img_ref, label in data:
        n_class = label.shape[1]
        break
    confidence_pred_list = [[] for i in range(n_class)]
    y_pred = []
    y_true = []
    soft_max = nn.Softmax(dim=1)
    model.eval()  # ensure the model is in evaluation mode
    # iterate over test data
    for imaer, imana, img_ref, label in data:
        imaer = imaer.transpose(0, 1)
        imana = imana.transpose(0, 1)
        img_ref = img_ref.transpose(0, 1)
        label = label.transpose(0, 1)
        label = label.squeeze()
        label = label.long()
        specie = torch.argmin(label)
        if torch.cuda.is_available():
            imaer = imaer.cuda()
            imana = imana.cuda()
            img_ref = img_ref.cuda()
            label = label.cuda()
        output = model(imaer, imana, img_ref)
        confidence = soft_max(output)
        confidence_pred_list[specie].append(confidence[:, 0].data.cpu().numpy())
        # mono-class output (keep only the most positive pair)
        output = torch.argmax(output[:, 0])
        label = torch.argmin(label)
        y_pred.append(output.tolist())
        y_true.append(label.tolist())  # save ground truth
    # build confusion matrix
    classes = data.dataset.classes
    cf_matrix = confusion_matrix(y_true, y_pred)
    confidence_matrix = np.zeros((n_class, n_class))
    for i in range(n_class):
        confidence_matrix[i] = np.mean(confidence_pred_list[i], axis=0)
    df_cm = pd.DataFrame(cf_matrix / np.sum(cf_matrix, axis=1)[:, None],
                         index=[i for i in classes], columns=[i for i in classes])
    print('Saving Confusion Matrix')
    plt.clf()
    plt.figure(figsize=(14, 9))
    sn.heatmap(df_cm, annot=cf_matrix)
    plt.savefig(f_name)
    df_cm = pd.DataFrame(confidence_matrix, index=[i for i in classes],
                         columns=[i for i in classes])
    print('Saving Confidence Matrix')
    plt.clf()
    plt.figure(figsize=(14, 9))
    sn.heatmap(df_cm, annot=confidence_matrix)
    plt.savefig(f_name2)


def save_model(model, path):
    print('Model saved')
    torch.save(model.state_dict(), path)


def load_model(model, path):
    model.load_state_dict(torch.load(path, weights_only=True))


if __name__ == '__main__':
    args = load_args_contrastive()
    run_duo(args)
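
# Illustrative sketch only: in the script the arguments come from
# config.load_args_contrastive(). The attribute names below are the ones
# run_duo() actually reads; every value shown is a placeholder.
#
# from argparse import Namespace
# example_args = Namespace(
#     dataset_train_dir='data/train', dataset_val_dir='data/val',
#     dataset_ref_dir='data/ref', batch_size=16, positive_prop=0.5,
#     model='ResNet18', pretrain_path=None, epoches=50, eval_inter=1,
#     save_path='output/best_model.pt', noise_threshold=0, lr=0.001,
# )
# run_duo(example_args)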