Commit f13251b8 authored by Ludovic Moncla

Update Predict_LGE.ipynb

parent 926002bb
%% Cell type:markdown id: tags:
# BERT Predict classification
## 1. Setup the environment
### 1.1 Setup colab environment
#### 1.1.1 Install packages
%% Cell type:code id: tags:
``` python
!pip install transformers==4.10.3
!pip install sentencepiece
```
%% Cell type:markdown id: tags:
#### 1.1.2 Use more RAM
%% Cell type:code id: tags:
``` python
from psutil import virtual_memory
ram_gb = virtual_memory().total / 1e9
print('Your runtime has {:.1f} gigabytes of available RAM\n'.format(ram_gb))
if ram_gb < 20:
    print('Not using a high-RAM runtime')
else:
    print('You are using a high-RAM runtime!')
```
%% Cell type:markdown id: tags:
#### 1.1.3 Mount Google Drive
%% Cell type:code id: tags:
``` python
from google.colab import drive
drive.mount('/content/drive')
```
%% Cell type:markdown id: tags:
### 1.2 Import libraries
%% Cell type:code id: tags:
``` python
import os
import pandas as pd
import numpy as np
import pickle
import torch
from tqdm import tqdm
from transformers import BertTokenizer, BertForSequenceClassification, CamembertTokenizer, CamembertForSequenceClassification
from torch.utils.data import TensorDataset, DataLoader, SequentialSampler
```
%% Cell type:markdown id: tags:
### 1.3 Setup GPU
%% Cell type:code id: tags:
``` python
# If there's a GPU available...
if torch.cuda.is_available():
    # Tell PyTorch to use the GPU.
    device = torch.device("cuda")
    gpu_name = "cuda"
    print('There are %d GPU(s) available.' % torch.cuda.device_count())
    print('We will use the GPU:', torch.cuda.get_device_name(0))
# for macOS (Apple Silicon)
elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
    device = torch.device("mps")
    gpu_name = "mps"
    print('We will use the GPU')
else:
    device = torch.device("cpu")
    gpu_name = "cpu"
    print('No GPU available, using the CPU instead.')
```
%% Output
We will use the GPU
%% Cell type:markdown id: tags:
## 2. Utils
%% Cell type:code id: tags:
``` python
def generate_dataloader(tokenizer, sentences, batch_size=8, max_len=512):
    # Tokenize all of the sentences and map the tokens to their word IDs.
    input_ids_test = []
    # For every sentence...
    for sent in sentences:
        # `encode` will:
        #   (1) Tokenize the sentence.
        #   (2) Prepend the `[CLS]` token to the start.
        #   (3) Append the `[SEP]` token to the end.
        #   (4) Map tokens to their IDs.
        encoded_sent = tokenizer.encode(
            sent,                      # Sentence to encode.
            add_special_tokens=True,   # Add '[CLS]' and '[SEP]'
            # `encode` also supports truncation and conversion to PyTorch
            # tensors, but padding is handled manually below, so these
            # options are not used here.
            #max_length = max_len,     # Truncate all sentences.
            #return_tensors = 'pt',    # Return PyTorch tensors.
        )
        input_ids_test.append(encoded_sent)

    # Pad (or truncate) our input tokens to max_len
    padded_test = []
    for i in input_ids_test:
        if len(i) > max_len:
            padded_test.append(i[:max_len])
        else:
            padded_test.append(i + [0] * (max_len - len(i)))
    input_ids_test = np.array(padded_test)

    # Create attention masks: 1 for each real token, 0 for padding
    attention_masks = []
    for seq in input_ids_test:
        seq_mask = [float(i > 0) for i in seq]
        attention_masks.append(seq_mask)

    # Convert to tensors.
    inputs = torch.tensor(input_ids_test)
    masks = torch.tensor(attention_masks)

    # Create the DataLoader.
    data = TensorDataset(inputs, masks)
    prediction_sampler = SequentialSampler(data)
    return DataLoader(data, sampler=prediction_sampler, batch_size=batch_size)


def predict(model, dataloader, device):
    # Put model in evaluation mode
    model.eval()
    # Tracking variables
    predictions_test = []
    pred_labels_ = []
    # Predict
    for batch in dataloader:
        # Add batch to GPU
        batch = tuple(t.to(device) for t in batch)
        # Unpack the inputs from the dataloader
        b_input_ids, b_input_mask = batch
        # Telling the model not to compute or store gradients,
        # saving memory and speeding up prediction
        with torch.no_grad():
            # Forward pass, calculate logit predictions
            outputs = model(b_input_ids, token_type_ids=None,
                            attention_mask=b_input_mask)
        logits = outputs[0]
        # Move logits to CPU
        logits = logits.detach().cpu().numpy()
        # Store the predictions
        predictions_test.append(logits)

    pred_labels = []
    for i in range(len(predictions_test)):
        # The predictions for this batch are an (n_samples, n_classes)
        # ndarray of logits. Pick the label with the highest value.
        pred_labels_i = np.argmax(predictions_test[i], axis=1).flatten()
        pred_labels.append(pred_labels_i)
    pred_labels_ += [item for sublist in pred_labels for item in sublist]
    return pred_labels_


def text_folder_to_dataframe(path):
    data = []
    # columns: id, tome, filename, content, nb_words
    for tome in sorted(os.listdir(path)):
        try:
            for article in tqdm(sorted(os.listdir(path + "/" + tome))):
                filename = article[:-4]
                id = tome + filename
                if article[-4:] == ".txt":
                    with open(path + "/" + tome + "/" + article) as f:
                        content = f.read()
                    data.append([id, tome, filename, content, len(content.split(' '))])
        except NotADirectoryError:
            pass
    return pd.DataFrame(data, columns=['id', 'tome', 'filename', 'content', 'nb_words'])
```
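%% Cell type:markdown id: tags:
A minimal usage sketch of the two helpers above, on a toy list of sentences (an illustration, not part of the original pipeline; it assumes `tokenizer`, `model`, and `device` have already been created as in the sections below):
%% Cell type:code id: tags:
``` python
# Hypothetical example: build a DataLoader from two short sentences
# and run the loaded classifier on them.
sample_sentences = ["Premier article de test.", "Second article de test."]
sample_loader = generate_dataloader(tokenizer, sample_sentences, batch_size=2, max_len=32)
sample_preds = predict(model, sample_loader, device)  # list of predicted class indices
```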
%% Cell type:markdown id: tags:
## 3. Load Data
%% Cell type:markdown id: tags:
### 3.1 LGE (Nakala)
%% Cell type:code id: tags:
``` python
!wget https://api.nakala.fr/data/10.34847/nkl.74eb1xfd/e522413b58b04ab7c283f8fa68642e9cb69ab5c5
```
%% Cell type:code id: tags:
``` python
!unzip e522413b58b04ab7c283f8fa68642e9cb69ab5c5
```
%% Cell type:code id: tags:
``` python
# Local path (when running outside Colab):
#input_path = "/Users/lmoncla/Documents/Data/Corpus/LGE/Text"
input_path = "./Text"
```
%% Cell type:code id: tags:
``` python
df_LGE = text_folder_to_dataframe(input_path)
#df_LGE = pd.read_csv(path + "data/LGE_withContent.tsv", sep="\t")
data_LGE = df_LGE["content"].values
```
%% Cell type:code id: tags:
``` python
df_LGE.head()
```
%% Output
id tome rank domain remark \
0 abrabeses-0 1 623 geography NaN
1 accius-0 1 1076 biography NaN
2 achenbach-2 1 1357 biography NaN
3 acireale-0 1 1513 geography NaN
4 actée-0 1 1731 botany NaN
content
0 ABRABESES. Village d’Espagne de la prov. de Za...
1 ACCIUS, L. ou L. ATTIUS (170-94 av. J.-C.), po...
2 ACHENBACH(Henri), administrateur prussien, né ...
3 ACIREALE. Yille de Sicile, de la province et d...
4 ACTÉE(Actœa L.). Genre de plantes de la famill...
%% Cell type:code id: tags:
``` python
df_LGE.shape
```
%% Output
(310, 6)
%% Cell type:markdown id: tags:
## 4. Load model and predict
### 4.1 BERT / CamemBERT
%% Cell type:code id: tags:
``` python
#path = "drive/MyDrive/Classification-EDdA/"
path = "../"
model_name = "bert-base-multilingual-cased"
#model_name = "camembert-base"
model_path = path + "models/model_" + model_name + "_s10000.pt"
```
%% Cell type:code id: tags:
``` python
if model_name == 'bert-base-multilingual-cased':
    print('Loading Bert Tokenizer...')
    tokenizer = BertTokenizer.from_pretrained(model_name)
elif model_name == 'camembert-base':
    print('Loading Camembert Tokenizer...')
    tokenizer = CamembertTokenizer.from_pretrained(model_name)
```
%% Output
Loading Bert Tokenizer...
%% Cell type:code id: tags:
``` python
data_loader = generate_dataloader(tokenizer, data_LGE)
```
%% Output
Token indices sequence length is longer than the specified maximum sequence length for this model (1204 > 512). Running this sequence through the model will result in indexing errors
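%% Cell type:markdown id: tags:
This warning is expected: some LGE articles exceed the model's 512-token limit, and `generate_dataloader` truncates the token-id lists itself after encoding, so no indexing error actually occurs. A minimal alternative sketch (an assumption, not what this notebook does) would let the tokenizer truncate directly and silence the warning:
%% Cell type:code id: tags:
``` python
# Hypothetical variant: truncate at tokenization time instead of afterwards.
encoded = tokenizer.encode(
    data_LGE[0],              # first LGE article
    add_special_tokens=True,  # add '[CLS]' and '[SEP]'
    max_length=512,           # hard limit of the BERT model
    truncation=True,          # cut longer articles instead of warning
)
```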
%% Cell type:code id: tags:
``` python
model = BertForSequenceClassification.from_pretrained(model_path).to(gpu_name) #.to("cuda")
```
%% Cell type:code id: tags:
``` python
pred = predict(model, data_loader, device)
```
%% Cell type:code id: tags:
``` python
encoder_filename = "models/label_encoder.pkl"
with open(path + encoder_filename, 'rb') as file:
encoder = pickle.load(file)
```
%% Cell type:code id: tags:
``` python
p2 = list(encoder.inverse_transform(pred))
```
%% Cell type:code id: tags:
``` python
df_LGE['domain'] = p2
```
%% Cell type:code id: tags:
``` python
df_LGE.head(50)
```
%% Output
id tome rank domain remark \
0 abrabeses-0 1 623 geography NaN
1 accius-0 1 1076 biography NaN
2 achenbach-2 1 1357 biography NaN
3 acireale-0 1 1513 geography NaN
4 actée-0 1 1731 botany NaN
5 adulteration-0 1 2197 NaN cross reference
6 aérides-0 1 2334 botany NaN
7 ager-0 1 2710 biography NaN
8 aigu-1 1 3160 NaN cross reference
9 alavika-0 1 3664 theology NaN
10 allassac-0 2 755 geography NaN
11 allegretto-0 2 786 NaN cross reference
12 alleuze-0 2 908 geography NaN
13 alliat-0 2 933 geography NaN
14 amanty-0 2 1651 geography NaN
15 âmasserah-0 2 1701 geography explicit domain
16 a-118 2 2971 history NaN
17 androclès-0 2 3261 mythology explicit domain
18 anfouson-0 2 3394 zoology NaN
19 anicet-bourgeois-0 2 3717 biography NaN
20 anomalistique-0 3 238 astronomy explicit domain
21 anostostome-0 3 298 zoology NaN
22 anthoxanthème-0 3 571 chemistry NaN
23 aod-0 3 1024 theology NaN
24 aphellan-0 3 1177 astronomy NaN
25 appelle-0 3 1494 geography NaN
26 aragona-1 3 1841 biography NaN
27 araujuzon-0 3 1940 geography NaN
28 ardant-0 3 2421 biography NaN
29 ariano-0 3 2839 geography NaN
30 athabaska-0 4 1118 anthropology NaN
31 aslonnes-0 4 446 geography NaN
32 astr0rh1za-0 4 992 zoology explicit domain
33 atthidographes-0 4 1397 NaN cross reference
34 aubery-2 4 1577 biography NaN
35 aula-0 4 1992 history NaN
36 au-113 4 2112 botany explicit domain
37 auriol-4 4 2224 NaN cross reference
38 ave-lalleniant-0 4 2739 biography NaN
39 badin-2 4 3857 biography NaN
40 baizieux-0 5 133 geography NaN
41 balsam1te-0 5 677 botany explicit domain
42 balze-0 5 757 navy explicit domain
43 bande-2 5 880 history NaN
44 barbosa-5 5 1580 biography NaN
45 bati-0 5 2955 architecture NaN
46 baveuse-0 5 3457 zoology explicit domain
47 beard-2 5 3728 biography NaN
48 beaufort-4 5 3838 geography NaN
49 beaumont-26 5 4018 biography NaN
content \
0 ABRABESES. Village d’Espagne de la prov. de Za...
1 ACCIUS, L. ou L. ATTIUS (170-94 av. J.-C.), po...
2 ACHENBACH(Henri), administrateur prussien, né ...
3 ACIREALE. Yille de Sicile, de la province et d...
4 ACTÉE(Actœa L.). Genre de plantes de la famill...
5 ADULTERATION. Altération d’un médicament, d’un...
6 AÉRIDES{Aérides Lour.). Genres de plantes de l...
7 AGERouAGERIUS (Nicolaus), médecin alsacien, né...
8 AIGU1 LH E (V. Raimond d’).\n
9 ALAVIKA« qui est d'Alava »(V. ce mot) : Bhikch...
10 ALLASSAC. Com. du dép. de la Corrèze, arr. de ...
11 ALLEGRETTO(V. Allegro).\n
12 ALLEUZE. Com. du dép. du Cantal, arr. et cant....
13 ALLIAT. Com. du dép. de l’Ariège, arr. de Foix...
14 AMANTY. Corn, du dép. de la Meuse, arr. de Com...
15 ÂMASSERAH, AMASR1 ou AMASRAH (Géogr.). Ville d...
16 AN Cl LIA. Boucliers sacrés des Romains, au no...
17 ANDROCLÈS(Myth.), un fils d’Eole qui régna sur...
18 ANFOUSON. Nom donné à Nice au Néron brun\n(V. ...
19 ANICET-BOURGEOIS(Auguste Anicet, connu sous le...
20 ANOMALISTIQUE(Astron.). On appelle révolution\...
21 ANOSTOSTOME(Anostostoma Gray). Genre d’insecte...
22 ANTHOXANTHÈME. L’un des deux principes coloran...
23 AOD, plus exactement Ehoud. personnage des com...
24 APHELLAN(Astron.). Un des noms de l’étoile a2 ...
25 APPELLE. Com. du dép. du Tarn, arr. de Lavaux,...
26 ARAGONA, cardinal d’origine sicilienne, né en ...
27 ARAUJUZON. Com. du dép. des Basses-Pyrénées, a...
28 ARDANT(Paul-Joseph), général français, né en 1...
29 ARIANOdi Puglia. Ville de la prov. de principa...
30 ATHABASKA. Col, rivière, lac, territoire et fa...
31 ASLONNES, corn, du dép. de la Vienne, arr. de ...
32 ASTR0RH1ZA(Zool.).Genre deForaminifèresimperfo...
33 ATTHIDOGRAPHES(V. Atthide).\n
34 AUBERY(Antoine;, historien français, né le .18...
35 AULA. Mot latin signifiant cour, lieu découver...
36 AUNÉE (bot.). L'Aunée, Grande Année, Année off...
37 AURIOL. Nom donné à Marseille au Maquereau (V....
38 AVE-LALLENIANT(Robert-Christian-Barthold), méd...
39 BADIN(Pierre-Adolphe), peintre français, né à ...
40 BAIZIEUX(Bacium, Basium). Com. du dép. de la\n...
41 BALSAM1TE(Bot.) (Balsamita Desf.). Genre de Co...
42 BALZE(Mar.). Radeau delà côte occidentale de l...
43 BANDE(Ordre delà) ou de l’ECHARPE.Ordre milita...
44 BARBOSA(Antonio), jésuite et orientaliste port...
45 BATIÈRE. Toit en forme de bât se terminant à c...
46 BAVEUSE(Zool.). Nom vulgaire par lequel les\np...
47 BEARD(James-Henry), peintre américain contempo...
48 BEAUFORT. Com. du dép. de la Meuse, arr. de Mo...
49 BEAUMONT(J.-G. Leprevôt de), secrétaire du cle...
class_bert
0 Géographie
1 Belles-lettres - Poésie
2 Histoire
3 Géographie
4 Histoire naturelle
5 Chimie
6 Histoire naturelle
7 Histoire
8 Marine
9 Religion
10 Géographie
11 Musique
12 Géographie
13 Géographie
14 Géographie
15 Géographie
16 Antiquité
17 Antiquité
18 Histoire naturelle
19 Belles-lettres - Poésie
20 Physique - [Sciences physico-mathématiques]
21 Histoire naturelle
22 Pharmacie
23 Histoire
24 Physique - [Sciences physico-mathématiques]
25 Géographie
26 Religion
27 Géographie
28 Militaire (Art) - Guerre - Arme
29 Géographie
30 Géographie
31 Géographie
32 Histoire naturelle
33 Géographie
34 Histoire
35 Architecture
36 Histoire naturelle
37 Histoire naturelle
38 Histoire
39 Arts et métiers
40 Géographie
41 Histoire naturelle
42 Marine
43 Histoire
44 Religion
45 Architecture
46 Histoire naturelle
47 Beaux-arts
48 Géographie
49 Histoire
%% Cell type:code id: tags:
``` python
filepath = path + "results_LGE/LGE-metadata-withContent.csv"
df_LGE.to_csv(filepath, sep="\,")
```
%% Cell type:code id: tags:
``` python
df_LGE.drop(columns=['content'], inplace=True)
filepath = path + "results_LGE/LGE-metadata.csv"
df_LGE.to_csv(filepath, sep="\,")
```
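%% Cell type:markdown id: tags:
Optional sanity check (a sketch, assuming the two files above were written): reload the exported metadata and confirm the predicted `domain` column is present.
%% Cell type:code id: tags:
``` python
# Hypothetical check: read the exported metadata back and inspect it.
df_check = pd.read_csv(path + "results_LGE/LGE-metadata.csv", sep=",")
print(df_check.shape)
print(df_check['domain'].value_counts().head())
```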