Skip to content
Snippets Groups Projects
Commit d7e8544f authored by Alice Brenon's avatar Alice Brenon
Browse files

Encoder seems to belong to the BERT model

parent 0d657426
No related branches found
No related tags found
No related merge requests found
......@@ -12,20 +12,6 @@ def get_device():
print('No GPU available, using the CPU instead.')
return torch.device("cpu")
def get_encoder(root_path, create_from=None):
path = f"{root_path}/label_encoder.pkl"
if os.path.isfile(path):
with open(path, 'rb') as pickled:
return pickle.load(pickled)
elif create_from is not None:
encoder = preprocessing.LabelEncoder()
encoder.fit(create_from)
with open(path, 'wb') as file:
pickle.dump(encoder, file)
return encoder
else:
raise FileNotFoundError(path)
def loader(f):
def wrapped(*args, **kwargs):
name = f.__name__.replace('_init_', '')
......@@ -63,6 +49,20 @@ class BERT:
bert = BertForSequenceClassification.from_pretrained(self.root_path)
self.model = bert.to(self.device.type)
@loader
def _init_encoder(self, create_from=None):
path = f"{self.root_path}/label_encoder.pkl"
if os.path.isfile(path):
with open(path, 'rb') as pickled:
self.encoder = pickle.load(pickled)
elif create_from is not None:
self.encoder = preprocessing.LabelEncoder()
self.encoder.fit(create_from)
with open(path, 'wb') as file:
pickle.dump(self.encoder, file)
else:
raise FileNotFoundError(path)
def import_data(self, data):
return map(lambda d: d.to(self.device), data)
......
from BERT.Base import BERT, get_encoder
from BERT.Base import BERT
import numpy
from tqdm import tqdm
from transformers import TextClassificationPipeline
......@@ -19,7 +19,7 @@ class Classifier(BERT):
def __init__(self, root_path):
BERT.__init__(self, root_path)
self._init_pipe()
self.encoder = get_encoder(root_path)
self._init_encoder()
def _init_pipe(self):
self.pipe = TextClassificationPipeline(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment