diff --git a/code/binary_bpr_ablation/main.py b/code/binary_bpr_ablation/main.py index 0f8217a25a3c6d6a5df47629e822a2bd2c5bbf26..77727cd23187bbbeab3c08cea2591ff840bf192b 100644 --- a/code/binary_bpr_ablation/main.py +++ b/code/binary_bpr_ablation/main.py @@ -149,11 +149,11 @@ def init_frequencielle(train, n, p, dim, dico_items): cc_exp_col[i] = cc_exp_col[i] / np.sum(cc_exp_col[i]) return cc_exp_row, cc_exp_col -def write_file_doa(FileName, embed, train, dico_kc, dico_users, dico_items): +def write_file_doa(FileName, embed, train, dico_kc, dico_users, dico_items, ablation): # write embeddings it = list(dico_items) ut = list(dico_users) - nom = FileName+"_embed.csv" + nom = FileName+"_embed_ablation_"+str(ablation)+".csv" f = open(nom, 'w') writer = csv.writer(f) for i in range(embed.shape[0]): @@ -220,7 +220,7 @@ if __name__ == '__main__': acc = bpr_model.train(train, len(dico_kc), epochs, batch_size, y_train, ablation) # DOA new_embedding_value = bpr_model.user_embeddings.weight.clone().detach().cpu().numpy() - write_file_doa(file, new_embedding_value, train, dico_kc, dico_users, dico_items) + write_file_doa(file, new_embedding_value, train, dico_kc, dico_users, dico_items, ablation) doa = compute_doa(file) print("DOA:", doa) # Test