# Reviewed reconstruction of the code added to data_viz.py by a patch whose
# line breaks were lost (the whole diff arrived collapsed onto two lines).
#
# The functions below rely on names the target module already imports at the
# top of the file: `matplotlib.pyplot as plt` and `numpy as np`.
#
# NOTE(review): the patch also adds `from sympy.utilities.misc import replace`
# to data_viz.py. Nothing in the added code uses it (the only `replace` here
# is the `replace=False` keyword of `np.random.choice`), so it looks like an
# accidental auto-import -- confirm and drop it from the patch.
#
# NOTE(review): the same patch adds, in main_custom.py's `run(...)`, the line
#     save_pred(model, data_val, forward, 'output/out.csv')
# right after `save(model, 'model_common_' + str(e) + '.pt')`. Indentation was
# lost, so it cannot be told from here whether the call sits inside the
# `if e % save_inter == 0:` guard -- confirm against the real patch.


def _show_or_save(fig, display, save, path):
    """Write *fig* to *path* and/or show it on screen.

    Saving is done first and through the figure object: after a blocking
    ``plt.show()`` returns, the window has been closed and a subsequent
    ``plt.savefig`` can end up writing an empty canvas.

    Raises:
        ValueError: if ``save`` is true but no ``path`` was given.
    """
    if save:
        if path is None:
            raise ValueError("`path` must be given when save=True")
        fig.savefig(path)
    if display:
        plt.show()


def _abs_error_by_length(dataframe):
    """Split the ``abs_error`` column by the value of the ``length`` column.

    Rows with a missing ``length`` or ``abs_error`` are ignored. Only lengths
    that actually occur are returned, so no group is ever empty.

    Returns:
        (lengths, groups): ``lengths`` in ascending order and, aligned with
        it, one 1-D array of ``abs_error`` values per length.

    Raises:
        ValueError: if there is no usable row.
    """
    usable = dataframe[['length', 'abs_error']].dropna()
    if len(usable) == 0:
        raise ValueError("no 'length'/'abs_error' rows to plot")

    lengths, groups = [], []
    # groupby sorts its keys, and unlike range(min, max) it includes the
    # largest length and skips lengths that have no rows.
    for length, errors in usable.groupby('length')['abs_error']:
        lengths.append(length)
        groups.append(errors.to_numpy())
    return lengths, groups


def _length_by_error_bin(dataframe, bins):
    """Split the ``length`` column into ``bins`` equal-width ``abs_error`` bins.

    The bins cover ``[0, max(abs_error)]``. Each bin is half-open
    ``[low, high)`` except the last one, which also contains the maximum.
    Values below 0 (not expected for an absolute error) fall into the first
    bin. Rows with a missing ``length`` or ``abs_error`` are ignored.

    Returns:
        (edges, positions, groups): the ``bins + 1`` bin edges, the 1-based
        indices of the non-empty bins, and one 1-D array of ``length`` values
        per non-empty bin (aligned with ``positions``).

    Raises:
        ValueError: if ``bins < 1`` or there is no usable row.
    """
    if bins < 1:
        raise ValueError("`bins` must be at least 1")
    usable = dataframe[['length', 'abs_error']].dropna()
    if len(usable) == 0:
        raise ValueError("no 'length'/'abs_error' rows to plot")

    errors = usable['abs_error'].to_numpy()
    lengths = usable['length'].to_numpy()

    # bins intervals need bins + 1 edges.
    edges = np.linspace(0, errors.max(), num=bins + 1)
    # side='right' gives [low, high) bins; the clip folds the maximum (which
    # sits exactly on the last edge) back into the last bin.
    bin_of = np.clip(np.searchsorted(edges, errors, side='right') - 1,
                     0, bins - 1)

    positions, groups = [], []
    for index in range(bins):
        members = lengths[bin_of == index]
        if members.size:  # a violin cannot be drawn from an empty sample
            positions.append(index + 1)
            groups.append(members)
    return edges, positions, groups


def _sample_shared_positions(len_1, len_2, k=10):
    """Pick up to *k* distinct row positions valid for both frame lengths."""
    shared = min(len_1, len_2)
    return np.random.choice(shared, size=min(k, shared), replace=False)


def histo_abs_error(dataframe, display=False, save=False, path=None):
    """Draw a horizontal violin plot of the ``abs_error`` column.

    Args:
        dataframe: table with an ``abs_error`` column (pandas-style).
        display: show the figure with ``plt.show()``.
        save: write the figure to ``path``.
        path: output file, required when ``save`` is true.

    Raises:
        ValueError: if there is no non-missing ``abs_error`` value, or if
            ``save`` is true without a ``path``.
    """
    errors = dataframe['abs_error'].dropna().to_numpy()
    if errors.size == 0:
        raise ValueError("no 'abs_error' values to plot")

    # A regular subplot leaves room for tick labels; an axes spanning the
    # whole canvas ([0, 0, 1, 1]) pushes them outside the figure.
    fig, ax = plt.subplots()
    ax.violinplot([errors], vert=False)
    _show_or_save(fig, display, save, path)


def histo_abs_error_by_length(dataframe, display=False, save=False, path=None):
    """Draw one vertical ``abs_error`` violin per distinct ``length`` value.

    Each violin is placed at its ``length`` on the x axis.

    Args:
        dataframe: table with ``length`` and ``abs_error`` columns.
        display: show the figure with ``plt.show()``.
        save: write the figure to ``path``.
        path: output file, required when ``save`` is true.

    Raises:
        ValueError: if there is no usable row, or if ``save`` is true
            without a ``path``.
    """
    lengths, groups = _abs_error_by_length(dataframe)

    fig, ax = plt.subplots()
    ax.violinplot(groups, positions=lengths, vert=True)
    _show_or_save(fig, display, save, path)


def histo_length_by_error(dataframe, bins, display=False, save=False, path=None):
    """Draw one horizontal ``length`` violin per ``abs_error`` bin.

    The error range ``[0, max(abs_error)]`` is cut into ``bins`` equal-width
    bins; each non-empty bin gets a violin at its 1-based bin index on the
    y axis.

    Args:
        dataframe: table with ``length`` and ``abs_error`` columns.
        bins: number of error bins (>= 1).
        display: show the figure with ``plt.show()``.
        save: write the figure to ``path``.
        path: output file, required when ``save`` is true.

    Raises:
        ValueError: if ``bins < 1``, there is no usable row, or ``save`` is
            true without a ``path``.
    """
    _, positions, groups = _length_by_error_bin(dataframe, bins)

    fig, ax = plt.subplots()
    ax.violinplot(groups, positions=positions, vert=False)
    _show_or_save(fig, display, save, path)


def compare_error(df1, df2, display=False, save=False, path=None):
    """Bar-plot ``abs_error`` per ``seq`` for the same random rows of two frames.

    Up to 10 row positions are drawn at random (without repetition) from the
    range valid for both frames; ``df1`` is drawn in the top panel and
    ``df2`` in the bottom one.

    Args:
        df1: first table with ``seq`` and ``abs_error`` columns.
        df2: second table with ``seq`` and ``abs_error`` columns.
        display: show the figure with ``plt.show()``.
        save: write the figure to ``path``.
        path: output file, required when ``save`` is true.

    Raises:
        ValueError: if ``save`` is true without a ``path``.
    """
    picks = _sample_shared_positions(len(df1), len(df2))

    # Two stacked panels: the grid shape is (nrows, ncols), not `figsize`,
    # and a single-column grid comes back as a 1-D array of axes.
    fig, (top, bottom) = plt.subplots(2, 1)
    for panel, frame in ((top, df1), (bottom, df2)):
        # iloc selects by position, so the frames need not share an index.
        rows = frame.iloc[picks]
        panel.bar(rows['seq'].to_numpy(), rows['abs_error'].to_numpy(),
                  width=0.8)

    _show_or_save(fig, display, save, path)