Commit b9d052a7 authored by Khalleud

[FIX] update experimentationclassicclassifiers by saving and loading models

parent 780f702c
1 merge request: !3 Branch dev
@@ -22,6 +22,8 @@ classifiers = [
param_grid_svm = {'C':[1,10,100,1000],'gamma':[1,0.1,0.001,0.0001], 'kernel':['linear','rbf']}
#param_grid_svm = {'C':[1,10],'gamma':[1], 'kernel':['linear','rbf']}
#param_grid_svm = [{'kernel': ['rbf'], 'gamma': [1e-3, 1e-4], 'C': [1, 10, 100, 1000]}, {'kernel': ['linear'], 'C': [1, 10, 100, 1000]}]
param_grid_decisionTree = { 'criterion' : ['gini', 'entropy'], 'max_depth':range(5,10), 'min_samples_split': range(5,10), 'min_samples_leaf': range(1,5) }
param_grid_rfc = { 'n_estimators': [200, 500], 'max_features': ['auto', 'sqrt', 'log2'], 'max_depth' : [4,5,6,7,8], 'criterion' :['gini', 'entropy'] }
param_grid_lr = {"C":np.logspace(-3,3,7), "penalty":["l1","l2"]}
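The grids above are the hyper-parameter spaces later handed to GridSearchCV. A minimal sketch of how one of them is consumed, assuming scikit-learn estimators paired with the grids as the training loop further down does:

import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV

param_grid_lr = {"C": np.logspace(-3, 3, 7), "penalty": ["l1", "l2"]}
# 'liblinear' supports both l1 and l2 penalties, so every grid point is valid.
search = GridSearchCV(LogisticRegression(solver="liblinear"), param_grid_lr,
                      refit=True, verbose=3)
# search.fit(train_x, train_y) would then select the best C/penalty pair.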
@@ -51,9 +51,8 @@ if not os.path.exists('models'):
# Reading data and preprocessings steps
preprocessor = Preprocessor()
df_original = pd.read_csv(dataPath)
df = pd.read_csv(dataPath)
df = df_original[[columnClass,columnText]].copy()
df = remove_weak_classes(df, columnClass, minOfInstancePerClass)
df = resample_classes(df, columnClass, maxOfInstancePerClass)
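remove_weak_classes and resample_classes are project-specific helpers not shown in this commit. A hypothetical sketch of the filtering they appear to perform, assuming pandas DataFrames (drop classes below a minimum count, then cap each class at a maximum):

import pandas as pd

def remove_weak_classes(df, column_class, min_count):
    # Assumption: keep only classes that have at least min_count rows.
    counts = df[column_class].value_counts()
    keep = counts[counts >= min_count].index
    return df[df[column_class].isin(keep)]

def resample_classes(df, column_class, max_count):
    # Assumption: downsample each class to at most max_count rows.
    return (df.groupby(column_class, group_keys=False)
              .apply(lambda g: g.sample(min(len(g), max_count), random_state=42)))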
@@ -73,7 +72,7 @@ for columnInput in [columnText, 'firstParagraph']:
print('Process: ' + columnInput)
extractor = feature_extractor(df,columnText, columnClass)
extractor = feature_extractor(df, columnInput, columnClass)
features_techniques = [
('counter', extractor.count_vect(max_df = vectorization_max_df, min_df = vectorization_min_df, numberOfFeatures = vectorization_numberOfFeatures )),
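feature_extractor is also project code; the 'counter' technique presumably wraps scikit-learn's CountVectorizer with the script's vectorization_* settings. A sketch under that assumption, with illustrative values:

from sklearn.feature_extraction.text import CountVectorizer

# max_df / min_df / max_features mirror the script's vectorization_* parameters (values illustrative).
vectorizer = CountVectorizer(max_df=0.95, min_df=2, max_features=5000)
# features = vectorizer.fit_transform(df[columnInput])  # df and columnInput come from the script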
@@ -95,19 +94,22 @@ for columnInput in [columnText, 'firstParagraph']:
clf_name, clf = tmp_clf
grid_param_name, grid_param = tmp_grid_params
print(clf_name, clf, grid_param_name, grid_param)
model_file_name = columnInput + '_' +feature_technique_name + '_' + clf_name+ str(minOfInstancePerClass) + '_' + str(maxOfInstancePerClass) +".pkl"
model_file_name = columnInput + '_' + feature_technique_name + '_' + clf_name + '_' + str(minOfInstancePerClass) + '_' + str(maxOfInstancePerClass) +".pkl"
if clf_name == 'bayes' :
if feature_technique_name == 'doc2vec':
continue
else:
t_begin = time.time()
# if model exist
if os.path.isfile(os.path.join('./model', model_file_name)):
with open(model_file_name, 'rb') as file:
if os.path.isfile(os.path.join('./models', model_file_name)):
print('trained model loaded')
with open(os.path.join('./models', model_file_name), 'rb') as file:
clf = pickle.load(file)
else:
print('model training')
#if model not exists we save
with open(Pkl_Filename, 'wb') as file:
with open(os.path.join('./models', model_file_name), 'wb') as file:
clf.fit(train_x, train_y)
pickle.dump(clf, file)
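This hunk replaces the undefined Pkl_Filename with a path built from model_file_name under ./models, and skips the bayes/doc2vec combination, likely because multinomial naive Bayes expects non-negative count features, which doc2vec embeddings do not provide. The save/load logic reduces to a small load-or-train helper; a sketch under that reading, with illustrative names:

import os
import pickle

def fit_or_load(clf, model_path, train_x, train_y):
    if os.path.isfile(model_path):
        with open(model_path, 'rb') as f:
            return pickle.load(f)      # reuse the previously trained model
    clf.fit(train_x, train_y)          # otherwise train from scratch
    with open(model_path, 'wb') as f:
        pickle.dump(clf, f)            # and persist it for the next run
    return clf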
@@ -121,11 +123,13 @@ for columnInput in [columnText, 'firstParagraph']:
clf = GridSearchCV(clf, grid_param, refit = True, verbose = 3)
t_begin = time.time()
if os.path.isfile(os.path.join('./model', model_file_name)):
with open(model_file_name, 'rb') as file:
if os.path.isfile(os.path.join('./models', model_file_name)):
print('trained model loaded')
with open(os.path.join('./models', model_file_name), 'rb') as file:
clf = pickle.load(file)
else:
with open(Pkl_Filename, 'wb') as file:
print('model training')
with open(os.path.join('./models', model_file_name), 'wb') as file:
clf.fit(train_x, train_y)
pickle.dump(clf, file)
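Because the pickled object here is the fitted GridSearchCV itself, a later run can reload it and inspect the chosen hyper-parameters directly. A sketch, assuming model_file_name is built as in the loop above:

import os
import pickle

model_path = os.path.join('./models', model_file_name)  # model_file_name as built in the loop above
with open(model_path, 'rb') as f:
    clf = pickle.load(f)
print(clf.best_params_)     # hyper-parameter combination selected by the search
print(clf.best_estimator_)  # refit=True means this estimator was retrained on the full training set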