diff --git a/training_bertFineTuning.py b/training_bertFineTuning.py
index 6f1381e22b02bc627fc623f332b2b1e5dd609ac4..662647865ff5e42ae8db2b9343c83299336d3a7d 100644
--- a/training_bertFineTuning.py
+++ b/training_bertFineTuning.py
@@ -2,6 +2,7 @@ import torch
 import pandas as pd
 import numpy as np
 from sklearn import preprocessing
+from sklearn.model_selection import train_test_split
 from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
 from transformers import BertTokenizer, CamembertTokenizer
 from transformers import BertForSequenceClassification, AdamW, BertConfig, CamembertForSequenceClassification
@@ -11,6 +12,35 @@ import datetime
 import random
 import os
 import argparse
+import configparser
+
+
+
+
+def create_dict(df, classColumnName):
+    return dict(df[classColumnName].value_counts())
+
+
+def remove_weak_classes(df, classColumnName, threshold):
+
+    dictOfClassInstances = create_dict(df,classColumnName)
+
+
+    dictionary = {k: v for k, v in dictOfClassInstances.items() if v >= threshold }
+    keys = [*dictionary]
+    df_tmp = df[~ df[classColumnName].isin(keys)]
+    df = pd.concat([df,df_tmp]).drop_duplicates(keep=False)
+    return df
+
+
+def resample_classes(df, classColumnName, numberOfInstances):
+
+    #random numberOfInstances elements
+    replace = False # with replacement
+
+    fn = lambda obj: obj.loc[np.random.choice(obj.index, numberOfInstances if len(obj) > numberOfInstances else len(obj), replace),:]
+    return df.groupby(classColumnName, as_index=False).apply(fn)
+
 
 
 def flat_accuracy(preds, labels):
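
For reference, a minimal usage sketch of the helpers this hunk adds (not part of the patch itself): remove_weak_classes keeps only classes with at least `threshold` rows, and resample_classes caps each remaining class at `numberOfInstances` rows. The file name, the "label" column, and the numeric values below are placeholder assumptions.

    # Usage sketch, assuming a tab-separated corpus with a "label" column.
    import pandas as pd

    df = pd.read_csv("corpus.tsv", sep="\t")                      # assumed input file
    df = remove_weak_classes(df, "label", threshold=50)           # drop classes with fewer than 50 rows
    df = resample_classes(df, "label", numberOfInstances=500)     # sample at most 500 rows per class
    print(df["label"].value_counts())
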