From bc7e99730955964e60a566248ee4cb30f80751be Mon Sep 17 00:00:00 2001
From: Fize Jacques <jacques.fize@cirad.fr>
Date: Fri, 26 Mar 2021 16:43:31 +0100
Subject: [PATCH] Add evaluation script for our models

---
 evaluate_geocoder.py | 57 ++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 57 insertions(+)
 create mode 100644 evaluate_geocoder.py

diff --git a/evaluate_geocoder.py b/evaluate_geocoder.py
new file mode 100644
index 0000000..e57c2b3
--- /dev/null
+++ b/evaluate_geocoder.py
@@ -0,0 +1,57 @@
+# coding: utf-8
+
+import glob
+import pandas as pd
+import numpy as np
+from tqdm import tqdm
+
+import os
+
+from lib.geocoder.our_geocoder import Geocoder
+from lib.utils_geo import haversine_pd
+
+import argparse
+
+# Three positional arguments: models directory, geocoding dataset file, output file.
+parser = argparse.ArgumentParser()
+for positional_name in ("models_dir","geocoding_dataset","output_filename"):
+    parser.add_argument(positional_name)
+
+args = parser.parse_args()
+
+# Strip every trailing "/" from the models directory and append exactly one,
+# so that file names can be concatenated to it directly.
+MODELS_DIR = "{0}/".format(args.models_dir.rstrip("/"))
+GEOCODING_DATASET_FN = args.geocoding_dataset
+OUTPUT_FN = args.output_filename
+
+missing = [fn for fn in (MODELS_DIR,GEOCODING_DATASET_FN) if not os.path.exists(fn)]
+if missing:  # report the first missing path (models directory is checked before the dataset file)
+    raise FileNotFoundError("{0} does not exists!".format(missing[0]))
+
+geocoding_df = pd.read_csv(GEOCODING_DATASET_FN,sep="\t",index_col=0)  # tab-separated file, first column used as index
+geocoding_df = geocoding_df[geocoding_df.split == "test"].copy()  # "test" rows only; explicit copy so any later column assignment targets an independent frame
+
+# Model name = file name without directory and ".h5" extension. splitext is used because str.rstrip(".h5") strips any trailing '.', 'h' or '5' characters (e.g. "graph.h5" -> "grap").
+model_available = [os.path.splitext(os.path.basename(mod))[0] for mod in glob.glob(MODELS_DIR+"*.h5")]
+print("Models that will be evaluated :")
+for model_fn in model_available:
+    print("\t*",model_fn)
+
+def accuracy_at_k(geocoding_df,geocoder,k=100):
+    """Return the share of rows whose predicted point lies less than `k` from the true one (haversine_pd distance, presumably km)."""
+    # Query the geocoder received as argument rather than a module-level instance.
+    lons,lats = geocoder.get_coords(geocoding_df.toponym.values,geocoding_df.toponym_context.values)
+    # assign() builds a new frame, so the caller's DataFrame does not gain prediction columns on every call.
+    scored = geocoding_df.assign(pred_latitude=lats,pred_longitude=lons)
+    return (haversine_pd(scored.longitude,scored.latitude,scored.pred_longitude,scored.pred_latitude) <k).sum()/len(scored)
+
+# One result row per model: its name followed by accuracy_at_k at k=100, 50 and 20, in the order of the report columns.
+res_ = []
+for mod in tqdm(model_available):
+    model_fn, index_fn = MODELS_DIR + mod +".h5", MODELS_DIR + mod +"_index"
+    g = Geocoder(model_fn, index_fn)
+    res_.append([mod] + [accuracy_at_k(geocoding_df,g,k) for k in (100,50,20)])
+
+report_columns = ["dataset","accuracy@100km","accuracy@50km","accuracy@20km"]
+pd.DataFrame(res_,columns=report_columns).to_csv(OUTPUT_FN,sep="\t",index=None)
-- 
GitLab