Commit d4f07c5f authored by Ludovic Moncla

Update training_bertFineTuning.py

parent 078c0596
--- a/training_bertFineTuning.py
+++ b/training_bertFineTuning.py
@@ -60,7 +60,7 @@ def format_time(elapsed):
     return str(datetime.timedelta(seconds=elapsed_rounded))
 
-def training_bertFineTuning(chosen_model, sentences, labels, max_len, batch_size, epochs = 4):
+def training_bertFineTuning(chosen_model, model_path, sentences, labels, max_len, batch_size, epochs = 4):
     # If there's a GPU available...
     if torch.cuda.is_available():
@@ -82,12 +82,12 @@ def training_bertFineTuning(chosen_model, sentences, labels, max_len, batch_size, epochs = 4):
     ###########################################################################################################
-    if chosen_model == 'bert-base-multilingual-cased' :
+    if chosen_model == 'bert' :
         print('Loading Bert Tokenizer...')
-        tokenizer = BertTokenizer.from_pretrained(chosen_model, do_lower_case=True)
+        tokenizer = BertTokenizer.from_pretrained(model_path, do_lower_case=True)
 
-    elif chosen_model == 'camembert-base':
+    elif chosen_model == 'camembert':
         print('Loading Camembert Tokenizer...')
-        tokenizer = CamembertTokenizer.from_pretrained(chosen_model , do_lower_case=True)
+        tokenizer = CamembertTokenizer.from_pretrained(model_path , do_lower_case=True)
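
Note: the effect of this hunk is to decouple the architecture key from the weight location: `chosen_model` now selects the tokenizer class, while `model_path` can point at a Hugging Face hub id or a local checkpoint. A minimal sketch of the new usage (the values below are illustrative, not taken from the repository's config):

from transformers import BertTokenizer, CamembertTokenizer

chosen_model = 'camembert'        # architecture key tested in the if/elif above
model_path = 'camembert-base'     # a Hugging Face hub id ...
# model_path = './models/my_camembert'   # ... or a local checkpoint directory

if chosen_model == 'bert':
    tokenizer = BertTokenizer.from_pretrained(model_path, do_lower_case=True)
elif chosen_model == 'camembert':
    tokenizer = CamembertTokenizer.from_pretrained(model_path, do_lower_case=True)
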
@@ -192,18 +192,18 @@ def training_bertFineTuning(chosen_model, sentences, labels, max_len, batch_size, epochs = 4):
     # Load BertForSequenceClassification, the pretrained BERT model with a single
     # linear classification layer on top.
-    if chosen_model == 'bert-base-multilingual-cased':
+    if chosen_model == 'bert':
         model = BertForSequenceClassification.from_pretrained(
-            chosen_model, # Use the 12-layer BERT model, with an uncased vocab.
+            model_path, # Use the 12-layer BERT model, with an uncased vocab.
             num_labels = numberOfClasses, # The number of output labels--2 for binary classification.
                                           # You can increase this for multi-class tasks.
             output_attentions = False, # Whether the model returns attentions weights.
             output_hidden_states = False, # Whether the model returns all hidden-states.
         )
-    elif chosen_model == 'camembert-base':
+    elif chosen_model == 'camembert':
         model = CamembertForSequenceClassification.from_pretrained(
-            chosen_model, # Use the 12-layer BERT model, with an uncased vocab.
+            model_path, # Use the 12-layer BERT model, with an uncased vocab.
             num_labels = numberOfClasses, # The number of output labels--2 for binary classification.
                                           # You can increase this for multi-class tasks.
             output_attentions = False, # Whether the model returns attentions weights.
             output_hidden_states = False, # Whether the model returns all hidden-states.
         )
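
Note: `numberOfClasses` is defined elsewhere in the script and falls outside this hunk. A plausible derivation, assuming `labels` is a flat list of integer class ids (an assumption, not shown in this diff):

labels = [0, 2, 1, 0, 2]            # hypothetical integer class ids
numberOfClasses = len(set(labels))  # 3 distinct classes -> num_labels = 3
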
@@ -456,7 +456,7 @@ if __name__ == "__main__":
     minOfInstancePerClass = int(config.get('general','minOfInstancePerClass'))
     maxOfInstancePerClass = int(config.get('general','maxOfInstancePerClass'))
 
-    chosen_tokeniser = config.get('model','tokeniser')
+    model_path = config.get('model','path')
     chosen_model = config.get('model','model')
 
     max_len = int(config.get('model','max_len_sequences'))
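
Note: the `config.get(section, key)` calls above imply a settings file shaped roughly as follows. The section and key names come from the diff, but the concrete values are assumptions. A self-contained sketch with configparser:

import configparser

config = configparser.ConfigParser()
config.read_string("""
[general]
minOfInstancePerClass = 50
maxOfInstancePerClass = 1500

[model]
model = camembert
path = camembert-base
max_len_sequences = 256
""")

model_path = config.get('model', 'path')       # 'camembert-base'
chosen_model = config.get('model', 'model')    # 'camembert'
max_len = int(config.get('model', 'max_len_sequences'))

Since this revision replaces the old `tokeniser` key with `path`, existing config files would need the `[model]` section renamed accordingly.
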
@@ -484,7 +484,7 @@ if __name__ == "__main__":
     #call train method
-    model = training_bertFineTuning(chosen_model, sentences, labels, max_len, batch_size, epochs)
+    model = training_bertFineTuning(chosen_model,model_path, sentences, labels, max_len, batch_size, epochs)
 
     #save the model
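
Note: the save step itself is collapsed out of this diff. With a transformers sequence-classification model, one typical pattern would be the following (the output directory is a placeholder, and the original script may save differently):

model.save_pretrained('./fine_tuned_model')  # writes config.json plus the model weights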