From a49eb721349572eee734dd8590cc3a5f0f572104 Mon Sep 17 00:00:00 2001 From: Schneider Leo <leo.schneider@etu.ec-lyon.fr> Date: Tue, 24 Sep 2024 13:37:33 +0200 Subject: [PATCH] dataset rain ISA --- main_custom.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/main_custom.py b/main_custom.py index fb99a74..06b4fbd 100644 --- a/main_custom.py +++ b/main_custom.py @@ -11,7 +11,7 @@ from config_common import load_args from common_dataset import load_data from dataloader import load_data from loss import masked_cos_sim, distance, masked_spectral_angle -from model_custom import Model_Common_Transformer, Model_Common_Transformer_TAPE +from model_custom import Model_Common_Transformer from model import RT_pred_model_self_attention_multi @@ -25,7 +25,6 @@ def train(model, data_train, epoch, optimizer, criterion_rt, criterion_intensity for param in model.parameters(): param.requires_grad = True if forward == 'both': - print(data_train.dataset.data['Sequence']) for seq, charge, rt, intensity in data_train: rt, intensity = rt.float(), intensity.float() if torch.cuda.is_available(): @@ -201,7 +200,7 @@ def main(args): data_train, data_val, data_test = common_dataset.load_data(path_train=args.dataset_train, path_val=args.dataset_val, path_test=args.dataset_test, - batch_size=args.batch_size, length=25, pad = False, convert=True, vocab='iapuc') + batch_size=args.batch_size, length=25, pad = True, convert=True, vocab='iapuc') elif args.forward == 'rt': data_train, data_val, data_test = dataloader.load_data(data_sources=[args.dataset_train,args.dataset_val,args.dataset_test], batch_size=args.batch_size, length=25) -- GitLab