From 2f7d5c052db22152c73c2529219fb85485ebdce2 Mon Sep 17 00:00:00 2001
From: Ludovic Moncla <ludovic.moncla@insa-lyon.fr>
Date: Mon, 20 Sep 2021 15:58:57 +0000
Subject: [PATCH] Update training_bertFineTuning.py

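---
Reviewer notes (text between the "---" separator and the diff is not part
of the commit message and is ignored when the patch is applied):

 * The validation share goes from 10% to 30% (test_size=0.1 -> 0.3) for both
   the inputs and the attention masks, but the unchanged context comment
   "Use 90% for training and 10% for validation." is kept and no longer
   matches the code (it is now a 70% / 30% split).
 * read_csv no longer passes quoting=csv.QUOTE_NONE, so pandas uses its
   default quoting mode: double quotes in the tab-separated input are
   treated as quote characters instead of literal text.
 * The 67% / 33% train_test_split in the __main__ block is commented out
   rather than removed; sentences and labels are now taken from the whole
   of df / y, so this block no longer sets aside a hold-out test portion
   before the training call.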
 training_bertFineTuning.py | 14 ++++++++------
 1 file changed, 8 insertions(+), 6 deletions(-)

diff --git a/training_bertFineTuning.py b/training_bertFineTuning.py
index dee7279..72cd70f 100644
--- a/training_bertFineTuning.py
+++ b/training_bertFineTuning.py
@@ -147,9 +147,9 @@ def training_bertFineTuning(chosen_model, model_path,  sentences, labels, max_le
 
 
     # Use 90% for training and 10% for validation.
-    train_inputs, validation_inputs, train_labels, validation_labels = train_test_split(padded, labels, random_state=2018, test_size=0.1, stratify = labels )
+    train_inputs, validation_inputs, train_labels, validation_labels = train_test_split(padded, labels, random_state=2018, test_size=0.3, stratify = labels )
     # Do the same for the masks.
-    train_masks, validation_masks, _, _ = train_test_split(attention_masks, labels, random_state=2018, test_size=0.1, stratify = labels)
+    train_masks, validation_masks, _, _ = train_test_split(attention_masks, labels, random_state=2018, test_size=0.3, stratify = labels)
 
 
     # Convert all inputs and labels into torch tensors, the required datatype
@@ -464,7 +464,7 @@ if __name__ == "__main__":
     epochs = int(config.get('model','epochs'))
 
 
-    df = pd.read_csv(INPUT_DATASET, sep="\t", quoting=csv.QUOTE_NONE)
+    df = pd.read_csv(INPUT_DATASET, sep="\t")
     df = remove_weak_classes(df, columnClass, minOfInstancePerClass)
     df = resample_classes(df, columnClass, maxOfInstancePerClass)
     #df = df[df[columnClass] != 'unclassified']
@@ -476,10 +476,12 @@ if __name__ == "__main__":
     y = encoder.fit_transform(y)
 
 
-    train_x, test_x, train_y, test_y = train_test_split(df, y, test_size=0.33, random_state=42, stratify = y )
+    #train_x, test_x, train_y, test_y = train_test_split(df, y, test_size=0.33, random_state=42, stratify = y )
 
-    sentences = train_x[columnText].values
-    labels = train_y.tolist()
+    #sentences = train_x[columnText].values
+    sentences = df[columnText].values
+    #labels = train_y.tolist()
+    labels = y.tolist()
 
 
     #call train method
-- 
GitLab