Skip to content
Snippets Groups Projects
Commit 38de3c27 authored by Alice Brenon's avatar Alice Brenon
Browse files

Take GPU detection out into a base class

parent 02253ba3
No related branches found
No related tags found
No related merge requests found
import torch
class WithGPU:
def __init__(self):
if torch.cuda.is_available():
print('We will use the GPU:', torch.cuda.get_device_name(0))
self.device = torch.device("cuda")
else:
print('No GPU available, using the CPU instead.')
self.device = torch.device("cpu")
#!/usr/bin/env python3
from gpu import WithGPU
import numpy
import pandas
import pickle
import sklearn
from sys import argv
import torch
from tqdm import tqdm
from transformers import BertForSequenceClassification, BertTokenizer, TextClassificationPipeline
class Classifier:
class Classifier(WithGPU):
"""
A class wrapping all the different models and classes used throughout a
classification task:
......@@ -22,12 +22,11 @@ class Classifier:
containing the texts to classify
"""
def __init__(self, root_path):
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
WithGPU.__init__(self)
self._init_tokenizer()
self._init_model(root_path)
self._init_pipe()
self._init_encoder(f"{root_path}/label_encoder.pkl")
self.log()
def _init_model(self, path):
bert = BertForSequenceClassification.from_pretrained(path)
......@@ -48,12 +47,6 @@ class Classifier:
with open(path, 'rb') as pickled:
self.encoder = pickle.load(pickled)
def log(self):
if self.device.type == 'cpu':
print('No GPU available, using the CPU instead.')
else:
print('We will use the GPU:', torch.cuda.get_device_name(0))
def __call__(self, text_generator):
tokenizer_kwargs = {'padding':True, 'truncation':True, 'max_length':512}
predictions = []
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment