diff --git a/scripts/ML/BERT/Base.py b/scripts/ML/BERT/Base.py
index a1d1b8e50cad8fe60ed1fbef0d3b9ee5029bebb2..db816029cb57561f96dba46233ce9fe2a2c28eec 100644
--- a/scripts/ML/BERT/Base.py
+++ b/scripts/ML/BERT/Base.py
@@ -22,6 +22,7 @@ def loader(f):
 class BERT:
     model_name = 'bert-base-multilingual-cased'
+    encoder_file = 'label_encoder.pkl'
 
     def __init__(self, root_path, train_on=None):
         self.device = get_device()
 
@@ -50,8 +51,8 @@ class BERT:
 
     @loader
     def _init_encoder(self, train_on):
-        path = f"{self.root_path}/label_encoder.pkl"
-        if os.path.isfile(path):
+        path = f"{self.root_path}/{BERT.encoder_file}"
+        if os.path.exists(path):
             with open(path, 'rb') as pickled:
                 self.encoder = pickle.load(pickled)
         elif train_on is not None:
diff --git a/scripts/ML/LabeledData.py b/scripts/ML/LabeledData.py
new file mode 100644
index 0000000000000000000000000000000000000000..79fff38cbd3846b0d51af59cd0828742eacc3cee
--- /dev/null
+++ b/scripts/ML/LabeledData.py
@@ -0,0 +1,37 @@
+import numpy
+from os.path import dirname, isdir, isfile
+import pandas
+import torch
+from torch.utils.data import DataLoader, RandomSampler, TensorDataset
+
+class LabeledData:
+    def __init__(self,
+                 corpus,
+                 label_column,
+                 max_length=512,
+                 batch_size=16):
+        self.corpus = corpus
+        self.max_length = max_length
+        self.batch_size = batch_size
+        self._init_labels(label_column)
+
+    def _init_labels(self, label_column):
+        self.corpus.load()
+        self.labels = self.corpus.data[label_column]
+        self.unique = self.labels.unique()
+
+    def load(self, bert):
+        encoded_data = TensorDataset(*map(torch.tensor, self.train_for(bert)))
+        return DataLoader(encoded_data,
+                          sampler=RandomSampler(encoded_data),
+                          batch_size=self.batch_size)
+
+    def train_for(self, bert):
+        texts = self.corpus.get_all('content')
+        tokenized = bert.tokenizer(list(texts),
+                                   max_length=self.max_length,
+                                   padding='max_length',
+                                   truncation='only_first')
+        return (tokenized.input_ids,
+                tokenized.attention_mask,
+                bert.encoder.transform(self.labels))
diff --git a/scripts/ML/trainMultiBERT.py b/scripts/ML/trainMultiBERT.py
new file mode 100755
index 0000000000000000000000000000000000000000..637ad2942745a02656705da1587181809f667415
--- /dev/null
+++ b/scripts/ML/trainMultiBERT.py
@@ -0,0 +1,31 @@
+#!/usr/bin/env python3
+from BERT import BERT, Trainer
+from Corpus import Directory
+import GEODE.discursive as discursive
+from LabeledData import LabeledData
+import os
+import sys
+
+def split(columnName):
+    return {}
+
+def load(rootPath):
+    classes = {}
+    for f in os.listdir(rootPath):
+        if f[-4:] == '.tsv':
+            classes[f[:-4]] = f"{rootPath}/{f}"
+    return classes
+
+def trainSubClassifier(trainRoot, modelRoot, className):
+    trainData = Directory(trainRoot, tsv_filename=className)
+    labeled_data = LabeledData(trainData, "answer")
+    subModelPath = f"{modelRoot}/{className}"
+    os.makedirs(subModelPath, exist_ok=True)
+    # share the root label encoder: BERT._init_encoder looks for it in each model dir
+    os.symlink(f"../{BERT.encoder_file}", f"{subModelPath}/{BERT.encoder_file}")
+    trainer = Trainer(subModelPath, labeled_data)
+    trainer()
+
+if __name__ == '__main__':
+    for className in discursive.functions:
+        trainSubClassifier(sys.argv[1], sys.argv[2], className)
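Note on the layout trainMultiBERT.py expects: each sub-model's label_encoder.pkl is a relative symlink to a single encoder at the model root, which the patch assumes already exists there. Below is a minimal sketch of that setup, assuming the pickled object is a scikit-learn LabelEncoder (consistent with the encoder.transform call in LabeledData.train_for); the paths, class names and label values are hypothetical stand-ins:

import os
import pickle
from sklearn.preprocessing import LabelEncoder

model_root = 'models/multi'                      # hypothetical path
os.makedirs(model_root, exist_ok=True)

# Fit and pickle one encoder at the root; BERT._init_encoder unpickles this file.
encoder = LabelEncoder().fit(['False', 'True'])  # assumed binary 'answer' labels
with open(f"{model_root}/label_encoder.pkl", 'wb') as f:
    pickle.dump(encoder, f)

# Give each sub-model a relative symlink to the shared encoder, mirroring
# trainSubClassifier; guarded so a re-run does not raise FileExistsError.
for class_name in ['definition', 'example']:     # stand-ins for discursive.functions
    sub_model = f"{model_root}/{class_name}"
    os.makedirs(sub_model, exist_ok=True)
    link = f"{sub_model}/label_encoder.pkl"
    if not os.path.lexists(link):
        os.symlink('../label_encoder.pkl', link)

The relative link target keeps the tree relocatable: ../label_encoder.pkl stays valid wherever the model root directory is moved.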
diff --git a/scripts/ML/trainSimpleBERT.py b/scripts/ML/trainSimpleBERT.py
new file mode 100755
index 0000000000000000000000000000000000000000..d869b4a8b0f73743515f5c6e5824a5aa267475c1
--- /dev/null
+++ b/scripts/ML/trainSimpleBERT.py
@@ -0,0 +1,10 @@
+#!/usr/bin/env python3
+from Corpus import corpus
+from BERT import Trainer
+from LabeledData import LabeledData
+import sys
+
+if __name__ == '__main__':
+    labeled_data = LabeledData(corpus(sys.argv[1]), "paragraphFunction")
+    trainer = Trainer(sys.argv[2], labeled_data)
+    trainer()
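For reference, a sketch of the data flow both training scripts set up; the paths are hypothetical, and passing train_on=labeled_data.unique is an assumption based on the BERT(root_path, train_on=None) signature and the _init_encoder fallback in Base.py:

from BERT import BERT
from Corpus import corpus
from LabeledData import LabeledData

# Hypothetical paths; corpus() comes from the existing Corpus module.
labeled_data = LabeledData(corpus('data/train'), 'paragraphFunction')
bert = BERT('models/simple', train_on=labeled_data.unique)  # assumed use of train_on

# load() tokenizes every 'content' text, label-encodes the targets and wraps
# both in a shuffled DataLoader of fixed-size batches.
for input_ids, attention_mask, labels in labeled_data.load(bert):
    print(input_ids.shape, attention_mask.shape, labels.shape)
    break

Each batch is the (input_ids, attention_mask, labels) triple built in LabeledData.train_for, shuffled by the RandomSampler.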