From d83be71f33045515b448e3b0b05d953b83ce2709 Mon Sep 17 00:00:00 2001 From: Ludovic Moncla <ludovic.moncla@insa-lyon.fr> Date: Tue, 9 Mar 2021 09:36:54 +0100 Subject: [PATCH] Update generate_dataset.py remove duplicates from pairs --- generate_dataset.py | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/generate_dataset.py b/generate_dataset.py index bf1f8af..90224c8 100644 --- a/generate_dataset.py +++ b/generate_dataset.py @@ -28,6 +28,7 @@ PREFIX = PREFIX + "_" + args.split_method # Â LOAD DATA geonames_data = read_geonames(args.geonames_dataset) +geonames_data = geonames_data[geonames_data.feature_class.isin("A P".split())] # filter populated places and areas wikipedia_data = pd.read_csv(args.wikipedia_dataset, sep="\t") geonames_hierarchy_data = pd.read_csv(args.geonames_hierarchy_data, sep="\t", header=None, names="parentId,childId,type".split(",")).fillna("") @@ -69,6 +70,15 @@ def get_adjacent_pairs(dataframe, sampling_nb=4,no_sampling=False): new_pairs.extend([[row.geonameid, topo_prin, sel, lat, lon] for sel in selected]) return new_pairs +def random_sample(values, m): + res = [] + if len(values) == 1: + values.append(values[0]) + for i in range(min(max(len(values),2), m)): + pos = np.random.randint(len(values)) + res.append(values[pos]) + values.pop(pos) + return res def get_cooccurrence_pairs(dataframe, sampling_nb=4,no_sampling=False): """ @@ -87,7 +97,8 @@ def get_cooccurrence_pairs(dataframe, sampling_nb=4,no_sampling=False): """ new_pairs = [] if not no_sampling: - dataframe["interlinks"] = dataframe.interlinks.apply(lambda x: np.random.choice(x.split("|"), sampling_nb)) + #dataframe["interlinks"] = dataframe.interlinks.apply(lambda x: np.random.choice(x.split("|"), sampling_nb)) + dataframe["interlinks"] = dataframe.interlinks.apply(lambda x: random_sample(x.split("|"),sampling_nb)) else: dataframe["interlinks"] = dataframe.interlinks.apply(lambda x: x.split("|")) for ix, row in tqdm(dataframe.iterrows(), total=len(dataframe), desc="Get Cooccurrent Toponym Pairs"): @@ -149,7 +160,13 @@ inc_train, _ = train_test_split(inclusion_pairs, test_size=0.33) inclusion_pairs["split"] = "test" inclusion_pairs.loc[inc_train.index.values, "split"] = "train" -#Â SAVE DATA + +# PRINT NB PAIRS +print('# cooc_pairs: ', len(cooc_pairs)) +print('# adjacent_pairs: ', len(adjacent_pairs)) +print('# inclusion_pairs: ', len(inclusion_pairs)) + +# SAVE DATA inclusion_pairs.to_csv("{0}_inclusion.csv".format(PREFIX), sep="\t") adjacent_pairs.to_csv("{0}_adjacent.csv".format(PREFIX), sep="\t") cooc_pairs.to_csv("{0}_cooc.csv".format(PREFIX), sep="\t") -- GitLab