From 00420f298c796f1f4d8f5f4a381dc84b04bd24a6 Mon Sep 17 00:00:00 2001 From: Fize Jacques <jacques.fize@cirad.fr> Date: Fri, 6 Nov 2020 11:17:38 +0100 Subject: [PATCH] Add label encoder in the training process of the bert geocoder --- train_bert_geocoder.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/train_bert_geocoder.py b/train_bert_geocoder.py index 7dfa289..d0749b6 100644 --- a/train_bert_geocoder.py +++ b/train_bert_geocoder.py @@ -27,6 +27,9 @@ from transformers import BertTokenizer from transformers import BertForSequenceClassification, AdamW, BertConfig from transformers import get_linear_schedule_with_warmup +from sklearn import preprocessing +from joblib import dump + def flat_accuracy(preds, labels): pred_flat = np.argmax(preds, axis=1).flatten() labels_flat = labels.flatten() @@ -76,6 +79,13 @@ if not os.path.isdir(args.outputdir): df_train = pd.read_csv(args.train, sep="\t") df_test = pd.read_csv(args.test, sep="\t") +label_encoder = preprocessing.LabelEncoder() +label_encoder.fit(np.concatenate((df_train.label.values,df_test.label.values))) +dump(label_encoder,filename=output_dir+"/label_encoder.dump") + +df_train["label"] = label_encoder.transform(df_train.label.values) +df_test["label"] = label_encoder.transform(df_test.label.values) + # Get the GPU device name. device_name = tf.test.gpu_device_name() -- GitLab