diff --git a/scripts/ML/gpu.py b/scripts/ML/gpu.py
new file mode 100644
index 0000000000000000000000000000000000000000..b0b34b2bb075299d9cd233a6559142fa9089c7c5
--- /dev/null
+++ b/scripts/ML/gpu.py
@@ -0,0 +1,15 @@
+import torch
+
+class WithGPU:
+    """
+    Mixin that owns torch device selection: it uses the first CUDA GPU
+    when one is available, falls back to the CPU otherwise, and logs
+    the choice. The selected device is exposed as `self.device`.
+    """
+    def __init__(self):
+        if torch.cuda.is_available():
+            print('We will use the GPU:', torch.cuda.get_device_name(0))
+            self.device = torch.device("cuda")
+        else:
+            print('No GPU available, using the CPU instead.')
+            self.device = torch.device("cpu")
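
For context, a minimal sketch of how the new mixin is meant to be consumed; the Embedder class, its model argument, and its encode method are hypothetical examples, not part of this change:

    from gpu import WithGPU

    # Hypothetical consumer: any class needing a torch device inherits
    # WithGPU instead of duplicating the CUDA-availability check.
    class Embedder(WithGPU):
        def __init__(self, model):
            super().__init__()                  # sets self.device, logs the choice
            self.model = model.to(self.device)  # move weights to that device

        def encode(self, batch):
            # inputs must live on the same device as the model
            return self.model(batch.to(self.device))

Keeping the device decision in one mixin also keeps the print-based logging next to the decision it reports, which is why the log() method below can be deleted.
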
diff --git a/scripts/ML/predict.py b/scripts/ML/predict.py
index 974fc69a27bcaa230b2ca5d046849daae89c8758..0f00ab0469f7ba512b2cc100bfb7e6e72f92b583 100644
--- a/scripts/ML/predict.py
+++ b/scripts/ML/predict.py
@@ -1,14 +1,14 @@
 #!/usr/bin/env python3
+from gpu import WithGPU
 import numpy
 import pandas
 import pickle
 import sklearn
 from sys import argv
-import torch
 from tqdm import tqdm
 from transformers import BertForSequenceClassification, BertTokenizer, TextClassificationPipeline
 
-class Classifier:
+class Classifier(WithGPU):
     """
     A class wrapping all the different models and classes used throughout a
     classification task:
@@ -22,12 +22,11 @@ class Classifier:
     containing the texts to classify
     """
     def __init__(self, root_path):
-        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        super().__init__()  # WithGPU sets self.device and logs the GPU/CPU choice
         self._init_tokenizer()
         self._init_model(root_path)
         self._init_pipe()
         self._init_encoder(f"{root_path}/label_encoder.pkl")
-        self.log()
 
     def _init_model(self, path):
         bert = BertForSequenceClassification.from_pretrained(path)
@@ -48,12 +47,6 @@ class Classifier:
         with open(path, 'rb') as pickled:
             self.encoder = pickle.load(pickled)
 
-    def log(self):
-        if self.device.type == 'cpu':
-            print('No GPU available, using the CPU instead.')
-        else:
-            print('We will use the GPU:', torch.cuda.get_device_name(0))
-
     def __call__(self, text_generator):
         tokenizer_kwargs = {'padding':True, 'truncation':True, 'max_length':512}
         predictions = []
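
Finally, a hedged sketch of driving the refactored Classifier; the model directory and input file are placeholders, and it assumes __call__ returns its accumulated predictions once the truncated loop above completes:

    from predict import Classifier

    # hypothetical paths: the directory must hold the BERT weights plus
    # the label_encoder.pkl expected by _init_encoder
    classifier = Classifier('models/my-model')
    with open('texts.txt') as f:
        labels = classifier(line.strip() for line in f)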