Skip to content
Snippets Groups Projects
word_index.py 4.83 KiB
import json

import numpy as np

from ngram import NGram

# Machine learning 
from gensim.models import Word2Vec

class WordIndex:
    """
    Class used for encoding words in ngram representation.

    Here a "word" is a space-separated string and each space-separated token
    (lower-cased) is one unit of the index. Index 0 is reserved for the empty
    string and is used as the padding value by `complete`.
    """
    def __init__(self, loaded=False):
        """
        Constructor

        Parameters
        ----------
        loaded : bool
            True if the index was loaded from an external file. A loaded index
            is treated as frozen: `encode` neither adds unknown tokens nor
            grows `max_len`, so every encoding has the saved length.
        """
        self.ngram_index = {"": 0}
        self.index_ngram = {0: ""}
        self.cpt = 0  # last index handed out by `add`
        self.max_len = 0  # longest token sequence seen; target length of encodings

        self.loaded = loaded

    def split_and_add(self, word):
        """
        Split a word on spaces and add each lower-cased token to the index.

        Also grows `max_len` so that it covers this word's token count.

        Parameters
        ----------
        word : str
            a word (space-separated tokens)
        """
        grams = word.lower().split(" ")
        for subword in grams:
            self.add(subword)
        self.max_len = max(self.max_len, len(grams))

    def add(self, subword):
        """
        Add a token to the index if it is not already present.

        New tokens receive the next integer index (`cpt + 1`).

        Parameters
        ----------
        subword : str
            token to register
        """
        if subword not in self.ngram_index:
            self.cpt += 1
            self.ngram_index[subword] = self.cpt
            self.index_ngram[self.cpt] = subword

    def encode(self, word):
        """
        Return the index representation of a word, padded to `max_len`.

        When the index is not loaded, unknown tokens are added and `max_len`
        grows to fit the word. When it is loaded (frozen), unknown tokens are
        dropped and the result is padded or truncated to the saved `max_len`.

        Parameters
        ----------
        word : str
            a word (space-separated tokens)

        Returns
        -------
        list of int
            list of token indices
        """
        subwords = word.lower().split(" ")
        if not self.loaded:
            for subword in subwords:
                self.add(subword)
            # Only a mutable index may change the encoding length; a loaded
            # index keeps the length it was saved with.
            self.max_len = max(self.max_len, len(subwords))
        encoding = [self.ngram_index[sw] for sw in subwords if sw in self.ngram_index]
        return self.complete(encoding, self.max_len)

    def complete(self, ngram_encoding, MAX_LEN, filling_item=0):
        """
        Pad an encoded word with `filling_item` up to `MAX_LEN`. It's necessary for neural network.

        The input sequence is not modified; a new list is returned.

        Parameters
        ----------
        ngram_encoding : list of int
            first encoding of a word
        MAX_LEN : int
            desired length of the encoding
        filling_item : int, optional
            index used for padding, by default 0

        Returns
        -------
        list of int
            list of token indices, exactly MAX_LEN long

        Raises
        ------
        ValueError
            if the index is not loaded and the encoding is longer than MAX_LEN
        """
        if self.loaded and len(ngram_encoding) >= MAX_LEN:
            # A frozen index truncates over-long words instead of failing.
            return list(ngram_encoding[:MAX_LEN])
        if len(ngram_encoding) > MAX_LEN:
            raise ValueError(
                "encoding of length {0} exceeds MAX_LEN={1}".format(len(ngram_encoding), MAX_LEN)
            )
        diff = MAX_LEN - len(ngram_encoding)
        return list(ngram_encoding) + [filling_item] * diff

    def get_embedding_layer(self, texts, dim=100, **kwargs):
        """
        Return an embedding matrix for each token using encoded texts. Using gensim.Word2vec model.

        Row i of the matrix is the Word2Vec vector for index i. Rows of
        indices that never appear in `texts` stay zero.

        Parameters
        ----------
        texts : list of [list of int]
            list of encoded words
        dim : int, optional
            embedding dimension, by default 100
        **kwargs
            extra keyword arguments forwarded to gensim's Word2Vec

        Returns
        -------
        np.array
            embedding matrix of shape (len(ngram_index), dim)
        """
        # NOTE(review): `size=` is the gensim<4 keyword (renamed `vector_size`
        # in gensim 4) — confirm against the installed gensim version.
        model = Word2Vec([[str(w) for w in t] for t in texts], size=dim, window=5, min_count=1, workers=4, **kwargs)
        N = len(self.ngram_index)
        embedding_matrix = np.zeros((N, dim))
        for i in range(N):
            if str(i) in model.wv:
                embedding_matrix[i] = model.wv[str(i)]
        return embedding_matrix

    def save(self, fn):
        """
        Save the WordIndex state as JSON.

        Parameters
        ----------
        fn : str
            output filename
        """
        data = {
            "word_index": self.ngram_index,
            "cpt_state": self.cpt,
            "max_len_state": self.max_len
        }
        with open(fn, 'w') as f:
            json.dump(data, f)

    @staticmethod
    def load(fn):
        """
        Load a WordIndex state from a file written by `save`.

        The returned index is marked as loaded (frozen).

        Parameters
        ----------
        fn : str
            input filename

        Returns
        -------
        WordIndex
            word index

        Raises
        ------
        json.JSONDecodeError
            raised if the file is not valid JSON
        KeyError
            raised if a required field does not appear in the input file
        """
        try:
            with open(fn) as f:
                data = json.load(f)
        except json.JSONDecodeError:
            print("Data file must be a JSON")
            raise  # previously fell through and failed on an unbound `data`
        for key in ["word_index", "cpt_state", "max_len_state"]:
            if key not in data:
                raise KeyError("{0} field cannot be found in given file".format(key))
        new_obj = WordIndex(loaded=True)
        # Must match the key written by `save` (was wrongly "ngram_index").
        new_obj.ngram_index = data["word_index"]
        new_obj.index_ngram = {v: k for k, v in new_obj.ngram_index.items()}
        new_obj.cpt = data["cpt_state"]
        new_obj.max_len = data["max_len_state"]
        return new_obj