Skip to content
Snippets Groups Projects
Commit 04dd1c35 authored by Ludovic Moncla
Browse files

Merge branch 'branch_v1' into 'master'

[ADD] save and load trained models

See merge request !2
parents 9e575a9d 780f702c
No related branches found
No related tags found
1 merge request: !2 [ADD] save and load trained models
......@@ -12,6 +12,7 @@ from sklearn import preprocessing
from evaluate_model import evaluate_model
from sklearn.model_selection import GridSearchCV
import configparser
import pickle
import nltk
nltk.download('stopwords')
......@@ -33,7 +34,7 @@ maxOfInstancePerClass = args.maxOfInstancePerClass
# Create reports/<columnClass> in a single call: os.makedirs also builds the
# missing parent 'reports' directory, and exist_ok=True makes reruns a no-op
# without a separate, race-prone os.path.exists() check.
os.makedirs(os.path.join('reports', columnClass), exist_ok=True)
......@@ -42,6 +43,11 @@ dir_name_report = str(minOfInstancePerClass) + '_' + str(maxOfInstancePerClass)
# Directory receiving the reports of this run; exist_ok=True replaces the
# check-then-create pattern so an already existing directory is not an error.
os.makedirs(os.path.join('reports', columnClass, dir_name_report), exist_ok=True)

# Directory used to save and load trained models.
os.makedirs('models', exist_ok=True)
# Reading data and preprocessing steps
preprocessor = Preprocessor()
......@@ -89,22 +95,42 @@ for columnInput in [columnText, 'firstParagraph']:
clf_name, clf = tmp_clf
grid_param_name, grid_param = tmp_grid_params
print(clf_name, clf, grid_param_name, grid_param)
model_file_name = columnInput + '_' +feature_technique_name + '_' + clf_name+ str(minOfInstancePerClass) + '_' + str(maxOfInstancePerClass) +".pkl"
if clf_name == 'bayes' :
if feature_technique_name == 'doc2vec':
continue
else:
t_begin = time.time()
clf.fit(train_x, train_y)
# if model exist
if os.path.isfile(os.path.join('./model', model_file_name)):
with open(model_file_name, 'rb') as file:
clf = pickle.load(file)
else:
#if model not exists we save
with open(Pkl_Filename, 'wb') as file:
clf.fit(train_x, train_y)
pickle.dump(clf, file)
t_end =time.time()
training_time = t_end - t_begin
y_pred = clf.predict(test_x)
else :
clf = GridSearchCV(clf, grid_param, refit = True, verbose = 3)
t_begin = time.time()
clf.fit(train_x, train_y)
if os.path.isfile(os.path.join('./model', model_file_name)):
with open(model_file_name, 'rb') as file:
clf = pickle.load(file)
else:
with open(Pkl_Filename, 'wb') as file:
clf.fit(train_x, train_y)
pickle.dump(clf, file)
t_end =time.time()
training_time = t_end - t_begin
y_pred = clf.predict(test_x)
......@@ -126,4 +152,3 @@ for columnInput in [columnText, 'firstParagraph']:
# Emit the measured training duration to the current standard output.
training_time_message = 'training time : {}'.format(training_time)
print(training_time_message)
# Point sys.stdout back at the interpreter's original stream.
sys.stdout = sys.__stdout__
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment