From 5c0d03fa44ff1860a59783ba47bc4ad8a94dfae9 Mon Sep 17 00:00:00 2001
From: Fize Jacques <jacques.fize@cirad.fr>
Date: Thu, 17 Dec 2020 11:43:10 +0100
Subject: [PATCH] Optimisation on ngram embedding

---
Review note (text between "---" and the diffstat is ignored by `git am`):
the line breaks and indentation of this patch were restored from a copy
whose whitespace had been collapsed. The indentation of the context lines
is assumed (4 spaces per level) and still has to be confirmed against
lib/ngram_index.py; apply with `git apply --ignore-whitespace` if strict
context matching fails.

 lib/ngram_index.py | 52 +++++++++++++++++++++++++++++++++++++++++++++--
 1 file changed, 50 insertions(+), 2 deletions(-)

diff --git a/lib/ngram_index.py b/lib/ngram_index.py
index 5f86220..0b7a42b 100644
--- a/lib/ngram_index.py
+++ b/lib/ngram_index.py
@@ -73,7 +73,7 @@ class NgramIndex():
             self.index_ngram[self.cpt]=ngram
 
 
-    def encode(self,word):
+    def encode(self,word,complete=True):
         """
         Return a ngram representation of a word
 
@@ -93,6 +93,8 @@ class NgramIndex():
         ngrams = [ng for ng in ngrams if ng.count(self.empty_char)<2]
         if not self.loaded:
             [self.add(ng) for ng in ngrams if not ng in self.ngram_index]
+        if not complete:
+            return [self.ngram_index[ng] for ng in ngrams if ng in self.ngram_index]
         return self.complete([self.ngram_index[ng] for ng in ngrams if ng in self.ngram_index],self.max_len)
 
     def complete(self,ngram_encoding,MAX_LEN,filling_item=0):
@@ -136,7 +138,8 @@ class NgramIndex():
         np.array
             embedding matrix
         """
-        model = Word2Vec([[str(w) for w in t] for t in texts], size=dim,window=5, min_count=1, workers=4,**kwargs)
+        sentences = SentenceIterator(self,texts,True)
+        model = Word2Vec(sentences, size=dim,window=5, min_count=1, workers=4,**kwargs)
         N = len(self.ngram_index)
         embedding_matrix = np.zeros((N,dim))
         for i in range(N):
@@ -227,3 +230,48 @@ class NgramIndex():
         new_obj.max_len = data["max_len_state"]
         return new_obj
 
+class SentenceIterator:
+    """
+    Restartable iterator over the ngram encodings of a sequence of words.
+
+    Calling ``iter()`` on an instance rewinds it to the first item, so the
+    same instance can be consumed several times. The position is stored on
+    the instance, so two loops over one instance share their progress.
+
+    Parameters
+    ----------
+    ng_encoder : object
+        encoder whose ``encode(word)`` method returns the ngram ids of a
+        word (``NgramIndex.get_embedding_layer`` passes the index itself)
+    input_ : sequence
+        words to encode; must support ``len()`` and integer indexing
+    cast_str : bool, optional
+        if True, each ngram id is converted to ``str``, by default False
+    """
+
+    def __init__(self, ng_encoder, input_, cast_str=False):
+        self.ng_encoder = ng_encoder
+        self.input = input_
+        self.i = -1
+        self.cast_str = cast_str
+
+    def __iter__(self):
+        self.i = -1
+        return self
+
+    def __len__(self):
+        return len(self.input)
+
+    def __next__(self):
+        self.i += 1
+        # Only the lookup may end the iteration: an error raised by
+        # ``encode`` has to propagate instead of being swallowed by a bare
+        # ``except`` and silently truncating the sequence.
+        try:
+            word = self.input[self.i]
+        except (IndexError, KeyError):
+            raise StopIteration
+        encoding = self.ng_encoder.encode(word)
+        if self.cast_str:
+            return list(map(str, encoding))
+        return encoding
-- 
GitLab