Skip to content
Snippets Groups Projects
main.py 8.47 KiB
import matplotlib.pyplot as plt
import numpy as np

from config.config import load_args
from dataset.dataset import load_data, load_data_duo
import torch
import torch.nn as nn
from models.model import Classification_model, Classification_model_duo
import torch.optim as optim
from sklearn.metrics import confusion_matrix
import seaborn as sn
import pandas as pd



def train(model, data_train, optimizer, loss_function, epoch):
    """Run one training epoch over ``data_train``, updating ``model`` in place.

    Args:
        model: network to train. It is switched to train mode and all of its
            parameters are re-enabled for gradients (``test`` disables them).
        data_train: iterable of ``(images, labels)`` batches that also exposes
            a ``.dataset`` attribute supporting ``len()`` (e.g. a DataLoader).
        optimizer: optimizer over ``model``'s parameters.
        loss_function: criterion called as ``loss_function(logits, labels)``.
            Assumed to return the batch *mean*, which is the default reduction
            of ``nn.CrossEntropyLoss`` (TODO confirm for other criteria).
        epoch: epoch index, used only in the progress message.

    Returns:
        ``(mean_loss, accuracy)`` averaged over every sample of the epoch.
    """
    model.train()
    total_loss = 0.
    correct = 0
    for param in model.parameters():
        param.requires_grad = True

    for im, label in data_train:
        label = label.long()
        if torch.cuda.is_available():
            im, label = im.cuda(), label.cuda()
        # Call the module itself (not .forward) so registered hooks run.
        pred_logits = model(im)
        pred_class = torch.argmax(pred_logits, dim=1)
        correct += (pred_class == label).sum().item()
        loss = loss_function(pred_logits, label)
        # The criterion yields a per-batch mean. Weight it by the batch size
        # so that dividing by the dataset size gives a true per-sample mean,
        # including when the last batch is smaller than the others.
        total_loss += loss.item() * label.size(0)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # Guard against an empty dataset (avoids ZeroDivisionError).
    n_samples = max(len(data_train.dataset), 1)
    losses = total_loss / n_samples
    acc = correct / n_samples
    print('Train epoch {}, loss : {:.3f} acc : {:.3f}'.format(epoch,losses,acc))
    return losses, acc

def test(model, data_test, loss_function, epoch):
    model.eval()
    losses = 0.
    acc = 0.
    for param in model.parameters():
        param.requires_grad = False

    for im, label in data_test:
        label = label.long()
        if torch.cuda.is_available():
            im, label = im.cuda(), label.cuda()
        pred_logits = model.forward(im)
        pred_class = torch.argmax(pred_logits,dim=1)
        acc += (pred_class==label).sum().item()
        loss = loss_function(pred_logits,label)
        losses += loss.item()
    losses = losses/len(data_test.dataset)
    acc = acc/len(data_test.dataset)
    print('Test epoch {}, loss : {:.3f} acc : {:.3f}'.format(epoch,losses,acc))
    return losses,acc

