# main_custom.py
import os

import pandas as pd
import torch
import torch.optim as optim
import wandb as wdb

import common_dataset
import dataloader
from config_common import load_args
from common_dataset import load_data
# NOTE: this rebinds the bare name ``load_data`` to the ``dataloader`` version.
from dataloader import load_data
from loss import masked_cos_sim, distance, masked_spectral_angle
from model_custom import Model_Common_Transformer
from model import RT_pred_model_self_attention_multi


def train(model, data_train, epoch, optimizer, criterion_rt, criterion_intensity, metric_rt, metric_intensity, forward,
          wandb=None):
    """Run one optimisation pass over ``data_train`` and report the epoch averages.

    Args:
        model: Network exposing ``forward(seq, charge)``, ``forward_rt(seq)`` and
            ``forward_int(seq, charge)``; only the method matching ``forward`` is called.
        data_train: Sized iterable of batches. The batch layout depends on the mode:
            ``'both'`` -> ``(seq, charge, rt, intensity)``, ``'rt'`` -> ``(seq, rt)``,
            ``'int'`` -> ``(seq, charge, intensity)``.
        epoch: Epoch index, used for reporting only.
        optimizer: Optimizer that is zeroed and stepped once per batch.
        criterion_rt: Callable ``(target, prediction)`` returning the ``rt`` loss
            (presumably retention time - confirm against the dataset).
        criterion_intensity: Callable ``(target, prediction)`` returning the intensity loss.
        metric_rt: Callable ``(target, prediction)`` returning a scalar tensor; reported only.
        metric_intensity: Callable ``(target, prediction)`` returning a scalar tensor;
            reported only.
        forward: ``'both'``, ``'rt'`` or ``'int'``. Any other value leaves the model untouched.
        wandb: When not ``None`` the averages are also logged to Weights & Biases.

    Returns:
        None. Averages (sum over batches divided by ``len(data_train)``) are printed and
        optionally logged.
    """
    losses_rt = 0.
    losses_int = 0.
    dist_rt_acc = 0.
    dist_int_acc = 0.
    use_cuda = torch.cuda.is_available()
    model.train()
    # eval() and save_pred() switch requires_grad off on every parameter, so it has to be
    # switched back on before each training pass.
    for param in model.parameters():
        param.requires_grad = True

    if forward == 'both':
        for seq, charge, rt, intensity in data_train:
            rt, intensity = rt.float(), intensity.float()
            if use_cuda:
                seq, charge, rt, intensity = seq.cuda(), charge.cuda(), rt.cuda(), intensity.cuda()
            pred_rt, pred_int = model.forward(seq, charge)
            loss_rt = criterion_rt(rt, pred_rt)
            loss_int = criterion_intensity(intensity, pred_int)
            loss = loss_rt + loss_int
            dist_rt_acc += metric_rt(rt, pred_rt).item()
            dist_int_acc += metric_intensity(intensity, pred_int).item()
            losses_rt += loss_rt.item()
            # Report the same unweighted term that enters ``loss`` so the value is
            # comparable with the validation loss and with the 'int' mode.
            losses_int += loss_int.item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        n_batches = len(data_train)
        if wandb is not None:
            wdb.log({"train rt loss": losses_rt / n_batches, "train int loss": losses_int / n_batches,
                     "train rt mean metric": dist_rt_acc / n_batches,
                     "train int mean metric": dist_int_acc / n_batches,
                     'train epoch': epoch})

        print('epoch : ', epoch, 'train rt loss', losses_rt / n_batches, 'train int loss',
              losses_int / n_batches, "train rt mean metric : ", dist_rt_acc / n_batches,
              "train int mean metric",
              dist_int_acc / n_batches)

    elif forward == 'rt':
        for seq, rt in data_train:
            rt = rt.float()
            if use_cuda:
                seq, rt = seq.cuda(), rt.cuda()
            pred_rt = model.forward_rt(seq)
            loss_rt = criterion_rt(rt, pred_rt)
            dist_rt_acc += metric_rt(rt, pred_rt).item()
            losses_rt += loss_rt.item()
            optimizer.zero_grad()
            loss_rt.backward()
            optimizer.step()

        n_batches = len(data_train)
        if wandb is not None:
            wdb.log({"train rt loss": losses_rt / n_batches,
                     "train rt mean metric": dist_rt_acc / n_batches,
                     'train epoch': epoch})

        print('epoch : ', epoch, 'train rt loss', losses_rt / n_batches, "train rt mean metric : ",
              dist_rt_acc / n_batches)

    elif forward == 'int':
        # NOTE(review): eval() unpacks four items per batch in 'int' mode while this loop
        # expects three - confirm which loader format is intended.
        for seq, charge, intensity in data_train:
            intensity = intensity.float()
            if use_cuda:
                seq, charge, intensity = seq.cuda(), charge.cuda(), intensity.cuda()
            pred_int = model.forward_int(seq, charge)
            loss_int = criterion_intensity(intensity, pred_int)
            dist_int_acc += metric_intensity(intensity, pred_int).item()
            losses_int += loss_int.item()
            optimizer.zero_grad()
            loss_int.backward()
            optimizer.step()

        n_batches = len(data_train)
        if wandb is not None:
            wdb.log({"train int loss": losses_int / n_batches,
                     "train int mean metric": dist_int_acc / n_batches,
                     'train epoch': epoch})

        print('epoch : ', epoch, 'train int loss',
              losses_int / n_batches,
              "train int mean metric",
              dist_int_acc / n_batches)


