diff --git a/generate_mixed_model_graph.py b/generate_mixed_model_graph.py
index fce2101acb0b7a0d98ca5c08e21660e7a5d8fab1..6a6d80c5f624b51325168c2d7a821337bec7c3ba 100644
--- a/generate_mixed_model_graph.py
+++ b/generate_mixed_model_graph.py
@@ -4,10 +4,6 @@ import os
 
 import networkx as nx
 import argparse
-import numpy as np
-import pandas as pd
-import random
-import copy
 from tqdm import tqdm
 import lib.random as ra
 
diff --git a/generate_theoric_random_graph.py b/generate_theoric_random_graph.py
index 885b99785dbdf883ca940f98d0802ca86a4990f5..51ca442da1585eb8a584de318966febf2f94fb44 100644
--- a/generate_theoric_random_graph.py
+++ b/generate_theoric_random_graph.py
@@ -4,10 +4,6 @@ import os
 
 import networkx as nx
 import argparse
-import numpy as np
-import pandas as pd
-import random
-import copy
 from tqdm import tqdm
 import lib.random as ra
 
diff --git a/lib/erosion_model.py b/lib/erosion_model.py
index ef37b8c1eefb58b48790b229ab7d41de20f50324..48eb731473891ac9edc82570336532d1bcb5798b 100644
--- a/lib/erosion_model.py
+++ b/lib/erosion_model.py
@@ -1,10 +1,9 @@
 # coding = utf-8
+
 from sklearn.linear_model import LogisticRegression
 from sklearn.metrics import roc_auc_score
-from tqdm import tqdm
 
 from .link_prediction_eval import get_auc_heuristics, split_train_test, get_all_possible_edges
-from .random import get_spat_probs, get_sbm_probs
 from .lambda_func import euclid_dist as dist
 from .lambda_func import  hash_func
 
@@ -21,20 +20,22 @@ def log(x):
     if VERBOSE:
         print(x)
 
-def probs_computation_based_on_weight(weights,n=100000):
+
+def new_probs(weights,nb_edges,n=1000):
     a = np.copy(weights)
     idx_vals = np.arange(len(a))
     res = np.zeros(len(a))
     for i in range(n):
-        idxrand = np.random.choice(idx_vals,1,p=a)
-        res[idxrand] = res[idxrand] + 1
+        idxrand = np.random.choice(a=idx_vals,size=nb_edges,p=a)
+        for x in idxrand:
+            res[x] = res[x] + 1
 
     res/=n
     return res
 
 
 class ErosionModel():
-    def __init__(self, G):
+    def __init__(self, G,spatial_exponent = 2):
         self.G = G
         self.coordinates = nx.get_node_attributes(G, "pos")
         self.block_assign = nx.get_node_attributes(G, "block")
@@ -42,10 +43,13 @@ class ErosionModel():
         self.initialize()
         self.H = G.copy()
 
+        self.graph_history = []
+
         self.nb_of_erosion = 0
+        self.spatial_exponent  = spatial_exponent
 
-    def erode(self):
 
+    def erode(self):
         test_H, _ = pp.prep_graph(self.H.copy(), maincc=True, relabel=False)
         if len(test_H) < 30:
             return False
@@ -53,69 +57,45 @@ class ErosionModel():
             return False
         self.nb_of_erosion += 1
 
-        old_probs = dict(self.probs_df["hash_ p_{0}".format(self.nb_of_erosion - 1).split()].values)
-
         auc_sbm, auc_spatial = get_auc_heuristics(self.H, 60)
-        if VERBOSE:print("SBM AUC",auc_sbm,"SPATIAL AUC",auc_spatial)
+        if VERBOSE: print("SBM AUC", auc_sbm, "SPATIAL AUC", auc_spatial)
         edges = get_all_possible_edges(self.H)
+
         if auc_sbm > auc_spatial:
             probs = stochastic_block_model(self.H, edges)
         else:
-            probs = spatial_link_prediction(self.H, edges)
+            probs = spatial_link_prediction(self.H, edges,exponent=self.spatial_exponent)
 
         edges = np.asarray(edges)
-        probs_dom = np.asarray(probs)
-
-        probs_dom = probs_computation_based_on_weight(probs_dom/probs_dom.sum())
-        sum_prob_dom = probs_dom.sum()
-        sum_prob_dom_H = sum([probs[ix] for ix, ed in enumerate(edges) if self.H.has_edge(*ed)])
+        probs= np.asarray(probs)
 
-        #store the model
-        probs_dom /= sum_prob_dom
+        is_in_H = np.asarray([int(self.H.has_edge(*ed)) for ed in edges]) # is a pair of nodes in H
+        empiric_probs = new_probs(probs/probs.sum(),self.H.size()) # compute empiric probs
+        erode_model = is_in_H-empiric_probs #Compute erode model
+        erode_model[erode_model<0] = 0
 
