diff --git a/scripts/ML/gpu.py b/scripts/ML/gpu.py new file mode 100644 index 0000000000000000000000000000000000000000..b0b34b2bb075299d9cd233a6559142fa9089c7c5 --- /dev/null +++ b/scripts/ML/gpu.py @@ -0,0 +1,10 @@ +import torch + +class WithGPU: + def __init__(self): + if torch.cuda.is_available(): + print('We will use the GPU:', torch.cuda.get_device_name(0)) + self.device = torch.device("cuda") + else: + print('No GPU available, using the CPU instead.') + self.device = torch.device("cpu") diff --git a/scripts/ML/predict.py b/scripts/ML/predict.py index 974fc69a27bcaa230b2ca5d046849daae89c8758..0f00ab0469f7ba512b2cc100bfb7e6e72f92b583 100644 --- a/scripts/ML/predict.py +++ b/scripts/ML/predict.py @@ -1,14 +1,14 @@ #!/usr/bin/env python3 +from gpu import WithGPU import numpy import pandas import pickle import sklearn from sys import argv -import torch from tqdm import tqdm from transformers import BertForSequenceClassification, BertTokenizer, TextClassificationPipeline -class Classifier: +class Classifier(WithGPU): """ A class wrapping all the different models and classes used throughout a classification task: @@ -22,12 +22,11 @@ class Classifier: containing the texts to classify """ def __init__(self, root_path): - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + WithGPU.__init__(self) self._init_tokenizer() self._init_model(root_path) self._init_pipe() self._init_encoder(f"{root_path}/label_encoder.pkl") - self.log() def _init_model(self, path): bert = BertForSequenceClassification.from_pretrained(path) @@ -48,12 +47,6 @@ class Classifier: with open(path, 'rb') as pickled: self.encoder = pickle.load(pickled) - def log(self): - if self.device.type == 'cpu': - print('No GPU available, using the CPU instead.') - else: - print('We will use the GPU:', torch.cuda.get_device_name(0)) - def __call__(self, text_generator): tokenizer_kwargs = {'padding':True, 'truncation':True, 'max_length':512} predictions = []