Skip to content
Snippets Groups Projects
Commit a2a70432 authored by Léo Schneider's avatar Léo Schneider
Browse files

new metrics

parent 90bf2b77
No related branches found
No related tags found
No related merge requests found
import matplotlib.pyplot as plt
import numpy as np
import random
from sympy.utilities.misc import replace
from mass_prediction import compute_frag_mz_ration
seq = 'YEEEFLR'
......@@ -110,4 +113,89 @@ def frag_spectra_comparison(int_1, seq_1, int_2, seq_2=None):
ax.spines[["left", "top", "right"]].set_visible(False)
ax.margins(y=0.1)
plt.show()
\ No newline at end of file
plt.show()
def histo_abs_error(dataframe, display=False, save=False, path=None):
points = dataframe['abs_error']
## combine these different collections into a list
data_to_plot = [points]
# Create a figure instance
fig = plt.figure()
# Create an axes instance
ax = fig.add_axes([0, 0, 1, 1])
# Create the boxplot
bp = ax.violinplot(data_to_plot, vert=False)
if display :
plt.show()
if save :
plt.savefig(path)
def histo_abs_error_by_length(dataframe, display=False, save=False, path=None):
data_to_plot =[]
max_length = max(dataframe['length'])
min_length = min(dataframe['length'])
for l in range(min_length, max_length):
data_to_plot.append(dataframe['abs_error'].where(dataframe['length']==l))
# Create a figure instance
fig = plt.figure()
# Create an axes instance
ax = fig.add_axes([0, 0, 1, 1])
# Create the boxplot
bp = ax.violinplot(data_to_plot, vert=True)
if display:
plt.show()
if save:
plt.savefig(path)
def histo_length_by_error(dataframe, bins, display=False, save=False, path=None):
data_to_plot = []
max_error = max(dataframe['abs_error'])
inter = np.linspace(0, max_error, num=bins)
for i in range(bins):
data_to_plot.append(dataframe['length'].where(inter[i] < dataframe['abs_error'] < inter[i+1]))
# Create a figure instance
fig = plt.figure()
# Create an axes instance
ax = fig.add_axes([0, 0, 1, 1])
# Create the boxplot
bp = ax.violinplot(data_to_plot, vert=False)
if display:
plt.show()
if save:
plt.savefig(path)
def compare_error(df1, df2, display=False, save=False, path=None):
size = len(df2)
ind = np.random.choice(range(size), size=10, replace=False)
seq1 = df1['seq'][ind]
seq2 = df2['seq'][ind]
data_1 = df1['abs_error'][ind]
data_2 = df2['abs_error'][ind]
fig, ax = plt.subplots(figsize=(2, 1))
ax[0, 0].bar(seq1, data_1, width=0.8)
ax[1, 0].bar(seq2, data_2, width=0.8)
if display:
plt.show()
if save:
plt.savefig(path)
......@@ -183,6 +183,7 @@ def run(epochs, eval_inter, save_inter, model, data_train, data_val, data_test,
wandb=wandb)
if e % save_inter == 0:
save(model, 'model_common_' + str(e) + '.pt')
save_pred(model, data_val, forward, 'output/out.csv')
def main(args):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment