Skip to content
Snippets Groups Projects
Commit 00420f29 authored by Fize Jacques's avatar Fize Jacques
Browse files

Addd label encoder in the training process of the bert geocoder

parent 01154519
No related branches found
No related tags found
No related merge requests found
......@@ -27,6 +27,9 @@ from transformers import BertTokenizer
from transformers import BertForSequenceClassification, AdamW, BertConfig
from transformers import get_linear_schedule_with_warmup
from sklearn import preprocessing
from joblib import dump
def flat_accuracy(preds, labels):
pred_flat = np.argmax(preds, axis=1).flatten()
labels_flat = labels.flatten()
......@@ -76,6 +79,13 @@ if not os.path.isdir(args.outputdir):
df_train = pd.read_csv(args.train, sep="\t")
df_test = pd.read_csv(args.test, sep="\t")
label_encoder = preprocessing.LabelEncoder()
label_encoder.fit(np.concatenate((df_train.label.values,df_test.label.values)))
dump(label_encoder,filename=output_dir+"/label_encoder.dump")
df_train["label"] = label_encoder.transform(df_train.label.values)
df_test["label"] = label_encoder.transform(df_test.label.values)
# Get the GPU device name.
device_name = tf.test.gpu_device_name()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment