diff --git a/scripts/ML/BERT/Base.py b/scripts/ML/BERT/Base.py index 5c2de164bac527d7f41ea13b74b53b594964d4f5..fe28dbb7cdb632b9049be06b48d2e9da4b5fd4a8 100644 --- a/scripts/ML/BERT/Base.py +++ b/scripts/ML/BERT/Base.py @@ -12,20 +12,6 @@ def get_device(): print('No GPU available, using the CPU instead.') return torch.device("cpu") -def get_encoder(root_path, create_from=None): - path = f"{root_path}/label_encoder.pkl" - if os.path.isfile(path): - with open(path, 'rb') as pickled: - return pickle.load(pickled) - elif create_from is not None: - encoder = preprocessing.LabelEncoder() - encoder.fit(create_from) - with open(path, 'wb') as file: - pickle.dump(encoder, file) - return encoder - else: - raise FileNotFoundError(path) - def loader(f): def wrapped(*args, **kwargs): name = f.__name__.replace('_init_', '') @@ -63,6 +49,20 @@ class BERT: bert = BertForSequenceClassification.from_pretrained(self.root_path) self.model = bert.to(self.device.type) + @loader + def _init_encoder(self, create_from=None): + path = f"{self.root_path}/label_encoder.pkl" + if os.path.isfile(path): + with open(path, 'rb') as pickled: + self.encoder = pickle.load(pickled) + elif create_from is not None: + self.encoder = preprocessing.LabelEncoder() + self.encoder.fit(create_from) + with open(path, 'wb') as file: + pickle.dump(self.encoder, file) + else: + raise FileNotFoundError(path) + def import_data(self, data): return map(lambda d: d.to(self.device), data) diff --git a/scripts/ML/BERT/Classifier.py b/scripts/ML/BERT/Classifier.py index 2807e36b77124ca8e0f5c9006a70cd89d1612f10..5a5d60f9174be9b395229b81e5562c6f9736705a 100644 --- a/scripts/ML/BERT/Classifier.py +++ b/scripts/ML/BERT/Classifier.py @@ -1,4 +1,4 @@ -from BERT.Base import BERT, get_encoder +from BERT.Base import BERT import numpy from tqdm import tqdm from transformers import TextClassificationPipeline @@ -19,7 +19,7 @@ class Classifier(BERT): def __init__(self, root_path): BERT.__init__(self, root_path) self._init_pipe() - self.encoder = get_encoder(root_path) + self._init_encoder() def _init_pipe(self): self.pipe = TextClassificationPipeline(