Commit 28606210 authored by Arthur Batel

emb naming

parent b7edf5b9
@@ -65,7 +65,7 @@ def read_file(dataTrain, dataTest):
     dico_items = { k:v for (k,v) in zip(items, range(num_items))}
     print("NB Items", num_items, len(dico_items))
     return dico_kc, dico_users, dico_items
-def save_embeddings(xpName: str, modelName: str, embeddings, userEmbDir: str, itemEmbDir: str, grid_search_id):
+def save_embeddings(xpName: str, modelName: str, embeddings, userEmbDir: str, itemEmbDir: str):
     """
     Saves the user and item embeddings produced by the training process.
@@ -77,10 +77,11 @@ def save_embeddings(xpName: str, modelName: str, embeddings, userEmbDir: str, itemEmbDir: str, grid_search_id):
     u_emb, i_emb = embeddings
-    results_name_file = (xpName + modelName + "_" + str(grid_search_id))
+    results_name_file = (xpName + "_" + modelName)
     # save user embeddings
     results_path = os.path.join(userEmbDir, results_name_file + ".csv")
+    print(results_path)
     np.savetxt(results_path, u_emb, delimiter=',')
     # save item embeddings
     results_path = os.path.join(itemEmbDir, results_name_file + ".csv")
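Taken together, the two hunks above leave the helper in roughly the following shape. This is a sketch assembled from the diff, not a verbatim copy: the final np.savetxt for the item table is cut off by the hunk boundary and is inferred by symmetry with the user table, and the os/numpy imports are assumed from the module header.

import os
import numpy as np

def save_embeddings(xpName: str, modelName: str, embeddings, userEmbDir: str, itemEmbDir: str):
    """Save the user and item embedding matrices as CSV, one file per table."""
    u_emb, i_emb = embeddings
    results_name_file = xpName + "_" + modelName  # e.g. "assist17_tkde_0_BPR"
    # save user embeddings
    results_path = os.path.join(userEmbDir, results_name_file + ".csv")
    print(results_path)
    np.savetxt(results_path, u_emb, delimiter=',')
    # save item embeddings (this write is inferred; the hunk ends before it)
    results_path = os.path.join(itemEmbDir, results_name_file + ".csv")
    np.savetxt(results_path, i_emb, delimiter=',')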
@@ -335,9 +336,7 @@ if __name__ == '__main__':
     emb = [bpr_model.user_embeddings.weight.clone().detach().cpu().numpy(),
            bpr_model.item_embeddings.weight.clone().detach().cpu().numpy()]
     print("save")
-    save_embeddings(xpName=datasetName, modelName="BPR", embeddings=emb, itemEmbDir=embPath+'items/', userEmbDir=embPath+'users/',
-                    grid_search_id=get_formatted_date())
+    save_embeddings(xpName=datasetName+"_"+str(i_fold), modelName="BPR", embeddings=emb, itemEmbDir=embPath+'items/', userEmbDir=embPath+'users/')
     write_file_doa(FileNameTrain_temp, emb[0], train, dico_kc, dico_users, dico_items)
     doa = compute_doa(FileNameTrain_temp)
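The practical effect of folding i_fold into xpName is that each cross-validation fold now writes distinct, reproducible file names keyed by the fold index, where the old grid_search_id=get_formatted_date() scheme keyed them to the wall-clock time of the run. A quick illustration of the resulting paths; the dataset name and fold count here are examples, not values from a specific run:

embPath = "../../results/table_2/"  # example; matches embDirPath in the run script below
datasetName = "assist17_tkde"
for i_fold in range(5):
    name = datasetName + "_" + str(i_fold) + "_" + "BPR"  # xpName + "_" + modelName
    print(embPath + "users/" + name + ".csv")
# ../../results/table_2/users/assist17_tkde_0_BPR.csv
# ../../results/table_2/users/assist17_tkde_1_BPR.csv  ... and so on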
@@ -398,7 +397,7 @@ if __name__ == '__main__':
     acc = bpr_model.train(train, len(dico_kc), epochs, batch_size, y_train, test, y_test)
     emb = [bpr_model.user_embeddings.weight.clone().detach().cpu().numpy(), bpr_model.item_embeddings.weight.clone().detach().cpu().numpy()]
-    save_embeddings(xpName=datasetName, modelName="BPR", embeddings=emb, itemEmbDir=embPath+'items/', userEmbDir=embPath+'users/', grid_search_id=get_formatted_date())
+    save_embeddings(xpName=datasetName, modelName="BPR", embeddings=emb, itemEmbDir=embPath+'items/', userEmbDir=embPath+'users/')
     write_file_doa(trainFileName, emb[0], train, dico_kc, dico_users, dico_items)
     doa = compute_doa(trainFileName)
@@ -2,8 +2,8 @@ import os
 dPath = "../../data/cdbpr_format/"
 embDirPath = "../../results/table_2/"
 datasets = ['assist0910_tkde', 'assist17_tkde', 'algebra', 'math_1', 'math_2']
-epochs = [75, 95, 5, 90, 90]
-batchSize = [512, 512, 512, 512, 512]
+epochs = [1, 75, 95, 5, 90, 90]
+batchSize = [512, 512, 512, 512, 512, 4000]
 learningRate = [0.01, 0.01, 0.01, 0.01, 0.01]
 mode = [1, 1, 1, 1, 1]
 for i in range(len(datasets)):
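One thing to flag in this last hunk: datasets still lists five names, but epochs now has six entries and batchSize a trailing 4000. Because the loop indexes by position, the prepended 1 shifts every dataset's epoch count (assist0910_tkde would train for a single epoch) while the extra 4000 is never read; both look like debugging leftovers. Below is a sketch of an alternative layout that turns such drift into a hard error rather than a silent shift; it is illustrative only, not the repository's code, with values copied from the pre-commit lists:

# One row per dataset: (name, epochs, batch_size, learning_rate, mode).
configs = [
    ("assist0910_tkde", 75, 512, 0.01, 1),
    ("assist17_tkde",   95, 512, 0.01, 1),
    ("algebra",          5, 512, 0.01, 1),
    ("math_1",          90, 512, 0.01, 1),
    ("math_2",          90, 512, 0.01, 1),
]
for dataset, n_epochs, batch_size, lr, mode in configs:
    # unpacking fails loudly if a row is malformed; no positional-index drift
    ...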