From d83be71f33045515b448e3b0b05d953b83ce2709 Mon Sep 17 00:00:00 2001
From: Ludovic Moncla <ludovic.moncla@insa-lyon.fr>
Date: Tue, 9 Mar 2021 09:36:54 +0100
Subject: [PATCH] Update generate_dataset.py: remove duplicates from pairs

---
 generate_dataset.py | 21 +++++++++++++++++++--
 1 file changed, 19 insertions(+), 2 deletions(-)

diff --git a/generate_dataset.py b/generate_dataset.py
index bf1f8af..90224c8 100644
--- a/generate_dataset.py
+++ b/generate_dataset.py
@@ -28,6 +28,7 @@ PREFIX = PREFIX + "_" + args.split_method
 
 #  LOAD DATA
 geonames_data = read_geonames(args.geonames_dataset)
+geonames_data = geonames_data[geonames_data.feature_class.isin("A P".split())] # filter populated places and areas
 wikipedia_data = pd.read_csv(args.wikipedia_dataset, sep="\t")
 geonames_hierarchy_data = pd.read_csv(args.geonames_hierarchy_data, sep="\t", header=None,
                                       names="parentId,childId,type".split(",")).fillna("")
@@ -69,6 +70,15 @@ def get_adjacent_pairs(dataframe, sampling_nb=4,no_sampling=False):
         new_pairs.extend([[row.geonameid, topo_prin, sel, lat, lon] for sel in selected])
     return new_pairs
 
+def random_sample(values, m):  # draw up to m items from values without replacement (mutates values)
+    res = []
+    if len(values) == 1:  # lone candidate: duplicate it so a pair can still be formed downstream
+        values.append(values[0])
+    for _ in range(min(len(values), m)):  # was min(max(len,2), m): raised ValueError via randint(0) on empty input
+        pos = np.random.randint(len(values))
+        res.append(values[pos])
+        values.pop(pos)  # remove the drawn item so it cannot be picked again
+    return res
 
 def get_cooccurrence_pairs(dataframe, sampling_nb=4,no_sampling=False):
     """
@@ -87,7 +97,8 @@ def get_cooccurrence_pairs(dataframe, sampling_nb=4,no_sampling=False):
     """
     new_pairs = []
     if not no_sampling:
-        dataframe["interlinks"] = dataframe.interlinks.apply(lambda x: np.random.choice(x.split("|"), sampling_nb))
+        #dataframe["interlinks"] = dataframe.interlinks.apply(lambda x: np.random.choice(x.split("|"), sampling_nb))
+        dataframe["interlinks"] = dataframe.interlinks.apply(lambda x: random_sample(x.split("|"),sampling_nb))
     else:
         dataframe["interlinks"] = dataframe.interlinks.apply(lambda x: x.split("|"))
     for ix, row in tqdm(dataframe.iterrows(), total=len(dataframe), desc="Get Cooccurrent Toponym Pairs"):
@@ -149,7 +160,13 @@ inc_train, _ = train_test_split(inclusion_pairs, test_size=0.33)
 inclusion_pairs["split"] = "test"
 inclusion_pairs.loc[inc_train.index.values, "split"] = "train"
 
-# SAVE DATA
+
+# PRINT NB PAIRS
+print('# cooc_pairs: ', len(cooc_pairs))
+print('# adjacent_pairs: ', len(adjacent_pairs))
+print('# inclusion_pairs: ', len(inclusion_pairs))
+
+# SAVE DATA
 inclusion_pairs.to_csv("{0}_inclusion.csv".format(PREFIX), sep="\t")
 adjacent_pairs.to_csv("{0}_adjacent.csv".format(PREFIX), sep="\t")
 cooc_pairs.to_csv("{0}_cooc.csv".format(PREFIX), sep="\t")
-- 
GitLab