diff --git a/train_bert_geocoder.py b/train_bert_geocoder.py index 7dfa289e8cc410d55d722b1016f99f7248a1cd3c..d0749b6bbd2c13bd490310805e0bed4abde237e6 100644 --- a/train_bert_geocoder.py +++ b/train_bert_geocoder.py @@ -27,6 +27,9 @@ from transformers import BertTokenizer from transformers import BertForSequenceClassification, AdamW, BertConfig from transformers import get_linear_schedule_with_warmup +from sklearn import preprocessing +from joblib import dump + def flat_accuracy(preds, labels): pred_flat = np.argmax(preds, axis=1).flatten() labels_flat = labels.flatten() @@ -76,6 +79,13 @@ if not os.path.isdir(args.outputdir): df_train = pd.read_csv(args.train, sep="\t") df_test = pd.read_csv(args.test, sep="\t") +label_encoder = preprocessing.LabelEncoder() +label_encoder.fit(np.concatenate((df_train.label.values,df_test.label.values))) +dump(label_encoder,filename=output_dir+"/label_encoder.dump") + +df_train["label"] = label_encoder.transform(df_train.label.values) +df_test["label"] = label_encoder.transform(df_test.label.values) + # Get the GPU device name. device_name = tf.test.gpu_device_name()