From 2f7d5c052db22152c73c2529219fb85485ebdce2 Mon Sep 17 00:00:00 2001 From: Ludovic Moncla <ludovic.moncla@insa-lyon.fr> Date: Mon, 20 Sep 2021 15:58:57 +0000 Subject: [PATCH] Update training_bertFineTuning.py --- training_bertFineTuning.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/training_bertFineTuning.py b/training_bertFineTuning.py index dee7279..72cd70f 100644 --- a/training_bertFineTuning.py +++ b/training_bertFineTuning.py @@ -147,9 +147,9 @@ def training_bertFineTuning(chosen_model, model_path, sentences, labels, max_le # Use 90% for training and 10% for validation. - train_inputs, validation_inputs, train_labels, validation_labels = train_test_split(padded, labels, random_state=2018, test_size=0.1, stratify = labels ) + train_inputs, validation_inputs, train_labels, validation_labels = train_test_split(padded, labels, random_state=2018, test_size=0.3, stratify = labels ) # Do the same for the masks. - train_masks, validation_masks, _, _ = train_test_split(attention_masks, labels, random_state=2018, test_size=0.1, stratify = labels) + train_masks, validation_masks, _, _ = train_test_split(attention_masks, labels, random_state=2018, test_size=0.3, stratify = labels) # Convert all inputs and labels into torch tensors, the required datatype @@ -464,7 +464,7 @@ if __name__ == "__main__": epochs = int(config.get('model','epochs')) - df = pd.read_csv(INPUT_DATASET, sep="\t", quoting=csv.QUOTE_NONE) + df = pd.read_csv(INPUT_DATASET, sep="\t") df = remove_weak_classes(df, columnClass, minOfInstancePerClass) df = resample_classes(df, columnClass, maxOfInstancePerClass) #df = df[df[columnClass] != 'unclassified'] @@ -476,10 +476,12 @@ if __name__ == "__main__": y = encoder.fit_transform(y) - train_x, test_x, train_y, test_y = train_test_split(df, y, test_size=0.33, random_state=42, stratify = y ) + #train_x, test_x, train_y, test_y = train_test_split(df, y, test_size=0.33, random_state=42, stratify = y ) - sentences = train_x[columnText].values - labels = train_y.tolist() + #sentences = train_x[columnText].values + sentences = df[columnText].values + #labels = train_y.tolist() + labels = y.tolist() #call train method -- GitLab