From 5c0d03fa44ff1860a59783ba47bc4ad8a94dfae9 Mon Sep 17 00:00:00 2001
From: Fize Jacques <jacques.fize@cirad.fr>
Date: Thu, 17 Dec 2020 11:43:10 +0100
Subject: [PATCH] Optimisation on ngram embedding

---
 lib/ngram_index.py | 33 +++++++++++++++++++++++++++++++--
 1 file changed, 31 insertions(+), 2 deletions(-)

diff --git a/lib/ngram_index.py b/lib/ngram_index.py
index 5f86220..0b7a42b 100644
--- a/lib/ngram_index.py
+++ b/lib/ngram_index.py
@@ -73,7 +73,7 @@ class NgramIndex():
             self.index_ngram[self.cpt]=ngram
         
 
-    def encode(self,word):
+    def encode(self,word,complete=True):
         """
         Return a ngram representation of a word
         
@@ -93,6 +93,8 @@ class NgramIndex():
         ngrams = [ng for ng in ngrams if ng.count(self.empty_char)<2]
         if not self.loaded:
             [self.add(ng) for ng in ngrams if not ng in self.ngram_index]
+        if not complete:
+            return [self.ngram_index[ng] for ng in ngrams if ng in self.ngram_index]
         return self.complete([self.ngram_index[ng] for ng in ngrams if ng in self.ngram_index],self.max_len)
 
     def complete(self,ngram_encoding,MAX_LEN,filling_item=0):
@@ -136,7 +138,8 @@ class NgramIndex():
         np.array
             embedding matrix
         """
-        model = Word2Vec([[str(w) for w in t] for t in texts], size=dim,window=5, min_count=1, workers=4,**kwargs)
+        sentences = SentenceIterator(self,texts,True)
+        model = Word2Vec(sentences, size=dim,window=5, min_count=1, workers=4,**kwargs)
         N = len(self.ngram_index)
         embedding_matrix = np.zeros((N,dim))
         for i in range(N):
@@ -227,3 +230,29 @@ class NgramIndex():
         new_obj.max_len = data["max_len_state"]
         return new_obj
 
+class SentenceIterator:
+
+    """Iterator that counts upward forever."""
+
+    def __init__(self, ng_encoder, input_,cast_str = False):
+        self.ng_encoder = ng_encoder
+        self.input = input_
+        self.i = -1
+        self.cast_str = cast_str
+
+    def __iter__(self):
+        self.i = -1
+        return self
+    
+    def __len__(self):
+        return len(self.input)
+
+    def __next__(self):
+        try:
+            self.i +=1
+            if not self.cast_str:
+                return self.ng_encoder.encode(self.input[self.i])
+            else:
+                return list(map(str,self.ng_encoder.encode(self.input[self.i])))
+        except:
+            raise StopIteration
-- 
GitLab