From d607b30088ef31e3f4b80f42fc3702ef0d55c042 Mon Sep 17 00:00:00 2001 From: Schneider Leo <leo.schneider@etu.ec-lyon.fr> Date: Thu, 5 Dec 2024 14:22:17 +0100 Subject: [PATCH] split args --- config.py | 3 +++ data/data_viz.py | 48 ++++++++++++++++++++++++------------------------ main.py | 6 +++--- 3 files changed, 30 insertions(+), 27 deletions(-) diff --git a/config.py b/config.py index 28fb0c6..4b7800e 100644 --- a/config.py +++ b/config.py @@ -12,8 +12,11 @@ def load_args(): parser.add_argument('--model', type=str, default='RT_multi_sum') parser.add_argument('--wandb', type=str, default=None) parser.add_argument('--dataset_train', type=str, default='data/data_prosit/data.csv') + parser.add_argument('--split_train', type=str, default='train') parser.add_argument('--dataset_val', type=str, default='data/data_prosit/data.csv') + parser.add_argument('--split_val', type=str, default='validation') parser.add_argument('--dataset_test', type=str, default='data/data_prosit/data.csv') + parser.add_argument('--split_test', type=str, default='holdout') parser.add_argument('--embedding_dim', type=int, default=16) parser.add_argument('--encoder_ff', type=int, default=2048) parser.add_argument('--decoder_rt_ff', type=int, default=2048) diff --git a/data/data_viz.py b/data/data_viz.py index 5aa57f7..e4d6b1f 100644 --- a/data/data_viz.py +++ b/data/data_viz.py @@ -208,13 +208,13 @@ if __name__ == '__main__' : # list_df = [df_1,df_2,df_3,df_4] # df = select_best_data(list_df, 0.05) # df.to_pickle('data_ISA/data_ISA_additionnal_005.pkl') - df = pd.read_pickle('data_ISA/data_ISA_additionnal_005.pkl') - df['state'] = 'train' - df['sequence'] = df['sequence'].map(numerical_to_alphabetical_str) - df_augmented_1 = pd.concat([df, df_base], axis=0).reset_index(drop=True) - df_augmented_1.columns = ['sequence', 'irt_scaled','state'] - - df_augmented_1.to_csv('data_ISA/isa_data_augmented_005.csv', index=False) + # df = pd.read_pickle('data_ISA/data_ISA_additionnal_005.pkl') + # df['state'] = 'train' + # df['sequence'] = df['sequence'].map(numerical_to_alphabetical_str) + # df_augmented_1 = pd.concat([df, df_base], axis=0).reset_index(drop=True) + # df_augmented_1.columns = ['sequence', 'irt_scaled','state'] + # + # df_augmented_1.to_csv('data_ISA/isa_data_augmented_005.csv', index=False) # # # df_1 = pd.read_csv('../output/out_ISA_noc_prosit_1.csv') @@ -226,30 +226,30 @@ if __name__ == '__main__' : # df = select_best_data(list_df, 0.1) # df.to_pickle('data_ISA/data_ISA_additionnal_01.pkl') # - df = pd.read_pickle('data_ISA/data_ISA_additionnal_01.pkl') - df['state'] = 'train' - df['sequence'] = df['sequence'].map(numerical_to_alphabetical_str) - df_augmented_1 = pd.concat([df, df_base], axis=0).reset_index(drop=True) - df_augmented_1.columns = ['sequence', 'irt_scaled','state'] - - df_augmented_1.to_csv('data_ISA/isa_data_augmented_01.csv', index=False) + # df = pd.read_pickle('data_ISA/data_ISA_additionnal_01.pkl') + # df['state'] = 'train' + # df['sequence'] = df['sequence'].map(numerical_to_alphabetical_str) + # df_augmented_1 = pd.concat([df, df_base], axis=0).reset_index(drop=True) + # df_augmented_1.columns = ['sequence', 'irt_scaled','state'] # - # df_1 = pd.read_csv('../output/out_ISA_noc_prosit_1.csv') - # df_2 = pd.read_csv('../output/out_ISA_noc_prosit_2.csv') - # df_3 = pd.read_csv('../output/out_ISA_noc_prosit_3.csv') - # df_4 = pd.read_csv('../output/out_ISA_noc_prosit_4.csv') - # - # list_df = [df_1, df_2, df_3, df_4] - # df = select_best_data(list_df, 0.2) - # df.to_pickle('data_ISA/data_ISA_additionnal_02.pkl') + # df_augmented_1.to_csv('data_ISA/isa_data_augmented_01.csv', index=False) + # # + df_1 = pd.read_csv('../output/out_ISA_noc_prosit_1.csv') + df_2 = pd.read_csv('../output/out_ISA_noc_prosit_2.csv') + df_3 = pd.read_csv('../output/out_ISA_noc_prosit_3.csv') + df_4 = pd.read_csv('../output/out_ISA_noc_prosit_4.csv') + + list_df = [df_1, df_2, df_3, df_4] + df = select_best_data(list_df, 0.5) + df.to_pickle('data_ISA/data_ISA_additionnal_05.pkl') - df = pd.read_pickle('data_ISA/data_ISA_additionnal_02.pkl') + df = pd.read_pickle('data_ISA/data_ISA_additionnal_05.pkl') df['state'] = 'train' df['sequence'] = df['sequence'].map(numerical_to_alphabetical_str) df_augmented_1 = pd.concat([df, df_base], axis=0).reset_index(drop=True) df_augmented_1.columns = ['sequence', 'irt_scaled','state'] - df_augmented_1.to_csv('data_ISA/isa_data_augmented_02.csv', index=False) + df_augmented_1.to_csv('data_ISA/isa_data_augmented_05.csv', index=False) diff --git a/main.py b/main.py index 1345ab8..6cd9458 100644 --- a/main.py +++ b/main.py @@ -90,9 +90,9 @@ def main(args): print(args) print('Cuda : ', torch.cuda.is_available()) - data_train = load_data(data_source=args.dataset_train, batch_size=args.batch_size, length=25, mode='train') - data_test = load_data(data_source=args.dataset_test , batch_size=args.batch_size, length=25, mode='holdout') - data_val = load_data(data_source=args.dataset_val, batch_size=args.batch_size, length=25, mode='validation') + data_train = load_data(data_source=args.dataset_train, batch_size=args.batch_size, length=25, mode=args.split_train) + data_test = load_data(data_source=args.dataset_test , batch_size=args.batch_size, length=25, mode=args.split_test) + data_val = load_data(data_source=args.dataset_val, batch_size=args.batch_size, length=25, mode=args.split_val) print('\nData loaded') model = ModelTransformer(encoder_ff=args.encoder_ff, decoder_rt_ff=args.decoder_rt_ff, -- GitLab