def eval(model, data_val, epoch, criterion_rt, criterion_intensity, metric_rt, metric_intensity, forward, wandb=None):
    """Evaluate ``model`` on ``data_val`` and report the averaged losses and metrics.

    Note that this function shadows the ``eval`` builtin inside this module.

    Args:
        model: Network exposing ``forward(seq, charge)``, ``forward_rt(seq)`` and
            ``forward_int(seq, charge)``; only the method matching ``forward`` is called.
        data_val: Sized iterable of batches. The batch layout depends on the mode:
            ``'both'`` -> ``(seq, charge, rt, intensity)``, ``'rt'`` -> ``(seq, rt)``,
            ``'int'`` -> ``(seq, charge, <ignored>, intensity)``.
        epoch: Epoch index, used for reporting only.
        criterion_rt: Callable ``(target, prediction)`` returning the ``rt`` loss.
        criterion_intensity: Callable ``(target, prediction)`` returning the intensity loss.
        metric_rt: Callable ``(target, prediction)`` returning a scalar tensor.
        metric_intensity: Callable ``(target, prediction)`` returning a scalar tensor.
        forward: ``'both'``, ``'rt'`` or ``'int'``. Any other value evaluates nothing.
        wandb: When not ``None`` the averages are also logged to Weights & Biases.

    Returns:
        None. Averages (sum over batches divided by ``len(data_val)``) are printed and
        optionally logged.

    Side effects:
        Puts the model in eval mode and sets ``requires_grad = False`` on every parameter.
    """
    losses_rt = 0.
    losses_int = 0.
    dist_rt_acc = 0.
    dist_int_acc = 0.
    use_cuda = torch.cuda.is_available()
    model.eval()
    # Freezing the parameters keeps autograd from recording the forward passes below.
    for param in model.parameters():
        param.requires_grad = False

    if forward == 'both':
        for seq, charge, rt, intensity in data_val:
            rt, intensity = rt.float(), intensity.float()
            if use_cuda:
                seq, charge, rt, intensity = seq.cuda(), charge.cuda(), rt.cuda(), intensity.cuda()
            pred_rt, pred_int = model.forward(seq, charge)
            losses_rt += criterion_rt(rt, pred_rt).item()
            losses_int += criterion_intensity(intensity, pred_int).item()
            dist_rt_acc += metric_rt(rt, pred_rt).item()
            dist_int_acc += metric_intensity(intensity, pred_int).item()

        n_batches = len(data_val)
        if wandb is not None:
            wdb.log({"val rt loss": losses_rt / n_batches, "val int loss": losses_int / n_batches,
                     "val rt mean metric": dist_rt_acc / n_batches,
                     "val int mean metric": dist_int_acc / n_batches,
                     'val epoch': epoch})

        print('epoch : ', epoch, 'val rt loss', losses_rt / n_batches, 'val int loss', losses_int / n_batches,
              "val rt mean metric : ",
              dist_rt_acc / n_batches, "val int mean metric", dist_int_acc / n_batches)

    elif forward == 'rt':  # adapted to prosit dataset format
        for seq, rt in data_val:
            rt = rt.float()
            if use_cuda:
                seq, rt = seq.cuda(), rt.cuda()
            pred_rt = model.forward_rt(seq)
            losses_rt += criterion_rt(rt, pred_rt).item()
            dist_rt_acc += metric_rt(rt, pred_rt).item()

        n_batches = len(data_val)
        if wandb is not None:
            wdb.log({"val rt loss": losses_rt / n_batches,
                     "val rt mean metric": dist_rt_acc / n_batches,
                     'val epoch': epoch})

        print('epoch : ', epoch, 'val rt loss', losses_rt / n_batches,
              "val rt mean metric : ",
              dist_rt_acc / n_batches)

    elif forward == 'int':  # adapted to prosit dataset format
        # NOTE(review): train() unpacks three items per batch in 'int' mode while this loop
        # expects four (the third is ignored) - confirm which loader format is intended.
        for seq, charge, _, intensity in data_val:
            intensity = intensity.float()
            if use_cuda:
                seq, charge, intensity = seq.cuda(), charge.cuda(), intensity.cuda()
            pred_int = model.forward_int(seq, charge)
            losses_int += criterion_intensity(intensity, pred_int).item()
            dist_int_acc += metric_intensity(intensity, pred_int).item()

        n_batches = len(data_val)
        if wandb is not None:
            wdb.log({"val int loss": losses_int / n_batches,
                     "val int mean metric": dist_int_acc / n_batches,
                     'val epoch': epoch})
        print('epoch : ', epoch, 'val int loss', losses_int / n_batches, "val int mean metric",
              dist_int_acc / n_batches)


def run(epochs, eval_inter, save_inter, model, data_train, data_val, data_test, optimizer, criterion_rt,
        criterion_intensity, metric_rt, metric_intensity, forward, wandb=None, output='output/out.csv', file=False):
    """Train for ``epochs`` epochs, evaluate periodically and write the final predictions.

    Args:
        epochs: Number of training epochs (epochs are numbered from 1).
        eval_inter: Evaluate on ``data_val`` every ``eval_inter`` epochs.
        save_inter: Checkpoint interval; only honoured in the ``'transfer'`` and
            ``'reverse'`` modes.
        model: Model handed to ``train``, ``eval``, ``save`` and ``save_pred``.
        data_train: Training batches.
        data_val: Validation batches, also used for the final prediction dump.
        data_test: Accepted for interface compatibility; not used by this function.
        optimizer: Optimizer forwarded to ``train``.
        criterion_rt, criterion_intensity: Loss callables forwarded to ``train``/``eval``.
        metric_rt, metric_intensity: Metric callables forwarded to ``train``/``eval``.
        forward: Run mode. ``'transfer'`` trains with ``'rt'`` and evaluates with ``'both'``;
            ``'reverse'`` trains with ``'both'`` and evaluates with ``'rt'``; any other value
            is passed unchanged to ``train``, ``eval`` and ``save_pred``.
        wandb: Forwarded to ``train``/``eval``; ``None`` disables W&B logging.
        output: Path of the CSV written by ``save_pred``.
        file: Forwarded to ``save_pred`` in the plain modes only.
    """
    cross_task = forward in ('transfer', 'reverse')
    if forward == 'transfer':
        train_mode, eval_mode = 'rt', 'both'
    elif forward == 'reverse':
        train_mode, eval_mode = 'both', 'rt'
    else:
        train_mode = eval_mode = forward

    for e in range(1, epochs + 1):
        train(model, data_train, e, optimizer, criterion_rt, criterion_intensity, metric_rt, metric_intensity,
              train_mode, wandb=wandb)
        if e % eval_inter == 0:
            eval(model, data_val, e, criterion_rt, criterion_intensity, metric_rt, metric_intensity, eval_mode,
                 wandb=wandb)
        # Periodic checkpointing is only active for the two cross-task modes; it is
        # deliberately left off for the plain modes.
        if cross_task and e % save_inter == 0:
            save(model, 'model_common_' + str(e) + '.pt')

    if cross_task:
        # NOTE(review): ``file`` is not forwarded here, so no 'file' column is written in
        # the cross-task modes - confirm this is intended.
        save_pred(model, data_val, eval_mode, output)
    else:
        save_pred(model, data_val, forward, output, file=file)


def main(args):
    """Load the data for ``args.forward``, build the model and run the experiment.

    Args:
        args: Namespace-like object. Attributes read here: ``wandb``, ``forward``,
            ``dataset_train``, ``dataset_val``, ``dataset_test``, ``batch_size``,
            ``seq_length``, ``encoder_ff``, ``decoder_rt_ff``, ``decoder_int_ff``, ``n_head``,
            ``encoder_num_layer``, ``decoder_int_num_layer``, ``decoder_rt_num_layer``,
            ``drop_rate``, ``embedding_dim``, ``activation``, ``norm_first``, ``lr``,
            ``epochs``, ``eval_inter``, ``save_inter``, ``output`` and ``file``.

    Raises:
        ValueError: If ``args.forward`` is not one of ``'both'``, ``'rt'``, ``'transfer'``
            or ``'reverse'`` (no data loading is defined for any other mode).
    """
    if args.wandb is not None:
        # No API key is written into the environment here: credentials must not live in
        # source control. Runs are recorded offline; supply WANDB_API_KEY through the
        # environment (or ``wandb login``) when syncing them later.
        os.environ["WANDB_MODE"] = "offline"
        os.environ["WANDB_DIR"] = os.path.abspath("./wandb_run")

        wdb.init(project="Common prediction", dir='./wandb_run', name=args.wandb)
    print(args)
    print('Cuda : ', torch.cuda.is_available())

    if args.forward == 'both':
        data_train, data_val = common_dataset.load_data(path_train=args.dataset_train,
                                                        path_val=args.dataset_val,
                                                        path_test=args.dataset_test,
                                                        batch_size=args.batch_size, length=args.seq_length,
                                                        pad=False, convert=False, vocab='unmod')
    elif args.forward == 'rt':
        data_train, data_val = dataloader.load_data(
            data_sources=[args.dataset_train, args.dataset_val, args.dataset_test],
            batch_size=args.batch_size, length=args.seq_length)

    elif args.forward == 'transfer':
        # Training batches come from the rt-only loader, validation batches from the
        # common (rt + intensity) loader.
        data_train, _ = dataloader.load_data(
            data_sources=[args.dataset_train, 'database/data_holdout.csv', 'database/data_holdout.csv'],
            batch_size=args.batch_size, length=args.seq_length)

        _, data_val = common_dataset.load_data(path_train=args.dataset_val,
                                               path_val=args.dataset_val,
                                               path_test=args.dataset_test,
                                               batch_size=args.batch_size, length=args.seq_length,
                                               pad=True, convert=True, vocab='unmod')

    elif args.forward == 'reverse':
        # Mirror image of 'transfer': validation from the rt-only loader, training from the
        # common loader.
        _, data_val = dataloader.load_data(
            data_sources=['database/data_train.csv', args.dataset_val, args.dataset_test],
            batch_size=args.batch_size, length=args.seq_length)
        data_train, _ = common_dataset.load_data(path_train=args.dataset_train,
                                                 path_val=args.dataset_train,
                                                 path_test=args.dataset_train,
                                                 batch_size=args.batch_size, length=args.seq_length,
                                                 pad=True, convert=True, vocab='unmod')
    else:
        # Without this guard the code below fails later with an unbound ``data_train``.
        raise ValueError("Unsupported forward mode {!r}; expected 'both', 'rt', 'transfer' or 'reverse'"
                         .format(args.forward))

    print('\nData loaded')

    model = Model_Common_Transformer(encoder_ff=args.encoder_ff, decoder_rt_ff=args.decoder_rt_ff,
                                     decoder_int_ff=args.decoder_int_ff,
                                     n_head=args.n_head, encoder_num_layer=args.encoder_num_layer,
                                     decoder_int_num_layer=args.decoder_int_num_layer,
                                     decoder_rt_num_layer=args.decoder_rt_num_layer, drop_rate=args.drop_rate,
                                     embedding_dim=args.embedding_dim, acti=args.activation, norm=args.norm_first,
                                     seq_length=args.seq_length)
    if torch.cuda.is_available():
        model = model.cuda()
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    print('\nModel initialised')

    # The validation loader doubles as the test loader; run() does not use data_test.
    run(epochs=args.epochs, eval_inter=args.eval_inter, save_inter=args.save_inter, model=model, data_train=data_train,
        data_val=data_val, data_test=data_val, optimizer=optimizer, criterion_rt=torch.nn.MSELoss(),
        criterion_intensity=masked_cos_sim, metric_rt=distance, metric_intensity=masked_spectral_angle,
        wandb=args.wandb, forward=args.forward, output=args.output, file=args.file)

    if args.wandb is not None:
        wdb.finish()


def save(model, checkpoint_name):
    """Serialise the whole ``model`` object to ``checkpoints/<checkpoint_name>``.

    The ``checkpoints`` directory is resolved relative to the current working directory
    and is created when it does not exist yet.
    """
    checkpoint_dir = 'checkpoints'
    target = os.path.join(checkpoint_dir, checkpoint_name)
    print('\nModel Saving...')
    os.makedirs(checkpoint_dir, exist_ok=True)
    torch.save(model, target)


def load(path):
    """Load the object stored at ``checkpoints/<path>`` and return it.

    Args:
        path: File name relative to the ``checkpoints`` directory (itself relative to the
            current working directory).

    Returns:
        Whatever object was written to that file with ``torch.save``.

    NOTE: ``torch.load`` unpickles the file, so only load checkpoints from trusted sources.
    NOTE(review): recent PyTorch releases default to ``weights_only=True``, which may reject
    whole-model pickles such as those written by ``save`` - confirm against the targeted
    PyTorch version.
    """
    # main() moves the model to the GPU when one is available, so checkpoints may hold CUDA
    # tensors; remap them to the CPU on hosts without CUDA instead of failing. ``None`` keeps
    # torch.load's default placement on GPU machines.
    device = None if torch.cuda.is_available() else 'cpu'
    model = torch.load(os.path.join('checkpoints', path), map_location=device)
    return model


def get_n_params(model):
    """Return the total number of scalar entries over all named parameters of ``model``."""
    # numel() is the product of a tensor's dimensions (1 for a zero-dimensional tensor).
    return sum(tensor.numel() for _name, tensor in model.named_parameters())

def save_pred(model, data_val, forward, output_path, file=False):
    """Run ``model`` over ``data_val`` and write predictions and targets to a CSV file.

    Args:
        model: Network exposing ``forward(seq, charge)``, ``forward_rt(seq)`` and
            ``forward_int(seq, charge)``; only the method matching ``forward`` is called.
        data_val: Iterable of batches. The batch layout depends on the mode:
            ``'both'`` -> ``(seq, charge, rt, intensity)`` (plus a trailing ``files`` tensor
            when ``file`` is true), ``'rt'`` -> ``(seq, rt)``,
            ``'int'`` -> ``(seq, charge, <ignored>, intensity)``.
        forward: ``'both'``, ``'rt'`` or ``'int'``. Any other value writes an empty frame.
        output_path: Destination CSV path; missing parent directories are created.
        file: When true, ``data_val.dataset.set_file_mode`` is switched on for the
            ``'both'`` pass, a ``'file'`` column is added, and file mode is switched off
            again before returning.

    Columns written:
        ``'both'``: ``rt pred``, ``seq``, ``pred int``, ``true rt``, ``true int``,
        ``charge`` (and ``file``); ``'rt'``: ``rt pred``, ``seq``, ``true rt``;
        ``'int'``: ``seq``, ``pred int``, ``true int``, ``charge``.

    Side effects:
        Puts the model in eval mode and sets ``requires_grad = False`` on every parameter.
    """
    def as_list(tensor):
        # Detach and move to host memory before converting to (nested) Python lists.
        return tensor.detach().cpu().tolist()

    use_cuda = torch.cuda.is_available()
    data_frame = pd.DataFrame()
    model.eval()
    for param in model.parameters():
        param.requires_grad = False

    try:
        if forward == 'both':
            pred_rt, pred_int, seqs, charges, true_rt, true_int, file_list = [], [], [], [], [], [], []
            if file:
                data_val.dataset.set_file_mode(True)
            for batch in data_val:
                if file:
                    seq, charge, rt, intensity, files = batch
                else:
                    seq, charge, rt, intensity = batch
                rt, intensity = rt.float(), intensity.float()
                if use_cuda:
                    seq, charge, rt, intensity = seq.cuda(), charge.cuda(), rt.cuda(), intensity.cuda()
                pr_rt, pr_intensity = model.forward(seq, charge)
                pred_rt.extend(as_list(pr_rt))
                pred_int.extend(as_list(pr_intensity))
                seqs.extend(as_list(seq))
                charges.extend(as_list(charge))
                true_rt.extend(as_list(rt))
                true_int.extend(as_list(intensity))
                if file:
                    file_list.extend(as_list(files))
            data_frame['rt pred'] = pred_rt
            data_frame['seq'] = seqs
            data_frame['pred int'] = pred_int
            data_frame['true rt'] = true_rt
            data_frame['true int'] = true_int
            data_frame['charge'] = charges
            if file:
                data_frame['file'] = file_list

        elif forward == 'rt':  # adapted to prosit dataset format
            pred_rt, seqs, true_rt = [], [], []
            for seq, rt in data_val:
                rt = rt.float()
                if use_cuda:
                    seq, rt = seq.cuda(), rt.cuda()
                pr_rt = model.forward_rt(seq)
                pred_rt.extend(as_list(pr_rt))
                seqs.extend(as_list(seq))
                true_rt.extend(as_list(rt))
            data_frame['rt pred'] = pred_rt
            data_frame['seq'] = seqs
            data_frame['true rt'] = true_rt

        elif forward == 'int':  # adapted to prosit dataset format
            pred_int, seqs, charges, true_int = [], [], [], []
            for seq, charge, _, intensity in data_val:
                intensity = intensity.float()
                if use_cuda:
                    seq, charge, intensity = seq.cuda(), charge.cuda(), intensity.cuda()
                # Keep the batch output in its own name so the accumulator list is extended
                # with every batch instead of being overwritten by the last one.
                pr_int = model.forward_int(seq, charge)
                pred_int.extend(as_list(pr_int))
                seqs.extend(as_list(seq))
                charges.extend(as_list(charge))
                true_int.extend(as_list(intensity))
            data_frame['seq'] = seqs
            data_frame['pred int'] = pred_int
            data_frame['true int'] = true_int
            data_frame['charge'] = charges
    finally:
        # Restore the loader even when a batch fails, so later passes see the usual format.
        if file:
            data_val.dataset.set_file_mode(False)

    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    data_frame.to_csv(output_path)

if __name__ == "__main__":
    # Script entry point: obtain the arguments from load_args() and hand them to main().
    args = load_args()
    main(args)