# main.py
import os

import pandas as pd
import torch
import torch.optim as optim
import wandb as wdb

from data.dataset import load_data
from config import load_args
from model.loss import distance
from model.model import ModelTransformer


def train(model, data_train, epoch, optimizer, criterion_rt, metric_rt, wandb=None):
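    """Run one training epoch over `data_train`, logging the mean loss and metric."""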
    losses_rt = 0.
    dist_rt_acc = 0.
    model.train()

    for seq, rt in data_train:
        rt = rt.float()
        if torch.cuda.is_available():
            seq, rt = seq.cuda(), rt.cuda()
        pred_rt = model(seq)  # call the module directly so forward hooks run
        loss_rt = criterion_rt(pred_rt, rt)  # MSE is symmetric; (input, target) matches the PyTorch convention
        dist_rt = metric_rt(rt, pred_rt)
        dist_rt_acc += dist_rt.item()
        losses_rt += loss_rt.item()
        optimizer.zero_grad()
        loss_rt.backward()
        optimizer.step()

    if wandb is not None:  # `wandb` holds the run name and doubles as an on/off flag
        wdb.log({"train rt loss": losses_rt / len(data_train),
                 "train rt mean metric": dist_rt_acc / len(data_train),
                 "train epoch": epoch})

    print(f"epoch: {epoch} | train rt loss: {losses_rt / len(data_train)} "
          f"| train rt mean metric: {dist_rt_acc / len(data_train)}")



def evaluate(model, data_val, epoch, criterion_rt, metric_rt, wandb=None):
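    """Run one validation pass; log/print mean metrics and return the summed loss."""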
    model.eval()
    losses_rt = 0.
    dist_rt_acc = 0.

    # Disable autograd for inference instead of mutating requires_grad on the
    # parameters; this avoids having to re-enable gradients before training.
    with torch.no_grad():
        for seq, rt in data_val:
            rt = rt.float()
            if torch.cuda.is_available():
                seq, rt = seq.cuda(), rt.cuda()
            pred_rt = model(seq)
            loss_rt = criterion_rt(pred_rt, rt)
            losses_rt += loss_rt.item()
            dist_rt = metric_rt(rt, pred_rt)
            dist_rt_acc += dist_rt.item()

    if wandb is not None:
        wdb.log({"val rt loss": losses_rt / len(data_val),
                 "val rt mean metric": dist_rt_acc / len(data_val),
                 'val epoch': epoch})

    print(f"epoch: {epoch} | val rt loss: {losses_rt / len(data_val)} "
          f"| val rt mean metric: {dist_rt_acc / len(data_val)}")
    # Return the summed (not averaged) validation loss; `run` compares it across epochs.
    return losses_rt


def run(epochs, eval_inter, save_inter, model, data_train, data_val, data_test, optimizer, criterion_rt,
        metric_rt, wandb=None, output='output/out.csv'):
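    """Train for `epochs` epochs, checkpointing whenever the validation loss
    improves, then reload the best weights and write test predictions to a CSV
    at `output`. `save_inter` is accepted but currently unused.
    """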
    best_loss = float('inf')
    # str.strip('.csv') strips a *set of characters*, not the suffix, so build
    # the checkpoint path with os.path.splitext instead.
    checkpoint_path = os.path.splitext(output)[0] + '.pt'
    for e in range(1, epochs + 1):
        train(model, data_train, e, optimizer, criterion_rt, metric_rt, wandb=wandb)
        if e % eval_inter == 0:
            losses_rt = evaluate(model, data_val, e, criterion_rt, metric_rt, wandb=wandb)
            if losses_rt < best_loss:
                best_loss = losses_rt
                torch.save(model.state_dict(), checkpoint_path)
                print('model saved')
    # Reload the best checkpoint before producing test-set predictions.
    model.load_state_dict(torch.load(checkpoint_path, weights_only=True))
    save_pred(model, data_test, output, criterion_rt, metric_rt, wandb)


def main(args):
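    """Set up W&B logging, data loaders, and the model from CLI args, then train."""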
    if args.wandb is not None:
        # Never commit an API key; supply WANDB_API_KEY via the environment.
        # Offline mode does not need a key at all.
        os.environ["WANDB_MODE"] = "offline"
        os.environ["WANDB_DIR"] = os.path.abspath("./wandb_run")

        wdb.init(project="Common prediction", dir='./wandb_run', name=args.wandb)
    print(args)
    print('Cuda : ', torch.cuda.is_available())

    data_train = load_data(data_source=args.dataset_train, batch_size=args.batch_size, length=30,
                           mode=args.split_train, seq_col=args.seq_train)
    data_test = load_data(data_source=args.dataset_test, batch_size=args.batch_size, length=30,
                          mode=args.split_test, seq_col=args.seq_test)
    data_val = load_data(data_source=args.dataset_val, batch_size=args.batch_size, length=30,
                         mode=args.split_val, seq_col=args.seq_val)
    print('\nData loaded')

    model = ModelTransformer(encoder_ff=args.encoder_ff, decoder_rt_ff=args.decoder_rt_ff,
                             n_head=args.n_head, encoder_num_layer=args.encoder_num_layer,
                             decoder_rt_num_layer=args.decoder_rt_num_layer, drop_rate=args.drop_rate,
                             embedding_dim=args.embedding_dim, acti=args.activation, norm=args.norm_first,
                             seq_length=30)

    if args.model_weigh is not None:
        model.load_state_dict(torch.load(args.model_weigh, weights_only=True))

    if torch.cuda.is_available():
        model = model.cuda()
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    print('\nModel initialised')

    run(epochs=args.epochs, eval_inter=args.eval_inter, save_inter=args.save_inter, model=model, data_train=data_train,
        data_val=data_val, data_test=data_test, optimizer=optimizer, criterion_rt=torch.nn.MSELoss(), metric_rt=distance,
        wandb=args.wandb, output=args.output)

    if args.wandb is not None:
        wdb.finish()


def save(model, checkpoint_name):
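    """Serialise the entire model object (not just its state_dict) under checkpoints/."""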
    print('\nModel Saving...')
    os.makedirs('checkpoints', exist_ok=True)
    torch.save(model, os.path.join('checkpoints', checkpoint_name))


def load(path):
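    """Load a model object saved by `save`; the model class must be importable."""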
    model = torch.load(os.path.join('checkpoints', path))
    return model


def get_n_params(model):
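    """Print the element count of every named parameter and return the total."""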
    total = 0
    for name, p in model.named_parameters():
        n = p.numel()  # number of elements in this parameter tensor
        print(name, n)
        total += n
    return total

def save_pred(model, data_test, output_path, criterion_rt, metric_rt, wandb=None):
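    """Evaluate on the test loader, log/print mean metrics, and dump predictions to CSV."""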
    data_frame = pd.DataFrame()
    losses_rt = 0.
    dist_rt_acc = 0.
    model.eval()

    pred_rt, seqs, true_rt = [], [], []
    with torch.no_grad():
        for seq, rt in data_test:
            rt = rt.float()
            if torch.cuda.is_available():
                seq, rt = seq.cuda(), rt.cuda()
            pr_rt = model(seq)
            pred_rt.extend(pr_rt.cpu().tolist())
            seqs.extend(seq.cpu().tolist())
            true_rt.extend(rt.cpu().tolist())

            loss_rt = criterion_rt(pr_rt, rt)
            losses_rt += loss_rt.item()
            dist_rt = metric_rt(rt, pr_rt)
            dist_rt_acc += dist_rt.item()

    if wandb is not None:
        wdb.log({"test rt loss": losses_rt / len(data_test),
                 "test rt mean metric": dist_rt_acc / len(data_test)})

    print(f"test rt loss: {losses_rt / len(data_test)} "
          f"| test rt mean metric: {dist_rt_acc / len(data_test)}")

    data_frame['rt pred'] = pred_rt
    data_frame['seq'] = seqs
    data_frame['true rt'] = true_rt
    data_frame.to_csv(output_path)


if __name__ == "__main__":
    args = load_args()
    main(args)



# output/out_coli_augmented_04_coli_8.pt
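
# Hypothetical example invocation (the real flag names come from config.load_args,
# which is not shown here); paths and hyperparameters are illustrative only:
#   python main.py --dataset_train data/train.csv --dataset_val data/val.csv \
#       --dataset_test data/test.csv --epochs 50 --eval_inter 5 --save_inter 5 \
#       --batch_size 64 --lr 1e-4 --wandb my_run --output output/out.csv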