From c9c23588453d221a0520eb1ad6b757b5006d9620 Mon Sep 17 00:00:00 2001
From: lmoncla <ludovic.moncla@insa-lyon.fr>
Date: Fri, 17 Sep 2021 11:46:03 +0200
Subject: [PATCH] update

---
 training_bertFineTuning.py | 30 ++++++++++++++++++++++++++++++
 1 file changed, 30 insertions(+)

diff --git a/training_bertFineTuning.py b/training_bertFineTuning.py
index 6f1381e..6626478 100644
--- a/training_bertFineTuning.py
+++ b/training_bertFineTuning.py
@@ -2,6 +2,7 @@ import torch
 import pandas as pd
 import numpy as np
 from sklearn import preprocessing
+from sklearn.model_selection import train_test_split
 from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
 from transformers import BertTokenizer, CamembertTokenizer
 from transformers import BertForSequenceClassification, AdamW, BertConfig, CamembertForSequenceClassification
@@ -11,6 +12,35 @@ import datetime
 import random
 import os
 import argparse
+import configparser
+
+
+
+
+def create_dict(df, classColumnName):
+    return dict(df[classColumnName].value_counts())
+
+
+def remove_weak_classes(df, classColumnName, threshold):
+
+    dictOfClassInstances = create_dict(df,classColumnName)
+
+
+    dictionary = {k: v for k, v in dictOfClassInstances.items() if v >= threshold }
+    keys = [*dictionary]
+    df_tmp = df[~ df[classColumnName].isin(keys)]
+    df =  pd.concat([df,df_tmp]).drop_duplicates(keep=False)
+    return df
+
+
+def resample_classes(df, classColumnName, numberOfInstances):
+
+    #random numberOfInstances elements
+    replace = False  # with replacement
+
+    fn = lambda obj: obj.loc[np.random.choice(obj.index, numberOfInstances if len(obj) > numberOfInstances else len(obj), replace),:]
+    return df.groupby(classColumnName, as_index=False).apply(fn)
+
 
 
 def flat_accuracy(preds, labels):
-- 
GitLab