From 2d98534b78b350d71003197ff9f2d99be6255683 Mon Sep 17 00:00:00 2001
From: Schneider Leo <leo.schneider@etu.ec-lyon.fr>
Date: Thu, 23 Jan 2025 15:23:40 +0100
Subject: [PATCH] data viz

data_processing.py: drop a stale progress comment.

data_viz.py:
- calc_and_plot_res() now reads the 'plasma_train_augmented_*' result
  files, draws two stacked subplots (R2 box plot annotated with the mean
  of each group on top, dataset size normalised by the largest dataset
  below) and saves to summary_early_stop_plasma_train.png.
- add plot_augmented_dataset_size(ref_path, base_path), which plots the
  normalised row count of the reference CSV and of each augmented CSV.
---
Note: relative to the original commit, plot_augmented_dataset_size() places
its x ticks on the plotted x positions (0..9) instead of 1..10, so each label
sits under its own data point; the unused local `t` is replaced by a
docstring. Leading indentation was reconstructed; apply with
`git apply --ignore-whitespace` if context lines do not match.

 data/data_processing.py |  1 -
 data/data_viz.py        | 68 ++++++++++++++++++++++++++++++++++-------
 2 files changed, 57 insertions(+), 12 deletions(-)

diff --git a/data/data_processing.py b/data/data_processing.py
index 0898283..a1ac05b 100644
--- a/data/data_processing.py
+++ b/data/data_processing.py
@@ -155,7 +155,6 @@ if __name__ == '__main__':
     df_base = pd.read_csv('./data_PXD006109/plasma_train/data_aligned_train_plasma.csv')
     df_base = df_base[['sequence', 'irt_scaled','state']]
     t = [0.05,0.1,0.2,0.3,0.4,0.5,0.7,1,10]
-    #reste 07 1 et all
     name = ['005','01','02','03','04','05','07','1','all']
     df_0 = pd.read_csv('../output/out_plasma_aligned_train_0.csv')
     df_1 = pd.read_csv('../output/out_plasma_aligned_train_1.csv')
diff --git a/data/data_viz.py b/data/data_viz.py
index 42a5f82..f3ccbe1 100644
--- a/data/data_viz.py
+++ b/data/data_viz.py
@@ -166,26 +166,51 @@ def plot_res():
 def calc_and_plot_res():
     all_data=[]
     base = 'out_'
-    for name in ['early_stop_plasma_plasma','plasma_aug_005_plasma','plasma_aug_01_plasma',
-                 'plasma_aug_02_plasma','plasma_aug_03_plasma','plasma_aug_04_plasma',
-                 'plasma_aug_05_plasma','plasma_aug_07_plasma','plasma_aug_1_plasma',
-                 'plasma_aug_all_plasma','prosit_plasma']:
+    for name in ['early_stop_plasma_plasma','plasma_train_augmented_005_plasma_train','plasma_train_augmented_01_plasma_train',
+                 'plasma_train_augmented_02_plasma_train','plasma_train_augmented_03_plasma_train','plasma_train_augmented_04_plasma_train',
+                 'plasma_train_augmented_05_plasma_train','plasma_train_augmented_07_plasma_train','plasma_train_augmented_1_plasma_train',
+                 'plasma_train_augmented_all_plasma_train','prosit_plasma']:
         print(name)
         r2_list=[]
         for index in range(9):
             dataframe = pd.read_csv('../output/'+base+name+'_'+str(index)+'.csv')
             r2_list.append(r2_score(dataframe['true rt'], dataframe['rt pred']))
         all_data.append(r2_list)
-    fig, axs = plt.subplots(figsize=(9, 4))
-    axs.boxplot(all_data)
-    axs.set_title('Box plot')
+    fig, axs = plt.subplots(2, 1, figsize=(9, 8))
+    axs[0].boxplot(all_data)
+
+    axs[0].set_title('Box plot')
     # adding horizontal grid lines
-    axs.yaxis.grid(True)
-    axs.set_xticks([y + 1 for y in range(len(all_data))],
+    axs[0].yaxis.grid(True)
+    axs[1].set_xticks([y for y in range(len(all_data))],
                    labels=[ 'plasma', 'Augm 0.05', 'Augm 0.1', 'Augm 0.2', 'Augm 0.3', 'Augm 0.4', 'Augm 0.5',
                            'Augm 0.7',
-                           'Augm 1', 'Augm all', 'Prosit', ])
-    plt.savefig('../fig/model perf/summary_early_stop_plasma.png')
+                           'Augm 1', 'Augm all', 'Prosit', ], rotation=30)
+    ref_path, base_path = './data_PXD006109/plasma_train/data_aligned_train_plasma.csv', './data_PXD006109/plasma_train/plasma_train_data_augmented_'
+    name = ['005','01','02','03','04','05','07','1','all']
+    df = pd.read_csv(ref_path)
+    size = [df.shape[0]]
+    x = [i for i in range(len(name)+2)]
+    for i in range(len(name)):
+        path = base_path+name[i]+'.csv'
+        df = pd.read_csv(path)
+        s = df.shape[0]
+
+
+        size.append(s)
+    df = pd.read_csv('./data_prosit/data.csv')
+    size.append(df.shape[0])
+    size = np.array(size)/max(size)
+    for i in range(len(size)):
+        axs[1].text(i, size[i], '{:.3f}'.format(size[i]), style='italic')
+        axs[0].text(i+0.5, np.mean(all_data[i])-0.01, f': {np.mean(all_data[i]):.3f}')
+    axs[1].plot(x,size)
+
+
+
+    plt.savefig('../fig/model perf/summary_early_stop_plasma_train.png')
+
+
 
 def error_by_methionine(dataframe):
     def fonc(a):
@@ -216,6 +241,27 @@ def filter_outlier_rt(dataframe):
     plt.savefig('../fig/data_exploration/outlier_selection.png')
     return df_out
 
+def plot_augmented_dataset_size(ref_path,base_path):
+    """Plot row counts of the reference and augmented CSVs, normalised by the largest, and save the figure."""
+    name = ['005','01','02','03','04','05','07','1','all']
+    df = pd.read_csv(ref_path)
+    size = [df.shape[0]]
+    x = [i for i in range(len(name)+1)]
+    for i in range(len(name)):
+        path = base_path+name[i]+'.csv'
+        df = pd.read_csv(path)
+        size.append(df.shape[0])
+    size = np.array(size)/max(size)
+    fig, ax = plt.subplots()
+    ax.plot(x,size)
+    ax.set_xticks(x,
+                  labels=['base', 'Augm 0.05', 'Augm 0.1', 'Augm 0.2', 'Augm 0.3', 'Augm 0.4', 'Augm 0.5',
+                          'Augm 0.7',
+                          'Augm 1', 'Augm all'], rotation=30)
+
+    plt.savefig('../fig/data_exploration/augmented_dataset_size.png')
+
+
 if __name__ == '__main__' :
     calc_and_plot_res()
     # base = ['plasma_plasma','plasma_prosit']
-- 
GitLab