Skip to content
Snippets Groups Projects
Commit 0d657426 authored by Alice Brenon's avatar Alice Brenon
Browse files

Fix mistakes created when refactoring

parent 75f61841
No related branches found
No related tags found
No related merge requests found
......@@ -42,15 +42,15 @@ class BERT:
print('Loading BERT tools')
self._init_tokenizer()
self.root_path = root_path
_init_classifier(training)
self._init_classifier(training)
@loader
def _init_tokenizer():
def _init_tokenizer(self):
self.tokenizer = BertTokenizer.from_pretrained(BERT.model_name)
@loader
def _init_classifier(training)
if training
def _init_classifier(self, training):
if training:
bert = BertForSequenceClassification.from_pretrained(
model_name, # Use the 12-layer BERT model, with an uncased vocab.
num_labels = numberOfClasses, # The number of output labels--2 for binary classification.
......
from BERT.Base import BERT
class Trainer(BERT):
pass
......@@ -30,6 +30,8 @@ class TSVIndexed(Corpus):
self.tsv_path = tsv_path
self.column_name = column_name
self.data = None
self.projectors = dict((p, self.__getattribute__(p))
for p in ['key', 'content', 'full'])
def load(self):
if self.data is None:
......@@ -46,19 +48,19 @@ class TSVIndexed(Corpus):
def content(self, key, row):
pass
def keys(self, _, row):
def key(self, _, row):
return row[self.keys].to_dict()
def full(self, key, row):
d = self.keys(key, row)
d = self.key(key, row)
d[self.column_name] = self.content(key, row).strip() + '\n'
return d
def get_all(self, projector):
def get_all(self, projector=None):
if projector is None:
projector = self.full
elif type(projector) == str:
projector = self.__getattribute__(projector)
elif type(projector) == str and projector in self.projectors:
projector = self.projectors[projector]
self.load()
for row in self.data.iterrows():
yield projector(*row)
......
......@@ -21,8 +21,8 @@ def label(classify, source, name='label'):
:return: a panda dataframe containing the records from the input TSV file plus
an additional column
"""
records = pandas.DataFrame(source.get_all('keys'))
records[name] = classify(source.get_all('content')
records = pandas.DataFrame(source.get_all('key'))
records[name] = classify(source.get_all('content'))
return records
if __name__ == '__main__':
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment