import functools
import os
import pickle

import torch
from sklearn import preprocessing
from transformers import BertForSequenceClassification, BertTokenizer

def get_device():
    """Return the torch device to run on: CUDA when available, else the CPU.

    Prints a one-line message saying which device was selected.
    """
    if not torch.cuda.is_available():
        print('No GPU available, using the CPU instead.')
        return torch.device("cpu")
    gpu_name = torch.cuda.get_device_name(0)
    print('We will use the GPU:', gpu_name)
    return torch.device("cuda")

def loader(f):
    """Decorate an ``_init_*`` step so it prints a progress line.

    While ``f`` runs, `` - <name>`` is shown; when it returns, the line is
    overwritten with a check mark. ``<name>`` is ``f.__name__`` with
    ``_init_`` removed. The wrapped function's return value is passed
    through, and its metadata (``__name__``, ``__doc__``) is preserved.
    """
    @functools.wraps(f)
    def wrapped(*args, **kwargs):
        name = f.__name__.replace('_init_', '')
        # Flush so the pending line is visible while a slow load is running
        # (print with end='' would otherwise sit in the stdout buffer).
        print(f' - {name}', end='', flush=True)
        try:
            result = f(*args, **kwargs)
        except BaseException:
            # Terminate the pending line so any traceback starts cleanly.
            print()
            raise
        print(f'\r✔️  {name}')
        return result
    return wrapped

class BERT:
    """Bundle a BERT tokenizer, sequence classifier and label encoder.

    The classifier and label encoder are either created fresh for training
    (when ``train_on`` is given) or restored from ``root_path``.
    """

    model_name = 'bert-base-multilingual-cased'
    encoder_file = 'label_encoder.pkl'

    def __init__(self, root_path, train_on=None):
        """Load the tokenizer, classifier and label encoder.

        Args:
            root_path: Directory that holds (or will hold) the fine-tuned
                model and the pickled label encoder.
            train_on: Collection of class labels to train on, or ``None`` to
                restore a previously saved classifier from ``root_path``.
                It is iterated more than once, so it should not be a
                one-shot iterator.
        """
        self.device = get_device()
        print('Loading BERT tools')
        self._init_tokenizer()
        self.root_path = root_path
        self._init_classifier(train_on)
        self._init_encoder(train_on)

    @loader
    def _init_tokenizer(self):
        """Load the pretrained tokenizer for ``BERT.model_name``."""
        self.tokenizer = BertTokenizer.from_pretrained(BERT.model_name)

    @loader
    def _init_classifier(self, train_on):
        """Build (training) or restore (inference) the classifier and move it to the device."""
        if train_on is not None:
            # One output per *distinct* label: LabelEncoder.fit deduplicates
            # too, so the head size then matches encoder.classes_.
            bert = BertForSequenceClassification.from_pretrained(
                BERT.model_name,  # 12-layer multilingual BERT with a cased vocab.
                num_labels=len(set(train_on)),
                output_attentions=False,
                output_hidden_states=False,
            )
        else:
            bert = BertForSequenceClassification.from_pretrained(self.root_path)
        self.model = bert.to(self.device)

    @loader
    def _init_encoder(self, train_on):
        """Load the pickled label encoder, or fit and persist a new one.

        An encoder already stored under ``root_path`` takes precedence, even
        when ``train_on`` is given.

        Raises:
            FileNotFoundError: if no encoder is stored under ``root_path``
                and ``train_on`` is ``None``.
        """
        path = os.path.join(self.root_path, BERT.encoder_file)
        if os.path.exists(path):
            # NOTE: unpickling can execute arbitrary code; only load
            # encoders from trusted directories.
            with open(path, 'rb') as pickled:
                self.encoder = pickle.load(pickled)
        elif train_on is not None:
            self.encoder = preprocessing.LabelEncoder()
            self.encoder.fit(train_on)
            # A fresh training run may target a directory that does not exist
            # yet; without this the pickle write below would fail.
            if self.root_path:
                os.makedirs(self.root_path, exist_ok=True)
            with open(path, 'wb') as file:
                pickle.dump(self.encoder, file)
        else:
            raise FileNotFoundError(path)

    def import_data(self, data):
        """Lazily move each item of ``data`` onto ``self.device``.

        Each item must provide a ``.to(device)`` method (presumably torch
        tensors — confirm against callers). Returns a lazy ``map`` iterator.
        """
        return map(lambda d: d.to(self.device), data)

    def save(self):
        """Save the classifier to ``root_path`` via ``save_pretrained``."""
        self.model.save_pretrained(self.root_path)