From 470500a3176acefa35a021832eec6846eef11b39 Mon Sep 17 00:00:00 2001 From: lschneider <leo.schneider@univ-lyon1.fr> Date: Thu, 19 Sep 2024 14:25:17 +0200 Subject: [PATCH] test cossim --- data_viz.py | 4 ++-- main_custom.py | 2 +- vector_sim_test.py | 27 +++++++++++++++++++++++++++ 3 files changed, 30 insertions(+), 3 deletions(-) create mode 100644 vector_sim_test.py diff --git a/data_viz.py b/data_viz.py index 52fd309..3d0f590 100644 --- a/data_viz.py +++ b/data_viz.py @@ -234,5 +234,5 @@ df = pd.read_csv('output/out_ISA.csv') add_length(df) df['abs_error'] = np.abs(df['rt pred']-df['true rt']) # histo_abs_error(df, display=False, save=True, path='temp.png') -# scatter_rt(df, display=False, save=True, path='temp.png') -histo_length_by_error(df, 10, save=True, path='temp.png') \ No newline at end of file +scatter_rt(df, display=False, save=True, path='temp.png') +# histo_length_by_error(df, 10, save=True, path='temp.png') \ No newline at end of file diff --git a/main_custom.py b/main_custom.py index fcc9074..db984b4 100644 --- a/main_custom.py +++ b/main_custom.py @@ -207,7 +207,7 @@ def main(args): print('\nData loaded') - model = Model_Common_Transformer_TAPE(encoder_ff=args.encoder_ff, decoder_rt_ff=args.decoder_rt_ff, + model = Model_Common_Transformer(encoder_ff=args.encoder_ff, decoder_rt_ff=args.decoder_rt_ff, decoder_int_ff=args.decoder_int_ff , n_head=args.n_head, encoder_num_layer=args.encoder_num_layer, decoder_int_num_layer=args.decoder_int_num_layer, diff --git a/vector_sim_test.py b/vector_sim_test.py new file mode 100644 index 0000000..1adb080 --- /dev/null +++ b/vector_sim_test.py @@ -0,0 +1,27 @@ +import numpy as np +import matplotlib.pyplot as plt + +def cos_sim(a,b): + return (a.dot(b))/(np.linalg.norm(a)*np.linalg.norm(b)) + +l10=[] +l100=[] +for _ in range(1000): + vec10 = np.random.random(10) + vec10b = vec10 + np.random.random(10) + + l10.append(cos_sim(vec10,vec10b)) + + vec100 = np.random.random(100) + vec100b = vec100 + np.random.random(100) + + l100.append(cos_sim(vec100,vec100b)) + + +fig, ax = plt.subplots() + +ax.plot(l10,c='b',ls='',marker='.') +ax.plot(l100,c='y',ls='',marker='.') +plt.savefig('temp.png') +print(np.mean(l10)) +print(np.mean(l100)) \ No newline at end of file -- GitLab