diff --git a/ML/NN/NN.cpp b/ML/NN/NN.cpp index 371c8f486be694906a99330b8ea0f39e36cc2057..786da49e0a8bd8f58796431ad92e519b436168a0 100644 --- a/ML/NN/NN.cpp +++ b/ML/NN/NN.cpp @@ -221,34 +221,67 @@ void NN::train () // } -//todo : transform this method so that prediction happens for fa test batch void NN::Test( ){ int counter =0; + int recordCounter = 0; int size= dt->test_size; Record * record; + vector<Record*> XB; int label; - + int sizeBatch=batchSize; std::ofstream classOutput; classOutput.open (logfile); extTestBd = 0; auto begin = chrono::high_resolution_clock::now(); + + while (counter < size) { - counter++; + if (size - counter < batchSize) + sizeBatch = size - counter; - label = predict(record,true); - if(classOutput.is_open()) - { - classOutput<<label<< endl; + for (recordCounter = 0; recordCounter < sizeBatch; recordCounter++) { + try { + + record = dt->getTrainRecord(); + XB.push_back(record); + extTrainBd += record->values.size() + 1; + counter++; + } + catch (std::exception const &e) { + cout << e.what() << endl; + } + + } + + vector<int> prediction = predict(XB, true); + + for (int k=0; k < prediction.size();k++) { + int label = prediction[k]; + + if (classOutput.is_open()) { + classOutput << label << endl; + + } + } + + + for (int i = 0; i < XB.size(); i++) { + delete XB[i]; } - delete record; + + XB.clear(); + + } + + auto end = chrono::high_resolution_clock::now(); std::chrono::duration<double, std::milli> duration = end - begin ;