From d7e8544f7e9d230c8c4d853cb1953e2cf789a995 Mon Sep 17 00:00:00 2001
From: Alice BRENON <alice.brenon@ens-lyon.fr>
Date: Fri, 22 Sep 2023 10:53:42 +0200
Subject: [PATCH] Encoder seems to belong to the BERT model

---
 scripts/ML/BERT/Base.py | 28 ++++++++++++++--------------
 scripts/ML/BERT/Classifier.py | 4 ++--
 2 files changed, 16 insertions(+), 16 deletions(-)

diff --git a/scripts/ML/BERT/Base.py b/scripts/ML/BERT/Base.py
index 5c2de16..fe28dbb 100644
--- a/scripts/ML/BERT/Base.py
+++ b/scripts/ML/BERT/Base.py
@@ -12,20 +12,6 @@ def get_device():
         print('No GPU available, using the CPU instead.')
         return torch.device("cpu")
 
-def get_encoder(root_path, create_from=None):
-    path = f"{root_path}/label_encoder.pkl"
-    if os.path.isfile(path):
-        with open(path, 'rb') as pickled:
-            return pickle.load(pickled)
-    elif create_from is not None:
-        encoder = preprocessing.LabelEncoder()
-        encoder.fit(create_from)
-        with open(path, 'wb') as file:
-            pickle.dump(encoder, file)
-        return encoder
-    else:
-        raise FileNotFoundError(path)
-
 def loader(f):
     def wrapped(*args, **kwargs):
         name = f.__name__.replace('_init_', '')
@@ -63,6 +49,20 @@ class BERT:
         bert = BertForSequenceClassification.from_pretrained(self.root_path)
         self.model = bert.to(self.device.type)
 
+    @loader
+    def _init_encoder(self, create_from=None):
+        path = f"{self.root_path}/label_encoder.pkl"
+        if os.path.isfile(path):
+            with open(path, 'rb') as pickled:
+                self.encoder = pickle.load(pickled)
+        elif create_from is not None:
+            self.encoder = preprocessing.LabelEncoder()
+            self.encoder.fit(create_from)
+            with open(path, 'wb') as file:
+                pickle.dump(self.encoder, file)
+        else:
+            raise FileNotFoundError(path)
+
     def import_data(self, data):
         return map(lambda d: d.to(self.device), data)
 
diff --git a/scripts/ML/BERT/Classifier.py b/scripts/ML/BERT/Classifier.py
index 2807e36..5a5d60f 100644
--- a/scripts/ML/BERT/Classifier.py
+++ b/scripts/ML/BERT/Classifier.py
@@ -1,4 +1,4 @@
-from BERT.Base import BERT, get_encoder
+from BERT.Base import BERT
 import numpy
 from tqdm import tqdm
 from transformers import TextClassificationPipeline
@@ -19,7 +19,7 @@ class Classifier(BERT):
     def __init__(self, root_path):
         BERT.__init__(self, root_path)
         self._init_pipe()
-        self.encoder = get_encoder(root_path)
+        self._init_encoder()
 
     def _init_pipe(self):
         self.pipe = TextClassificationPipeline(
-- 
GitLab