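"""Training and evaluation entry point for Model_Common_Transformer.

The model predicts peptide retention time (rt) and/or fragment intensity from a
sequence (plus charge for intensity). The `forward` argument selects the mode:
'rt', 'int', 'both', or 'transfer' (train the rt head only, evaluate both heads).
"""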
import os

import pandas as pd
import torch
import torch.optim as optim
import wandb as wdb

import common_dataset
import dataloader
from config_common import load_args
from loss import masked_cos_sim, distance, masked_spectral_angle
from model_custom import Model_Common_Transformer
from model import RT_pred_model_self_attention_multi


def train(model, data_train, epoch, optimizer, criterion_rt, criterion_intensity, metric_rt, metric_intensity, forward,
          wandb=None):
    losses_rt = 0.
    losses_int = 0.
    dist_rt_acc = 0.
    dist_int_acc = 0.
    model.train()
    for param in model.parameters():
        param.requires_grad = True
    if forward == 'both':
        for seq, charge, rt, intensity in data_train:
            rt, intensity = rt.float(), intensity.float()
            if torch.cuda.is_available():
                seq, charge, rt, intensity = seq.cuda(), charge.cuda(), rt.cuda(), intensity.cuda()
            pred_rt, pred_int = model.forward(seq, charge)
            loss_rt = criterion_rt(rt, pred_rt)
            loss_int = criterion_intensity(intensity, pred_int)
            loss = loss_rt + loss_int
            dist_rt = metric_rt(rt, pred_rt)
            dist_int = metric_intensity(intensity, pred_int)
            dist_rt_acc += dist_rt.item()
            dist_int_acc += dist_int.item()
            losses_rt += loss_rt.item()
            losses_int += 5. * loss_int.item()  # note: the logged intensity loss is scaled by 5; the optimised loss above is unweighted
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        if wandb is not None:
            wdb.log({"train rt loss": losses_rt / len(data_train), "train int loss": losses_int / len(data_train),
                     "train rt mean metric": dist_rt_acc / len(data_train),
                     "train int mean metric": dist_int_acc / len(data_train),
                     'train epoch': epoch})

        print('epoch : ', epoch, 'train rt loss', losses_rt / len(data_train), 'train int loss',
              losses_int / len(data_train), "train rt mean metric : ", dist_rt_acc / len(data_train),
              "train int mean metric",
              dist_int_acc / len(data_train))

    if forward == 'rt':
        for seq, rt in data_train:
            rt = rt.float()
            if torch.cuda.is_available():
                seq, rt = seq.cuda(), rt.cuda()
            pred_rt = model.forward_rt(seq)
            loss_rt = criterion_rt(rt, pred_rt)
            loss = loss_rt
            dist_rt = metric_rt(rt, pred_rt)
            dist_rt_acc += dist_rt.item()
            losses_rt += loss_rt.item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        if wandb is not None:
            wdb.log({"train rt loss": losses_rt / len(data_train),
                     "train rt mean metric": dist_rt_acc / len(data_train),
                     'train epoch': epoch})

        print('epoch : ', epoch, 'train rt loss', losses_rt / len(data_train), "train rt mean metric : ",
              dist_rt_acc / len(data_train))

    if forward == 'int':
        for seq, charge, intensity in data_train:
            intensity = intensity.float()
            if torch.cuda.is_available():
                seq, charge, intensity = seq.cuda(), charge.cuda(), intensity.cuda()
            pred_int = model.forward_int(seq, charge)
            loss_int = criterion_intensity(intensity, pred_int)
            loss = loss_int
            dist_int = metric_intensity(intensity, pred_int)
            dist_int_acc += dist_int.item()
            losses_int += loss_int.item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        if wandb is not None:
            wdb.log({"train int loss": losses_int / len(data_train),
                     "train int mean metric": dist_int_acc / len(data_train),
                     'train epoch': epoch})

        print('epoch : ', epoch, 'train int loss',
              losses_int / len(data_train),
              "train int mean metric",
              dist_int_acc / len(data_train))
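
# Usage sketch (illustrative; mirrors the call made from run() with the losses and
# metrics wired up in main()):
#
#     train(model, data_train, epoch, optimizer,
#           criterion_rt=torch.nn.MSELoss(), criterion_intensity=masked_cos_sim,
#           metric_rt=distance, metric_intensity=masked_spectral_angle,
#           forward='rt', wandb=None)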


def eval(model, data_val, epoch, criterion_rt, criterion_intensity, metric_rt, metric_intensity, forward, wandb=None):
    losses_rt = 0.
    losses_int = 0.
    dist_rt_acc = 0.
    dist_int_acc = 0.
    model.eval()  # disable dropout during validation
    for param in model.parameters():
        param.requires_grad = False
    if forward == 'both':
        for seq, charge, rt, intensity in data_val:
            rt, intensity = rt.float(), intensity.float()
            if torch.cuda.is_available():
                seq, charge, rt, intensity = seq.cuda(), charge.cuda(), rt.cuda(), intensity.cuda()
            pred_rt, pred_int = model.forward(seq, charge)
            loss_rt = criterion_rt(rt, pred_rt)
            loss_int = criterion_intensity(intensity, pred_int)
            losses_rt += loss_rt.item()
            losses_int += loss_int.item()
            dist_rt = metric_rt(rt, pred_rt)
            dist_int = metric_intensity(intensity, pred_int)
            dist_rt_acc += dist_rt.item()
            dist_int_acc += dist_int.item()

        if wandb is not None:
            wdb.log({"val rt loss": losses_rt / len(data_val), "val int loss": losses_int / len(data_val),
                     "val rt mean metric": dist_rt_acc / len(data_val),
                     "val int mean metric": dist_int_acc / len(data_val),
                     'val epoch': epoch})

        print('epoch : ', epoch, 'val rt loss', losses_rt / len(data_val), 'val int loss', losses_int / len(data_val),
              "val rt mean metric : ",
              dist_rt_acc / len(data_val), "val int mean metric", dist_int_acc / len(data_val))

    if forward == 'rt':  #adapted to prosit dataset format
        for seq, rt in data_val:
            rt = rt.float()
            if torch.cuda.is_available():
                seq, rt = seq.cuda(), rt.cuda()
            pred_rt = model.forward_rt(seq)
            loss_rt = criterion_rt(rt, pred_rt)
            losses_rt += loss_rt.item()
            dist_rt = metric_rt(rt, pred_rt)
            dist_rt_acc += dist_rt.item()

        if wandb is not None:
            wdb.log({"val rt loss": losses_rt / len(data_val),
                     "val rt mean metric": dist_rt_acc / len(data_val),
                     'val epoch': epoch})

        print('epoch : ', epoch, 'val rt loss', losses_rt / len(data_val),
              "val rt mean metric : ",
              dist_rt_acc / len(data_val))

    if forward == 'int':  #adapted to prosit dataset format
        for seq, charge, _, intensity in data_val:
            intensity = intensity.float()
            if torch.cuda.is_available():
                seq, charge, intensity = seq.cuda(), charge.cuda(), intensity.cuda()
            pred_int = model.forward_int(seq, charge)
            loss_int = criterion_intensity(intensity, pred_int)
            losses_int += loss_int.item()
            dist_int = metric_intensity(intensity, pred_int)
            dist_int_acc += dist_int.item()
        if wandb is not None:
            wdb.log({"val int loss": losses_int / len(data_val),
                     "val int mean metric": dist_int_acc / len(data_val),
                     'val epoch': epoch})
        print('epoch : ', epoch, 'val int loss', losses_int / len(data_val), "val int mean metric",
              dist_int_acc / len(data_val))
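
# Note: this script disables gradients during validation by toggling requires_grad;
# an equivalent, more idiomatic sketch would wrap the loops above in torch.no_grad():
#
#     model.eval()
#     with torch.no_grad():
#         for batch in data_val:
#             ...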


def run(epochs, eval_inter, save_inter, model, data_train, data_val, data_test, optimizer, criterion_rt,
        criterion_intensity, metric_rt, metric_intensity, forward, wandb=None, output='output/out.csv'):

    if forward == 'transfer':  # transfer mode: train the rt head only, evaluate both heads
        for e in range(1, epochs + 1):
            train(model, data_train, e, optimizer, criterion_rt, criterion_intensity, metric_rt, metric_intensity, 'rt',
                  wandb=wandb)
            if e % eval_inter == 0:
                eval(model, data_val, e, criterion_rt, criterion_intensity, metric_rt, metric_intensity, 'both',
                     wandb=wandb)
            if e % save_inter == 0:
                save(model, 'model_common_' + str(e) + '.pt')
        save_pred(model, data_val, 'rt', output)

    else:
        for e in range(1, epochs + 1):
            train(model, data_train, e, optimizer, criterion_rt, criterion_intensity, metric_rt, metric_intensity, forward,
                  wandb=wandb)
            if e % eval_inter == 0:
                eval(model, data_val, e, criterion_rt, criterion_intensity, metric_rt, metric_intensity, forward,
                     wandb=wandb)
            if e % save_inter == 0:
                save(model, 'model_common_' + str(e) + '.pt')
        save_pred(model, data_val, forward, output)


def main(args):
    if args.wandb is not None:
        os.environ["WANDB_API_KEY"] = 'b4a27ac6b6145e1a5d0ee7f9e2e8c20bd101dccd'
        os.environ["WANDB_MODE"] = "offline"
        os.environ["WANDB_DIR"] = os.path.abspath("./wandb_run")

        wdb.init(project="Common prediction", dir='./wandb_run', name=args.wandb)
    print(args)
    print('Cuda : ', torch.cuda.is_available())

    if args.forward == 'both':
        data_train, data_val, data_test = common_dataset.load_data(path_train=args.dataset_train,
                                                                   path_val=args.dataset_val,
                                                                   path_test=args.dataset_test,
                                                                   batch_size=args.batch_size, length=25, pad=True, convert=True, vocab='iapuc')
    elif args.forward == 'rt':
        # keep all three splits: run() expects data_val and data_test to be defined
        data_train, data_val, data_test = dataloader.load_data(data_sources=[args.dataset_train, args.dataset_val, args.dataset_test],
                                                               batch_size=args.batch_size, length=25)

    elif args.forward == 'transfer':
        data_train, _, _ = dataloader.load_data(data_sources=[args.dataset_train, 'database/data_holdout.csv', 'database/data_holdout.csv'],
                                                batch_size=args.batch_size, length=25)

        # only the validation/test loaders from this call are used; path_train is a placeholder
        _, data_val, data_test = common_dataset.load_data(path_train=args.dataset_val,
                                                          path_val=args.dataset_val,
                                                          path_test=args.dataset_test,
                                                          batch_size=args.batch_size, length=25, pad=True, convert=True, vocab='iapuc')

    print('\nData loaded')

    model = Model_Common_Transformer(encoder_ff=args.encoder_ff, decoder_rt_ff=args.decoder_rt_ff,
                                     decoder_int_ff=args.decoder_int_ff,
                                     n_head=args.n_head, encoder_num_layer=args.encoder_num_layer,
                                     decoder_int_num_layer=args.decoder_int_num_layer,
                                     decoder_rt_num_layer=args.decoder_rt_num_layer, drop_rate=args.drop_rate,
                                     embedding_dim=args.embedding_dim, acti=args.activation, norm=args.norm_first)
    if torch.cuda.is_available():
        model = model.cuda()
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    print('\nModel initialised')

    run(epochs=args.epochs, eval_inter=args.eval_inter, save_inter=args.save_inter, model=model, data_train=data_train,
        data_val=data_val, data_test=data_test, optimizer=optimizer, criterion_rt=torch.nn.MSELoss(),
        criterion_intensity=masked_cos_sim, metric_rt=distance, metric_intensity=masked_spectral_angle,
        wandb=args.wandb, forward=args.forward, output=args.output)

    if args.wandb is not None:
        wdb.finish()


def save(model, checkpoint_name):
    print('\nModel Saving...')
    os.makedirs('checkpoints', exist_ok=True)
    torch.save(model, os.path.join('checkpoints', checkpoint_name))
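
# Note: torch.save(model, ...) pickles the whole nn.Module, so load() only works where
# the model's class definition is importable. A state_dict-based variant (a sketch, not
# what this script does) would decouple checkpoints from the class:
#
#     torch.save(model.state_dict(), os.path.join('checkpoints', checkpoint_name))
#     model = Model_Common_Transformer(...)  # rebuild with the same constructor arguments
#     model.load_state_dict(torch.load(os.path.join('checkpoints', checkpoint_name)))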


def load(path):
    model = torch.load(os.path.join('checkpoints', path))
    return model


def get_n_params(model):
    pp = 0
    for n, p in list(model.named_parameters()):
        nn = 1

        for s in list(p.size()):
            nn = nn * s
        print(n, nn)
        pp += nn
    return pp
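
# Example (illustrative): report the model size after initialisation, e.g. in main():
#
#     print('total parameters:', get_n_params(model))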

def save_pred(model, data_val, forward, output_path):
    data_frame = pd.DataFrame()
    model.eval()  # disable dropout when generating predictions
    for param in model.parameters():
        param.requires_grad = False
    if forward == 'both':
        pred_rt, pred_int, seqs, charges, true_rt, true_int = [], [], [], [], [], []
        for seq, charge, rt, intensity in data_val:
            rt, intensity = rt.float(), intensity.float()
            if torch.cuda.is_available():
                seq, charge, rt, intensity = seq.cuda(), charge.cuda(), rt.cuda(), intensity.cuda()
            pr_rt, pr_intensity = model.forward(seq, charge)
            pred_rt.extend(pr_rt.data.cpu().tolist())
            pred_int.extend(pr_intensity.data.cpu().tolist())
            seqs.extend(seq.data.cpu().tolist())
            charges.extend(charge.data.cpu().tolist())
            true_rt.extend(rt.data.cpu().tolist())
            true_int.extend(intensity.data.cpu().tolist())
        data_frame['rt pred'] = pred_rt
        data_frame['seq'] = seqs
        data_frame['pred int'] = pred_int
        data_frame['true rt'] = true_rt
        data_frame['true int'] = true_int
        data_frame['charge'] = charges



    if forward == 'rt':  #adapted to prosit dataset format
        pred_rt, seqs, true_rt = [], [], []
        for seq, rt in data_val:
            rt = rt.float()
            if torch.cuda.is_available():
                seq, rt = seq.cuda(), rt.cuda()
            pr_rt = model.forward_rt(seq)
            pred_rt.extend(pr_rt.data.cpu().tolist())
            seqs.extend(seq.data.cpu().tolist())
            true_rt.extend(rt.data.cpu().tolist())
        data_frame['rt pred'] = pred_rt
        data_frame['seq'] = seqs
        data_frame['true rt'] = true_rt


    if forward == 'int':  #adapted to prosit dataset format
        pred_int, seqs, charges, true_int = [], [], [], []
        for seq, charge, _, intensity in data_val:
            intensity = intensity.float()
            if torch.cuda.is_available():
                seq, charge, intensity = seq.cuda(), charge.cuda(), intensity.cuda()
            pr_int = model.forward_int(seq, charge)
            pred_int.extend(pr_int.data.cpu().tolist())  # accumulate predictions rather than overwriting the list
            seqs.extend(seq.data.cpu().tolist())
            charges.extend(charge.data.cpu().tolist())
            true_int.extend(intensity.data.cpu().tolist())
        data_frame['seq'] = seqs
        data_frame['pred int'] = pred_int
        data_frame['true int'] = true_int
        data_frame['charge'] = charges

    data_frame.to_csv(output_path)


if __name__ == "__main__":
    args = load_args()
    main(args)
