# Base.py — BERT tokenizer / classifier / label-encoder utilities.
# (GitLab web-interface chrome removed from this scraped copy.)
import functools
import os
import pickle

import torch
from sklearn import preprocessing
from transformers import BertForSequenceClassification, BertTokenizer

def get_device():
    """Pick the computation device, preferring an available GPU.

    Prints which device was selected and returns the matching
    ``torch.device`` object.
    """
    if not torch.cuda.is_available():
        print('No GPU available, using the CPU instead.')
        return torch.device("cpu")
    print('We will use the GPU:', torch.cuda.get_device_name(0))
    return torch.device("cuda")

def loader(f):
    """Decorator for ``_init_*`` loading steps.

    Prints a pending ``- <name>`` marker, runs the step, then rewrites
    the line with a check mark. Fixes over the original: the wrapped
    function's return value is now propagated instead of discarded, and
    ``functools.wraps`` preserves the wrapped function's metadata
    (``__name__``, docstring) for introspection.
    """
    @functools.wraps(f)
    def wrapped(*args, **kwargs):
        # Display name: the step name without its '_init_' prefix.
        name = f.__name__.replace('_init_', '')
        print(f' - {name}', end='')
        result = f(*args, **kwargs)
        # '\r' returns to the line start so the check mark overwrites
        # the pending marker printed above.
        print(f'\r✔️  {name}')
        return result
    return wrapped

class BERT:
    """Bundle of tokenizer, BERT sequence classifier and label encoder.

    When ``train_on`` (a sized collection of class labels) is given, a
    fresh classification head and label encoder are initialised for
    fine-tuning; otherwise both are loaded back from ``root_path``.
    """

    # Pretrained multilingual checkpoint used for tokenizer and model.
    model_name = 'bert-base-multilingual-cased'
    # Pickle file name (under root_path) holding the fitted LabelEncoder.
    # NOTE(review): this attribute is required by _init_encoder but its
    # definition was lost from the visible source — confirm the exact
    # filename against the saved artefacts.
    encoder_file = 'encoder.pkl'

    def __init__(self, root_path, train_on=None):
        """Load all components, printing progress for each step.

        :param root_path: directory holding (or receiving) the saved
            model weights and pickled label encoder
        :param train_on: labels to initialise a new classifier/encoder
            for; ``None`` means load previously saved ones instead
        """
        self.device = get_device()
        print('Loading BERT tools')
        self._init_tokenizer()
        self.root_path = root_path
        self._init_classifier(train_on)
        self._init_encoder(train_on)

    @loader
    def _init_tokenizer(self):
        # The tokenizer always comes from the pretrained checkpoint,
        # never from root_path.
        self.tokenizer = BertTokenizer.from_pretrained(BERT.model_name)

    @loader
    def _init_classifier(self, train_on):
        """Create a fresh classification head, or reload a saved model."""
        if train_on is not None:
            # New head sized to the number of labels to train on.
            bert = BertForSequenceClassification.from_pretrained(
                    BERT.model_name,
                    num_labels=len(train_on),
                    output_attentions=False,
                    output_hidden_states=False
                    )
        else:
            bert = BertForSequenceClassification.from_pretrained(self.root_path)
        # Move to the full torch.device object rather than `.type`:
        # passing the bare "cuda"/"cpu" string would drop any device
        # index (e.g. "cuda:1").
        self.model = bert.to(self.device)

    @loader
    def _init_encoder(self, train_on):
        """Load the pickled LabelEncoder, or fit and persist a new one.

        :raises FileNotFoundError: when no saved encoder exists and no
            labels were provided to fit a new one
        """
        path = f"{self.root_path}/{BERT.encoder_file}"
        if os.path.exists(path):
            # SECURITY: pickle.load executes arbitrary code on load —
            # only ever load encoders written by this application.
            with open(path, 'rb') as pickled:
                self.encoder = pickle.load(pickled)
        elif train_on is not None:
            self.encoder = preprocessing.LabelEncoder()
            self.encoder.fit(train_on)
            # Persist immediately so inference runs can reload it.
            with open(path, 'wb') as file:
                pickle.dump(self.encoder, file)
        else:
            raise FileNotFoundError(path)

    def import_data(self, data):
        """Lazily move a batch of tensors onto the selected device."""
        return map(lambda d: d.to(self.device), data)

    def save(self):
        """Persist the classifier weights under root_path."""
        self.model.save_pretrained(self.root_path)