From 38de3c27f5ce6b65ff7b48f605116948942f079a Mon Sep 17 00:00:00 2001
From: Alice BRENON <alice.brenon@ens-lyon.fr>
Date: Fri, 15 Sep 2023 11:58:11 +0200
Subject: [PATCH] Take GPU detection out into a base class

---
 scripts/ML/gpu.py     | 10 ++++++++++
 scripts/ML/predict.py | 13 +++----------
 2 files changed, 13 insertions(+), 10 deletions(-)
 create mode 100644 scripts/ML/gpu.py

diff --git a/scripts/ML/gpu.py b/scripts/ML/gpu.py
new file mode 100644
index 0000000..b0b34b2
--- /dev/null
+++ b/scripts/ML/gpu.py
@@ -0,0 +1,10 @@
+import torch
+
+class WithGPU:
+    def __init__(self):
+        if torch.cuda.is_available():
+            print('We will use the GPU:', torch.cuda.get_device_name(0))
+            self.device = torch.device("cuda")
+        else:
+            print('No GPU available, using the CPU instead.')
+            self.device = torch.device("cpu")
diff --git a/scripts/ML/predict.py b/scripts/ML/predict.py
index 974fc69..0f00ab0 100644
--- a/scripts/ML/predict.py
+++ b/scripts/ML/predict.py
@@ -1,14 +1,14 @@
 #!/usr/bin/env python3
+from gpu import WithGPU
 import numpy
 import pandas
 import pickle
 import sklearn
 from sys import argv
-import torch
 from tqdm import tqdm
 from transformers import BertForSequenceClassification, BertTokenizer, TextClassificationPipeline
 
-class Classifier:
+class Classifier(WithGPU):
     """
     A class wrapping all the different models and classes used throughout a
     classification task:
@@ -22,12 +22,11 @@ class Classifier:
     containing the texts to classify
     """
     def __init__(self, root_path):
-        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        WithGPU.__init__(self)
         self._init_tokenizer()
         self._init_model(root_path)
         self._init_pipe()
         self._init_encoder(f"{root_path}/label_encoder.pkl")
-        self.log()
 
     def _init_model(self, path):
         bert = BertForSequenceClassification.from_pretrained(path)
@@ -48,12 +47,6 @@ class Classifier:
         with open(path, 'rb') as pickled:
             self.encoder = pickle.load(pickled)
 
-    def log(self):
-        if self.device.type == 'cpu':
-            print('No GPU available, using the CPU instead.')
-        else:
-            print('We will use the GPU:', torch.cuda.get_device_name(0))
-
     def __call__(self, text_generator):
         tokenizer_kwargs = {'padding':True, 'truncation':True, 'max_length':512}
         predictions = []
-- 
GitLab