Skip to content
Snippets Groups Projects
Commit e97b3aa7 authored by Schneider Leo's avatar Schneider Leo
Browse files

color scatter + model eval

parent 1096b4e3
No related branches found
No related tags found
No related merge requests found
...@@ -4,7 +4,7 @@ import matplotlib.pyplot as plt ...@@ -4,7 +4,7 @@ import matplotlib.pyplot as plt
import numpy as np import numpy as np
import random import random
import pandas as pd import pandas as pd
import matplotlib.colors as mcolors
from mass_prediction import compute_frag_mz_ration from mass_prediction import compute_frag_mz_ration
...@@ -124,22 +124,37 @@ def histo_abs_error(dataframe, display=False, save=False, path=None): ...@@ -124,22 +124,37 @@ def histo_abs_error(dataframe, display=False, save=False, path=None):
## combine these different collections into a list ## combine these different collections into a list
data_to_plot = [points] data_to_plot = [points]
# Create a figure instance # Create a figure instance
fig, ax = plt.subplots() fig, ax = plt.subplots()
# Create the boxplot # Create the boxplot
ax.set_xlabel('abs error') ax.set_xlabel('abs error')
ax.violinplot(data_to_plot, vert=False, side='high', showmedians=True, quantiles=[0.95]) ax.violinplot(data_to_plot, vert=False, side='high', showmedians=True, quantiles=[0.95])
ax.set_xlim(0,175)
if display : if display :
plt.show() plt.show()
if save : if save :
plt.savefig(path) plt.savefig(path)
def scatter_rt(dataframe, display=False, save=False, path=None):
fig, ax = plt.subplots()
ax.scatter(dataframe['true rt'], dataframe['rt pred'], s=.1) def random_color_deterministic(df):
def rd10(str):
color = list(mcolors.CSS4_COLORS)
random.seed(str)
return color[random.randint(0,147)]
df['color']=df['seq'].map(rd10)
def scatter_rt(dataframe, display=False, save=False, path=None, color = False):
fig, ax = plt.subplots()
if color :
random_color_deterministic(dataframe)
ax.scatter(dataframe['true rt'], dataframe['rt pred'], s=.1, color = dataframe['color'])
else :
ax.scatter(dataframe['true rt'], dataframe['rt pred'], s=.1)
ax.set_xlabel('true RT') ax.set_xlabel('true RT')
ax.set_ylabel('pred RT') ax.set_ylabel('pred RT')
x = np.array([min(dataframe['true rt']), max(dataframe['true rt'])]) x = np.array([min(dataframe['true rt']), max(dataframe['true rt'])])
...@@ -153,6 +168,8 @@ def scatter_rt(dataframe, display=False, save=False, path=None): ...@@ -153,6 +168,8 @@ def scatter_rt(dataframe, display=False, save=False, path=None):
plt.savefig(path) plt.savefig(path)
def histo_abs_error_by_length(dataframe, display=False, save=False, path=None): def histo_abs_error_by_length(dataframe, display=False, save=False, path=None):
data_to_plot =[] data_to_plot =[]
max_length = max(dataframe['length']) max_length = max(dataframe['length'])
...@@ -160,6 +177,8 @@ def histo_abs_error_by_length(dataframe, display=False, save=False, path=None): ...@@ -160,6 +177,8 @@ def histo_abs_error_by_length(dataframe, display=False, save=False, path=None):
for l in range(min_length, max_length): for l in range(min_length, max_length):
data_to_plot.append(dataframe['abs_error'].where(dataframe['length']==l)) data_to_plot.append(dataframe['abs_error'].where(dataframe['length']==l))
# data_to_plot.append()
fig, ax = plt.subplots() fig, ax = plt.subplots()
...@@ -234,9 +253,23 @@ def add_length(dataframe): ...@@ -234,9 +253,23 @@ def add_length(dataframe):
dataframe['length']=dataframe['seq'].map(fonc) dataframe['length']=dataframe['seq'].map(fonc)
df = pd.read_csv('output/out_common_isa_no_tape.csv') df = pd.read_csv('output/output_common_data_ISA.csv')
add_length(df)
df['abs_error'] = np.abs(df['rt pred']-df['true rt'])
histo_abs_error(df, display=False, save=True, path='fig/custom model res/histo_ISA_ISA.png')
scatter_rt(df, display=False, save=True, path='fig/custom model res/RT_pred_ISA_ISA.png', color=True)
histo_length_by_error(df, bins=10, display=False, save=True, path='fig/custom model res/histo_length_ISA_ISA.png')
df = pd.read_csv('output/out_prosit_common.csv')
add_length(df)
df['abs_error'] = np.abs(df['rt pred']-df['true rt'])
histo_abs_error(df, display=False, save=True, path='fig/custom model res/histo_prosit_prosit.png')
scatter_rt(df, display=False, save=True, path='fig/custom model res/RT_pred_prosit_prosit.png', color=True)
histo_length_by_error(df, bins=10, display=False, save=True, path='fig/custom model res/histo_length_prosit_prosit.png')
df = pd.read_csv('output/out_common_transfer.csv')
add_length(df) add_length(df)
df['abs_error'] = np.abs(df['rt pred']-df['true rt']) df['abs_error'] = np.abs(df['rt pred']-df['true rt'])
# histo_abs_error(df, display=False, save=True, path='temp.png') histo_abs_error(df, display=False, save=True, path='fig/custom model res/histo_prosit_ISA.png')
scatter_rt(df, display=False, save=True, path='RT_pred_ISA_ISA.png') scatter_rt(df, display=False, save=True, path='fig/custom model res/RT_pred_prosit_ISA.png', color=True)
# histo_length_by_error(df, 10, save=True, path='temp.png') histo_length_by_error(df, bins=10, display=False, save=True, path='fig/custom model res/histo_length_prosit_ISA.png')
\ No newline at end of file \ No newline at end of file
...@@ -108,6 +108,7 @@ def eval(model, data_val, epoch, criterion_rt, criterion_intensity, metric_rt, m ...@@ -108,6 +108,7 @@ def eval(model, data_val, epoch, criterion_rt, criterion_intensity, metric_rt, m
losses_int = 0. losses_int = 0.
dist_rt_acc = 0. dist_rt_acc = 0.
dist_int_acc = 0. dist_int_acc = 0.
model.eval()
for param in model.parameters(): for param in model.parameters():
param.requires_grad = False param.requires_grad = False
if forward == 'both': if forward == 'both':
...@@ -273,6 +274,7 @@ def get_n_params(model): ...@@ -273,6 +274,7 @@ def get_n_params(model):
def save_pred(model, data_val, forward, output_path): def save_pred(model, data_val, forward, output_path):
data_frame = pd.DataFrame() data_frame = pd.DataFrame()
model.eval()
for param in model.parameters(): for param in model.parameters():
param.requires_grad = False param.requires_grad = False
if forward == 'both': if forward == 'both':
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment