diff --git a/lib/random.py b/lib/random.py
index 55867e2e368384eab25975f70ec82691f8debfbe..29f164cd0f01089bee615308875cb6feab2fbc50 100644
--- a/lib/random.py
+++ b/lib/random.py
@@ -190,6 +190,8 @@ def spatial_graph(nb_nodes, nb_edges, coords="country", dist_func=lambda a, b: n
     df = pd.DataFrame(data, columns="src tar weight".split()).astype({"src": int, "tar": int})
     df["hash"] = df.apply(lambda x: "_".join(sorted([str(int(x.src)), str(int(x.tar))])), axis=1)
     df = df.drop_duplicates(subset="hash")
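+    # normalise edge weights so the weight column sums to 1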
+    df["weight"] = df.weight/df.weight.sum()
 
     register = set([])
 
@@ -261,13 +263,31 @@ def stochastic_block_model_graph(nb_nodes, nb_edges, nb_com, percentage_edge_bet
     if nb_edges > edge_max:
         raise ValueError("nb_edges must be inferior to {0}".format(edge_max))
 
-    percentage_edge_within = 1 - percentage_edge_betw
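+    # helper: number of unordered node pairs among N nodes, i.e. C(N, 2)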
+    def nb_of_pair(N):
+        return (N * (N - 1)) / 2
 
-    G = nx.planted_partition_graph(nb_com, nb_nodes//nb_com, 1, 1)
+    G = nx.planted_partition_graph(nb_com, nb_nodes // nb_com, 1, 1)
+    block_assign = nx.get_node_attributes(G, "block")
+    b_assign_array = np.asarray(list(block_assign.values()))
     if verbose:
         print(G.size())
 
-    block_assign = nx.get_node_attributes(G, "block")
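+    # u_in / u_out: number of possible intra- / inter-community node pairs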
+    u_in = sum([nb_of_pair((b_assign_array == b).sum()) for b in range(nb_com)])
+    u_out = nb_of_pair(len(G)) - u_in
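+    # l_out: inter-community share of the edge budget; l_in: edges left to place within communities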
+    l_out = u_out * percentage_edge_betw
+    l_in = nb_edges - l_out
+
+    percentage_edge_within = l_in / u_in
+    if verbose:
+        print("u_out",u_out)
+        print("u_in",u_in)
+        print("l_out",l_out)
+        print("l_in", l_in)
+        print("p_in",percentage_edge_within)
+
     inter_edges, intra_edges = [], []
     register = set([])
     for n1 in list(G.nodes()):
@@ -285,16 +305,26 @@ def stochastic_block_model_graph(nb_nodes, nb_edges, nb_com, percentage_edge_bet
     inter_edges = np.asarray(inter_edges)
     intra_edges = np.asarray(intra_edges)
     inter_N, intra_N = len(inter_edges), len(intra_edges)
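+    # weight each candidate pair by its inter/intra probability and normalise into a sampling distribution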
+    probs_inter = np.ones(inter_N) * percentage_edge_betw
+    probs_intra = np.ones(intra_N) * percentage_edge_within
+
+    all_edges = np.concatenate((inter_edges, intra_edges))
+    del inter_edges
+    del intra_edges
+    all_probs = np.concatenate((probs_inter, probs_intra))
+    del probs_inter
+    del probs_intra
+    all_probs /= all_probs.sum()
 
     if verbose:
         print(inter_N, intra_N)
         print(int(np.ceil(nb_edges * percentage_edge_betw)), int(np.ceil(nb_edges * percentage_edge_within)))
 
     final_edges = []
-    index_inter = np.random.choice(np.arange(inter_N), int(np.round(nb_edges * percentage_edge_betw)), replace=False)
-    index_intra = np.random.choice(np.arange(intra_N), int(np.round(nb_edges * percentage_edge_within)), replace=False)
-    final_edges.extend(inter_edges[index_inter])
-    final_edges.extend(intra_edges[index_intra])
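+    # draw all nb_edges pairs in a single weighted sample instead of two fixed-size draws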
+    index_selected_pairs = np.random.choice(np.arange(len(all_edges)), nb_edges, p=all_probs, replace=False)
+    final_edges.extend(all_edges[index_selected_pairs])
 
     if verbose:
         print(len(final_edges))