-        edge_prob = dict(zip([hash_func(ed) for ed in edges], probs_dom))
+        edge_prob = dict(zip([hash_func(ed) for ed in edges], probs))
         self.probs_df["p_{0}".format(self.nb_of_erosion)] = self.probs_df.apply(
             lambda x: edge_prob[hash_func([int(x.u), int(x.v)])] if hash_func([int(x.u), int(x.v)]) in edge_prob else 0,
             axis=1)
+        new_nb_edges = erode_model.sum()*0.7
 
-        # Compute new edges
-        hhh = np.asarray(
-            [(1 / self.H.size()) - ((probs_dom[ix]*sum_prob_dom)/sum_prob_dom_H) for ix, ed in enumerate(edges) if self.H.has_edge(*ed)])
-        hhh[hhh < 0] = 0
-        new_nb_edges = hhh.sum() * self.H.size()
 
+        edges = edges[erode_model>0]
+        erode_model = erode_model[erode_model > 0]
+        sorted_idx = np.argsort(erode_model)[::-1][:round(new_nb_edges)]
+        #index_selected_pairs = np.random.choice(np.arange(len(edges)), round(new_nb_edges), p=erode_model/erode_model.sum(),
+        #                                        replace=False)  # round(0.7*H.size()) round(new_nb_edges)
 
+        #G2 = nx.from_edgelist(edges[index_selected_pairs])
+        G2 = nx.from_edgelist(edges[sorted_idx])
 
-
-
-
-        # Compute prob erosion
-        probs_erosion = np.asarray([old_probs[hash_func(ed)] - probs_dom[ix] for ix, ed in enumerate(edges)])
-        print("probs_erosion",probs_erosion)
-        probs_erosion[probs_erosion <= 0] = 0
-        print("probs erosion after filter negative value",probs_erosion)
-        probs_erosion /= probs_erosion.sum()
-        print("probserosion at ",self.nb_of_erosion,"with ",np.count_nonzero(probs_erosion),"of non zero values")
-
-        # Generate new graph
-        edges = edges[probs_erosion > 0]
-        probs_erosion=probs_erosion[probs_erosion > 0]
-        print("EDGES for erosion", edges)
-        print("|E| with erosion and len(probs_ero)",len(edges),len(probs_erosion))
-        print("new_edges_len",round(new_nb_edges))
-
-        if new_nb_edges > len(edges):
-            return False
-        final_edges = []
-        index_selected_pairs = np.random.choice(np.arange(len(edges)), round(new_nb_edges), p=probs_erosion,
-                                                replace=False)  # round(0.7*H.size()) round(new_nb_edges)
-        final_edges.extend(edges[index_selected_pairs])
-
-        G2 = nx.from_edgelist(final_edges)
         for n in list(G2.nodes()):
             G2.nodes[n]["block"] = self.block_assign[n]
             G2.nodes[n]["pos"] = self.coordinates[n]
         self.H = G2.copy()
+        self.graph_history.append(self.H.copy())
+        return probs
 
     def erode_n_times(self,n):
         if self.nb_of_erosion >0:
@@ -128,7 +108,8 @@ class ErosionModel():
             log(i)
             log(self.H.size())
             r = self.erode()
-            if r == False: # we cannot erode further
+
+            if type(r) == bool and  r == False: # we cannot erode further
                 log("Cannot erode further")
                 break
 
@@ -180,6 +161,24 @@ class ErosionModel():
 
         return X_train,X_test,y_train,y_test
 
+    def plot(self):
+        import matplotlib.pyplot as plt
+        fig, axes = plt.subplots(nrows=2, ncols=self.nb_of_erosion + 1, figsize=(40, 20))
+        nx.draw(self.G, ax=axes[0][0], node_size=20, pos=nx.get_node_attributes(self.G, "pos"),
+                node_color=[v for _, v in nx.get_node_attributes(self.G, "block").items()])
+        nx.draw(self.G, ax=axes[1][0], node_size=20,
+                node_color=[v for _, v in nx.get_node_attributes(self.G, "block").items()])
+        axes[0][0].set_title("Original graph (with pos)")
+        axes[1][0].set_title("Original graph (without pos)")
+        for i in range(1, self.nb_of_erosion + 1):
+            g_ = self.graph_history[i - 1]
+            nx.draw(g_, ax=axes[0][i], node_size=20, pos=nx.get_node_attributes(g_, "pos"),
+                    node_color=[v for _, v in nx.get_node_attributes(g_, "block").items()])
+            nx.draw(g_, ax=axes[1][i], node_size=20,
+                    node_color=[v for _, v in nx.get_node_attributes(g_, "block").items()])
+            axes[0][i].set_title("Erosion {0} (with pos)".format(i))
+            axes[1][i].set_title("Erosion {0} (without pos)".format(i))
+        return fig,axes
 
 def position_str_process(G):
     def foo(x):
diff --git a/lib/helpers.py b/lib/helpers.py
index 7e27eb096fbae124fbe10fb5025540929c5a5b3e..e5db269b09e8292ab00595ad05d369c72e10442f 100644
--- a/lib/helpers.py
+++ b/lib/helpers.py
@@ -1,7 +1,9 @@
-import pandas as pd
 import numpy as np
 import networkx as nx
-import os
+import pandas as pd
+from .link_prediction_eval import get_all_possible_edges
+
+
 
 try:
     import graph_tool as gt
@@ -154,4 +156,33 @@ def nx2gt(nxG):
             gtG.ep[key][e] = value  # ep is short for edge_properties
 
     # Done, finally!
-    return gtG
\ No newline at end of file
+    return gtG
+
+
+
+def get_distribution_dist(G, dist_func):
+    df = pd.DataFrame(get_all_possible_edges(G), columns="u v".split())
+
+    df["in_G"] = df.apply(lambda row: 1 if G.has_edge(row.u, row.v) else 0, axis=1)
+
+    # Compute distance
+    nodes_positions = nx.get_node_attributes(G, "pos")
+    df["distance"] = df.apply(lambda row: dist_func(nodes_positions[row.u], nodes_positions[row.v]), axis=1)
+    # df["distance"] = 1/(df.distance**4) # normaalise between 0 and 1
+
+    #  divided in 30 bins
+    df['bin'] = pd.cut(df['distance'], bins=np.linspace(0, df.distance.max(), num=30))
+
+    # compute the count of pair of nodes within each bins
+    count_all_per_bin = df.groupby("bin", as_index=False).count()
+    count_all_per_bin = dict(count_all_per_bin["bin distance".split()].values)
+
+    #  compute number of edges in each distance bins
+    new_df = df[df.in_G == 1].groupby("bin", as_index=False).count()
+
+    new_df = new_df["bin distance".split()].rename(columns={"distance": "edge_count_per_bin"})
+    new_df["count_all_per_bin"] = new_df.bin.apply(lambda x: count_all_per_bin[x])
+    new_df["value"] = new_df.apply(
+        lambda x: x.edge_count_per_bin / x.count_all_per_bin if x.count_all_per_bin > 0 else 0, axis=1)
+    return new_df
+
diff --git a/lib/random.py b/lib/random.py
index 2f439a512fd172e5b4982fbd073a4c5a41fa92f2..5b4f1c4ec451fcbdbc13440a1e7734103f4da649 100644
--- a/lib/random.py
+++ b/lib/random.py
@@ -1,12 +1,10 @@
 # coding = utf-8
-import copy
 from collections import Iterable
 
 import numpy as np
 import networkx as nx
 import pandas as pd
 from networkx.generators.degree_seq import _to_stublist
-from cdlib import algorithms
 import random
 float_epsilon = np.finfo(float).eps
 
@@ -181,8 +179,8 @@ def spatial_graph(nb_nodes, nb_edges, coords="country", dist_func=lambda a, b: n
             raise ValueError("number of nodes must match the size of the coords dict")
     elif coords == "random":
         coords = np.random.random(nb_nodes * 2).reshape(nb_nodes, 2)
-        coords[:, 0] = (coords[:, 0] * 360) - 180
-        coords[:, 1] = (coords[:, 1] * 180) - 90
+        #coords[:, 0] = (coords[:, 0] * 360) - 180
+        #coords[:, 1] = (coords[:, 1] * 180) - 90
     else:
         coords = get_countries_coords()
         if nb_nodes > len(coords):
diff --git a/lib/utils.py b/lib/utils.py
index aa53243454976aacdb7536b331f38ba182ae04d5..d6edfd859173f179ef41ce29acd3c57e61627264 100644
--- a/lib/utils.py
+++ b/lib/utils.py
@@ -1,6 +1,4 @@
 import pandas as pd
-from sklearn.preprocessing import LabelEncoder
-import numpy as np
 import networkx as nx
 
 
diff --git a/lib/visualisation.py b/lib/visualisation.py
index 032a64d1db03fd63e954c700aed5ec823a6a81a6..ed5d6f45019642453fefadd22d49577ed93bbca3 100644
--- a/lib/visualisation.py
+++ b/lib/visualisation.py
@@ -1,7 +1,6 @@
 # coding = utf-8
 
 import pandas as pd
-import numpy as np
 import seaborn as sns
 import matplotlib.pyplot as plt
 import re
diff --git a/requirements.txt b/requirements.txt
index 939f6cf2ec7faac059bc68cc69375de12e54b128..4d553c63be14c3096820ec8f69537515720c856a 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,3 +4,9 @@ sklearn
 seaborn
 haversine
 geopandas
+networkx
+matplotlib
+seaborn
+joblib
+tqdm
+scikit-learn