From d607b30088ef31e3f4b80f42fc3702ef0d55c042 Mon Sep 17 00:00:00 2001
From: Schneider Leo <leo.schneider@etu.ec-lyon.fr>
Date: Thu, 5 Dec 2024 14:22:17 +0100
Subject: [PATCH] Add --split_train/--split_val/--split_test args for load_data mode

---
 config.py        |  3 +++
 data/data_viz.py | 48 ++++++++++++++++++++++++------------------------
 main.py          |  6 +++---
 3 files changed, 30 insertions(+), 27 deletions(-)

diff --git a/config.py b/config.py
index 28fb0c6..4b7800e 100644
--- a/config.py
+++ b/config.py
@@ -12,8 +12,11 @@ def load_args():
     parser.add_argument('--model', type=str, default='RT_multi_sum')
     parser.add_argument('--wandb', type=str, default=None)
     parser.add_argument('--dataset_train', type=str, default='data/data_prosit/data.csv')
+    parser.add_argument('--split_train', type=str, default='train')
     parser.add_argument('--dataset_val', type=str, default='data/data_prosit/data.csv')
+    parser.add_argument('--split_val', type=str, default='validation')
     parser.add_argument('--dataset_test', type=str, default='data/data_prosit/data.csv')
+    parser.add_argument('--split_test', type=str, default='holdout')
     parser.add_argument('--embedding_dim', type=int, default=16)
     parser.add_argument('--encoder_ff', type=int, default=2048)
     parser.add_argument('--decoder_rt_ff', type=int, default=2048)
diff --git a/data/data_viz.py b/data/data_viz.py
index 5aa57f7..e4d6b1f 100644
--- a/data/data_viz.py
+++ b/data/data_viz.py
@@ -208,13 +208,13 @@ if __name__ == '__main__' :
     # list_df = [df_1,df_2,df_3,df_4]
     # df = select_best_data(list_df, 0.05)
     # df.to_pickle('data_ISA/data_ISA_additionnal_005.pkl')
-    df = pd.read_pickle('data_ISA/data_ISA_additionnal_005.pkl')
-    df['state'] = 'train'
-    df['sequence'] = df['sequence'].map(numerical_to_alphabetical_str)
-    df_augmented_1 = pd.concat([df, df_base], axis=0).reset_index(drop=True)
-    df_augmented_1.columns = ['sequence', 'irt_scaled','state']
-
-    df_augmented_1.to_csv('data_ISA/isa_data_augmented_005.csv', index=False)
+    # df = pd.read_pickle('data_ISA/data_ISA_additionnal_005.pkl')
+    # df['state'] = 'train'
+    # df['sequence'] = df['sequence'].map(numerical_to_alphabetical_str)
+    # df_augmented_1 = pd.concat([df, df_base], axis=0).reset_index(drop=True)
+    # df_augmented_1.columns = ['sequence', 'irt_scaled','state']
+    #
+    # df_augmented_1.to_csv('data_ISA/isa_data_augmented_005.csv', index=False)
     #
     #
     # df_1 = pd.read_csv('../output/out_ISA_noc_prosit_1.csv')
@@ -226,30 +226,30 @@ if __name__ == '__main__' :
     # df = select_best_data(list_df, 0.1)
     # df.to_pickle('data_ISA/data_ISA_additionnal_01.pkl')
     #
-    df = pd.read_pickle('data_ISA/data_ISA_additionnal_01.pkl')
-    df['state'] = 'train'
-    df['sequence'] = df['sequence'].map(numerical_to_alphabetical_str)
-    df_augmented_1 = pd.concat([df, df_base], axis=0).reset_index(drop=True)
-    df_augmented_1.columns = ['sequence', 'irt_scaled','state']
-
-    df_augmented_1.to_csv('data_ISA/isa_data_augmented_01.csv', index=False)
+    # df = pd.read_pickle('data_ISA/data_ISA_additionnal_01.pkl')
+    # df['state'] = 'train'
+    # df['sequence'] = df['sequence'].map(numerical_to_alphabetical_str)
+    # df_augmented_1 = pd.concat([df, df_base], axis=0).reset_index(drop=True)
+    # df_augmented_1.columns = ['sequence', 'irt_scaled','state']
     #
-    # df_1 = pd.read_csv('../output/out_ISA_noc_prosit_1.csv')
-    # df_2 = pd.read_csv('../output/out_ISA_noc_prosit_2.csv')
-    # df_3 = pd.read_csv('../output/out_ISA_noc_prosit_3.csv')
-    # df_4 = pd.read_csv('../output/out_ISA_noc_prosit_4.csv')
-    #
-    # list_df = [df_1, df_2, df_3, df_4]
-    # df = select_best_data(list_df, 0.2)
-    # df.to_pickle('data_ISA/data_ISA_additionnal_02.pkl')
+    # df_augmented_1.to_csv('data_ISA/isa_data_augmented_01.csv', index=False)
+    # #
+    df_1 = pd.read_csv('../output/out_ISA_noc_prosit_1.csv')
+    df_2 = pd.read_csv('../output/out_ISA_noc_prosit_2.csv')
+    df_3 = pd.read_csv('../output/out_ISA_noc_prosit_3.csv')
+    df_4 = pd.read_csv('../output/out_ISA_noc_prosit_4.csv')
+
+    list_df = [df_1, df_2, df_3, df_4]
+    df = select_best_data(list_df, 0.5)
+    df.to_pickle('data_ISA/data_ISA_additionnal_05.pkl')
 
-    df = pd.read_pickle('data_ISA/data_ISA_additionnal_02.pkl')
+    df = pd.read_pickle('data_ISA/data_ISA_additionnal_05.pkl')
     df['state'] = 'train'
     df['sequence'] = df['sequence'].map(numerical_to_alphabetical_str)
     df_augmented_1 = pd.concat([df, df_base], axis=0).reset_index(drop=True)
     df_augmented_1.columns = ['sequence', 'irt_scaled','state']
 
-    df_augmented_1.to_csv('data_ISA/isa_data_augmented_02.csv', index=False)
+    df_augmented_1.to_csv('data_ISA/isa_data_augmented_05.csv', index=False)
 
 
 
diff --git a/main.py b/main.py
index 1345ab8..6cd9458 100644
--- a/main.py
+++ b/main.py
@@ -90,9 +90,9 @@ def main(args):
     print(args)
     print('Cuda : ', torch.cuda.is_available())
 
-    data_train = load_data(data_source=args.dataset_train, batch_size=args.batch_size, length=25, mode='train')
-    data_test = load_data(data_source=args.dataset_test , batch_size=args.batch_size, length=25, mode='holdout')
-    data_val = load_data(data_source=args.dataset_val, batch_size=args.batch_size, length=25, mode='validation')
+    data_train = load_data(data_source=args.dataset_train, batch_size=args.batch_size, length=25, mode=args.split_train)
+    data_test = load_data(data_source=args.dataset_test , batch_size=args.batch_size, length=25, mode=args.split_test)
+    data_val = load_data(data_source=args.dataset_val, batch_size=args.batch_size, length=25, mode=args.split_val)
     print('\nData loaded')
 
     model = ModelTransformer(encoder_ff=args.encoder_ff, decoder_rt_ff=args.decoder_rt_ff,
-- 
GitLab