Skip to content
Snippets Groups Projects
Commit 470500a3 authored by Léo Schneider's avatar Léo Schneider
Browse files

test cossim

parent 19f55602
No related branches found
No related tags found
No related merge requests found
......@@ -234,5 +234,5 @@ df = pd.read_csv('output/out_ISA.csv')
add_length(df)
df['abs_error'] = np.abs(df['rt pred']-df['true rt'])
# histo_abs_error(df, display=False, save=True, path='temp.png')
# scatter_rt(df, display=False, save=True, path='temp.png')
histo_length_by_error(df, 10, save=True, path='temp.png')
\ No newline at end of file
scatter_rt(df, display=False, save=True, path='temp.png')
# histo_length_by_error(df, 10, save=True, path='temp.png')
\ No newline at end of file
......@@ -207,7 +207,7 @@ def main(args):
print('\nData loaded')
model = Model_Common_Transformer_TAPE(encoder_ff=args.encoder_ff, decoder_rt_ff=args.decoder_rt_ff,
model = Model_Common_Transformer(encoder_ff=args.encoder_ff, decoder_rt_ff=args.decoder_rt_ff,
decoder_int_ff=args.decoder_int_ff
, n_head=args.n_head, encoder_num_layer=args.encoder_num_layer,
decoder_int_num_layer=args.decoder_int_num_layer,
......
import numpy as np
import matplotlib.pyplot as plt
def cos_sim(a,b):
return (a.dot(b))/(np.linalg.norm(a)*np.linalg.norm(b))
l10=[]
l100=[]
for _ in range(1000):
vec10 = np.random.random(10)
vec10b = vec10 + np.random.random(10)
l10.append(cos_sim(vec10,vec10b))
vec100 = np.random.random(100)
vec100b = vec100 + np.random.random(100)
l100.append(cos_sim(vec100,vec100b))
fig, ax = plt.subplots()
ax.plot(l10,c='b',ls='',marker='.')
ax.plot(l100,c='y',ls='',marker='.')
plt.savefig('temp.png')
print(np.mean(l10))
print(np.mean(l100))
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment