From 00420f298c796f1f4d8f5f4a381dc84b04bd24a6 Mon Sep 17 00:00:00 2001
From: Fize Jacques <jacques.fize@cirad.fr>
Date: Fri, 6 Nov 2020 11:17:38 +0100
Subject: [PATCH] Addd label encoder in the training process of the bert
 geocoder

---
 train_bert_geocoder.py | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/train_bert_geocoder.py b/train_bert_geocoder.py
index 7dfa289..d0749b6 100644
--- a/train_bert_geocoder.py
+++ b/train_bert_geocoder.py
@@ -27,6 +27,9 @@ from transformers import BertTokenizer
 from transformers import BertForSequenceClassification, AdamW, BertConfig
 from transformers import get_linear_schedule_with_warmup
 
+from sklearn import preprocessing
+from joblib import dump
+
 def flat_accuracy(preds, labels):
     pred_flat = np.argmax(preds, axis=1).flatten()
     labels_flat = labels.flatten()
@@ -76,6 +79,13 @@ if not os.path.isdir(args.outputdir):
 df_train = pd.read_csv(args.train, sep="\t")
 df_test = pd.read_csv(args.test, sep="\t")
 
+label_encoder = preprocessing.LabelEncoder()
+label_encoder.fit(np.concatenate((df_train.label.values,df_test.label.values)))
+dump(label_encoder,filename=output_dir+"/label_encoder.dump")
+
+df_train["label"] = label_encoder.transform(df_train.label.values)
+df_test["label"] = label_encoder.transform(df_test.label.values)
+
 # Get the GPU device name.
 device_name = tf.test.gpu_device_name()
 
-- 
GitLab