You need to sign in or sign up before continuing.
Newer
Older
from transformers import BertForSequenceClassification, BertTokenizer
import os
import pickle
from sklearn import preprocessing
import torch
def get_device():
    """Pick the torch device to run on and announce the choice on stdout.

    Returns:
        ``torch.device("cuda")`` when ``torch.cuda.is_available()`` reports a
        GPU, ``torch.device("cpu")`` otherwise.
    """
    # Deal with the CPU fallback first so the GPU path reads as the main line.
    if not torch.cuda.is_available():
        print('No GPU available, using the CPU instead.')
        return torch.device("cpu")

    gpu_name = torch.cuda.get_device_name(0)
    print('We will use the GPU:', gpu_name)
    return torch.device("cuda")
def loader(f):
    """Decorate a loading step so that it reports its progress on stdout.

    Before the wrapped function runs, `` - <name>`` is printed without a
    trailing newline; once it has returned, the line is overwritten with
    ``✔️ <name>``. ``<name>`` is the function's ``__name__`` with every
    ``_init_`` occurrence removed (``_init_tokenizer`` -> ``tokenizer``).

    If the wrapped function raises, the exception propagates and no check
    mark is printed.

    Args:
        f: the callable to wrap.

    Returns:
        A wrapper with the same signature and metadata as ``f`` that returns
        whatever ``f`` returns.
    """
    # Function-scope import keeps this block self-contained.
    import functools

    name = f.__name__.replace('_init_', '')

    @functools.wraps(f)  # keep f's __name__/__doc__ on the wrapper
    def wrapped(*args, **kwargs):
        # flush=True: without a newline the pending text would otherwise stay
        # in the buffer and only appear after the (slow) step has finished.
        print(f' - {name}', end='', flush=True)
        result = f(*args, **kwargs)
        # '\r' moves back to the start of the line to overwrite the pending
        # marker with the check mark.
        print(f'\r✔️ {name}')
        return result

    return wrapped
class BERT:
    """A BERT sequence classifier bundled with its tokenizer and label encoder.

    Two modes are supported, selected by the ``train_on`` constructor
    argument:

    * ``train_on`` given: a fresh classifier is created from ``model_name``
      and a label encoder is fitted on ``train_on`` (unless an encoder file
      already exists under ``root_path``, in which case that one is loaded).
    * ``train_on`` is ``None``: the classifier is loaded from ``root_path``
      and the pickled label encoder must already exist there.

    Attributes set by ``__init__``: ``device``, ``tokenizer``, ``root_path``,
    ``model`` and ``encoder``.
    """

    # Identifier of the pretrained checkpoint used for the tokenizer and for
    # a freshly created classifier.
    model_name = 'bert-base-multilingual-cased'
    # File name, inside root_path, of the pickled label encoder.
    encoder_file = 'label_encoder.pkl'

    def __init__(self, root_path, train_on=None):
        """Load the tokenizer, the classifier and the label encoder.

        Args:
            root_path: directory the classifier is loaded from / saved to and
                where the label encoder pickle lives.
            train_on: 1-D collection of labels to train on, or ``None`` to
                load a previously saved classifier and encoder.

        Raises:
            FileNotFoundError: if ``train_on`` is ``None`` and no encoder
                file exists under ``root_path``.
        """
        self.device = get_device()
        print('Loading BERT tools')
        self._init_tokenizer()
        self.root_path = root_path
        self._init_classifier(train_on)
        self._init_encoder(train_on)

    @loader
    def _init_tokenizer(self):
        """Load the tokenizer matching ``model_name`` into ``self.tokenizer``."""
        self.tokenizer = BertTokenizer.from_pretrained(BERT.model_name)

    @loader
    def _init_classifier(self, train_on):
        """Create or load the classifier and move it to ``self.device``.

        With ``train_on`` a new classification model is built on top of
        ``model_name``; otherwise the model stored in ``root_path`` is loaded.
        """
        if train_on is not None:
            bert = BertForSequenceClassification.from_pretrained(
                BERT.model_name,  # the multilingual, cased checkpoint
                # One output per distinct label: identical to len(train_on)
                # when the labels are unique, and still right when the same
                # label occurs several times.
                num_labels=len(set(train_on)),
                output_attentions=False,
                output_hidden_states=False,
            )
        else:
            bert = BertForSequenceClassification.from_pretrained(self.root_path)
        # Pass the device object itself, as import_data does.
        self.model = bert.to(self.device)

    @loader
    def _init_encoder(self, train_on):
        """Load the label encoder from disk, or fit and persist a new one.

        An existing encoder file takes precedence over ``train_on``.

        Raises:
            FileNotFoundError: if there is neither an encoder file nor
                ``train_on`` labels to fit a new encoder on.
        """
        path = f"{self.root_path}/{BERT.encoder_file}"
        if os.path.exists(path):
            # NOTE(security): unpickling executes code embedded in the file;
            # only point root_path at directories you trust.
            # NOTE(review): when train_on is also given, this stored encoder
            # may not match the labels used to size the classifier - confirm
            # that reusing it is intended.
            with open(path, 'rb') as pickled:
                self.encoder = pickle.load(pickled)
        elif train_on is not None:
            self.encoder = preprocessing.LabelEncoder()
            self.encoder.fit(train_on)
            # root_path may not exist yet when training into a fresh output
            # directory; create it so the encoder can be written.
            os.makedirs(self.root_path, exist_ok=True)
            with open(path, 'wb') as file:
                pickle.dump(self.encoder, file)
        else:
            raise FileNotFoundError(path)

    def import_data(self, data):
        """Lazily move every item of ``data`` to ``self.device``.

        Args:
            data: iterable of objects exposing ``.to(device)``.

        Returns:
            A ``map`` iterator; each item is moved only when the iterator is
            consumed.
        """
        return map(lambda item: item.to(self.device), data)

    def save(self):
        """Save the classifier into ``root_path`` with ``save_pretrained``."""
        self.model.save_pretrained(self.root_path)