diff --git a/lib/ngram_index.py b/lib/ngram_index.py index 6f2215b5370cedc1a5b1637474fb896c10efbafc..462e327116ec57c66b7c947331c00fcbee026339 100644 --- a/lib/ngram_index.py +++ b/lib/ngram_index.py @@ -58,7 +58,7 @@ class NgramIndex(): """ ngrams = str(word).lower().replace(" ",self.empty_char) ngrams = list(self.ngram_gen.split(ngrams)) - [self.add(ngram) for ngram in ngrams] + [self.add(ngram) for ngram in ngrams if not ngram in self.ngram_index] self.max_len = max(self.max_len,len(ngrams)) def add(self,ngram): @@ -80,21 +80,21 @@ class NgramIndex(): def filter_ngram_by_freq(self,threshold=20): freq_data = pd.DataFrame(self.freq_ngram.items(),columns="ngram freq".split()) - selected_ngram = freq_data[freq_data.freq<threshold].ngram.values - for ng in selected_ngram: - index = self.ngram_index[ng] - del self.ngram_index[ng] - del self.index_ngram[index] + selected_ngram = freq_data[freq_data.freq>threshold] + selected_ngram["index__"] = np.arange(len(selected_ngram)) + self.ngram_index = dict(selected_ngram["ngram index__".split()].values) + self.index_ngram = dict(selected_ngram["index__ ngram".split()].values) def filter_top_ngram(self,threshold=20000): freq_data = pd.DataFrame(self.freq_ngram.items(),columns="ngram freq".split()).sort_values(by="freq",ascending=False) - if len(self.ngram_index)-threshold <0: + if len(self.ngram_index)-threshold <=0: return 0 - selected_ngram = freq_data.tail(len(self.ngram_index)-threshold).ngram.values - for ng in selected_ngram: - index = self.ngram_index[ng] - del self.ngram_index[ng] - del self.index_ngram[index] + selected_ngram = freq_data.head(threshold) + selected_ngram["index__"] = np.arange(len(selected_ngram)) + self.ngram_index = dict(selected_ngram["ngram index__".split()].values) + self.index_ngram = dict(selected_ngram["index__ ngram".split()].values) + + def encode(self,word,complete=True): diff --git a/train_geocoder.py b/train_geocoder.py index b8c6978c6f2d8991c911bdb6516fdcb619db2148..2ee31cd9271798f7c3e0d144fa4fa39d28a5a980 100644 --- a/train_geocoder.py +++ b/train_geocoder.py @@ -112,7 +112,8 @@ if args.tokenization_method == "bert": # Identify all ngram available pairs_of_toponym.toponym.apply(lambda x : index.split_and_add(x)) pairs_of_toponym.toponym_context.apply(lambda x : index.split_and_add(x)) -index.filter_top_ngram(40000) +print(len(index.ngram_index)) +index.filter_top_ngram(10000) num_words = len(index.index_ngram) # necessary for the embedding matrix