From a12ebebc63249a8ce95c508126f36a542d501b8a Mon Sep 17 00:00:00 2001
From: Fize Jacques <jacques.fize@cirad.fr>
Date: Mon, 12 Apr 2021 14:30:35 +0200
Subject: [PATCH] debug

---
 eval_mixed_model.py     | 37 ++++++++++++++++++++++++-------------
 run_eval_mixed_model.sh |  4 ++--
 2 files changed, 26 insertions(+), 15 deletions(-)

diff --git a/eval_mixed_model.py b/eval_mixed_model.py
index ba303bb..678dcda 100644
--- a/eval_mixed_model.py
+++ b/eval_mixed_model.py
@@ -36,6 +36,8 @@ parser.add_argument("nb_nodes",type=int)
 parser.add_argument("nb_edges",type=int)
 parser.add_argument("nb_com",type=int)
 parser.add_argument("alpha",type=float)
+parser.add_argument("nb_iterations",type=int)
+parser.add_argument('-f', '--features', help='Feature(s) used in the model training', type=str)
 parser.add_argument("-v","--verbose",action="store_true")
 
 args= parser.parse_args()
@@ -44,8 +46,9 @@ GRAPH_NODE_NB = args.nb_nodes
 GRAPH_EDGE_NB = args.nb_edges
 ALPHA = args.alpha
 NB_COM = args.nb_com
-NB_ITERATION = 3
+NB_ITERATION = args.nb_iterations
 VERBOSE = args.verbose
+FEATURES = set(args.features.split(","))
 
 dist = lambda a,b : np.linalg.norm(a-b)**2
 hash_func = lambda x:"_".join(sorted([str(x[0]),str(x[1])]))
@@ -113,23 +116,31 @@ y_train = traintest_split.train_labels
 X_test = traintest_split.test_edges
 y_test = traintest_split.test_labels
 
+if "pos" in FEATURES:
+    pos = nx.get_node_attributes(G,"pos")
+    dist_X_train = np.asarray([dist(pos[ed[0]],pos[ed[1]]) for ed in X_train]).reshape(-1,1)
+    dist_X_test = np.asarray([dist(pos[ed[0]],pos[ed[1]]) for ed in X_test]).reshape(-1,1)
 
-pos = nx.get_node_attributes(G,"pos")
-dist_X_train = np.asarray([dist(pos[ed[0]],pos[ed[1]]) for ed in X_train]).reshape(-1,1)
-dist_X_test = np.asarray([dist(pos[ed[0]],pos[ed[1]]) for ed in X_test]).reshape(-1,1)
+    X_train = np.concatenate((X_train, dist_X_train), axis=1)
+    X_test = np.concatenate((X_test, dist_X_test), axis=1)
+
+if "centrality" in FEATURES:
+    centrality = nx.degree_centrality(G)
+    centrality_X_train = np.asarray([[centrality[ed[0]],centrality[ed[1]]] for ed in X_train])
+    centrality_X_test = np.asarray([[centrality[ed[0]],centrality[ed[1]]] for ed in X_test])
+
+    X_train = np.concatenate((X_train, centrality_X_train), axis=1)
+    X_test = np.concatenate((X_test, centrality_X_test), axis=1)
 
 
-centrality = nx.degree_centrality(G)
-centrality_X_train = np.asarray([[centrality[ed[0]],centrality[ed[1]]] for ed in X_train])
-centrality_X_test = np.asarray([[centrality[ed[0]],centrality[ed[1]]] for ed in X_test])
+if "it_probs":
+    if_not =[0 for i in range(NB_ITERATION-1)]
+    feature_X_train = np.asarray([ (edge_feature[hash_func(ed)] if hash_func(ed) in edge_feature else if_not) for ed in X_train])
+    feature_X_test = np.asarray([ (edge_feature[hash_func(ed)] if hash_func(ed) in edge_feature else if_not) for ed in X_test])
 
-if_not =[0 for i in range(NB_ITERATION-1)]
-feature_X_train = np.asarray([ (edge_feature[hash_func(ed)] if hash_func(ed) in edge_feature else if_not) for ed in X_train])
-feature_X_test = np.asarray([ (edge_feature[hash_func(ed)] if hash_func(ed) in edge_feature else if_not) for ed in X_test])
+    X_train = np.concatenate((X_train, feature_X_train), axis=1)
+    X_test = np.concatenate((X_test, feature_X_test ), axis=1)
 
-##ADD centrality and distance to X train
-X_train = np.concatenate((X_train,dist_X_train,centrality_X_train),axis=1)
-X_test = np.concatenate((X_test,dist_X_test,centrality_X_test),axis=1)
 
 
 classifier_dict = {
diff --git a/run_eval_mixed_model.sh b/run_eval_mixed_model.sh
index fed67e7..2544835 100644
--- a/run_eval_mixed_model.sh
+++ b/run_eval_mixed_model.sh
@@ -5,7 +5,7 @@ do
   for nbcom in 2 3 4 5
   do
     echo "alpha= "$alpha", nb_com= "$nbcom
-    python eval_mixed_model.py 100 200 $nbcom $alpha
-    python eval_mixed_model.py 300 600 $nbcom $alpha
+    python eval_mixed_model.py 100 200 $nbcom $alpha -f pos,centrality,it_probs
+    python eval_mixed_model.py 300 600 $nbcom $alpha -f pos,centrality,it_probs
   done
 done
\ No newline at end of file
-- 
GitLab