Skip to content
Snippets Groups Projects
Commit 17fbdb37 authored by Schneider Leo's avatar Schneider Leo
Browse files

data augmented

parent c4cecdfd
No related branches found
No related tags found
No related merge requests found
...@@ -136,14 +136,45 @@ def numerical_to_alphabetical_str(s): ...@@ -136,14 +136,45 @@ def numerical_to_alphabetical_str(s):
seq+=ALPHABET_UNMOD_REV[arr[i]] seq+=ALPHABET_UNMOD_REV[arr[i]]
return seq return seq
def plot_res():
import matplotlib.pyplot as plt
import numpy as np
fig, axs = plt.subplots(figsize=(9, 4))
all_data = [[0.911,0.899,0.9,0.885],
[0.852,0.75,0.857,0.788],
[0.853,0.839,0.862,0.826],
[0.902,0.833,0.808],
[0.9,0.904,0.922,0.907],
[0.915,0.912,0.923,0.911],
[0.945,0.918,0.919,0.933],
[0.91,0.919,0.927,0.906],
[0.881,0.901,0.919,0.902],
[0.893,0.909,0.918,0.896],]
# plot box plot
axs.boxplot(all_data)
axs.set_title('Box plot')
# adding horizontal grid lines
axs.yaxis.grid(True)
axs.set_xticks([y + 1 for y in range(len(all_data))],
labels=['Prosit','ISA_noc','Augm 0.05','Augm 0.1','Augm 0.2','Augm 0.3','Augm 0.4','Augm 0.7','Augm 1','Augm all',])
plt.savefig('../fig/model perf/summary.png')
if __name__ == '__main__' : if __name__ == '__main__' :
base = ['ISA_noc_ISA_noc','prosit_ISA_noc', 'ISA_noc_prosit', 'prosit_prosit'] # base = ['ISA_noc_ISA_noc','prosit_ISA_noc', 'ISA_noc_prosit', 'prosit_prosit']
augmented = ['ISA_aug_07_ISA_noc','ISA_aug_1_ISA_noc','ISA_aug_all_ISA_noc'] # augmented = ['ISA_aug_07_ISA_noc','ISA_aug_1_ISA_noc','ISA_aug_all_ISA_noc']
for f_suffix_name in augmented: # for f_suffix_name in augmented:
for number in ['1','2','3','4']: # for number in ['1','2','3','4']:
df = pd.read_csv('../output/out_{}_{}.csv'.format(f_suffix_name,number)) # df = pd.read_csv('../output/out_{}_{}.csv'.format(f_suffix_name,number))
add_length(df) # add_length(df)
df['abs_error'] = np.abs(df['rt pred']-df['true rt']) # df['abs_error'] = np.abs(df['rt pred']-df['true rt'])
# histo_abs_error(df, display=False, save=True, path='../fig/model perf/histo_{}_{}.png'.format(f_suffix_name,number)) # histo_abs_error(df, display=False, save=True, path='../fig/model perf/histo_{}_{}.png'.format(f_suffix_name,number))
scatter_rt(df, display=False, save=True, path='../fig/model perf/RT_pred_{}_{}.png'.format(f_suffix_name,number), color=True) # scatter_rt(df, display=False, save=True, path='../fig/model perf/RT_pred_{}_{}.png'.format(f_suffix_name,number), color=True)
# histo_length_by_error(df, bins=10, display=False, save=True, path='../fig/model perf/histo_length_{}_{}.png'.format(f_suffix_name,number)) # histo_length_by_error(df, bins=10, display=False, save=True, path='../fig/model perf/histo_length_{}_{}.png'.format(f_suffix_name,number))
\ No newline at end of file plot_res()
...@@ -67,17 +67,24 @@ def eval(model, data_val, epoch, criterion_rt, metric_rt, wandb=None): ...@@ -67,17 +67,24 @@ def eval(model, data_val, epoch, criterion_rt, metric_rt, wandb=None):
print('epoch : ', epoch, 'val rt loss', losses_rt / len(data_val), print('epoch : ', epoch, 'val rt loss', losses_rt / len(data_val),
"val rt mean metric : ", "val rt mean metric : ",
dist_rt_acc / len(data_val)) dist_rt_acc / len(data_val))
return losses_rt
def run(epochs, eval_inter, save_inter, model, data_train, data_val, data_test, optimizer, criterion_rt, def run(epochs, eval_inter, save_inter, model, data_train, data_val, data_test, optimizer, criterion_rt,
metric_rt, wandb=None, output='output/out.csv'): metric_rt, wandb=None, output='output/out.csv'):
mem = 1000000.
for e in range(1, epochs + 1): for e in range(1, epochs + 1):
train(model, data_train, e, optimizer, criterion_rt, metric_rt, wandb=wandb) train(model, data_train, e, optimizer, criterion_rt, metric_rt, wandb=wandb)
if e % eval_inter == 0: if e % eval_inter == 0:
eval(model, data_val, e, criterion_rt, metric_rt, wandb=wandb) losses_rt = eval(model, data_val, e, criterion_rt, metric_rt, wandb=wandb)
if losses_rt < mem :
mem = losses_rt
torch.save(model.state_dict(), output.strip('.csv')+'pt')
print('model saved')
if e % save_inter == 0: if e % save_inter == 0:
save(model, 'model_common_' + str(e) + '.pt') save(model, 'model_common_' + str(e) + '.pt')
save_pred(model, data_val, output) model.load_state_dict(torch.load(output.strip('.csv')+'pt', weights_only=True))
save_pred(model, data_test, output)
def main(args): def main(args):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment