Skip to content
Snippets Groups Projects
Commit 0d657426 authored by Alice Brenon
Browse files

Fix mistakes created when refactoring

parent 75f61841
No related branches found
No related tags found
No related merge requests found
...@@ -42,15 +42,15 @@ class BERT: ...@@ -42,15 +42,15 @@ class BERT:
print('Loading BERT tools') print('Loading BERT tools')
self._init_tokenizer() self._init_tokenizer()
self.root_path = root_path self.root_path = root_path
_init_classifier(training) self._init_classifier(training)
@loader @loader
def _init_tokenizer(): def _init_tokenizer(self):
self.tokenizer = BertTokenizer.from_pretrained(BERT.model_name) self.tokenizer = BertTokenizer.from_pretrained(BERT.model_name)
@loader @loader
def _init_classifier(training) def _init_classifier(self, training):
if training if training:
bert = BertForSequenceClassification.from_pretrained( bert = BertForSequenceClassification.from_pretrained(
model_name, # Use the 12-layer BERT model, with an uncased vocab. model_name, # Use the 12-layer BERT model, with an uncased vocab.
num_labels = numberOfClasses, # The number of output labels--2 for binary classification. num_labels = numberOfClasses, # The number of output labels--2 for binary classification.
......
from BERT.Base import BERT
class Trainer(BERT):
pass
...@@ -30,6 +30,8 @@ class TSVIndexed(Corpus): ...@@ -30,6 +30,8 @@ class TSVIndexed(Corpus):
self.tsv_path = tsv_path self.tsv_path = tsv_path
self.column_name = column_name self.column_name = column_name
self.data = None self.data = None
self.projectors = dict((p, self.__getattribute__(p))
for p in ['key', 'content', 'full'])
def load(self): def load(self):
if self.data is None: if self.data is None:
...@@ -46,19 +48,19 @@ class TSVIndexed(Corpus): ...@@ -46,19 +48,19 @@ class TSVIndexed(Corpus):
def content(self, key, row): def content(self, key, row):
pass pass
def keys(self, _, row): def key(self, _, row):
return row[self.keys].to_dict() return row[self.keys].to_dict()
def full(self, key, row): def full(self, key, row):
d = self.keys(key, row) d = self.key(key, row)
d[self.column_name] = self.content(key, row).strip() + '\n' d[self.column_name] = self.content(key, row).strip() + '\n'
return d return d
def get_all(self, projector): def get_all(self, projector=None):
if projector is None: if projector is None:
projector = self.full projector = self.full
elif type(projector) == str: elif type(projector) == str and projector in self.projectors:
projector = self.__getattribute__(projector) projector = self.projectors[projector]
self.load() self.load()
for row in self.data.iterrows(): for row in self.data.iterrows():
yield projector(*row) yield projector(*row)
......
...@@ -21,8 +21,8 @@ def label(classify, source, name='label'): ...@@ -21,8 +21,8 @@ def label(classify, source, name='label'):
:return: a panda dataframe containing the records from the input TSV file plus :return: a panda dataframe containing the records from the input TSV file plus
an additional column an additional column
""" """
records = pandas.DataFrame(source.get_all('keys')) records = pandas.DataFrame(source.get_all('key'))
records[name] = classify(source.get_all('content') records[name] = classify(source.get_all('content'))
return records return records
if __name__ == '__main__': if __name__ == '__main__':
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment