Skip to content
Snippets Groups Projects
Commit d6f858af authored by Ludovic Moncla's avatar Ludovic Moncla
Browse files

Upload New File

parent a8599fb4
Branches master
No related tags found
No related merge requests found
import pandas as pd
import numpy as np
from tqdm import tqdm
from lib.utils_geo import haversine_pd
import warnings
from pandas.core.common import SettingWithCopyWarning
from lib.geocoder.our_geocoder import Geocoder
import argparse
import glob
parser = argparse.ArgumentParser()
parser.add_argument("models_dir")
parser.add_argument("coocurrence_dataset")
parser.add_argument("output_filename")
parser.add_argument("-k",default=4,type=int)
args = parser.parse_args()
tqdm.pandas()
warnings.simplefilter(action="ignore", category=SettingWithCopyWarning)
def heuristic_mean(geocoder,toponym, context_toponyms):
input_ = np.asarray([[toponym,t1] for t1 in context_toponyms if toponym != t1])
if len(input_) == 0:
input_ = np.asarray([[toponym,toponym]])
res_geocode = pd.DataFrame(input_,columns="t tc".split())
lons,lats = geocoder.get_coords(input_[:,0],input_[:,1])
res_geocode["lon"] = lons
res_geocode["lat"] = lats
return [res_geocode["lon"].mean(),res_geocode["lat"].mean()]
def accuracy_at_k(geocoding_df,k=100):
geocoding_df["distanceKM"] = haversine_pd(geocoding_df.longitude,geocoding_df.latitude,geocoding_df.pred_longitude,geocoding_df.pred_latitude)
return (geocoding_df.distanceKM <k).sum()/len(geocoding_df)
def median_distance_error(geocoding_df):
geocoding_df["distanceKM"] = haversine_pd(geocoding_df.longitude,geocoding_df.latitude,geocoding_df.pred_longitude,geocoding_df.pred_latitude)
return geocoding_df.distanceKM.median()
def geocode_wikipages(df,geo,k=None):
import random
random.seed(42)
if not k:
found_coords = df.progress_apply(lambda x: heuristic_mean(geo,x.title,x.interlinks),axis=1).values
else:
found_coords = df.progress_apply(lambda x: heuristic_mean(geo,x.title,random.choices(x.interlinks,k=k)),axis=1).values
found_coords = np.asarray(found_coords.tolist())
return found_coords
MODELS_DIR = args.models_dir.rstrip("/") + "/"
COOC_DATASET_FN = args.coocurrence_dataset
OUTPUT_FN = args.output_filename
k_cooc_used = args.k
df = pd.read_csv(COOC_DATASET_FN,sep="\t")
df["interlinks"] = df.interlinks.apply(lambda x: x.split("|"))
model_available = glob.glob(MODELS_DIR+"*.h5")
model_available = [mod.rstrip(".h5").split("/")[-1] for mod in model_available]
print("Models that will be evaluated :")
for model_fn in model_available:
print("\t*",model_fn)
res_ = []
for mod in tqdm(model_available):
index_fn = MODELS_DIR + mod +"_index"
model_fn = MODELS_DIR + mod +".h5"
g = Geocoder(model_fn, index_fn)
found_coords = geocode_wikipages(df,g, k_cooc_used)
df.loc[:, "pred_longitude"] = found_coords[:, 0]
df.loc[:, "pred_latitude"] = found_coords[:, 1]
res_.append([mod,accuracy_at_k(df,161),accuracy_at_k(df,100),accuracy_at_k(df,50),accuracy_at_k(df,20),median_distance_error(df)])
pd.DataFrame(res_,columns="dataset accuracy@161km accuracy@100km accuracy@50km accuracy@20km MDE".split()).to_csv(OUTPUT_FN,sep="\t",index=None)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment