From 6fe99473a48eb1b802bf51e24d5f44ccae3766fa Mon Sep 17 00:00:00 2001 From: rtalbi <dr_talbi@esi.dz> Date: Mon, 20 Dec 2021 12:54:51 +0100 Subject: [PATCH] non-privacy presrerving neural networks (finished code, started debug) --- ML/NN/NN.cpp | 49 +++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 41 insertions(+), 8 deletions(-) diff --git a/ML/NN/NN.cpp b/ML/NN/NN.cpp index 371c8f48..786da49e 100644 --- a/ML/NN/NN.cpp +++ b/ML/NN/NN.cpp @@ -221,34 +221,67 @@ void NN::train () // } -//todo : transform this method so that prediction happens for fa test batch void NN::Test( ){ int counter =0; + int recordCounter = 0; int size= dt->test_size; Record * record; + vector<Record*> XB; int label; - + int sizeBatch=batchSize; std::ofstream classOutput; classOutput.open (logfile); extTestBd = 0; auto begin = chrono::high_resolution_clock::now(); + + while (counter < size) { - counter++; + if (size - counter < batchSize) + sizeBatch = size - counter; - label = predict(record,true); - if(classOutput.is_open()) - { - classOutput<<label<< endl; + for (recordCounter = 0; recordCounter < sizeBatch; recordCounter++) { + try { + + record = dt->getTrainRecord(); + XB.push_back(record); + extTrainBd += record->values.size() + 1; + counter++; + } + catch (std::exception const &e) { + cout << e.what() << endl; + } + + } + + vector<int> prediction = predict(XB, true); + + for (int k=0; k < prediction.size();k++) { + int label = prediction[k]; + + if (classOutput.is_open()) { + classOutput << label << endl; + + } + } + + + for (int i = 0; i < XB.size(); i++) { + delete XB[i]; } - delete record; + + XB.clear(); + + } + + auto end = chrono::high_resolution_clock::now(); std::chrono::duration<double, std::milli> duration = end - begin ; -- GitLab