From 780f702c5a065d96a2480aa31d46bb16aa205259 Mon Sep 17 00:00:00 2001
From: Khalleud <ledk14@gmail.com>
Date: Mon, 7 Jun 2021 13:32:47 +0200
Subject: [PATCH] [ADD] save and load trained models

---
 experimentsClassicClassifiers.py | 33 ++++++++++++++++++++++++++++----
 1 file changed, 29 insertions(+), 4 deletions(-)

diff --git a/experimentsClassicClassifiers.py b/experimentsClassicClassifiers.py
index eb2aa0f..c6a9f72 100644
--- a/experimentsClassicClassifiers.py
+++ b/experimentsClassicClassifiers.py
@@ -12,6 +12,7 @@ from sklearn import preprocessing
 from evaluate_model import evaluate_model
 from sklearn.model_selection import GridSearchCV
 import configparser
+import pickle
 
 import nltk
 nltk.download('stopwords')
@@ -33,7 +34,7 @@ maxOfInstancePerClass = args.maxOfInstancePerClass
 
 if not os.path.exists('reports'):
     os.makedirs('reports')
-    
+
 if not os.path.exists(os.path.join('reports',  columnClass)):
     os.makedirs(os.path.join('reports', columnClass))
 
@@ -42,6 +43,11 @@ dir_name_report = str(minOfInstancePerClass) + '_' + str(maxOfInstancePerClass)
 if not os.path.exists(os.path.join('reports',  columnClass, dir_name_report)):
     os.makedirs(os.path.join('reports', columnClass, dir_name_report))
 
+
+# create directory to save and load models
+if not os.path.exists('models'):
+    os.makedirs('models')
+
 # Reading data and preprocessings steps
 preprocessor = Preprocessor()
 
@@ -89,22 +95,42 @@ for columnInput in [columnText, 'firstParagraph']:
             clf_name, clf = tmp_clf
             grid_param_name, grid_param = tmp_grid_params
             print(clf_name, clf, grid_param_name, grid_param)
+            model_file_name = columnInput + '_' +feature_technique_name + '_' + clf_name+ str(minOfInstancePerClass) + '_' + str(maxOfInstancePerClass) +".pkl"
             if clf_name == 'bayes' :
                 if feature_technique_name == 'doc2vec':
                     continue
                 else:
                     t_begin = time.time()
-                    clf.fit(train_x, train_y)
+                    # if model exist
+                    if os.path.isfile(os.path.join('./model', model_file_name)):
+                        with open(model_file_name, 'rb') as file:
+                            clf = pickle.load(file)
+                    else:
+                        #if model not exists we save
+                        with open(Pkl_Filename, 'wb') as file:
+                            clf.fit(train_x, train_y)
+                            pickle.dump(clf, file)
+
                     t_end =time.time()
                     training_time = t_end - t_begin
 
                     y_pred = clf.predict(test_x)
 
             else :
+
                 clf = GridSearchCV(clf, grid_param, refit = True, verbose = 3)
                 t_begin = time.time()
-                clf.fit(train_x, train_y)
+
+                if os.path.isfile(os.path.join('./model', model_file_name)):
+                    with open(model_file_name, 'rb') as file:
+                        clf = pickle.load(file)
+                else:
+                    with open(Pkl_Filename, 'wb') as file:
+                        clf.fit(train_x, train_y)
+                        pickle.dump(clf, file)
+
                 t_end =time.time()
+
                 training_time = t_end - t_begin
 
                 y_pred = clf.predict(test_x)
@@ -126,4 +152,3 @@ for columnInput in [columnText, 'firstParagraph']:
                 print('training time   : {}'.format(training_time))
                 #sys.stdout = sys.stdout # Reset the standard output to its original value
                 sys.stdout = sys.__stdout__
-
-- 
GitLab