Base.py
    from transformers import BertForSequenceClassification, BertTokenizer
    import os
    import pickle
    from sklearn import preprocessing
    import torch
    
    def get_device():
        """Return a CUDA device when a GPU is available, otherwise the CPU."""
        if torch.cuda.is_available():
            print('We will use the GPU:', torch.cuda.get_device_name(0))
            return torch.device("cuda")
        else:
            print('No GPU available, using the CPU instead.')
            return torch.device("cpu")
    
    def loader(f):
        """Decorator that prints a progress line while an _init_* step runs."""
        def wrapped(*args, **kwargs):
            # Derive a short step name, e.g. _init_tokenizer -> tokenizer.
            name = f.__name__.replace('_init_', '')
            print(f' - {name}', end='')
            f(*args, **kwargs)
            print(f'\r✔️  {name}')
        return wrapped
    
    class BERT:
        """Bundles a multilingual BERT sequence classifier, its tokenizer and a label encoder."""
        model_name = 'bert-base-multilingual-cased'
        encoder_file = 'label_encoder.pkl'
    
        def __init__(self, root_path, train_on=None):
            # root_path: directory holding the saved model and label encoder.
            # train_on: iterable of labels when training from scratch; None to load a saved model.
            self.device = get_device()
            print('Loading BERT tools')
            self._init_tokenizer()
            self.root_path = root_path
            self._init_classifier(train_on)
            self._init_encoder(train_on)
    
        @loader
        def _init_tokenizer(self):
            self.tokenizer = BertTokenizer.from_pretrained(BERT.model_name)
    
        @loader
        def _init_classifier(self, train_on):
            if train_on is not None:
                # Fresh classification head on top of the 12-layer multilingual cased BERT.
                bert = BertForSequenceClassification.from_pretrained(
                        BERT.model_name,
                        num_labels=len(train_on),
                        output_attentions=False,
                        output_hidden_states=False
                        )
            else:
                # Reload a previously fine-tuned model from root_path.
                bert = BertForSequenceClassification.from_pretrained(self.root_path)
            self.model = bert.to(self.device)
    
        @loader
        def _init_encoder(self, train_on):
            path = f"{self.root_path}/{BERT.encoder_file}"
            if os.path.exists(path):
                # Reuse the label encoder saved alongside the model.
                with open(path, 'rb') as pickled:
                    self.encoder = pickle.load(pickled)
            elif train_on is not None:
                # Fit a new encoder on the training labels and persist it for later runs.
                self.encoder = preprocessing.LabelEncoder()
                self.encoder.fit(train_on)
                with open(path, 'wb') as file:
                    pickle.dump(self.encoder, file)
            else:
                raise FileNotFoundError(path)
    
        def import_data(self, data):
            # Move each tensor in a batch to the active device (lazy; returns a map object).
            return map(lambda d: d.to(self.device), data)

        def save(self):
            # Persist the fine-tuned weights so they can be reloaded later via root_path.
            self.model.save_pretrained(self.root_path)
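
A minimal usage sketch, not part of Base.py: the model directory, label names and example sentences below are made up, and it assumes Base.py's imports (including torch) are in scope and a recent transformers release whose model outputs expose a .loss attribute.

    # Hypothetical setup: fine-tune with three made-up labels in a made-up directory.
    bert = BERT('models/sentiment', train_on=['negative', 'neutral', 'positive'])

    # Tokenize a small batch and move the resulting tensors to the selected device.
    batch = bert.tokenizer(['Das ist großartig!', 'This is terrible.'],
                           padding=True, truncation=True, return_tensors='pt')
    input_ids, attention_mask = bert.import_data([batch['input_ids'], batch['attention_mask']])

    # Labels go through the fitted LabelEncoder before reaching the model.
    labels = torch.tensor(bert.encoder.transform(['positive', 'negative'])).to(bert.device)
    outputs = bert.model(input_ids, attention_mask=attention_mask, labels=labels)
    print(outputs.loss)

    # After training, persist the weights next to the pickled encoder.
    bert.save()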