Commit 9958a37c authored by Alice Brenon

Implement Trainer

parent bb824282
@@ -23,41 +23,40 @@ def loader(f):
 class BERT:
     model_name = 'bert-base-multilingual-cased'

-    def __init__(self, root_path, training=False):
+    def __init__(self, root_path, train_on=None):
         self.device = get_device()
         print('Loading BERT tools')
         self._init_tokenizer()
         self.root_path = root_path
-        self._init_classifier(training)
+        self._init_classifier(train_on)
+        self._init_encoder(train_on)

     @loader
     def _init_tokenizer(self):
         self.tokenizer = BertTokenizer.from_pretrained(BERT.model_name)

     @loader
-    def _init_classifier(self, training):
-        if training:
+    def _init_classifier(self, train_on):
+        if train_on is not None:
             bert = BertForSequenceClassification.from_pretrained(
-                model_name, # Use the 12-layer BERT model, with an uncased vocab.
-                num_labels = numberOfClasses, # The number of output labels--2 for binary classification.
-                                              # You can increase this
-                                              # for multi-class tasks.
-                output_attentions = False, # Whether the model returns attentions weights.
-                output_hidden_states = False, # Whether the model returns all hidden-states.
+                BERT.model_name, # the 12-layer multilingual BERT model, cased vocabulary
+                num_labels = len(train_on),
+                output_attentions = False,
+                output_hidden_states = False
             )
         else:
             bert = BertForSequenceClassification.from_pretrained(self.root_path)
         self.model = bert.to(self.device.type)

     @loader
-    def _init_encoder(self, create_from=None):
+    def _init_encoder(self, train_on):
         path = f"{self.root_path}/label_encoder.pkl"
         if os.path.isfile(path):
             with open(path, 'rb') as pickled:
                 self.encoder = pickle.load(pickled)
-        elif create_from is not None:
+        elif train_on is not None:
             self.encoder = preprocessing.LabelEncoder()
-            self.encoder.fit(create_from)
+            self.encoder.fit(train_on)
             with open(path, 'wb') as file:
                 pickle.dump(self.encoder, file)
         else:
...
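The net effect of this hunk is an API change: instead of a boolean `training` flag, `BERT.__init__` now takes `train_on`, an optional collection of distinct class labels, which both sizes the classification head (`num_labels = len(train_on)`) and fits the pickled label encoder in one pass. A minimal sketch of the two construction paths, with a hypothetical label list (the names and paths below are illustrative, not from this commit):

    labels = ['arts', 'history', 'science']                # hypothetical label set
    training_model = BERT('models/run1', train_on=labels)  # fresh head with len(labels) outputs
    inference_model = BERT('models/run1')                  # reloads fine-tuned weights and encoder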
@@ -19,7 +19,6 @@ class Classifier(BERT):
     def __init__(self, root_path):
         BERT.__init__(self, root_path)
         self._init_pipe()
-        self._init_encoder()

     def _init_pipe(self):
         self.pipe = TextClassificationPipeline(
...
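This one-line removal follows from the base-class change above: `BERT.__init__` now calls `_init_encoder` itself, so `Classifier` no longer needs to do it explicitly.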
from BERT.Base import BERT
import datetime
from loaders import set_random
import time
import torch
from transformers import AdamW, get_linear_schedule_with_warmup


def chrono(f):
    def wrapped(*args, **kwargs):
        t0 = time.time()
        result = f(*args, **kwargs)  # keep the result so timed methods can still return values
        duration = datetime.timedelta(seconds=round(time.time() - t0))
        print(f"\n    {f.__name__} took: {duration}")
        return result
    return wrapped


class Trainer(BERT):
    def __init__(self, root_path, labeled_data, epochs=4):
        self.epochs = epochs
        BERT.__init__(self, root_path, train_on=labeled_data.unique)
        self._init_utils(labeled_data.load(self))

    def _init_utils(self, data_loader):
        self.optimizer = AdamW(
            self.model.parameters(),
            lr = 2e-5, # args.learning_rate - default is 5e-5
        )
        self.data_loader = data_loader
        self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps = 0, # default value in run_glue.py
            num_training_steps = self.epochs * len(data_loader))

    def __call__(self):
        set_random()
        losses = [self.epoch(e) for e in range(self.epochs)]
        self.save()
        print("\nTraining complete!")

    @chrono
    def epoch(self, epoch):
        self._start_epoch(epoch)
        self.model.train()
        total_loss = sum(self.learn_on(*self.import_data(batch))
                         for batch in self.data_loader)
        avg_train_loss = total_loss / len(self.data_loader)
        print("\n    Average training loss: {0:.2f}".format(avg_train_loss))
        return avg_train_loss

    def learn_on(self, input_ids, input_mask, labels):
        self.model.zero_grad()
        outputs = self.model(input_ids,
                             token_type_ids=None,
                             attention_mask=input_mask,
                             labels=labels)
        loss = outputs[0]
        loss.backward()
        # clip gradients to guard against exploding-gradient updates
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
        self.optimizer.step()
        self.scheduler.step()
        return loss.item()

    def _start_epoch(self, epoch):
        print(f'\n======== Epoch {epoch + 1} / {self.epochs} ========')
        print('Training...')
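`epoch` relies on an `import_data` helper that does not appear in this commit; from its call site, `learn_on(*self.import_data(batch))`, it must turn a batch into the `(input_ids, input_mask, labels)` triple on the model's device. A hypothetical sketch under that assumption, not the author's actual implementation:

    # Hypothetical: assumes each batch from the DataLoader is a tuple of
    # three tensors (input ids, attention mask, labels).
    def import_data(self, batch):
        return tuple(tensor.to(self.device.type) for tensor in batch)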
@@ -26,12 +26,12 @@ class Corpus:
 class TSVIndexed(Corpus):
     default_keys = ['work', 'volume', 'article']
+    projectors = ['key', 'content', 'full']

     def __init__(self, tsv_path, column_name):
         self.tsv_path = tsv_path
         self.column_name = column_name
         self.data = None
-        self.projectors = dict((p, self.__getattribute__(p))
-                               for p in ['key', 'content', 'full'])

     def load(self):
         if self.data is None:
@@ -60,7 +60,7 @@ class TSVIndexed(Corpus):
         if projector is None:
             projector = self.full
         elif type(projector) == str and projector in self.projectors:
-            projector = self.projectors[projector]
+            projector = self.__getattribute__(projector)
         self.load()
         for row in self.data.iterrows():
             yield projector(*row)
...
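The second hunk swaps the per-instance projector dict for a name-to-method lookup at call time: a projector name is validated against the class-level `projectors` list, then resolved to a bound method with `__getattribute__`. A standalone sketch of the pattern (class and method names here are illustrative):

    class Demo:
        projectors = ['key', 'content', 'full']

        def key(self, index, row):
            return index

        def resolve(self, name):
            if name in self.projectors:
                return self.__getattribute__(name)  # bound method, e.g. self.key
            raise ValueError(f"unknown projector: {name}")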
@@ -4,7 +4,7 @@ import torch

 def set_random():
     seed_value = 42
-    random.seed(seed_val)
-    numpy.random.seed(seed_val)
-    torch.manual_seed(seed_val)
-    torch.cuda.manual_seed_all(seed_val)
+    random.seed(seed_value)
+    numpy.random.seed(seed_value)
+    torch.manual_seed(seed_value)
+    torch.cuda.manual_seed_all(seed_value)
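This hunk fixes a NameError: the function defined `seed_value` but seeded with the undefined `seed_val`. With the fix, a quick check that reseeding makes draws reproducible:

    from loaders import set_random
    import torch

    set_random()
    a = torch.rand(3)
    set_random()
    b = torch.rand(3)
    assert torch.equal(a, b)  # identical seeds yield identical draws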
#!/usr/bin/env python3

from BERT import Trainer
from LabeledData import LabeledData
import sys

if __name__ == '__main__':
    labeled_data = LabeledData(sys.argv[1])
    trainer = Trainer(sys.argv[2], labeled_data)
    trainer()
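Assuming the script is saved as, say, train.py (no file name appears in this commit), it takes the labeled-data path first and the model directory second, e.g. ./train.py data/labeled.tsv models/run1 (paths illustrative).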