Skip to content
Snippets Groups Projects
Commit d607b300 authored by Schneider Leo's avatar Schneider Leo
Browse files

split args

parent 151ed3f4
No related branches found
No related tags found
No related merge requests found
......@@ -12,8 +12,11 @@ def load_args():
# Fragment of load_args(): CLI options for model selection, data sources and
# transformer sizing.  The ArgumentParser instance is created above this view.
parser.add_argument('--model', type=str, default='RT_multi_sum')
# Weights & Biases identifier; default None — presumably disables W&B logging,
# TODO confirm where args.wandb is consumed.
parser.add_argument('--wandb', type=str, default=None)
# Dataset CSV paths paired with the split label to select rows from each file
# (this commit splits the previously shared dataset argument per split).
parser.add_argument('--dataset_train', type=str, default='data/data_prosit/data.csv')
parser.add_argument('--split_train', type=str, default='train')
parser.add_argument('--dataset_val', type=str, default='data/data_prosit/data.csv')
parser.add_argument('--split_val', type=str, default='validation')
parser.add_argument('--dataset_test', type=str, default='data/data_prosit/data.csv')
parser.add_argument('--split_test', type=str, default='holdout')
# Model size hyper-parameters (embedding width and feed-forward dims).
parser.add_argument('--embedding_dim', type=int, default=16)
parser.add_argument('--encoder_ff', type=int, default=2048)
parser.add_argument('--decoder_rt_ff', type=int, default=2048)
......
......@@ -208,13 +208,13 @@ if __name__ == '__main__' :
# Build the 0.05-threshold augmented training set: merge selected ISA
# predictions with the base dataframe and export as CSV.
# list_df = [df_1,df_2,df_3,df_4]
# df = select_best_data(list_df, 0.05)
# df.to_pickle('data_ISA/data_ISA_additionnal_005.pkl')
# NOTE(review): the next 6 lines and the commented block below them are the
# same code — this view is a commit diff, so the removed (active) and added
# (commented-out) versions both appear.  This commit commented the block out.
df = pd.read_pickle('data_ISA/data_ISA_additionnal_005.pkl')
df['state'] = 'train'
# Convert numerically encoded sequences back to amino-acid letters.
df['sequence'] = df['sequence'].map(numerical_to_alphabetical_str)
df_augmented_1 = pd.concat([df, df_base], axis=0).reset_index(drop=True)
df_augmented_1.columns = ['sequence', 'irt_scaled','state']
df_augmented_1.to_csv('data_ISA/isa_data_augmented_005.csv', index=False)
# df = pd.read_pickle('data_ISA/data_ISA_additionnal_005.pkl')
# df['state'] = 'train'
# df['sequence'] = df['sequence'].map(numerical_to_alphabetical_str)
# df_augmented_1 = pd.concat([df, df_base], axis=0).reset_index(drop=True)
# df_augmented_1.columns = ['sequence', 'irt_scaled','state']
#
# df_augmented_1.to_csv('data_ISA/isa_data_augmented_005.csv', index=False)
#
#
# df_1 = pd.read_csv('../output/out_ISA_noc_prosit_1.csv')
......@@ -226,30 +226,30 @@ if __name__ == '__main__' :
# Same augmentation pipeline at the 0.1 threshold (commented out), then the
# active 0.5-threshold run below.
# df = select_best_data(list_df, 0.1)
# df.to_pickle('data_ISA/data_ISA_additionnal_01.pkl')
#
# NOTE(review): commit-diff artifact — the active block here (removed lines)
# and the commented block after it (added lines) are the same 0.1 pipeline.
df = pd.read_pickle('data_ISA/data_ISA_additionnal_01.pkl')
df['state'] = 'train'
df['sequence'] = df['sequence'].map(numerical_to_alphabetical_str)
df_augmented_1 = pd.concat([df, df_base], axis=0).reset_index(drop=True)
df_augmented_1.columns = ['sequence', 'irt_scaled','state']
df_augmented_1.to_csv('data_ISA/isa_data_augmented_01.csv', index=False)
# df = pd.read_pickle('data_ISA/data_ISA_additionnal_01.pkl')
# df['state'] = 'train'
# df['sequence'] = df['sequence'].map(numerical_to_alphabetical_str)
# df_augmented_1 = pd.concat([df, df_base], axis=0).reset_index(drop=True)
# df_augmented_1.columns = ['sequence', 'irt_scaled','state']
#
# df_1 = pd.read_csv('../output/out_ISA_noc_prosit_1.csv')
# df_2 = pd.read_csv('../output/out_ISA_noc_prosit_2.csv')
# df_3 = pd.read_csv('../output/out_ISA_noc_prosit_3.csv')
# df_4 = pd.read_csv('../output/out_ISA_noc_prosit_4.csv')
#
# list_df = [df_1, df_2, df_3, df_4]
# df = select_best_data(list_df, 0.2)
# df.to_pickle('data_ISA/data_ISA_additionnal_02.pkl')
# df_augmented_1.to_csv('data_ISA/isa_data_augmented_01.csv', index=False)
# #
# Active path added by this commit: select the best 0.5-threshold predictions
# from four prosit output runs and build the augmented CSV.
df_1 = pd.read_csv('../output/out_ISA_noc_prosit_1.csv')
df_2 = pd.read_csv('../output/out_ISA_noc_prosit_2.csv')
df_3 = pd.read_csv('../output/out_ISA_noc_prosit_3.csv')
df_4 = pd.read_csv('../output/out_ISA_noc_prosit_4.csv')
list_df = [df_1, df_2, df_3, df_4]
df = select_best_data(list_df, 0.5)
df.to_pickle('data_ISA/data_ISA_additionnal_05.pkl')
# NOTE(review): diff before/after pair — the _02 read below was replaced by
# the _05 read; only one of the two lines exists in the real file.
df = pd.read_pickle('data_ISA/data_ISA_additionnal_02.pkl')
df = pd.read_pickle('data_ISA/data_ISA_additionnal_05.pkl')
df['state'] = 'train'
df['sequence'] = df['sequence'].map(numerical_to_alphabetical_str)
df_augmented_1 = pd.concat([df, df_base], axis=0).reset_index(drop=True)
df_augmented_1.columns = ['sequence', 'irt_scaled','state']
# NOTE(review): likewise, the _02 output filename was replaced by _05.
df_augmented_1.to_csv('data_ISA/isa_data_augmented_02.csv', index=False)
df_augmented_1.to_csv('data_ISA/isa_data_augmented_05.csv', index=False)
......
......@@ -90,9 +90,9 @@ def main(args):
# Fragment of main(args): report environment, load the three data splits,
# and construct the transformer model (call continues past this view).
print(args)
print('Cuda : ', torch.cuda.is_available())
# NOTE(review): diff before/after — the three hard-coded-mode calls were
# replaced by the three args-driven calls; only one trio exists in the file.
data_train = load_data(data_source=args.dataset_train, batch_size=args.batch_size, length=25, mode='train')
data_test = load_data(data_source=args.dataset_test , batch_size=args.batch_size, length=25, mode='holdout')
data_val = load_data(data_source=args.dataset_val, batch_size=args.batch_size, length=25, mode='validation')
# Splits are now configurable via --split_train/--split_test/--split_val.
data_train = load_data(data_source=args.dataset_train, batch_size=args.batch_size, length=25, mode=args.split_train)
data_test = load_data(data_source=args.dataset_test , batch_size=args.batch_size, length=25, mode=args.split_test)
data_val = load_data(data_source=args.dataset_val, batch_size=args.batch_size, length=25, mode=args.split_val)
print('\nData loaded')
model = ModelTransformer(encoder_ff=args.encoder_ff, decoder_rt_ff=args.decoder_rt_ff,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment