diff --git a/generate_mixed_model_graph.py b/generate_mixed_model_graph.py
index ee7646a4749979c5c35b39e8458ef330c0196f8a..fce2101acb0b7a0d98ca5c08e21660e7a5d8fab1 100644
--- a/generate_mixed_model_graph.py
+++ b/generate_mixed_model_graph.py
@@ -19,7 +19,7 @@ args = parser.parse_args()
 
 GRAPH_SIZE = [300,1000]
 EDGE_SIZE = [2]
-sample_per_params  = 10
+sample_per_params  = 1
 
 OUTPUT_DIR = args.output_dir
 if not os.path.exists(OUTPUT_DIR):
diff --git a/lib/erosion_model.py b/lib/erosion_model.py
index 5f503b1ad33d0c48fb066e2d94d9bd3c528c1d7c..ccc64ff3425d4fe1b2aa8669fd7052bca093bae6 100644
--- a/lib/erosion_model.py
+++ b/lib/erosion_model.py
@@ -7,6 +7,7 @@ from .random import get_spat_probs, get_sbm_probs
 from .lambda_func import euclid_dist as dist
 from .lambda_func import  hash_func
 
+from evalne.utils import preprocess as pp
 from evalne.methods.similarity import stochastic_block_model,spatial_link_prediction
 
 import pandas as pd
@@ -14,6 +15,10 @@ import networkx as nx
 import numpy as np
 float_epsilon = np.finfo(float).eps
 
+VERBOSE = True
+def log(x):
+    if VERBOSE:
+        print(x)
 
 class ErosionModel():
     def __init__(self, G):
@@ -27,35 +32,46 @@ class ErosionModel():
         self.nb_of_erosion = 0
 
     def erode(self):
-        self.nb_of_erosion += 1
 
+        test_H, _ = pp.prep_graph(self.H.copy(), maincc=True, relabel=False)
+        if len(test_H) < 30:
+            return False
         if self.H.size() < 30:
-            self.probs_df["p_{0}".format(self.nb_of_erosion)] = self.probs_df["p_{0}".format(self.nb_of_erosion - 1)]
-            return
+            return False
+        self.nb_of_erosion += 1
+
         old_probs = dict(self.probs_df["hash_ p_{0}".format(self.nb_of_erosion - 1).split()].values)
 
         auc_sbm, auc_spatial = get_auc_heuristics(self.H, 60)
+        print(auc_spatial,auc_sbm)
         edges = get_all_possible_edges(self.H)
         if auc_sbm > auc_spatial:
             probs = stochastic_block_model(self.H, edges)
         else:
             probs = spatial_link_prediction(self.H, edges)
+
         edges = np.asarray(edges)
         probs_dom = np.asarray(probs)
-        probs_dom /= probs_dom.sum()
+        sum_prob_dom = probs_dom.sum()
+        sum_prob_dom_H = sum([probs[ix] for ix, ed in enumerate(edges) if self.H.has_edge(*ed)])
+        probs_dom /= sum_prob_dom
 
         edge_prob = dict(zip([hash_func(ed) for ed in edges], probs_dom))
         self.probs_df["p_{0}".format(self.nb_of_erosion)] = self.probs_df.apply(
             lambda x: edge_prob[hash_func([int(x.u), int(x.v)])] if hash_func([int(x.u), int(x.v)]) in edge_prob else 0,
             axis=1)
 
-        new_nb_edges = (np.asarray(
-            [(1 / self.H.size()) - probs_dom[ix] for ix, ed in enumerate(edges) if self.H.has_edge(*ed)])).sum() * self.H.size()
+        hhh = np.asarray(
+            [(1 / self.H.size()) - ((probs_dom[ix]*sum_prob_dom)/sum_prob_dom_H) for ix, ed in enumerate(edges) if self.H.has_edge(*ed)])
+        hhh[hhh < 0] = 0
+        new_nb_edges = hhh.sum() * self.H.size()
+        #print(hhh)
 
         probs_erosion = np.asarray([old_probs[hash_func(ed)] - probs_dom[ix] for ix, ed in enumerate(edges)])
-        probs_erosion[probs_erosion < 0] = float_epsilon
+        probs_erosion[probs_erosion <= 0] = float_epsilon
         probs_erosion /= probs_erosion.sum()
 
+
         final_edges = []
         index_selected_pairs = np.random.choice(np.arange(len(edges)), round(new_nb_edges), p=probs_erosion,
                                                 replace=False)  # round(0.7*H.size())
@@ -66,6 +82,7 @@ class ErosionModel():
             G2.nodes[n]["block"] = self.block_assign[n]
             G2.nodes[n]["pos"] = self.coordinates[n]
         self.H = G2.copy()
+        return probs_erosion
 
     def erode_n_times(self,n):
         if self.nb_of_erosion >0:
@@ -75,7 +92,12 @@ class ErosionModel():
         self.nb_of_erosion = 0
         self.H = self.G.copy()
         for i in range(n):
-            self.erode()
+            log(i)
+            log(self.H.size())
+            r = self.erode()
+            if r == False: # we cannot erode further
+                log("Cannot erode further")
+                break
 
 
     def initialize(self):
diff --git a/lib/random.py b/lib/random.py
index 207bcba41e76763b80485567fccdd5bd9b42baef..9347d95acd5c7089df5eea2a03a708da48e6120b 100644
--- a/lib/random.py
+++ b/lib/random.py
@@ -488,15 +488,23 @@ def mixed_model_spat_sbm(nb_nodes, nb_edges, nb_com, alpha, percentage_edge_betw
     all_probs_sbm /= all_probs_sbm.sum()
 
     pos = nx.get_node_attributes(G,"pos")
-    all_probs_spa = np.asarray([1 / (float_epsilon +dist_func(pos[edge[0]], pos[edge[1]])) for edge in all_edges])
+    all_probs_spa = np.asarray([1 / (float_epsilon + dist_func(pos[edge[0]], pos[edge[1]])) for edge in all_edges])
     all_probs_spa /= all_probs_spa.sum()
 
 
-    all_probs = alpha * (all_probs_sbm) + (1 - alpha) * all_probs_spa
+    #all_probs = alpha * (all_probs_sbm) + (1 - alpha) * all_probs_spa
+    nb_edges_sbm,nb_edges_spa = round(alpha*nb_edges),round((1-alpha)*nb_edges)
 
     final_edges = []
-    index_selected_pairs = np.random.choice(np.arange(len(all_edges)), nb_edges, p=all_probs, replace=False)
-    final_edges.extend(all_edges[index_selected_pairs])
+    index_selected_pairs_sbm = np.random.choice(np.arange(len(all_edges)), nb_edges_sbm, p=all_probs_sbm, replace=False)
+    final_edges.extend(all_edges[index_selected_pairs_sbm])
+
+    all_probs_spa[index_selected_pairs_sbm] = all_probs_spa.min()
+    all_probs_spa/= all_probs_spa.sum()
+
+    index_selected_pairs_spa = np.random.choice(np.arange(len(all_edges)), nb_edges_spa, p=all_probs_spa, replace=False)
+    final_edges.extend(all_edges[index_selected_pairs_spa])
+
     G2 = nx.from_edgelist(final_edges)
 
     for n in list(G2.nodes()):