Commit d4f07c5f authored by Ludovic Moncla

Update training_bertFineTuning.py

parent 078c0596
@@ -60,7 +60,7 @@ def format_time(elapsed):
     return str(datetime.timedelta(seconds=elapsed_rounded))

-def training_bertFineTuning(chosen_model, sentences, labels, max_len, batch_size, epochs = 4):
+def training_bertFineTuning(chosen_model, model_path, sentences, labels, max_len, batch_size, epochs = 4):

     # If there's a GPU available...
     if torch.cuda.is_available():
@@ -82,12 +82,12 @@ def training_bertFineTuning(chosen_model, sentences, labels, max_len, batch_si
     ###########################################################################################################
-    if chosen_model == 'bert-base-multilingual-cased' :
+    if chosen_model == 'bert' :
         print('Loading Bert Tokenizer...')
-        tokenizer = BertTokenizer.from_pretrained(chosen_model, do_lower_case=True)
-    elif chosen_model == 'camembert-base':
+        tokenizer = BertTokenizer.from_pretrained(model_path, do_lower_case=True)
+    elif chosen_model == 'camembert':
         print('Loading Camembert Tokenizer...')
-        tokenizer = CamembertTokenizer.from_pretrained(chosen_model , do_lower_case=True)
+        tokenizer = CamembertTokenizer.from_pretrained(model_path , do_lower_case=True)
@@ -192,18 +192,18 @@ def training_bertFineTuning(chosen_model, sentences, labels, max_len, batch_si
     # Load BertForSequenceClassification, the pretrained BERT model with a single
     # linear classification layer on top.
-    if chosen_model == 'bert-base-multilingual-cased':
+    if chosen_model == 'bert':
         model = BertForSequenceClassification.from_pretrained(
-            chosen_model, # Use the 12-layer BERT model, with an uncased vocab.
+            model_path, # Use the 12-layer BERT model, with an uncased vocab.
             num_labels = numberOfClasses, # The number of output labels--2 for binary classification.
                                           # You can increase this for multi-class tasks.
             output_attentions = False, # Whether the model returns attentions weights.
             output_hidden_states = False, # Whether the model returns all hidden-states.
         )
-    elif chosen_model == 'camembert-base':
+    elif chosen_model == 'camembert':
         model = CamembertForSequenceClassification.from_pretrained(
-            chosen_model, # Use the 12-layer BERT model, with an uncased vocab.
+            model_path, # Use the 12-layer BERT model, with an uncased vocab.
             num_labels = numberOfClasses, # The number of output labels--2 for binary classification.
                                           # You can increase this for multi-class tasks.
             output_attentions = False, # Whether the model returns attentions weights.
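
Taken together, the tokenizer and model hunks decouple the architecture switch from the checkpoint location: chosen_model becomes a plain keyword ('bert' or 'camembert'), while model_path is what actually reaches from_pretrained(), so it can be a Hugging Face hub id such as 'bert-base-multilingual-cased' or a local checkpoint directory. A minimal sketch of the resulting selection logic; the helper name load_model_and_tokenizer is hypothetical, but the transformers classes and arguments are the ones visible in the diff:

from transformers import (
    BertTokenizer, BertForSequenceClassification,
    CamembertTokenizer, CamembertForSequenceClassification,
)

def load_model_and_tokenizer(chosen_model, model_path, num_labels):
    # Hypothetical helper illustrating the refactor: the architecture keyword
    # picks the classes, the checkpoint path is loaded by from_pretrained().
    if chosen_model == 'bert':
        tokenizer = BertTokenizer.from_pretrained(model_path, do_lower_case=True)
        model = BertForSequenceClassification.from_pretrained(
            model_path,
            num_labels=num_labels,
            output_attentions=False,
            output_hidden_states=False,
        )
    elif chosen_model == 'camembert':
        tokenizer = CamembertTokenizer.from_pretrained(model_path, do_lower_case=True)
        model = CamembertForSequenceClassification.from_pretrained(
            model_path,
            num_labels=num_labels,
            output_attentions=False,
            output_hidden_states=False,
        )
    else:
        raise ValueError(f'Unsupported model: {chosen_model}')
    return model, tokenizer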
@@ -456,7 +456,7 @@ if __name__ == "__main__":
     minOfInstancePerClass = int(config.get('general','minOfInstancePerClass'))
     maxOfInstancePerClass = int(config.get('general','maxOfInstancePerClass'))
-    chosen_tokeniser = config.get('model','tokeniser')
+    model_path = config.get('model','path')
     chosen_model = config.get('model','model')
     max_len = int(config.get('model','max_len_sequences'))
@@ -484,7 +484,7 @@ if __name__ == "__main__":
     #call train method
-    model = training_bertFineTuning(chosen_model, sentences, labels, max_len, batch_size, epochs)
+    model = training_bertFineTuning(chosen_model,model_path, sentences, labels, max_len, batch_size, epochs)

     #save the model
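
The last two hunks imply a config file that now provides both the architecture keyword and the checkpoint path. A hedged sketch of the [model] section and the updated call site; the file name config.ini and the concrete values are assumptions for illustration, and sentences, labels, batch_size and epochs are prepared earlier in the script:

import configparser

# Assumed config.ini excerpt (keys taken from the diff; values are examples):
#
#   [model]
#   model = bert
#   path = bert-base-multilingual-cased
#   max_len_sequences = 256
#
config = configparser.ConfigParser()
config.read('config.ini')

model_path = config.get('model', 'path')      # checkpoint directory or hub id
chosen_model = config.get('model', 'model')   # architecture keyword
max_len = int(config.get('model', 'max_len_sequences'))

# The architecture keyword and the checkpoint path now travel separately:
model = training_bertFineTuning(chosen_model, model_path, sentences, labels,
                                max_len, batch_size, epochs)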