diff --git a/scripts/ML/BERT/Base.py b/scripts/ML/BERT/Base.py index c8b8d11eced28882082aea89a3eba1d071ca4ea1..5c2de164bac527d7f41ea13b74b53b594964d4f5 100644 --- a/scripts/ML/BERT/Base.py +++ b/scripts/ML/BERT/Base.py @@ -42,15 +42,15 @@ class BERT: print('Loading BERT tools') self._init_tokenizer() self.root_path = root_path - _init_classifier(training) + self._init_classifier(training) @loader - def _init_tokenizer(): + def _init_tokenizer(self): self.tokenizer = BertTokenizer.from_pretrained(BERT.model_name) @loader - def _init_classifier(training) - if training + def _init_classifier(self, training): + if training: bert = BertForSequenceClassification.from_pretrained( model_name, # Use the 12-layer BERT model, with an uncased vocab. num_labels = numberOfClasses, # The number of output labels--2 for binary classification. diff --git a/scripts/ML/BERT/Trainer.py b/scripts/ML/BERT/Trainer.py index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..60bd0e57a7fc363dbbfeab470d946c5e467c07bb 100644 --- a/scripts/ML/BERT/Trainer.py +++ b/scripts/ML/BERT/Trainer.py @@ -0,0 +1,4 @@ +from BERT.Base import BERT + +class Trainer(BERT): + pass diff --git a/scripts/ML/Corpus.py b/scripts/ML/Corpus.py index b81a56c79fd41dd5862a48480260e8ffd2f468c7..910cf33ff5c636c380619c8c1778c52a1e6e1d83 100644 --- a/scripts/ML/Corpus.py +++ b/scripts/ML/Corpus.py @@ -30,6 +30,8 @@ class TSVIndexed(Corpus): self.tsv_path = tsv_path self.column_name = column_name self.data = None + self.projectors = dict((p, self.__getattribute__(p)) + for p in ['key', 'content', 'full']) def load(self): if self.data is None: @@ -46,19 +48,19 @@ class TSVIndexed(Corpus): def content(self, key, row): pass - def keys(self, _, row): + def key(self, _, row): return row[self.keys].to_dict() def full(self, key, row): - d = self.keys(key, row) + d = self.key(key, row) d[self.column_name] = self.content(key, row).strip() + '\n' return d - def get_all(self, projector): + def get_all(self, projector=None): if projector is None: projector = self.full - elif type(projector) == str: - projector = self.__getattribute__(projector) + elif type(projector) == str and projector in self.projectors: + projector = self.projectors[projector] self.load() for row in self.data.iterrows(): yield projector(*row) diff --git a/scripts/ML/predict.py b/scripts/ML/predict.py index f1768db8f7d7ceb338749c0e58459ace47dbbd34..2dfa3e867e1b39a78d34c160433792de49f7a2e5 100644 --- a/scripts/ML/predict.py +++ b/scripts/ML/predict.py @@ -21,8 +21,8 @@ def label(classify, source, name='label'): :return: a panda dataframe containing the records from the input TSV file plus an additional column """ - records = pandas.DataFrame(source.get_all('keys')) - records[name] = classify(source.get_all('content') + records = pandas.DataFrame(source.get_all('key')) + records[name] = classify(source.get_all('content')) return records if __name__ == '__main__':