def run(args):
    """Train a single-input classifier end to end and write its reports.

    Reads ``args.dataset_dir``, ``args.batch_size``, ``args.model``,
    ``args.pretrain_path``, ``args.epoches``, ``args.eval_inter``,
    ``args.save_path``, ``args.lr`` and ``args.noise_threshold``.

    The model with the best validation accuracy is checkpointed to
    ``args.save_path``. A training-curve plot and a confusion-matrix heatmap
    are written under ``output/``.
    """
    import os

    data_train, data_test = load_data(base_dir=args.dataset_dir, batch_size=args.batch_size)
    # Assumes the loaders wrap a Subset-like dataset whose inner dataset
    # exposes ``classes`` (TODO confirm against load_data).
    classes = data_train.dataset.dataset.classes
    model = Classification_model(model=args.model, n_class=len(classes))
    if args.pretrain_path is not None:
        load_model(model, args.pretrain_path)
    if torch.cuda.is_available():
        model = model.cuda()

    loss_function = nn.CrossEntropyLoss()
    # Honour the configured learning rate. It is also encoded in the output
    # file names, which mislabelled runs trained at a hard-coded 0.001.
    optimizer = optim.SGD(model.parameters(), lr=float(args.lr), momentum=0.9)

    # Start below any reachable accuracy so the first evaluation always
    # writes a checkpoint, which the final load_model call needs.
    best_acc = -1.
    train_acc, train_loss = [], []
    val_epochs, val_acc, val_loss = [], [], []

    for e in range(args.epoches):
        loss, acc = train(model, data_train, optimizer, loss_function, e)
        train_loss.append(loss)
        train_acc.append(acc)
        if e % args.eval_inter == 0:
            loss, acc = test(model, data_test, loss_function, e)
            val_epochs.append(e)
            val_loss.append(loss)
            val_acc.append(acc)
            if acc > best_acc:
                save_model(model, args.save_path)
                best_acc = acc

    os.makedirs('output', exist_ok=True)
    fig, (ax_acc, ax_loss) = plt.subplots(1, 2, figsize=(12, 5))
    # Validation only runs every ``eval_inter`` epochs, so plot it against
    # the epochs at which it was actually measured.
    ax_acc.plot(range(len(train_acc)), train_acc, label='train')
    ax_acc.plot(val_epochs, val_acc, label='validation')
    ax_acc.set_ylim(0, 1.05)
    ax_acc.set_xlabel('epoch')
    ax_acc.set_ylabel('accuracy')
    ax_acc.legend()
    ax_loss.plot(range(len(train_loss)), train_loss, label='train')
    ax_loss.plot(val_epochs, val_loss, label='validation')
    ax_loss.set_xlabel('epoch')
    ax_loss.set_ylabel('loss')
    ax_loss.legend()
    # Save before show(): with an interactive backend the figure can be torn
    # down when its window is closed, which may leave a blank PNG.
    fig.savefig('output/training_plot_noise_{}_lr_{}_model_{}.png'.format(args.noise_threshold,args.lr,args.model))
    plt.show()
    plt.close(fig)

    # Only reload if a checkpoint was written (i.e. at least one evaluation ran).
    if best_acc >= 0:
        load_model(model, args.save_path)
    make_prediction(model,data_test, 'output/confusion_matrix_noise_{}_lr_{}_model_{}.png'.format(args.noise_threshold,args.lr,args.model))


def make_prediction(model, data, f_name):
    """Predict on ``data`` and save a row-normalised confusion-matrix heatmap.

    Args:
        model: trained classifier returning one row of logits per sample.
            It is switched to eval mode.
        data: iterable of ``(images, labels)`` batches. ``data.dataset.dataset``
            must expose ``classes`` (assumes a Subset-like wrapper; TODO
            confirm).
        f_name: path of the heatmap image to write.
    """
    # Use inference behaviour (dropout off, running batch-norm statistics):
    # the caller typically arrives here straight after a training epoch.
    model.eval()
    y_pred = []
    y_true = []

    with torch.no_grad():
        for im, label in data:
            if torch.cuda.is_available():
                im = im.cuda()
            output = model(im)
            # argmax of the logits equals argmax of exp(logits), so there is
            # no need to exponentiate (which can overflow to inf).
            y_pred.extend(torch.argmax(output, dim=1).cpu().numpy())
            y_true.extend(label.long().cpu().numpy())

    classes = data.dataset.dataset.classes

    # Pass every class index so the matrix is always len(classes) square,
    # even if a class never occurs in y_true or y_pred. Otherwise the
    # DataFrame below gets a shape mismatch.
    cf_matrix = confusion_matrix(y_true, y_pred, labels=list(range(len(classes))))
    row_sums = cf_matrix.sum(axis=1, keepdims=True)
    # Normalise each row by its true-class count. Rows with no samples
    # become all-zero instead of NaN.
    normalised = np.divide(cf_matrix, row_sums,
                           out=np.zeros(cf_matrix.shape, dtype=float),
                           where=row_sums != 0)
    df_cm = pd.DataFrame(normalised, index=list(classes), columns=list(classes))
    fig = plt.figure(figsize=(12, 7))
    sn.heatmap(df_cm, annot=True)
    plt.savefig(f_name)
    plt.close(fig)


