diff --git a/generate_dataset.py b/generate_dataset.py index bf1f8af05130be86d3e0a9ebbfde8b4917fecd8b..90224c8997d3b3ab370ff29e5e5d8ee05cd96be1 100644 --- a/generate_dataset.py +++ b/generate_dataset.py @@ -28,6 +28,7 @@ PREFIX = PREFIX + "_" + args.split_method # Â LOAD DATA geonames_data = read_geonames(args.geonames_dataset) +geonames_data = geonames_data[geonames_data.feature_class.isin("A P".split())] # filter populated places and areas wikipedia_data = pd.read_csv(args.wikipedia_dataset, sep="\t") geonames_hierarchy_data = pd.read_csv(args.geonames_hierarchy_data, sep="\t", header=None, names="parentId,childId,type".split(",")).fillna("") @@ -69,6 +70,15 @@ def get_adjacent_pairs(dataframe, sampling_nb=4,no_sampling=False): new_pairs.extend([[row.geonameid, topo_prin, sel, lat, lon] for sel in selected]) return new_pairs +def random_sample(values, m): + res = [] + if len(values) == 1: + values.append(values[0]) + for i in range(min(max(len(values),2), m)): + pos = np.random.randint(len(values)) + res.append(values[pos]) + values.pop(pos) + return res def get_cooccurrence_pairs(dataframe, sampling_nb=4,no_sampling=False): """ @@ -87,7 +97,8 @@ def get_cooccurrence_pairs(dataframe, sampling_nb=4,no_sampling=False): """ new_pairs = [] if not no_sampling: - dataframe["interlinks"] = dataframe.interlinks.apply(lambda x: np.random.choice(x.split("|"), sampling_nb)) + #dataframe["interlinks"] = dataframe.interlinks.apply(lambda x: np.random.choice(x.split("|"), sampling_nb)) + dataframe["interlinks"] = dataframe.interlinks.apply(lambda x: random_sample(x.split("|"),sampling_nb)) else: dataframe["interlinks"] = dataframe.interlinks.apply(lambda x: x.split("|")) for ix, row in tqdm(dataframe.iterrows(), total=len(dataframe), desc="Get Cooccurrent Toponym Pairs"): @@ -149,7 +160,13 @@ inc_train, _ = train_test_split(inclusion_pairs, test_size=0.33) inclusion_pairs["split"] = "test" inclusion_pairs.loc[inc_train.index.values, "split"] = "train" -#Â SAVE DATA + +# PRINT NB PAIRS +print('# cooc_pairs: ', len(cooc_pairs)) +print('# adjacent_pairs: ', len(adjacent_pairs)) +print('# inclusion_pairs: ', len(inclusion_pairs)) + +# SAVE DATA inclusion_pairs.to_csv("{0}_inclusion.csv".format(PREFIX), sep="\t") adjacent_pairs.to_csv("{0}_adjacent.csv".format(PREFIX), sep="\t") cooc_pairs.to_csv("{0}_cooc.csv".format(PREFIX), sep="\t")