Skip to content
Snippets Groups Projects
Commit d83be71f authored by Ludovic Moncla's avatar Ludovic Moncla
Browse files

Update generate_dataset.py remove duplicates from pairs

parent d3c5e2b4
No related branches found
No related tags found
No related merge requests found
...@@ -28,6 +28,7 @@ PREFIX = PREFIX + "_" + args.split_method ...@@ -28,6 +28,7 @@ PREFIX = PREFIX + "_" + args.split_method
# LOAD DATA # LOAD DATA
geonames_data = read_geonames(args.geonames_dataset) geonames_data = read_geonames(args.geonames_dataset)
geonames_data = geonames_data[geonames_data.feature_class.isin("A P".split())] # filter populated places and areas
wikipedia_data = pd.read_csv(args.wikipedia_dataset, sep="\t") wikipedia_data = pd.read_csv(args.wikipedia_dataset, sep="\t")
geonames_hierarchy_data = pd.read_csv(args.geonames_hierarchy_data, sep="\t", header=None, geonames_hierarchy_data = pd.read_csv(args.geonames_hierarchy_data, sep="\t", header=None,
names="parentId,childId,type".split(",")).fillna("") names="parentId,childId,type".split(",")).fillna("")
...@@ -69,6 +70,15 @@ def get_adjacent_pairs(dataframe, sampling_nb=4,no_sampling=False): ...@@ -69,6 +70,15 @@ def get_adjacent_pairs(dataframe, sampling_nb=4,no_sampling=False):
new_pairs.extend([[row.geonameid, topo_prin, sel, lat, lon] for sel in selected]) new_pairs.extend([[row.geonameid, topo_prin, sel, lat, lon] for sel in selected])
return new_pairs return new_pairs
def random_sample(values, m):
    """Sample up to ``m`` elements from ``values`` without replacement.

    A single-element list is first duplicated so that at least two entries
    can be drawn (the dataset-generation code needs pairs, so a lone
    interlink is paired with itself). Works on a copy of ``values`` so the
    caller's list is never mutated.

    Parameters
    ----------
    values : list
        Candidate elements to sample from.
    m : int
        Maximum number of elements to draw.

    Returns
    -------
    list
        ``min(len(pool), m)`` sampled elements (``pool`` is ``values``
        after the singleton duplication), or ``[]`` for an empty input —
        the previous version raised ``ValueError`` from
        ``np.random.randint(0)`` in that case.
    """
    pool = list(values)  # copy: don't destructively pop from the caller's list
    if not pool:
        return []  # guard: np.random.randint(0) is an error
    if len(pool) == 1:
        pool.append(pool[0])  # ensure at least two candidates to form a pair
    res = []
    # After duplication len(pool) >= 2, so min(len(pool), m) matches the
    # original min(max(len(values), 2), m) bound.
    for _ in range(min(len(pool), m)):
        pos = np.random.randint(len(pool))
        res.append(pool.pop(pos))  # pop => sampling without replacement
    return res
def get_cooccurrence_pairs(dataframe, sampling_nb=4,no_sampling=False): def get_cooccurrence_pairs(dataframe, sampling_nb=4,no_sampling=False):
""" """
...@@ -87,7 +97,8 @@ def get_cooccurrence_pairs(dataframe, sampling_nb=4,no_sampling=False): ...@@ -87,7 +97,8 @@ def get_cooccurrence_pairs(dataframe, sampling_nb=4,no_sampling=False):
""" """
new_pairs = [] new_pairs = []
if not no_sampling: if not no_sampling:
dataframe["interlinks"] = dataframe.interlinks.apply(lambda x: np.random.choice(x.split("|"), sampling_nb)) #dataframe["interlinks"] = dataframe.interlinks.apply(lambda x: np.random.choice(x.split("|"), sampling_nb))
dataframe["interlinks"] = dataframe.interlinks.apply(lambda x: random_sample(x.split("|"),sampling_nb))
else: else:
dataframe["interlinks"] = dataframe.interlinks.apply(lambda x: x.split("|")) dataframe["interlinks"] = dataframe.interlinks.apply(lambda x: x.split("|"))
for ix, row in tqdm(dataframe.iterrows(), total=len(dataframe), desc="Get Cooccurrent Toponym Pairs"): for ix, row in tqdm(dataframe.iterrows(), total=len(dataframe), desc="Get Cooccurrent Toponym Pairs"):
...@@ -149,7 +160,13 @@ inc_train, _ = train_test_split(inclusion_pairs, test_size=0.33) ...@@ -149,7 +160,13 @@ inc_train, _ = train_test_split(inclusion_pairs, test_size=0.33)
inclusion_pairs["split"] = "test" inclusion_pairs["split"] = "test"
inclusion_pairs.loc[inc_train.index.values, "split"] = "train" inclusion_pairs.loc[inc_train.index.values, "split"] = "train"
# SAVE DATA
# PRINT NB PAIRS
print('# cooc_pairs: ', len(cooc_pairs))
print('# adjacent_pairs: ', len(adjacent_pairs))
print('# inclusion_pairs: ', len(inclusion_pairs))
# SAVE DATA
inclusion_pairs.to_csv("{0}_inclusion.csv".format(PREFIX), sep="\t") inclusion_pairs.to_csv("{0}_inclusion.csv".format(PREFIX), sep="\t")
adjacent_pairs.to_csv("{0}_adjacent.csv".format(PREFIX), sep="\t") adjacent_pairs.to_csv("{0}_adjacent.csv".format(PREFIX), sep="\t")
cooc_pairs.to_csv("{0}_cooc.csv".format(PREFIX), sep="\t") cooc_pairs.to_csv("{0}_cooc.csv".format(PREFIX), sep="\t")
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment