Skip to content
Snippets Groups Projects
Commit 3d904f5a authored by Céline Robardet's avatar Céline Robardet
Browse files

xp

parent 6f1e1b4f
No related branches found
No related tags found
No related merge requests found
......@@ -151,7 +151,7 @@ class BPRModel(nn.Module):
all_decisions = np.concatenate((all_decisions, comp), axis=0)
mse1 = mean_squared_error(all_labels, all_predictions)
print("RMSE", np.sqrt(mse1))#, root_mean_squared_error(all_labels, all_predictions))
print("RMSE", np.sqrt(mse1))
auc = roc_auc_score(all_labels, all_predictions)
print("AUC:", auc)
acc = accuracy_score(all_labels, all_decisions)
......
......@@ -38,21 +38,18 @@ def read_file(dataTrain, dataTest):
kc = flattern_arrays(kc.values, kcT.values)
num_kc = len(kc)
dico_kc = { k:v for (k,v) in zip(kc, range(len(kc)))}
print("NB KC", num_kc)
# dico users
users = df['user_id']
usersT = dfTest['user_id']
users = flattern_arrays(users.values, usersT.values)
num_users = len(users)
dico_users = { k:v for (k,v) in zip(users, range(num_users))}
print("NB Users", num_users)
# dico items and their associated kc
itemsDT = df['item_id']
itemsT = dfTest['item_id']
items = flattern_arrays(itemsDT.values, itemsT.values)
num_items = len(items)
dico_items = { k:v for (k,v) in zip(items, range(num_items))}
#print("NB Items", num_items)
return dico_kc, dico_users, dico_items
def parse_dataframe(data, dico_kc, dico_users, dico_item, is_train = True):
......@@ -79,7 +76,6 @@ def parse_dataframe(data, dico_kc, dico_users, dico_item, is_train = True):
col = row['item_id']
if col not in dico_items:
dico_items[col] = len(dico_items)
# Warning, all user's answers are positives!
q,r = parse_it(col)
col_neg = q+'_'+str(1-int(r))
if col_neg not in dico_items:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment