Skip to content
Snippets Groups Projects
Commit 8071b8c2 authored by Arthur Batel's avatar Arthur Batel
Browse files

embeddings_saving

parent cde26d76
No related branches found
No related tags found
No related merge requests found
...@@ -149,11 +149,11 @@ def init_frequencielle(train, n, p, dim, dico_items): ...@@ -149,11 +149,11 @@ def init_frequencielle(train, n, p, dim, dico_items):
cc_exp_col[i] = cc_exp_col[i] / np.sum(cc_exp_col[i]) cc_exp_col[i] = cc_exp_col[i] / np.sum(cc_exp_col[i])
return cc_exp_row, cc_exp_col return cc_exp_row, cc_exp_col
def write_file_doa(FileName, embed, train, dico_kc, dico_users, dico_items): def write_file_doa(FileName, embed, train, dico_kc, dico_users, dico_items, ablation):
# write embeddings # write embeddings
it = list(dico_items) it = list(dico_items)
ut = list(dico_users) ut = list(dico_users)
nom = FileName+"_embed.csv" nom = FileName+"_embed_ablation_"+str(ablation)+".csv"
f = open(nom, 'w') f = open(nom, 'w')
writer = csv.writer(f) writer = csv.writer(f)
for i in range(embed.shape[0]): for i in range(embed.shape[0]):
...@@ -220,7 +220,7 @@ if __name__ == '__main__': ...@@ -220,7 +220,7 @@ if __name__ == '__main__':
acc = bpr_model.train(train, len(dico_kc), epochs, batch_size, y_train, ablation) acc = bpr_model.train(train, len(dico_kc), epochs, batch_size, y_train, ablation)
# DOA # DOA
new_embedding_value = bpr_model.user_embeddings.weight.clone().detach().cpu().numpy() new_embedding_value = bpr_model.user_embeddings.weight.clone().detach().cpu().numpy()
write_file_doa(file, new_embedding_value, train, dico_kc, dico_users, dico_items) write_file_doa(file, new_embedding_value, train, dico_kc, dico_users, dico_items, ablation)
doa = compute_doa(file) doa = compute_doa(file)
print("DOA:", doa) print("DOA:", doa)
# Test # Test
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment