diff --git a/lib/ngram_index.py b/lib/ngram_index.py
index 5f862206aa2f3c852d09a1a2438906b57541b5d6..0b7a42b466a90661ffb2397509007db464942df9 100644
--- a/lib/ngram_index.py
+++ b/lib/ngram_index.py
@@ -73,7 +73,7 @@ class NgramIndex():
         self.index_ngram[self.cpt]=ngram
 
-    def encode(self,word):
+    def encode(self,word,complete=True):
         """
         Return a ngram representation of a word
 
@@ -93,6 +93,8 @@ class NgramIndex():
         ngrams = [ng for ng in ngrams if ng.count(self.empty_char)<2]
         if not self.loaded:
             [self.add(ng) for ng in ngrams if not ng in self.ngram_index]
+        if not complete:
+            return [self.ngram_index[ng] for ng in ngrams if ng in self.ngram_index]
         return self.complete([self.ngram_index[ng] for ng in ngrams if ng in self.ngram_index],self.max_len)
 
     def complete(self,ngram_encoding,MAX_LEN,filling_item=0):
@@ -136,7 +138,8 @@ class NgramIndex():
         np.array
             embedding matrix
 
         """
-        model = Word2Vec([[str(w) for w in t] for t in texts], size=dim,window=5, min_count=1, workers=4,**kwargs)
+        sentences = SentenceIterator(self,texts,True)
+        model = Word2Vec(sentences, size=dim,window=5, min_count=1, workers=4,**kwargs)
         N = len(self.ngram_index)
         embedding_matrix = np.zeros((N,dim))
         for i in range(N):
@@ -227,3 +230,29 @@ class NgramIndex():
         new_obj.max_len = data["max_len_state"]
         return new_obj
 
+class SentenceIterator:
+
+    """Restartable iterable of ngram-encoded sentences for Word2Vec training."""
+
+    def __init__(self, ng_encoder, input_,cast_str = False):
+        self.ng_encoder = ng_encoder
+        self.input = input_
+        self.i = -1
+        self.cast_str = cast_str
+
+    def __iter__(self):
+        self.i = -1
+        return self
+
+    def __len__(self):
+        return len(self.input)
+
+    def __next__(self):
+        try:
+            self.i +=1
+            if not self.cast_str:
+                return self.ng_encoder.encode(self.input[self.i])
+            else:
+                return list(map(str,self.ng_encoder.encode(self.input[self.i])))
+        except IndexError:
+            raise StopIteration