# torch_generator.py
import numpy as np
import torch
# Newer Keras releases expose this as keras.utils.pad_sequences; adjust the
# import below to match the installed version.
from keras.preprocessing.sequence import pad_sequences

def chunks(lst, n):
    """Yield successive n-sized chunks from lst."""
    for i in range(0, len(lst), n):
        yield lst[i:i + n]
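
# For example (illustrative):
#   list(chunks([1, 2, 3, 4, 5], 2)) -> [[1, 2], [3, 4], [5]]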

class SentenceDataset(torch.utils.data.Dataset):
    """A PyTorch dataset that tokenizes sentences lazily, one batch at a time."""

    def __init__(self, dataframe, tokenizer, max_len=96, batch_size=32, shuffle=True):
        self.sentences = dataframe["sentence"].values
        self.labels = dataframe["label"].values
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.batch_size = batch_size

        # Pre-assign dataset rows to batches (optionally in shuffled order);
        # each entry of batch_tokenization holds the dataset indices of one batch.
        indices = np.arange(len(dataframe))
        if shuffle:
            np.random.shuffle(indices)
        self.batch_tokenization = list(chunks(indices, batch_size))
        assert len(self.batch_tokenization[0]) == batch_size

        # Tokenize the first batch eagerly and remember which index range
        # (boundaries) the cached tokenization covers.
        self.current_batch_id = 0
        self.boundaries = (0, batch_size)
        self.current_batch_tokenized = self.tokenize(self.current_batch_id)

    def tokenize(self, batch_index):
        # Encode each sentence in the batch, then pad/truncate all of them to
        # max_len (0 is the padding id; encode() itself caps input at 512 tokens).
        X = [self.tokenizer.encode(self.sentences[i], add_special_tokens=True,
                                   max_length=512, truncation=True)
             for i in self.batch_tokenization[batch_index]]
        return pad_sequences(X, maxlen=self.max_len, dtype="long", value=0,
                             truncating="post", padding="post").tolist()

    def __len__(self):
        """Total number of samples."""
        return len(self.sentences)

    def __getitem__(self, index):
        """Return (input_ids, attention_mask, label) for one sample."""
        # Re-tokenize only when the requested index falls outside the batch that
        # is currently cached. This caching pays off under (mostly) sequential
        # access, e.g. a DataLoader with shuffle=False.
        if not (self.boundaries[0] <= index < self.boundaries[1]):
            self.current_batch_id = index // self.batch_size
            self.current_batch_tokenized = self.tokenize(self.current_batch_id)
            start = self.current_batch_id * self.batch_size
            self.boundaries = (start, start + self.batch_size)

        index_in_batch = index - self.boundaries[0]
        X = self.current_batch_tokenized[index_in_batch]
        M = [int(token_id > 0) for token_id in X]  # attention mask: 1 for real tokens, 0 for padding
        # The sentences were tokenized in (possibly shuffled) batch_tokenization
        # order, so the label must be looked up through the same mapping;
        # self.labels[index] would pair shuffled sentences with unshuffled labels.
        y = self.labels[self.batch_tokenization[self.current_batch_id][index_in_batch]]
        return torch.tensor(X), torch.tensor(M), torch.tensor(y)
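

# --- Usage sketch (not part of the original snippet). Assumes a Hugging Face
# `transformers` tokenizer and a pandas DataFrame with "sentence" and "label"
# columns; the model name and toy data below are illustrative. Because the
# dataset shuffles and tokenizes one internal batch at a time, the DataLoader
# should use shuffle=False and the same batch_size, so indices arrive
# sequentially and the per-batch tokenization cache is actually hit.
if __name__ == "__main__":
    import pandas as pd
    from transformers import AutoTokenizer

    df = pd.DataFrame({
        "sentence": ["a first example", "a second example"] * 32,
        "label": [0, 1] * 32,
    })
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

    dataset = SentenceDataset(df, tokenizer, max_len=96, batch_size=32, shuffle=True)
    loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=False)

    for input_ids, attention_mask, labels in loader:
        print(input_ids.shape, attention_mask.shape, labels.shape)
        break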