Skip to content
Snippets Groups Projects
Commit 2f7d5c05 authored by Ludovic Moncla
Browse files

Update training_bertFineTuning.py

parent c443f187
No related branches found
No related tags found
No related merge requests found
...@@ -147,9 +147,9 @@ def training_bertFineTuning(chosen_model, model_path, sentences, labels, max_le ...@@ -147,9 +147,9 @@ def training_bertFineTuning(chosen_model, model_path, sentences, labels, max_le
# Use 90% for training and 10% for validation. # Use 70% for training and 30% for validation.
train_inputs, validation_inputs, train_labels, validation_labels = train_test_split(padded, labels, random_state=2018, test_size=0.1, stratify = labels ) train_inputs, validation_inputs, train_labels, validation_labels = train_test_split(padded, labels, random_state=2018, test_size=0.3, stratify = labels )
# Do the same for the masks. # Do the same for the masks.
train_masks, validation_masks, _, _ = train_test_split(attention_masks, labels, random_state=2018, test_size=0.1, stratify = labels) train_masks, validation_masks, _, _ = train_test_split(attention_masks, labels, random_state=2018, test_size=0.3, stratify = labels)
# Convert all inputs and labels into torch tensors, the required datatype # Convert all inputs and labels into torch tensors, the required datatype
...@@ -464,7 +464,7 @@ if __name__ == "__main__": ...@@ -464,7 +464,7 @@ if __name__ == "__main__":
epochs = int(config.get('model','epochs')) epochs = int(config.get('model','epochs'))
df = pd.read_csv(INPUT_DATASET, sep="\t", quoting=csv.QUOTE_NONE) df = pd.read_csv(INPUT_DATASET, sep="\t")
df = remove_weak_classes(df, columnClass, minOfInstancePerClass) df = remove_weak_classes(df, columnClass, minOfInstancePerClass)
df = resample_classes(df, columnClass, maxOfInstancePerClass) df = resample_classes(df, columnClass, maxOfInstancePerClass)
#df = df[df[columnClass] != 'unclassified'] #df = df[df[columnClass] != 'unclassified']
...@@ -476,10 +476,12 @@ if __name__ == "__main__": ...@@ -476,10 +476,12 @@ if __name__ == "__main__":
y = encoder.fit_transform(y) y = encoder.fit_transform(y)
train_x, test_x, train_y, test_y = train_test_split(df, y, test_size=0.33, random_state=42, stratify = y ) #train_x, test_x, train_y, test_y = train_test_split(df, y, test_size=0.33, random_state=42, stratify = y )
sentences = train_x[columnText].values #sentences = train_x[columnText].values
labels = train_y.tolist() sentences = df[columnText].values
#labels = train_y.tolist()
labels = y.tolist()
#call train method #call train method
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment