import os

import pandas as pd
import torch
import torch.optim as optim
import wandb as wdb

import common_dataset
import dataloader
from config_common import load_args
from loss import masked_cos_sim, distance, masked_spectral_angle
from model_custom import Model_Common_Transformer
from model import RT_pred_model_self_attention_multi


def train(model, data_train, epoch, optimizer, criterion_rt, criterion_intensity,
          metric_rt, metric_intensity, forward, wandb=None):
    losses_rt = 0.
    losses_int = 0.
    dist_rt_acc = 0.
    dist_int_acc = 0.
    model.train()
    # Re-enable gradients on all parameters (eval() below freezes them).
    for param in model.parameters():
        param.requires_grad = True

    if forward == 'both':
        for seq, charge, rt, intensity in data_train:
            rt, intensity = rt.float(), intensity.float()
            if torch.cuda.is_available():
                seq, charge, rt, intensity = seq.cuda(), charge.cuda(), rt.cuda(), intensity.cuda()
            pred_rt, pred_int = model.forward(seq, charge)
            loss_rt = criterion_rt(rt, pred_rt)
            loss_int = criterion_intensity(intensity, pred_int)
            loss = loss_rt + loss_int
            dist_rt = metric_rt(rt, pred_rt)
            dist_int = metric_intensity(intensity, pred_int)
            dist_rt_acc += dist_rt.item()
            dist_int_acc += dist_int.item()
            losses_rt += loss_rt.item()
            # The 5x factor only rescales the logged value; the optimised loss is the unweighted sum.
            losses_int += 5. * loss_int.item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        if wandb is not None:
            wdb.log({"train rt loss": losses_rt / len(data_train),
                     "train int loss": losses_int / len(data_train),
                     "train rt mean metric": dist_rt_acc / len(data_train),
                     "train int mean metric": dist_int_acc / len(data_train),
                     'train epoch': epoch})
        print('epoch : ', epoch, 'train rt loss', losses_rt / len(data_train),
              'train int loss', losses_int / len(data_train),
              "train rt mean metric : ", dist_rt_acc / len(data_train),
              "train int mean metric", dist_int_acc / len(data_train))

    if forward == 'rt':
        for seq, rt in data_train:
            rt = rt.float()
            if torch.cuda.is_available():
                seq, rt = seq.cuda(), rt.cuda()
            pred_rt = model.forward_rt(seq)
            loss_rt = criterion_rt(rt, pred_rt)
            loss = loss_rt
            dist_rt = metric_rt(rt, pred_rt)
            dist_rt_acc += dist_rt.item()
            losses_rt += loss_rt.item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        if wandb is not None:
            wdb.log({"train rt loss": losses_rt / len(data_train),
                     "train rt mean metric": dist_rt_acc / len(data_train),
                     'train epoch': epoch})
        print('epoch : ', epoch, 'train rt loss', losses_rt / len(data_train),
              "train rt mean metric : ", dist_rt_acc / len(data_train))

    if forward == 'int':
        for seq, charge, intensity in data_train:
            intensity = intensity.float()
            if torch.cuda.is_available():
                seq, charge, intensity = seq.cuda(), charge.cuda(), intensity.cuda()
            pred_int = model.forward_int(seq, charge)
            loss_int = criterion_intensity(intensity, pred_int)
            loss = loss_int
            dist_int = metric_intensity(intensity, pred_int)
            dist_int_acc += dist_int.item()
            losses_int += loss_int.item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        if wandb is not None:
            wdb.log({"train int loss": losses_int / len(data_train),
                     "train int mean metric": dist_int_acc / len(data_train),
                     'train epoch': epoch})
        print('epoch : ', epoch, 'train int loss', losses_int / len(data_train),
              "train int mean metric", dist_int_acc / len(data_train))
def eval(model, data_val, epoch, criterion_rt, criterion_intensity, metric_rt,
         metric_intensity, forward, wandb=None):
    losses_rt = 0.
    losses_int = 0.
    dist_rt_acc = 0.
    dist_int_acc = 0.
    model.eval()
    for param in model.parameters():
        param.requires_grad = False

    if forward == 'both':
        for seq, charge, rt, intensity in data_val:
            rt, intensity = rt.float(), intensity.float()
            if torch.cuda.is_available():
                seq, charge, rt, intensity = seq.cuda(), charge.cuda(), rt.cuda(), intensity.cuda()
            pred_rt, pred_int = model.forward(seq, charge)
            loss_rt = criterion_rt(rt, pred_rt)
            loss_int = criterion_intensity(intensity, pred_int)
            losses_rt += loss_rt.item()
            losses_int += loss_int.item()
            dist_rt = metric_rt(rt, pred_rt)
            dist_int = metric_intensity(intensity, pred_int)
            dist_rt_acc += dist_rt.item()
            dist_int_acc += dist_int.item()
        if wandb is not None:
            wdb.log({"val rt loss": losses_rt / len(data_val),
                     "val int loss": losses_int / len(data_val),
                     "val rt mean metric": dist_rt_acc / len(data_val),
                     "val int mean metric": dist_int_acc / len(data_val),
                     'val epoch': epoch})
        print('epoch : ', epoch, 'val rt loss', losses_rt / len(data_val),
              'val int loss', losses_int / len(data_val),
              "val rt mean metric : ", dist_rt_acc / len(data_val),
              "val int mean metric", dist_int_acc / len(data_val))

    if forward == 'rt':  # adapted to the Prosit dataset format
        for seq, rt in data_val:
            rt = rt.float()
            if torch.cuda.is_available():
                seq, rt = seq.cuda(), rt.cuda()
            pred_rt = model.forward_rt(seq)
            loss_rt = criterion_rt(rt, pred_rt)
            losses_rt += loss_rt.item()
            dist_rt = metric_rt(rt, pred_rt)
            dist_rt_acc += dist_rt.item()
        if wandb is not None:
            wdb.log({"val rt loss": losses_rt / len(data_val),
                     "val rt mean metric": dist_rt_acc / len(data_val),
                     'val epoch': epoch})
        print('epoch : ', epoch, 'val rt loss', losses_rt / len(data_val),
              "val rt mean metric : ", dist_rt_acc / len(data_val))

    if forward == 'int':  # adapted to the Prosit dataset format
        for seq, charge, _, intensity in data_val:
            intensity = intensity.float()
            if torch.cuda.is_available():
                seq, charge, intensity = seq.cuda(), charge.cuda(), intensity.cuda()
            pred_int = model.forward_int(seq, charge)
            loss_int = criterion_intensity(intensity, pred_int)
            losses_int += loss_int.item()
            dist_int = metric_intensity(intensity, pred_int)
            dist_int_acc += dist_int.item()
        if wandb is not None:
            wdb.log({"val int loss": losses_int / len(data_val),
                     "val int mean metric": dist_int_acc / len(data_val),
                     'val epoch': epoch})
        print('epoch : ', epoch, 'val int loss', losses_int / len(data_val),
              "val int mean metric", dist_int_acc / len(data_val))
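# Illustrative sketch only: masked_cos_sim, distance and masked_spectral_angle are
# imported from loss.py and their real implementations are not shown in this file.
# The helper below is a plausible reference for what a masked spectral-angle metric
# usually computes on Prosit-style intensity vectors (padding / impossible fragments
# encoded as -1): cosine similarity over the valid positions, mapped to a normalised
# angle. It is hypothetical and is not called by the code above or below.
def _reference_masked_spectral_angle(y_true, y_pred, eps=1e-8):
    mask = (y_true >= 0).float()                     # keep only valid fragment positions
    t = y_true * mask
    p = y_pred * mask
    t = t / (t.norm(dim=-1, keepdim=True) + eps)     # L2-normalise each spectrum
    p = p / (p.norm(dim=-1, keepdim=True) + eps)
    cos_sim = (t * p).sum(dim=-1).clamp(-1.0 + eps, 1.0 - eps)
    return (1 - 2 * torch.acos(cos_sim) / torch.pi).mean()   # 1 = identical spectra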
def run(epochs, eval_inter, save_inter, model, data_train, data_val, data_test, optimizer,
        criterion_rt, criterion_intensity, metric_rt, metric_intensity, forward,
        wandb=None, output='output/out.csv', file=False):
    if forward == 'transfer':
        # Transfer setting: train on RT only, evaluate on both tasks.
        for e in range(1, epochs + 1):
            train(model, data_train, e, optimizer, criterion_rt, criterion_intensity,
                  metric_rt, metric_intensity, 'rt', wandb=wandb)
            if e % eval_inter == 0:
                eval(model, data_val, e, criterion_rt, criterion_intensity,
                     metric_rt, metric_intensity, 'both', wandb=wandb)
            if e % save_inter == 0:
                save(model, 'model_common_' + str(e) + '.pt')
                save_pred(model, data_val, 'both', output)
    elif forward == 'reverse':
        # Reverse setting: train on both tasks, evaluate on RT only.
        for e in range(1, epochs + 1):
            train(model, data_train, e, optimizer, criterion_rt, criterion_intensity,
                  metric_rt, metric_intensity, 'both', wandb=wandb)
            if e % eval_inter == 0:
                eval(model, data_val, e, criterion_rt, criterion_intensity,
                     metric_rt, metric_intensity, 'rt', wandb=wandb)
            if e % save_inter == 0:
                save(model, 'model_common_' + str(e) + '.pt')
                save_pred(model, data_val, 'rt', output)
    else:
        for e in range(1, epochs + 1):
            train(model, data_train, e, optimizer, criterion_rt, criterion_intensity,
                  metric_rt, metric_intensity, forward, wandb=wandb)
            if e % eval_inter == 0:
                eval(model, data_val, e, criterion_rt, criterion_intensity,
                     metric_rt, metric_intensity, forward, wandb=wandb)
            # if e % save_inter == 0:
            #     save(model, 'model_common_' + str(e) + '.pt')
        save_pred(model, data_val, forward, output, file=file)


def main(args):
    if args.wandb is not None:
        os.environ["WANDB_API_KEY"] = 'b4a27ac6b6145e1a5d0ee7f9e2e8c20bd101dccd'
        os.environ["WANDB_MODE"] = "offline"
        os.environ["WANDB_DIR"] = os.path.abspath("./wandb_run")
        wdb.init(project="Common prediction", dir='./wandb_run', name=args.wandb)

    print(args)
    print('Cuda : ', torch.cuda.is_available())

    if args.forward == 'both':
        data_train, data_val = common_dataset.load_data(path_train=args.dataset_train,
                                                        path_val=args.dataset_val,
                                                        path_test=args.dataset_test,
                                                        batch_size=args.batch_size,
                                                        length=args.seq_length,
                                                        pad=False, convert=False, vocab='unmod')
    elif args.forward == 'rt':
        data_train, data_val = dataloader.load_data(
            data_sources=[args.dataset_train, args.dataset_val, args.dataset_test],
            batch_size=args.batch_size, length=args.seq_length)
    elif args.forward == 'transfer':
        data_train, _ = dataloader.load_data(
            data_sources=[args.dataset_train, 'database/data_holdout.csv', 'database/data_holdout.csv'],
            batch_size=args.batch_size, length=args.seq_length)
        _, data_val = common_dataset.load_data(path_train=args.dataset_val, path_val=args.dataset_val,
                                               path_test=args.dataset_test, batch_size=args.batch_size,
                                               length=args.seq_length, pad=True, convert=True, vocab='unmod')
    elif args.forward == 'reverse':
        _, data_val = dataloader.load_data(
            data_sources=['database/data_train.csv', args.dataset_val, args.dataset_test],
            batch_size=args.batch_size, length=args.seq_length)
        data_train, _ = common_dataset.load_data(path_train=args.dataset_train, path_val=args.dataset_train,
                                                 path_test=args.dataset_train, batch_size=args.batch_size,
                                                 length=args.seq_length, pad=True, convert=True, vocab='unmod')
    # Note: no loader branch is defined for args.forward == 'int'.
    print('\nData loaded')

    model = Model_Common_Transformer(encoder_ff=args.encoder_ff,
                                     decoder_rt_ff=args.decoder_rt_ff,
                                     decoder_int_ff=args.decoder_int_ff,
                                     n_head=args.n_head,
                                     encoder_num_layer=args.encoder_num_layer,
                                     decoder_int_num_layer=args.decoder_int_num_layer,
                                     decoder_rt_num_layer=args.decoder_rt_num_layer,
                                     drop_rate=args.drop_rate,
                                     embedding_dim=args.embedding_dim,
                                     acti=args.activation,
                                     norm=args.norm_first,
                                     seq_length=args.seq_length)
    if torch.cuda.is_available():
        model = model.cuda()
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    print('\nModel initialised')

    run(epochs=args.epochs, eval_inter=args.eval_inter, save_inter=args.save_inter, model=model,
        data_train=data_train, data_val=data_val, data_test=data_val, optimizer=optimizer,
        criterion_rt=torch.nn.MSELoss(), criterion_intensity=masked_cos_sim,
        metric_rt=distance, metric_intensity=masked_spectral_angle,
        wandb=args.wandb, forward=args.forward, output=args.output, file=args.file)

    if args.wandb is not None:
        wdb.finish()


def save(model, checkpoint_name):
    print('\nModel Saving...')
    os.makedirs('checkpoints', exist_ok=True)
    torch.save(model, os.path.join('checkpoints', checkpoint_name))


def load(path):
    model = torch.load(os.path.join('checkpoints', path))
    return model


def get_n_params(model):
    # Total number of scalar entries across all named parameters.
    pp = 0
    for n, p in list(model.named_parameters()):
        nn = 1
        for s in list(p.size()):
            nn = nn * s
        pp += nn
    return pp
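# Illustrative usage of the helpers above (the checkpoint name is hypothetical).
# Note that save()/load() pickle the whole nn.Module via torch.save/torch.load rather
# than a state_dict, so loading requires the same model code to be importable:
#
#     model = Model_Common_Transformer(...)
#     print('parameters:', get_n_params(model))
#     save(model, 'model_common_42.pt')   # written to checkpoints/model_common_42.pt
#     model = load('model_common_42.pt')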
def save_pred(model, data_val, forward, output_path, file=False):
    data_frame = pd.DataFrame()
    model.eval()
    for param in model.parameters():
        param.requires_grad = False

    if forward == 'both':
        pred_rt, pred_int, seqs, charges, true_rt, true_int, file_list = [], [], [], [], [], [], []
        if file:
            data_val.dataset.set_file_mode(True)
            for seq, charge, rt, intensity, files in data_val:
                rt, intensity = rt.float(), intensity.float()
                if torch.cuda.is_available():
                    seq, charge, rt, intensity, files = seq.cuda(), charge.cuda(), rt.cuda(), intensity.cuda(), files.cuda()
                pr_rt, pr_intensity = model.forward(seq, charge)
                pred_rt.extend(pr_rt.data.cpu().tolist())
                pred_int.extend(pr_intensity.data.cpu().tolist())
                seqs.extend(seq.data.cpu().tolist())
                charges.extend(charge.data.cpu().tolist())
                true_rt.extend(rt.data.cpu().tolist())
                true_int.extend(intensity.data.cpu().tolist())
                file_list.extend(files.data.cpu().tolist())
        else:
            for seq, charge, rt, intensity in data_val:
                rt, intensity = rt.float(), intensity.float()
                if torch.cuda.is_available():
                    seq, charge, rt, intensity = seq.cuda(), charge.cuda(), rt.cuda(), intensity.cuda()
                pr_rt, pr_intensity = model.forward(seq, charge)
                pred_rt.extend(pr_rt.data.cpu().tolist())
                pred_int.extend(pr_intensity.data.cpu().tolist())
                seqs.extend(seq.data.cpu().tolist())
                charges.extend(charge.data.cpu().tolist())
                true_rt.extend(rt.data.cpu().tolist())
                true_int.extend(intensity.data.cpu().tolist())
        data_frame['rt pred'] = pred_rt
        data_frame['seq'] = seqs
        data_frame['pred int'] = pred_int
        data_frame['true rt'] = true_rt
        data_frame['true int'] = true_int
        data_frame['charge'] = charges
        if file:
            data_frame['file'] = file_list

    if forward == 'rt':  # adapted to the Prosit dataset format
        pred_rt, seqs, true_rt = [], [], []
        for seq, rt in data_val:
            rt = rt.float()
            if torch.cuda.is_available():
                seq, rt = seq.cuda(), rt.cuda()
            pr_rt = model.forward_rt(seq)
            pred_rt.extend(pr_rt.data.cpu().tolist())
            seqs.extend(seq.data.cpu().tolist())
            true_rt.extend(rt.data.cpu().tolist())
        data_frame['rt pred'] = pred_rt
        data_frame['seq'] = seqs
        data_frame['true rt'] = true_rt

    if forward == 'int':  # adapted to the Prosit dataset format
        pred_int, seqs, charges, true_int = [], [], [], []
        for seq, charge, _, intensity in data_val:
            intensity = intensity.float()
            if torch.cuda.is_available():
                seq, charge, intensity = seq.cuda(), charge.cuda(), intensity.cuda()
            # Keep the prediction in a separate variable instead of overwriting the
            # accumulator list, so the collected predictions end up in the dataframe.
            pr_int = model.forward_int(seq, charge)
            pred_int.extend(pr_int.data.cpu().tolist())
            seqs.extend(seq.data.cpu().tolist())
            charges.extend(charge.data.cpu().tolist())
            true_int.extend(intensity.data.cpu().tolist())
        data_frame['seq'] = seqs
        data_frame['pred int'] = pred_int
        data_frame['true int'] = true_int
        data_frame['charge'] = charges

    if file:
        data_val.dataset.set_file_mode(False)
    data_frame.to_csv(output_path)


if __name__ == "__main__":
    args = load_args()
    main(args)
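# Example command line (assumption: config_common.load_args exposes flags matching the
# attribute names read in main(), e.g. --forward, --dataset_train, --batch_size; the
# script name and dataset paths below are placeholders, check config_common.py for the
# actual argument names and defaults):
#
#     python main_custom.py --forward both \
#         --dataset_train database/data_train.csv \
#         --dataset_val database/data_holdout.csv \
#         --dataset_test database/data_holdout.csv \
#         --batch_size 256 --epochs 50 --eval_inter 1 --save_inter 10 \
#         --wandb my_run --output output/out.csv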