From 87143d8e1bdfceb0942d53b6fcf31b2640728f22 Mon Sep 17 00:00:00 2001 From: Fize Jacques <jacques.fize@cirad.fr> Date: Tue, 10 Nov 2020 09:17:14 +0100 Subject: [PATCH] Add bert tokenization --- lib/ngram_index.py | 23 +++++++++++++++---- .../toponym_combination_embedding_v3.json | 2 +- train_geocoder_v2.py | 5 +++- 3 files changed, 24 insertions(+), 6 deletions(-) diff --git a/lib/ngram_index.py b/lib/ngram_index.py index 9e422e9..5f86220 100644 --- a/lib/ngram_index.py +++ b/lib/ngram_index.py @@ -3,15 +3,24 @@ import json import numpy as np from ngram import NGram +from transformers import BertTokenizer # Machine learning from gensim.models import Word2Vec + +class bertTokenizer: + def __init__(self): + self.tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased',do_lower_case=False) + + def split(self,string): + return self.tokenizer.tokenize(string) + class NgramIndex(): """ Class used for encoding words in ngram representation """ - def __init__(self,n,loaded = False): + def __init__(self,n,bert_tokenization=False,loaded = False): """ Constructor @@ -21,6 +30,10 @@ class NgramIndex(): ngram size """ self.ngram_gen = NGram(N=n) + self.empty_char = "$" + if bert_tokenization: + self.ngram_gen = bertTokenizer() + self.empty_char = "#" self.size = n self.ngram_index = {"":0} @@ -29,6 +42,8 @@ class NgramIndex(): self.max_len = 0 self.loaded = loaded + + def split_and_add(self,word): """ Split word in multiple ngram and add each one of them to the index @@ -38,7 +53,7 @@ class NgramIndex(): word : str a word """ - ngrams = str(word).lower().replace(" ","$") + ngrams = str(word).lower().replace(" ",self.empty_char) ngrams = list(self.ngram_gen.split(ngrams)) [self.add(ngram) for ngram in ngrams] self.max_len = max(self.max_len,len(ngrams)) @@ -73,9 +88,9 @@ class NgramIndex(): listfrom shapely.geometry import Point,box of ngram index """ - ngrams = str(word).lower().replace(" ","$") + ngrams = str(word).lower().replace(" ",self.empty_char) ngrams = list(self.ngram_gen.split(ngrams)) - ngrams = [ng for ng in ngrams if ng.count("$")<2] + ngrams = [ng for ng in ngrams if ng.count(self.empty_char)<2] if not self.loaded: [self.add(ng) for ng in ngrams if not ng in self.ngram_index] return self.complete([self.ngram_index[ng] for ng in ngrams if ng in self.ngram_index],self.max_len) diff --git a/parser_config/toponym_combination_embedding_v3.json b/parser_config/toponym_combination_embedding_v3.json index d4362ef..507e9c5 100644 --- a/parser_config/toponym_combination_embedding_v3.json +++ b/parser_config/toponym_combination_embedding_v3.json @@ -15,6 +15,6 @@ { "short": "-e", "long": "--epochs", "type": "int", "default": 100 }, { "short": "-d", "long": "--dimension", "type": "int", "default": 256 }, { "short": "-l", "long": "--lstm-layer", "type": "int", "default": 2, "choices": [1, 2] }, - { "long": "--tokenization-method", "type": "str", "default": "char-level", "choices": ["char-level", "word-level"] } + { "long": "--tokenization-method", "type": "str", "default": "char-level", "choices": ["char-level", "word-level", "bert"] } ] } \ No newline at end of file diff --git a/train_geocoder_v2.py b/train_geocoder_v2.py index 73a6818..665661e 100644 --- a/train_geocoder_v2.py +++ b/train_geocoder_v2.py @@ -34,7 +34,8 @@ try: physical_devices = tf.config.list_physical_devices('GPU') tf.config.experimental.set_memory_growth(physical_devices[0], enable=True) except: - print("NO GPU FOUND") + print("NO GPU FOUND...") + #Â COMMAND ARGS args = ConfigurationReader("./parser_config/toponym_combination_embedding_v3.json")\ .parse_args()#("IGN ../data/IGN/IGN_inclusion.csv ../data/IGN/IGN_adjacent_corrected.csv ../data/IGN/IGN_cooc.csv -i -w -a -n 4 --ngram-word2vec-iter 1".split()) @@ -103,6 +104,8 @@ logging.info("Encoding toponyms to ngram...") index = NgramIndex(NGRAM_SIZE) if args.tokenization_method == "word-level": index = WordIndex() +if args.tokenization_method == "bert": + index = NgramIndex(NGRAM_SIZE,bert_tokenization=True) # Identify all ngram available pairs_of_toponym.toponym.apply(lambda x : index.split_and_add(x)) -- GitLab