From 4dd99f2a5658c5171b4ccb5e72b3fdd8d86e296a Mon Sep 17 00:00:00 2001
From: Fize Jacques <jacques.fize@cirad.fr>
Date: Mon, 11 Jan 2021 10:01:16 +0100
Subject: [PATCH] Optimisation on ngram embedding

---
Review notes (commentary only; this section is not part of the commit
message and does not affect the diff below).

What the patch does:
 * lib/ngram_index.py: renames NgramIndex.filter_ngram to
   filter_ngram_by_freq. The body is unchanged: every n-gram whose recorded
   frequency is below `threshold` (default 20) is deleted from both
   self.ngram_index and self.index_ngram. No alias for the old name is
   added by this patch.
 * lib/ngram_index.py: adds NgramIndex.filter_top_ngram(threshold=20000).
   It sorts the frequency table by descending frequency and deletes the
   last len(self.ngram_index) - threshold rows from self.ngram_index and
   self.index_ngram, i.e. it keeps the most frequent n-grams. When
   len(self.ngram_index) is smaller than `threshold` nothing is removed.
 * train_geocoder.py: replaces the index.filter_ngram() call with
   index.filter_top_ngram(40000).

Points to confirm before merging:
 * filter_top_ngram returns 0 on the early-exit path and falls off the end
   (returns None) otherwise; the caller in this patch ignores the result.
 * Neither filter removes the deleted n-grams from self.freq_ngram. If
   ngram_index is a plain dict (presumably -- verify), calling either
   filter a second time would look up an already-deleted n-gram and raise
   KeyError.
 * The number of rows to drop is computed from len(self.ngram_index) while
   the rows themselves come from self.freq_ngram; if the two ever differ
   in size the number of n-grams kept is not exactly `threshold`. Cannot
   tell from this patch whether they can differ.
 * Surviving n-grams keep their original index values (nothing here
   renumbers them), while train_geocoder.py sizes the embedding matrix from
   len(index.index_ngram). Confirm the downstream code tolerates indices
   that are larger than that length.
 * NOTE(review): line breaks and indentation in this copy were restored
   from a whitespace-collapsed version of the patch. Check context-line
   indentation and blank-line placement against the target files before
   applying.

 lib/ngram_index.py | 12 +++++++++++-
 train_geocoder.py  |  2 +-
 2 files changed, 12 insertions(+), 2 deletions(-)

diff --git a/lib/ngram_index.py b/lib/ngram_index.py
index b461718..8bf4ce3 100644
--- a/lib/ngram_index.py
+++ b/lib/ngram_index.py
@@ -78,13 +78,23 @@ class NgramIndex():
         else:
             self.freq_ngram[ngram] += 1
 
-    def filter_ngram(self,threshold=20):
+    def filter_ngram_by_freq(self,threshold=20):
         freq_data = pd.DataFrame(self.freq_ngram.items(),columns="ngram freq".split())
         selected_ngram = freq_data[freq_data.freq<threshold].ngram.values
         for ng in selected_ngram:
             index = self.ngram_index[ng]
             del self.ngram_index[ng]
             del self.index_ngram[index]
+
+    def filter_top_ngram(self,threshold=20000):
+        freq_data = pd.DataFrame(self.freq_ngram.items(),columns="ngram freq".split()).sort_values(by="freq",ascending=False)
+        if len(self.ngram_index)-threshold <0:
+            return 0
+        selected_ngram = freq_data.tail(len(self.ngram_index)-threshold).ngram.values
+        for ng in selected_ngram:
+            index = self.ngram_index[ng]
+            del self.ngram_index[ng]
+            del self.index_ngram[index]
 
 
     def encode(self,word,complete=True):
diff --git a/train_geocoder.py b/train_geocoder.py
index d143b58..b8c6978 100644
--- a/train_geocoder.py
+++ b/train_geocoder.py
@@ -112,7 +112,7 @@ if args.tokenization_method == "bert":
 # Identify all ngram available
 pairs_of_toponym.toponym.apply(lambda x : index.split_and_add(x))
 pairs_of_toponym.toponym_context.apply(lambda x : index.split_and_add(x))
-index.filter_ngram()
+index.filter_top_ngram(40000)
 num_words = len(index.index_ngram) # necessary for the embedding matrix
 
 
-- 
GitLab