From 0d65742640d42d1ee7e380c044d511f63b91b05c Mon Sep 17 00:00:00 2001 From: Alice BRENON <alice.brenon@ens-lyon.fr> Date: Thu, 21 Sep 2023 23:36:27 +0200 Subject: [PATCH] Fix mistakes created when refactoring --- scripts/ML/BERT/Base.py | 8 ++++---- scripts/ML/BERT/Trainer.py | 4 ++++ scripts/ML/Corpus.py | 12 +++++++----- scripts/ML/predict.py | 4 ++-- 4 files changed, 17 insertions(+), 11 deletions(-) diff --git a/scripts/ML/BERT/Base.py b/scripts/ML/BERT/Base.py index c8b8d11..5c2de16 100644 --- a/scripts/ML/BERT/Base.py +++ b/scripts/ML/BERT/Base.py @@ -42,15 +42,15 @@ class BERT: print('Loading BERT tools') self._init_tokenizer() self.root_path = root_path - _init_classifier(training) + self._init_classifier(training) @loader - def _init_tokenizer(): + def _init_tokenizer(self): self.tokenizer = BertTokenizer.from_pretrained(BERT.model_name) @loader - def _init_classifier(training) - if training + def _init_classifier(self, training): + if training: bert = BertForSequenceClassification.from_pretrained( model_name, # Use the 12-layer BERT model, with an uncased vocab. num_labels = numberOfClasses, # The number of output labels--2 for binary classification. diff --git a/scripts/ML/BERT/Trainer.py b/scripts/ML/BERT/Trainer.py index e69de29..60bd0e5 100644 --- a/scripts/ML/BERT/Trainer.py +++ b/scripts/ML/BERT/Trainer.py @@ -0,0 +1,4 @@ +from BERT.Base import BERT + +class Trainer(BERT): + pass diff --git a/scripts/ML/Corpus.py b/scripts/ML/Corpus.py index b81a56c..910cf33 100644 --- a/scripts/ML/Corpus.py +++ b/scripts/ML/Corpus.py @@ -30,6 +30,8 @@ class TSVIndexed(Corpus): self.tsv_path = tsv_path self.column_name = column_name self.data = None + self.projectors = dict((p, self.__getattribute__(p)) + for p in ['key', 'content', 'full']) def load(self): if self.data is None: @@ -46,19 +48,19 @@ class TSVIndexed(Corpus): def content(self, key, row): pass - def keys(self, _, row): + def key(self, _, row): return row[self.keys].to_dict() def full(self, key, row): - d = self.keys(key, row) + d = self.key(key, row) d[self.column_name] = self.content(key, row).strip() + '\n' return d - def get_all(self, projector): + def get_all(self, projector=None): if projector is None: projector = self.full - elif type(projector) == str: - projector = self.__getattribute__(projector) + elif type(projector) == str and projector in self.projectors: + projector = self.projectors[projector] self.load() for row in self.data.iterrows(): yield projector(*row) diff --git a/scripts/ML/predict.py b/scripts/ML/predict.py index f1768db..2dfa3e8 100644 --- a/scripts/ML/predict.py +++ b/scripts/ML/predict.py @@ -21,8 +21,8 @@ def label(classify, source, name='label'): :return: a panda dataframe containing the records from the input TSV file plus an additional column """ - records = pandas.DataFrame(source.get_all('keys')) - records[name] = classify(source.get_all('content') + records = pandas.DataFrame(source.get_all('key')) + records[name] = classify(source.get_all('content')) return records if __name__ == '__main__': -- GitLab