diff --git a/scripts/ML/BERT/Base.py b/scripts/ML/BERT/Base.py
index fe28dbb7cdb632b9049be06b48d2e9da4b5fd4a8..a1d1b8e50cad8fe60ed1fbef0d3b9ee5029bebb2 100644
--- a/scripts/ML/BERT/Base.py
+++ b/scripts/ML/BERT/Base.py
@@ -23,41 +23,40 @@ def loader(f):
 class BERT:
     model_name = 'bert-base-multilingual-cased'
 
-    def __init__(self, root_path, training=False):
+    def __init__(self, root_path, train_on=None):
         self.device = get_device()
         print('Loading BERT tools')
         self._init_tokenizer()
         self.root_path = root_path
-        self._init_classifier(training)
+        self._init_classifier(train_on)
+        self._init_encoder(train_on)
 
     @loader
     def _init_tokenizer(self):
         self.tokenizer = BertTokenizer.from_pretrained(BERT.model_name)
 
     @loader
-    def _init_classifier(self, training):
-        if training:
+    def _init_classifier(self, train_on):
+        if train_on is not None:
             bert = BertForSequenceClassification.from_pretrained(
-                model_name, # Use the 12-layer BERT model, with an uncased vocab.
-                num_labels = numberOfClasses, # The number of output labels--2 for binary classification.
-                # You can increase this
-                # for multi-class tasks.
-                output_attentions = False, # Whether the model returns attentions weights.
-                output_hidden_states = False, # Whether the model returns all hidden-states.
+                BERT.model_name, # Use the 12-layer BERT model, with an uncased vocab.
+                num_labels = len(train_on),
+                output_attentions = False,
+                output_hidden_states = False
             )
         else:
             bert = BertForSequenceClassification.from_pretrained(self.root_path)
         self.model = bert.to(self.device.type)
 
     @loader
-    def _init_encoder(self, create_from=None):
+    def _init_encoder(self, train_on):
        path = f"{self.root_path}/label_encoder.pkl"
        if os.path.isfile(path):
            with open(path, 'rb') as pickled:
                self.encoder = pickle.load(pickled)
-        elif create_from is not None:
+        elif train_on is not None:
            self.encoder = preprocessing.LabelEncoder()
-            self.encoder.fit(create_from)
+            self.encoder.fit(train_on)
            with open(path, 'wb') as file:
                pickle.dump(self.encoder, file)
        else:
diff --git a/scripts/ML/BERT/Classifier.py b/scripts/ML/BERT/Classifier.py
index 5a5d60f9174be9b395229b81e5562c6f9736705a..04bcffa0aa8361b0748444479b1177f1c8bb152b 100644
--- a/scripts/ML/BERT/Classifier.py
+++ b/scripts/ML/BERT/Classifier.py
@@ -19,7 +19,6 @@ class Classifier(BERT):
     def __init__(self, root_path):
         BERT.__init__(self, root_path)
         self._init_pipe()
-        self._init_encoder()
 
     def _init_pipe(self):
         self.pipe = TextClassificationPipeline(
diff --git a/scripts/ML/BERT/Trainer.py b/scripts/ML/BERT/Trainer.py
index 60bd0e57a7fc363dbbfeab470d946c5e467c07bb..06851eef22cfce4f575beaf358abde39d713e53d 100644
--- a/scripts/ML/BERT/Trainer.py
+++ b/scripts/ML/BERT/Trainer.py
@@ -1,4 +1,64 @@
 from BERT.Base import BERT
+import datetime
+from loaders import set_random
+import time
+import torch
+from transformers import AdamW, get_linear_schedule_with_warmup
+
+def chrono(f):
+    def wrapped(*args, **kwargs):
+        t0 = time.time()
+        f(*args, **kwargs)
+        duration = datetime.timedelta(seconds=round(time.time() - t0))
+        print(f"\n {f.__name__} took: {duration}")
+    return wrapped
 
 class Trainer(BERT):
-    pass
+    def __init__(self, root_path, labeled_data, epochs=4):
+        self.epochs = epochs
+        BERT.__init__(self, root_path, train_on=labeled_data.unique)
+        self._init_utils(labeled_data.load(self))
+
+    def _init_utils(self, data_loader):
+        self.optimizer = AdamW(
+            self.model.parameters(),
+            lr = 2e-5, # args.learning_rate - default is 5e-5
+        )
+        self.data_loader = data_loader
+        self.scheduler = get_linear_schedule_with_warmup(
+            self.optimizer,
+            num_warmup_steps = 0, # Default value in run_glue.py
+            num_training_steps = self.epochs * len(data_loader))
+
+    def __call__(self):
+        set_random()
+        losses = [self.epoch(e) for e in range(self.epochs)]
+        self.save()
+        print("\nTraining complete!")
+
+    @chrono
+    def epoch(self, epoch):
+        self._start_epoch(epoch)
+        self.model.train()
+        total_loss = sum([self.learn_on(*self.import_data(batch))
+                          for batch in self.data_loader])
+        avg_train_loss = total_loss / len(self.data_loader)
+        print("\n Average training loss: {0:.2f}".format(avg_train_loss))
+        return avg_train_loss
+
+    def learn_on(self, input_ids, input_mask, labels):
+        self.model.zero_grad()
+        outputs = self.model(input_ids,
+                             token_type_ids=None,
+                             attention_mask=input_mask,
+                             labels=labels)
+        loss = outputs[0]
+        loss.backward()
+        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
+        self.optimizer.step()
+        self.scheduler.step()
+        return loss.item()
+
+    def _start_epoch(self, epoch):
+        print(f'\n======== Epoch {epoch+1} / {self.epochs} ========')
+        print('Training...')
diff --git a/scripts/ML/Corpus.py b/scripts/ML/Corpus.py
index 910cf33ff5c636c380619c8c1778c52a1e6e1d83..159092b53763e18f1b4b225790e1c6d42de05337 100644
--- a/scripts/ML/Corpus.py
+++ b/scripts/ML/Corpus.py
@@ -26,12 +26,12 @@ class Corpus:
 
 class TSVIndexed(Corpus):
     default_keys = ['work', 'volume', 'article']
+    projectors = ['key', 'content', 'full']
+
     def __init__(self, tsv_path, column_name):
         self.tsv_path = tsv_path
         self.column_name = column_name
         self.data = None
-        self.projectors = dict((p, self.__getattribute__(p))
-                               for p in ['key', 'content', 'full'])
 
     def load(self):
         if self.data is None:
@@ -60,7 +60,7 @@ class TSVIndexed(Corpus):
         if projector is None:
             projector = self.full
         elif type(projector) == str and projector in self.projectors:
-            projector = self.projectors[projector]
+            projector = self.__getattribute__(projector)
         self.load()
         for row in self.data.iterrows():
             yield projector(*row)
diff --git a/scripts/ML/loaders.py b/scripts/ML/loaders.py
index 5aa9dc7a0ae58cf19612072886eb88a8b8235de3..93986f4f4d260f7c92473a5cd5909547da690e4f 100644
--- a/scripts/ML/loaders.py
+++ b/scripts/ML/loaders.py
@@ -4,7 +4,7 @@ import torch
 
 def set_random():
     seed_value = 42
-    random.seed(seed_val)
-    numpy.random.seed(seed_val)
-    torch.manual_seed(seed_val)
-    torch.cuda.manual_seed_all(seed_val)
+    random.seed(seed_value)
+    numpy.random.seed(seed_value)
+    torch.manual_seed(seed_value)
+    torch.cuda.manual_seed_all(seed_value)
diff --git a/scripts/ML/train.py b/scripts/ML/train.py
new file mode 100755
index 0000000000000000000000000000000000000000..95b812f79dde5dcaee522849126dd8188407a731
--- /dev/null
+++ b/scripts/ML/train.py
@@ -0,0 +1,9 @@
+#!/usr/bin/env python3
+from BERT import Trainer
+from LabeledData import LabeledData
+import sys
+
+if __name__ == '__main__':
+    labeled_data = LabeledData(sys.argv[1])
+    trainer = Trainer(sys.argv[2], labeled_data)
+    trainer()
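
Note on the chrono helper introduced in scripts/ML/BERT/Trainer.py: as committed, wrapped does not forward the decorated function's return value, so the losses list built in Trainer.__call__ collects None rather than the per-epoch averages that epoch returns. The standalone sketch below uses only the standard library; slow_add is a hypothetical stand-in for Trainer.epoch, and forwarding the result is a suggested variant, not what the patch itself does.

# Minimal, self-contained sketch of the chrono timing decorator, with the
# return value forwarded so callers can keep the decorated function's result.
import datetime
import time

def chrono(f):
    def wrapped(*args, **kwargs):
        t0 = time.time()
        result = f(*args, **kwargs)  # forward the result (the committed version drops it)
        duration = datetime.timedelta(seconds=round(time.time() - t0))
        print(f"\n {f.__name__} took: {duration}")
        return result
    return wrapped

@chrono
def slow_add(a, b):  # hypothetical stand-in for Trainer.epoch
    time.sleep(1)
    return a + b

print(slow_add(2, 3))  # prints the timing line, then 5

With this variant, a loop such as [self.epoch(e) for e in range(self.epochs)] would actually accumulate the average training loss of each epoch instead of None.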