From 6fe99473a48eb1b802bf51e24d5f44ccae3766fa Mon Sep 17 00:00:00 2001
From: rtalbi <dr_talbi@esi.dz>
Date: Mon, 20 Dec 2021 12:54:51 +0100
Subject: [PATCH] non-privacy presrerving neural networks (finished code,
 started debug)

---
 ML/NN/NN.cpp | 49 +++++++++++++++++++++++++++++++++++++++++--------
 1 file changed, 41 insertions(+), 8 deletions(-)

diff --git a/ML/NN/NN.cpp b/ML/NN/NN.cpp
index 371c8f48..786da49e 100644
--- a/ML/NN/NN.cpp
+++ b/ML/NN/NN.cpp
@@ -221,34 +221,67 @@ void NN::train () //
 
 }
 
-//todo : transform this method so that prediction happens for fa test batch
 void NN::Test( ){
 
     int counter =0;
+    int recordCounter = 0;
     int size= dt->test_size;
     Record * record;
+    vector<Record*> XB;
     int label;
-
+    int sizeBatch=batchSize;
     std::ofstream classOutput;
     classOutput.open (logfile);
 
     extTestBd = 0;
     auto begin = chrono::high_resolution_clock::now();
+
+
     while (counter < size) {
 
 
-        counter++;
+        if (size - counter < batchSize)
+            sizeBatch = size - counter;
 
-        label = predict(record,true);
 
-        if(classOutput.is_open())
-        {
-            classOutput<<label<< endl;
+        for (recordCounter = 0; recordCounter < sizeBatch; recordCounter++) {
+            try {
+
+                record = dt->getTrainRecord();
+                XB.push_back(record);
+                extTrainBd += record->values.size() + 1;
+                counter++;
+            }
+            catch (std::exception const &e) {
+                cout << e.what() << endl;
+            }
+
+        }
+
 
+        vector<int> prediction = predict(XB, true);
+
+        for (int k=0; k < prediction.size();k++) {
+            int label  = prediction[k];
+
+            if (classOutput.is_open()) {
+                classOutput << label << endl;
+
+            }
+        }
+
+
+        for (int i = 0; i < XB.size(); i++) {
+            delete XB[i];
         }
-        delete record;
+
+        XB.clear();
+
+
     }
 
+
+
     auto end = chrono::high_resolution_clock::now();
 
     std::chrono::duration<double, std::milli> duration = end  - begin ;
-- 
GitLab