Commit bb824282 authored by Alice Brenon

Move the BERT file into the BERT directory as a package and remove the left-over file

parent d7e8544f
from loaders import get_device
from transformers import BertForSequenceClassification, BertTokenizer

def loader(f):
    """Print a progress marker around an _init_* step while it runs."""
    def wrapped(*args, **kwargs):
        name = f.__name__.replace('_init_', '')
        print(f' - {name}', end='')
        f(*args, **kwargs)
        print(f'\r✔️ {name}')
    return wrapped

class BERT:
    model_name = 'bert-base-multilingual-cased'

    def __init__(self, root_path, training=False, numberOfClasses=2):
        # numberOfClasses was undefined in this file; it is taken here as a
        # constructor argument (assumption) so the classifier head can be sized.
        self.device = get_device()
        self.numberOfClasses = numberOfClasses
        print('Loading BERT tools')
        self._init_tokenizer()
        self.root_path = root_path
        self._init_classifier(training)

    @loader
    def _init_tokenizer(self):
        self.tokenizer = BertTokenizer.from_pretrained(BERT.model_name)
    @loader
    def _init_classifier(self, training):
        if training:
            # Fine-tuning: start from the pretrained multilingual cased BERT
            # and attach a fresh classification head.
            bert = BertForSequenceClassification.from_pretrained(
                BERT.model_name,
                num_labels = self.numberOfClasses, # 2 for binary classification,
                                                   # more for multi-class tasks.
                output_attentions = False,    # Do not return attention weights.
                output_hidden_states = False, # Do not return all hidden states.
            )
        else:
            # Inference: reload the fine-tuned model previously saved under root_path.
            bert = BertForSequenceClassification.from_pretrained(self.root_path)
        self.model = bert.to(self.device.type)
    def import_data(self, data):
        # Move a batch of tensors onto the device the model runs on.
        return map(lambda d: d.to(self.device), data)

    def save(self):
        # Persist the fine-tuned weights under root_path for later reloading.
        self.model.save_pretrained(self.root_path)
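
The class above relies on a local loaders.get_device helper that is not part of this commit. What follows is a minimal usage sketch, not the project's own code: it assumes get_device returns a torch.device, that 'path/to/model' (a hypothetical name) holds a previously saved checkpoint, and that the BERT class defined above is in scope.

import torch

# Hypothetical usage sketch (assumptions stated above): classify one string
# with a previously fine-tuned model saved under a made-up directory name.
bert = BERT('path/to/model')                 # training=False: reloads saved weights
encoded = bert.tokenizer('some text to classify',
                         return_tensors='pt',
                         truncation=True,
                         padding=True)
inputs = dict(zip(encoded.keys(), bert.import_data(encoded.values())))
with torch.no_grad():
    logits = bert.model(**inputs).logits
print(logits.argmax(dim=-1).item())          # index of the predicted class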