diff --git a/scripts/ML/BERT.py b/scripts/ML/BERT.py index 2159179af3ab2c063555712f6b171ab37446ed2a..ec4381c9c2fd5c07a9a0707cd08e3fe298049cfa 100644 --- a/scripts/ML/BERT.py +++ b/scripts/ML/BERT.py @@ -1,8 +1,10 @@ from loaders import get_device +from transformers import BertForSequenceClassification, BertTokenizer class BERT: model_name = 'bert-base-multilingual-cased' - def __init(self, path): + + def __init__(self, path): print('Loading BERT tools…') self.tokenizer = BertTokenizer.from_pretrained(BERT.model_name) print('âœ”ï¸ tokenizer') diff --git a/scripts/ML/predict.py b/scripts/ML/predict.py index f6bdba4d2f62fb4c96ddef42b9fe8b709da4a6b9..a64100329cf6680e288092dcb32a25bc346586df 100644 --- a/scripts/ML/predict.py +++ b/scripts/ML/predict.py @@ -7,7 +7,7 @@ import sklearn from Source import Source from sys import argv from tqdm import tqdm -from transformers import BertForSequenceClassification, BertTokenizer, TextClassificationPipeline +from transformers import TextClassificationPipeline class Classifier(BERT): """