diff --git a/data/data_viz.py b/data/data_viz.py index f3ccbe15595d25056c5232df1a2e98853bafb113..0dc9bd95d81345a1027493a4cc0021c1db3afc85 100644 --- a/data/data_viz.py +++ b/data/data_viz.py @@ -165,15 +165,20 @@ def plot_res(): def calc_and_plot_res(): all_data=[] - base = 'out_' - for name in ['early_stop_plasma_plasma','plasma_train_augmented_005_plasma_train','plasma_train_augmented_01_plasma_train', - 'plasma_train_augmented_02_plasma_train','plasma_train_augmented_03_plasma_train','plasma_train_augmented_04_plasma_train', - 'plasma_train_augmented_05_plasma_train','plasma_train_augmented_07_plasma_train','plasma_train_augmented_1_plasma_train', - 'plasma_train_augmented_all_plasma_train','prosit_plasma']: - print(name) + threshold = ['005','01','02','03','04','05','07','1','all'] + base_aug_1 = 'out_early_stop_ISA_aug_' + base_aug_2 = '_ISA_noc' + name_list = ['out_early_stop_ISA_noc_ISA_noc'] + + for t in threshold : + name_list.append(base_aug_1+t+base_aug_2) + name_list.append('out_early_stop_prosit_ISA_noc') + + print(name_list) + for name in name_list: r2_list=[] - for index in range(9): - dataframe = pd.read_csv('../output/'+base+name+'_'+str(index)+'.csv') + for index in range(10): + dataframe = pd.read_csv('../archive_output/ISA/'+name+'_'+str(index)+'.csv') r2_list.append(r2_score(dataframe['true rt'], dataframe['rt pred'])) all_data.append(r2_list) fig, axs = plt.subplots(2, 1, figsize=(9, 8)) @@ -186,7 +191,7 @@ def calc_and_plot_res(): axs[1].set_xticks([y for y in range(len(all_data))], labels=[ 'plasma', 'Augm 0.05', 'Augm 0.1', 'Augm 0.2', 'Augm 0.3', 'Augm 0.4', 'Augm 0.5', 'Augm 0.7', 'Augm 1', 'Augm all', 'Prosit', ], rotation=30) - ref_path, base_path = './data_PXD006109/plasma_train/data_aligned_train_plasma.csv', './data_PXD006109/plasma_train/plasma_train_data_augmented_' + ref_path, base_path = './data_ISA/data_isa_noc.csv', './data_ISA/isa_data_augmented_' name = ['005','01','02','03','04','05','07','1','all'] df = pd.read_csv(ref_path) size = [df.shape[0]] @@ -208,7 +213,7 @@ def calc_and_plot_res(): - plt.savefig('../fig/model perf/summary_early_stop_plasma_train.png') + plt.savefig('../fig/model perf/summary_early_stop_isa.png')