from lib.geocoder import Geocoder
# Instantiate the geocoder from the FR_MODEL_2 model file (.h5) and its
# companion "_index" file.
geocoder = Geocoder(
    "./outputs/FR_MODEL_2/FR.txt_100_4_100__A_I_C.h5",
    "./outputs/FR_MODEL_2/FR.txt_100_4_100__A_I_C_index",
)
import numpy as np
import pandas as pd
# Read the tab-separated toponym table, then keep only the part of each
# "name" value that precedes the first "¦" separator.
df = pd.read_csv("data/rando_toponymes.tsv", sep="\t")
df["name"] = df["name"].map(lambda raw: raw.split("¦")[0])

def _mean_coords_per_toponym(pairs, toponyms, model):
    """Geocode every pair and average the predictions per toponym.

    Parameters
    ----------
    pairs : sequence of [toponym, companion] pairs
        First element is the toponym being located (column ``t``); second
        element is the toponym it is paired with (column ``tc``).
    toponyms : sequence of str
        Toponyms for which a result entry is wanted.
    model : object
        Must expose ``get_coords(first_column, second_column)`` and
        ``wgs_coord(*coords)``, the latter returning ``(lons, lats)``.

    Returns
    -------
    dict
        ``{toponym: {"lat": mean_lat, "lon": mean_lon}}``. A toponym that
        never appears in the first column of ``pairs`` gets NaN means.
    """
    pairs = np.asarray(pairs)
    res_geocode = pd.DataFrame(pairs, columns=["t", "tc"])
    # NOTE(review): wgs_coord presumably maps the raw model output to WGS84
    # longitude/latitude -- confirm against lib.geocoder.
    lons, lats = model.wgs_coord(*model.get_coords(pairs[:, 0], pairs[:, 1]))
    res_geocode["lon"] = lons
    res_geocode["lat"] = lats

    results = {}
    for tp in toponyms:
        if tp in results:
            # Repeated toponym: its mean has already been computed.
            continue
        rows = res_geocode[res_geocode["t"] == tp]
        results[tp] = {"lat": rows["lat"].mean(), "lon": rows["lon"].mean()}
    return results


def heuristic_mean(toponyms, model=None):
    """Locate each toponym by averaging over all of its pairings.

    Every toponym is paired with every *different* toponym of the input and
    the coordinates predicted for those pairs are averaged. Repeated
    toponyms are kept as given, so a companion that occurs several times
    weighs proportionally more in the mean. When no such pair exists (a
    single toponym, or only repetitions of the same one) the first toponym
    is paired with itself.

    Parameters
    ----------
    toponyms : sequence of str
        Toponyms to locate (list or 1-D array).
    model : object, optional
        Geocoder to use; defaults to the module-level ``geocoder``.

    Returns
    -------
    dict
        ``{toponym: {"lat": ..., "lon": ...}}``; empty for empty input.
    """
    # ``len`` rather than truthiness: ``toponyms`` is usually a NumPy array.
    if len(toponyms) == 0:
        return {}
    model = geocoder if model is None else model
    pairs = [[t1, t2] for t2 in toponyms for t1 in toponyms if t2 != t1]
    if not pairs:
        pairs = [[toponyms[0], toponyms[0]]]
    return _mean_coords_per_toponym(pairs, toponyms, model)


def heuristic_one_couple(toponyms, model=None):
    """Locate each toponym from a single pair: the toponym with itself.

    Parameters
    ----------
    toponyms : sequence of str
        Toponyms to locate (list or 1-D array).
    model : object, optional
        Geocoder to use; defaults to the module-level ``geocoder``.

    Returns
    -------
    dict
        ``{toponym: {"lat": ..., "lon": ...}}``; empty for empty input.
    """
    if len(toponyms) == 0:
        return {}
    model = geocoder if model is None else model
    # One (t, t) pair per distinct toponym, in first-seen order. Repeating
    # an identical pair would only repeat the same prediction in the mean.
    pairs = [[tp, tp] for tp in dict.fromkeys(toponyms)]
    return _mean_coords_per_toponym(pairs, toponyms, model)

# Predict coordinates one "filename" group at a time.
# groupby() yields groups in sorted key order (and skips rows whose filename
# is missing), so the predictions are not produced in df's row order. The
# original row labels are therefore carried along and used to realign the
# predictions with df before concatenating; rows without a prediction end
# up with NaN in lat_pred / lon_pred.
results_fin = []
row_labels = []
for _, group in df.groupby("filename"):
    res_geocode = heuristic_one_couple(group.name_gazetteer.values)
    results_fin.extend(res_geocode[tp] for tp in group.name_gazetteer)
    row_labels.extend(group.index)
dd = pd.DataFrame(results_fin, index=row_labels, columns=["lat", "lon"])
dd = dd.rename(columns={"lat": "lat_pred", "lon": "lon_pred"}).reindex(df.index)
df2 = pd.concat((df, dd), axis=1)

from lib.geo import haversine_pd
# Per-row error between the reference (longitude, latitude) columns and the
# predicted (lon_pred, lat_pred) columns, followed by the mean error.
# NOTE(review): presumably a haversine (great-circle) distance; the unit is
# defined by lib.geo.haversine_pd -- confirm there.
df2["dist_error"] = haversine_pd(
    df2["longitude"],
    df2["latitude"],
    df2["lon_pred"],
    df2["lat_pred"],
)
print(df2["dist_error"].mean())