def train_duo(model, data_train, optimizer, loss_function, epoch):
    """Run one training epoch of a two-input model, updating it in place.

    Args:
        model: network called as ``model(imaer, imana)``. It is switched to
            train mode and all parameters are re-enabled for gradients
            (``test_duo`` disables them).
        data_train: iterable of ``(imaer, imana, labels)`` batches that also
            exposes a ``.dataset`` attribute supporting ``len()``.
        optimizer: optimizer over ``model``'s parameters.
        loss_function: criterion called as ``loss_function(logits, labels)``.
            Assumed to return the batch mean (TODO confirm for non-default
            reductions).
        epoch: epoch index, used only in the progress message.

    Returns:
        ``(mean_loss, accuracy)`` averaged over every sample of the epoch.
    """
    model.train()
    total_loss = 0.
    correct = 0
    for param in model.parameters():
        param.requires_grad = True

    for imaer, imana, label in data_train:
        label = label.long()
        if torch.cuda.is_available():
            imaer = imaer.cuda()
            imana = imana.cuda()
            label = label.cuda()
        # Call the module itself (not .forward) so registered hooks run.
        pred_logits = model(imaer, imana)
        pred_class = torch.argmax(pred_logits, dim=1)
        correct += (pred_class == label).sum().item()
        loss = loss_function(pred_logits, label)
        # Weight the per-batch mean by the batch size to get a true
        # per-sample mean after dividing by the dataset size.
        total_loss += loss.item() * label.size(0)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # Guard against an empty dataset (avoids ZeroDivisionError).
    n_samples = max(len(data_train.dataset), 1)
    losses = total_loss / n_samples
    acc = correct / n_samples
    print('Train epoch {}, loss : {:.3f} acc : {:.3f}'.format(epoch,losses,acc))
    return losses, acc

def test_duo(model, data_test, loss_function, epoch):
    model.eval()
    losses = 0.
    acc = 0.
    for param in model.parameters():
        param.requires_grad = False

    for imaer,imana, label in data_test:
        label = label.long()
        if torch.cuda.is_available():
            imaer = imaer.cuda()
            imana = imana.cuda()
            label = label.cuda()
        pred_logits = model.forward(imaer,imana)
        pred_class = torch.argmax(pred_logits,dim=1)
        acc += (pred_class==label).sum().item()
        loss = loss_function(pred_logits,label)
        losses += loss.item()
    losses = losses/len(data_test.dataset)
    acc = acc/len(data_test.dataset)
    print('Test epoch {}, loss : {:.3f} acc : {:.3f}'.format(epoch,losses,acc))
    return losses,acc

def run_duo(args):
    """Train a two-input ("duo") classifier end to end and write its reports.

    Reads ``args.dataset_dir``, ``args.batch_size``, ``args.model``,
    ``args.model_type``, ``args.pretrain_path``, ``args.epoches``,
    ``args.eval_inter``, ``args.save_path``, ``args.lr`` and
    ``args.noise_threshold``.

    The model with the best validation accuracy is checkpointed to
    ``args.save_path``. A training-curve plot and a confusion-matrix heatmap
    are written under ``output/``.
    """
    import os

    data_train, data_test = load_data_duo(base_dir=args.dataset_dir, batch_size=args.batch_size)
    # Assumes the loaders wrap a Subset-like dataset whose inner dataset
    # exposes ``classes`` (TODO confirm against load_data_duo).
    classes = data_train.dataset.dataset.classes
    model = Classification_model_duo(model=args.model, n_class=len(classes))
    if args.pretrain_path is not None:
        load_model(model, args.pretrain_path)
    if torch.cuda.is_available():
        model = model.cuda()

    loss_function = nn.CrossEntropyLoss()
    # Honour the configured learning rate. It is also encoded in the output
    # file names, which mislabelled runs trained at a hard-coded 0.001.
    optimizer = optim.SGD(model.parameters(), lr=float(args.lr), momentum=0.9)

    # Start below any reachable accuracy so the first evaluation always
    # writes a checkpoint, which the final load_model call needs.
    best_acc = -1.
    train_acc, train_loss = [], []
    val_epochs, val_acc, val_loss = [], [], []

    for e in range(args.epoches):
        loss, acc = train_duo(model, data_train, optimizer, loss_function, e)
        train_loss.append(loss)
        train_acc.append(acc)
        if e % args.eval_inter == 0:
            loss, acc = test_duo(model, data_test, loss_function, e)
            val_epochs.append(e)
            val_loss.append(loss)
            val_acc.append(acc)
            if acc > best_acc:
                save_model(model, args.save_path)
                best_acc = acc

    os.makedirs('output', exist_ok=True)
    fig, (ax_acc, ax_loss) = plt.subplots(1, 2, figsize=(12, 5))
    # Validation only runs every ``eval_inter`` epochs, so plot it against
    # the epochs at which it was actually measured.
    ax_acc.plot(range(len(train_acc)), train_acc, label='train')
    ax_acc.plot(val_epochs, val_acc, label='validation')
    ax_acc.set_ylim(0, 1.05)
    ax_acc.set_xlabel('epoch')
    ax_acc.set_ylabel('accuracy')
    ax_acc.legend()
    ax_loss.plot(range(len(train_loss)), train_loss, label='train')
    ax_loss.plot(val_epochs, val_loss, label='validation')
    ax_loss.set_xlabel('epoch')
    ax_loss.set_ylabel('loss')
    ax_loss.legend()
    # Save before show(): with an interactive backend the figure can be torn
    # down when its window is closed, which may leave a blank PNG.
    fig.savefig('output/training_plot_noise_{}_lr_{}_model_{}_{}.png'.format(args.noise_threshold,args.lr,args.model,args.model_type))
    plt.show()
    plt.close(fig)

    # Only reload if a checkpoint was written (i.e. at least one evaluation ran).
    if best_acc >= 0:
        load_model(model, args.save_path)
    make_prediction_duo(model,data_test, 'output/confusion_matrix_noise_{}_lr_{}_model_{}_{}.png'.format(args.noise_threshold,args.lr,args.model,args.model_type))


def make_prediction_duo(model, data, f_name):
    """Predict with a two-input model and save a confusion-matrix heatmap.

    The matrix is row-normalised (each row sums to 1 over its true class),
    matching ``make_prediction``.

    Args:
        model: trained classifier called as ``model(imaer, imana)``. It is
            switched to eval mode.
        data: iterable of ``(imaer, imana, labels)`` batches.
            ``data.dataset.dataset`` must expose ``classes`` (assumes a
            Subset-like wrapper; TODO confirm).
        f_name: path of the heatmap image to write.
    """
    # Use inference behaviour (dropout off, running batch-norm statistics).
    model.eval()
    y_pred = []
    y_true = []
    print('Building confusion matrix')
    with torch.no_grad():
        for imaer, imana, label in data:
            if torch.cuda.is_available():
                imaer = imaer.cuda()
                imana = imana.cuda()
            output = model(imaer, imana)
            # argmax of the logits equals argmax of exp(logits).
            y_pred.extend(torch.argmax(output, dim=1).cpu().numpy())
            # Labels are only needed on the CPU; no CUDA round-trip.
            y_true.extend(label.long().cpu().numpy())

    classes = data.dataset.dataset.classes
    print('Prediction made')
    # Pass every class index so the matrix is always len(classes) square,
    # even if a class never occurs in y_true or y_pred.
    cf_matrix = confusion_matrix(y_true, y_pred, labels=list(range(len(classes))))
    print('CM made')
    # The previous ``cf_matrix[:, None]`` built a 3-D array, which
    # pd.DataFrame rejects. Normalise per row instead, mapping rows with no
    # samples to zeros rather than NaN.
    row_sums = cf_matrix.sum(axis=1, keepdims=True)
    normalised = np.divide(cf_matrix, row_sums,
                           out=np.zeros(cf_matrix.shape, dtype=float),
                           where=row_sums != 0)
    df_cm = pd.DataFrame(normalised, index=list(classes), columns=list(classes))

    print('Saving Confusion Matrix')
    fig = plt.figure(figsize=(14, 9))
    sn.heatmap(df_cm, annot=True)
    plt.savefig(f_name)
    plt.close(fig)


def save_model(model, path):
    """Write ``model.state_dict()`` to ``path`` with ``torch.save``.

    The confirmation message is printed only after ``torch.save`` returns,
    so a failed write (e.g. a missing directory) no longer reports success.
    """
    torch.save(model.state_dict(), path)
    print('Model saved')

def load_model(model, path, map_location='cpu'):
    """Load a ``state_dict`` checkpoint from ``path`` into ``model`` in place.

    Args:
        model: module whose parameters receive the checkpoint values.
        path: file written by ``torch.save(model.state_dict(), ...)``.
        map_location: device the checkpoint tensors are materialised on
            before being copied into ``model``. Defaults to ``'cpu'`` so a
            checkpoint saved on a GPU can be loaded on a CPU-only machine.
            ``load_state_dict`` then copies the values onto whatever device
            ``model``'s parameters already live on.
    """
    # weights_only=True restricts unpickling to tensors and primitive
    # containers, so arbitrary pickled objects are not executed.
    state_dict = torch.load(path, map_location=map_location, weights_only=True)
    model.load_state_dict(state_dict)



if __name__ == '__main__':
    # Choose the pipeline from the CLI: 'duo' models consume two image
    # inputs per sample; every other model type uses the single-input path.
    args = load_args()
    pipeline = run_duo if args.model_type == 'duo' else run
    pipeline(args)