From 8071b8c27a4ed5e15b47701c8899c1c6a1d2fd9f Mon Sep 17 00:00:00 2001
From: Arthur BATEL <arthur.batel@insa-lyon.fr>
Date: Tue, 23 Apr 2024 21:02:00 +0200
Subject: [PATCH] embeddings_saving

---
 code/binary_bpr_ablation/main.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/code/binary_bpr_ablation/main.py b/code/binary_bpr_ablation/main.py
index 0f8217a..77727cd 100644
--- a/code/binary_bpr_ablation/main.py
+++ b/code/binary_bpr_ablation/main.py
@@ -149,11 +149,11 @@ def init_frequencielle(train, n, p, dim, dico_items):
         cc_exp_col[i] = cc_exp_col[i] / np.sum(cc_exp_col[i])
     return cc_exp_row, cc_exp_col 
 
-def write_file_doa(FileName, embed, train, dico_kc, dico_users, dico_items):
+def write_file_doa(FileName, embed, train, dico_kc, dico_users, dico_items, ablation):
     # write embeddings 
     it = list(dico_items)
     ut = list(dico_users)
-    nom = FileName+"_embed.csv"
+    nom = FileName+"_embed_ablation_"+str(ablation)+".csv"
     f = open(nom, 'w')
     writer = csv.writer(f)
     for i in range(embed.shape[0]):
@@ -220,7 +220,7 @@ if __name__ == '__main__':
     acc = bpr_model.train(train, len(dico_kc), epochs, batch_size, y_train, ablation)
     # DOA 
     new_embedding_value = bpr_model.user_embeddings.weight.clone().detach().cpu().numpy()
-    write_file_doa(file, new_embedding_value, train, dico_kc, dico_users, dico_items)  
+    write_file_doa(file, new_embedding_value, train, dico_kc, dico_users, dico_items, ablation)
     doa = compute_doa(file)
     print("DOA:", doa)
     # Test
-- 
GitLab