Skip to content
Snippets Groups Projects
Commit 26c6ee1d authored by Alice Brenon's avatar Alice Brenon
Browse files

Separate train script into one to train directly one classifier and one to...

Separate the training script into one that directly trains a single classifier and one that trains a binary classifier (accept/reject) for each discursive function
parent dc3b79a5
No related branches found
No related tags found
No related merge requests found
......@@ -22,6 +22,7 @@ def loader(f):
class BERT:
model_name = 'bert-base-multilingual-cased'
encoder_file = 'label_encoder.pkl'
def __init__(self, root_path, train_on=None):
self.device = get_device()
......@@ -50,8 +51,8 @@ class BERT:
@loader
def _init_encoder(self, train_on):
path = f"{self.root_path}/label_encoder.pkl"
if os.path.isfile(path):
path = f"{self.root_path}/{BERT.encoder_file}"
if os.path.exists(path):
with open(path, 'rb') as pickled:
self.encoder = pickle.load(pickled)
elif train_on is not None:
......
import numpy
from os.path import dirname, isdir, isfile
import pandas
import torch
from torch.utils.data import DataLoader, RandomSampler, TensorDataset
class LabeledData:
    """Wrap a labeled corpus and expose it as a tokenized torch DataLoader.

    Parameters
    ----------
    corpus : project Corpus object; must provide .load(), .data (a pandas
        DataFrame) and .get_all(column_name)
    label_column : name of the column in corpus.data holding the labels
    max_length : maximum token sequence length used when tokenizing
    batch_size : batch size of the DataLoader returned by load()
    """
    def __init__(self,
                 corpus,
                 label_column,
                 max_length=512,
                 batch_size=16):
        self.corpus = corpus
        self.max_length = max_length
        self.batch_size = batch_size
        self._init_labels(label_column)

    def _init_labels(self, label_column):
        # Load the corpus, then cache the label series and its distinct values.
        self.corpus.load()
        self.labels = self.corpus.data[label_column]
        self.unique = self.labels.unique()

    def load(self, bert):
        """Return a shuffled DataLoader of (input_ids, attention_mask, labels)."""
        encoded_data = TensorDataset(*map(torch.tensor, self.train_for(bert)))
        return DataLoader(encoded_data,
                          sampler=RandomSampler(encoded_data),
                          batch_size=self.batch_size)

    def train_for(self, bert):
        """Tokenize the corpus texts and label-encode with *bert*'s encoder."""
        texts = self.corpus.get_all('content')
        # Bug fix: self.max_length was stored but never forwarded, so the
        # tokenizer padded/truncated to the model's default length instead
        # of the length this instance was configured with.
        tokenized = bert.tokenizer(list(texts),
                                   padding='max_length',
                                   truncation='only_first',
                                   max_length=self.max_length)
        return (tokenized.input_ids,
                tokenized.attention_mask,
                bert.encoder.transform(self.labels))
#!/usr/bin/env python3
from BERT import BERT, Trainer
from Corpus import Directory
import GEODE.discursive as discursive
from LabeledData import LabeledData
import os
import sys
def split(columnName):
    """Placeholder split strategy: currently yields no partition at all.

    *columnName* is accepted for interface compatibility but unused.
    """
    return dict()
def load(rootPath):
    """Index the per-class TSV files found directly under *rootPath*.

    Parameters:
        rootPath: directory to scan (non-recursive).

    Returns:
        dict mapping each file's base name (extension stripped) to its
        path "rootPath/name.tsv".
    """
    classes = {}
    for f in os.listdir(rootPath):
        # str.endswith is the idiomatic (and clearer) form of f[-4:] == '.tsv'
        if f.endswith('.tsv'):
            classes[f[:-4]] = f"{rootPath}/{f}"
    return classes
def trainSubClassifier(trainRoot, modelRoot, className):
    """Train one binary (accept/reject) classifier for *className*.

    Reads the class's TSV corpus under *trainRoot*, trains a model stored in
    modelRoot/className, and shares the parent model's label encoder through
    a relative symlink so every sub-classifier uses the same encoding.
    """
    trainData = Directory(trainRoot, tsv_filename=className)
    labeled_data = LabeledData(trainData, "answer")
    subModelPath = f"{modelRoot}/{className}"
    os.makedirs(subModelPath, exist_ok=True)
    # Bug fix: re-running the script crashed with FileExistsError here —
    # makedirs above is idempotent but os.symlink was not. lexists (not
    # exists) is used so a dangling link also counts as already present.
    linkPath = f"{subModelPath}/{BERT.encoder_file}"
    if not os.path.lexists(linkPath):
        os.symlink(f"../{BERT.encoder_file}", linkPath)
    trainer = Trainer(subModelPath, labeled_data)
    trainer()
if __name__ == '__main__':
    # Usage: <script> TRAIN_ROOT MODEL_ROOT — trains one accept/reject
    # sub-classifier per discursive function.
    trainRoot, modelRoot = sys.argv[1], sys.argv[2]
    for className in discursive.functions:
        trainSubClassifier(trainRoot, modelRoot, className)
#!/usr/bin/env python3
from Corpus import corpus
from BERT import Trainer
from LabeledData import LabeledData
import sys
if __name__ == '__main__':
    # Usage: <script> CORPUS_PATH MODEL_PATH — trains a single classifier
    # over the "paragraphFunction" labels of the given corpus.
    corpusPath, modelPath = sys.argv[1], sys.argv[2]
    data = LabeledData(corpus(corpusPath), "paragraphFunction")
    Trainer(modelPath, data)()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment