Update to 2.0.0 tree from current Fremantle build
[opencv] / tests / ml / src / mltests.cpp
diff --git a/tests/ml/src/mltests.cpp b/tests/ml/src/mltests.cpp
new file mode 100644 (file)
index 0000000..33db951
--- /dev/null
@@ -0,0 +1,785 @@
+/*M///////////////////////////////////////////////////////////////////////////////////////\r
+//\r
+//  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.\r
+//\r
+//  By downloading, copying, installing or using the software you agree to this license.\r
+//  If you do not agree to this license, do not download, install,\r
+//  copy or use the software.\r
+//\r
+//\r
+//                        Intel License Agreement\r
+//                For Open Source Computer Vision Library\r
+//\r
+// Copyright (C) 2000, Intel Corporation, all rights reserved.\r
+// Third party copyrights are property of their respective owners.\r
+//\r
+// Redistribution and use in source and binary forms, with or without modification,\r
+// are permitted provided that the following conditions are met:\r
+//\r
+//   * Redistribution's of source code must retain the above copyright notice,\r
+//     this list of conditions and the following disclaimer.\r
+//\r
+//   * Redistribution's in binary form must reproduce the above copyright notice,\r
+//     this list of conditions and the following disclaimer in the documentation\r
+//     and/or other materials provided with the distribution.\r
+//\r
+//   * The name of Intel Corporation may not be used to endorse or promote products\r
+//     derived from this software without specific prior written permission.\r
+//\r
+// This software is provided by the copyright holders and contributors "as is" and\r
+// any express or implied warranties, including, but not limited to, the implied\r
+// warranties of merchantability and fitness for a particular purpose are disclaimed.\r
+// In no event shall the Intel Corporation or contributors be liable for any direct,\r
+// indirect, incidental, special, exemplary, or consequential damages\r
+// (including, but not limited to, procurement of substitute goods or services;\r
+// loss of use, data, or profits; or business interruption) however caused\r
+// and on any theory of liability, whether in contract, strict liability,\r
+// or tort (including negligence or otherwise) arising in any way out of\r
+// the use of this software, even if advised of the possibility of such damage.\r
+//\r
+//M*/\r
+\r
+#include "mltest.h"\r
+\r
+// auxiliary functions\r
+// 1. nbayes\r
+void nbayes_check_data( CvMLData* _data )
+{
+    if( _data->get_missing() )
+        CV_Error( CV_StsBadArg, "missing values are not supported" );
+    const CvMat* var_types = _data->get_var_types();
+    bool is_classifier = var_types->data.ptr[var_types->cols-1] == CV_VAR_CATEGORICAL;
+    if( ( fabs( cvNorm( var_types, 0, CV_L1 ) - \r
+        (var_types->rows + var_types->cols - 2)*CV_VAR_ORDERED - CV_VAR_CATEGORICAL ) > FLT_EPSILON ) ||\r
+        !is_classifier )\r
+        CV_Error( CV_StsBadArg, "incorrect types of predictors or responses" );
+}
+bool nbayes_train( CvNormalBayesClassifier* nbayes, CvMLData* _data )
+{
+    nbayes_check_data( _data );\r
+    const CvMat* values = _data->get_values();\r
+    const CvMat* responses = _data->get_responses();\r
+    const CvMat* train_sidx = _data->get_train_sample_idx();\r
+    const CvMat* var_idx = _data->get_var_idx();\r
+    return nbayes->train( values, responses, var_idx, train_sidx );\r
+}
+float nbayes_calc_error( CvNormalBayesClassifier* nbayes, CvMLData* _data, int type, vector<float> *resp )\r
+{\r
+    float err = 0;\r
+    nbayes_check_data( _data );\r
+    const CvMat* values = _data->get_values();
+    const CvMat* response = _data->get_responses();
+    const CvMat* sample_idx = (type == CV_TEST_ERROR) ? _data->get_test_sample_idx() : _data->get_train_sample_idx();\r
+    int* sidx = sample_idx ? sample_idx->data.i : 0;\r
+    int r_step = CV_IS_MAT_CONT(response->type) ?\r
+        1 : response->step / CV_ELEM_SIZE(response->type);\r
+    int sample_count = sample_idx ? sample_idx->cols : 0;\r
+    sample_count = (type == CV_TRAIN_ERROR && sample_count == 0) ? values->rows : sample_count;\r
+    float* pred_resp = 0;\r
+    if( resp && (sample_count > 0) )\r
+    {\r
+        resp->resize( sample_count );\r
+        pred_resp = &((*resp)[0]);\r
+    }\r
+\r
+    for( int i = 0; i < sample_count; i++ )\r
+    {\r
+        CvMat sample;\r
+        int si = sidx ? sidx[i] : i;\r
+        cvGetRow( values, &sample, si ); \r
+        float r = (float)nbayes->predict( &sample, 0 );\r
+        if( pred_resp )\r
+            pred_resp[i] = r;\r
+        int d = fabs((double)r - response->data.fl[si*r_step]) <= FLT_EPSILON ? 0 : 1;\r
+        err += d;\r
+    }\r
+    err = sample_count ? err / (float)sample_count * 100 : -FLT_MAX;\r
+    return err;\r
+}\r
+\r
+// 2. knearest\r
+void knearest_check_data_and_get_predictors( CvMLData* _data, CvMat* _predictors )
+{
+    const CvMat* values = _data->get_values();
+    const CvMat* var_idx = _data->get_var_idx();
+    if( var_idx->cols + var_idx->rows != values->cols )\r
+        CV_Error( CV_StsBadArg, "var_idx is not supported" );\r
+    if( _data->get_missing() )\r
+        CV_Error( CV_StsBadArg, "missing values are not supported" );\r
+    int resp_idx = _data->get_response_idx();\r
+    if( resp_idx == 0)\r
+        cvGetCols( values, _predictors, 1, values->cols );\r
+    else if( resp_idx == values->cols - 1 )\r
+        cvGetCols( values, _predictors, 0, values->cols - 1 );\r
+    else\r
+        CV_Error( CV_StsBadArg, "responses must be in the first or last column; other cases are not supported" );\r
+}
+bool knearest_train( CvKNearest* knearest, CvMLData* _data )
+{
+    const CvMat* responses = _data->get_responses();\r
+    const CvMat* train_sidx = _data->get_train_sample_idx();\r
+    bool is_regression = _data->get_var_type( _data->get_response_idx() ) == CV_VAR_ORDERED;\r
+    CvMat predictors;\r
+    knearest_check_data_and_get_predictors( _data, &predictors );\r
+    return knearest->train( &predictors, responses, train_sidx, is_regression );
+}
+float knearest_calc_error( CvKNearest* knearest, CvMLData* _data, int k, int type, vector<float> *resp )
+{
+    float err = 0;\r
+    const CvMat* response = _data->get_responses();\r
+    const CvMat* sample_idx = (type == CV_TEST_ERROR) ? _data->get_test_sample_idx() : _data->get_train_sample_idx();\r
+    int* sidx = sample_idx ? sample_idx->data.i : 0;\r
+    int r_step = CV_IS_MAT_CONT(response->type) ?\r
+        1 : response->step / CV_ELEM_SIZE(response->type);\r
+    bool is_regression = _data->get_var_type( _data->get_response_idx() ) == CV_VAR_ORDERED;\r
+    CvMat predictors;\r
+    knearest_check_data_and_get_predictors( _data, &predictors );\r
+    int sample_count = sample_idx ? sample_idx->cols : 0;\r
+    sample_count = (type == CV_TRAIN_ERROR && sample_count == 0) ? predictors.rows : sample_count;\r
+    float* pred_resp = 0;\r
+    if( resp && (sample_count > 0) )\r
+    {\r
+        resp->resize( sample_count );\r
+        pred_resp = &((*resp)[0]);\r
+    }\r
+    if ( !is_regression )\r
+    {\r
+        for( int i = 0; i < sample_count; i++ )\r
+        {\r
+            CvMat sample;\r
+            int si = sidx ? sidx[i] : i;\r
+            cvGetRow( &predictors, &sample, si ); \r
+            float r = knearest->find_nearest( &sample, k );\r
+            if( pred_resp )\r
+                pred_resp[i] = r;\r
+            int d = fabs((double)r - response->data.fl[si*r_step]) <= FLT_EPSILON ? 0 : 1;\r
+            err += d;\r
+        }\r
+        err = sample_count ? err / (float)sample_count * 100 : -FLT_MAX;\r
+    }\r
+    else\r
+    {\r
+        for( int i = 0; i < sample_count; i++ )\r
+        {\r
+            CvMat sample;\r
+            int si = sidx ? sidx[i] : i;\r
+            cvGetRow( &predictors, &sample, si ); \r
+            float r = knearest->find_nearest( &sample, k );\r
+            if( pred_resp )\r
+                pred_resp[i] = r;\r
+            float d = r - response->data.fl[si*r_step];\r
+            err += d*d;\r
+        }\r
+        err = sample_count ? err / (float)sample_count : -FLT_MAX;    \r
+    }\r
+    return err;
+}\r
+\r
+// 3. svm\r
+int str_to_svm_type(string& str)\r
+{\r
+    if( !str.compare("C_SVC") )\r
+        return CvSVM::C_SVC;\r
+    if( !str.compare("NU_SVC") )\r
+        return CvSVM::NU_SVC;\r
+    if( !str.compare("ONE_CLASS") )\r
+        return CvSVM::ONE_CLASS;\r
+    if( !str.compare("EPS_SVR") )\r
+        return CvSVM::EPS_SVR;\r
+    if( !str.compare("NU_SVR") )\r
+        return CvSVM::NU_SVR;\r
+    CV_Error( CV_StsBadArg, "incorrect svm type string" );\r
+    return -1;\r
+}\r
+int str_to_svm_kernel_type( string& str )\r
+{\r
+    if( !str.compare("LINEAR") )\r
+        return CvSVM::LINEAR;\r
+    if( !str.compare("POLY") )\r
+        return CvSVM::POLY;\r
+    if( !str.compare("RBF") )\r
+        return CvSVM::RBF;\r
+    if( !str.compare("SIGMOID") )\r
+        return CvSVM::SIGMOID;\r
+    CV_Error( CV_StsBadArg, "incorrect svm type string" );\r
+    return -1;\r
+}\r
+void svm_check_data( CvMLData* _data )
+{
+    if( _data->get_missing() )
+        CV_Error( CV_StsBadArg, "missing values are not supported" );
+    const CvMat* var_types = _data->get_var_types();
+    for( int i = 0; i < var_types->cols-1; i++ )
+        if (var_types->data.ptr[i] == CV_VAR_CATEGORICAL)
+        {
+            char msg[50];
+            sprintf( msg, "incorrect type of %d-predictor", i );
+            CV_Error( CV_StsBadArg, msg );
+        }
+}
+bool svm_train( CvSVM* svm, CvMLData* _data, CvSVMParams _params )
+{
+    svm_check_data(_data);
+    const CvMat* _train_data = _data->get_values();
+    const CvMat* _responses = _data->get_responses();
+    const CvMat* _var_idx = _data->get_var_idx();
+    const CvMat* _sample_idx = _data->get_train_sample_idx();
+    return svm->train( _train_data, _responses, _var_idx, _sample_idx, _params );
+}
+bool svm_train_auto( CvSVM* svm, CvMLData* _data, CvSVMParams _params,
+                    int k_fold, CvParamGrid C_grid, CvParamGrid gamma_grid,
+                    CvParamGrid p_grid, CvParamGrid nu_grid, CvParamGrid coef_grid,
+                    CvParamGrid degree_grid )
+{
+    svm_check_data(_data);
+    const CvMat* _train_data = _data->get_values();
+    const CvMat* _responses = _data->get_responses();
+    const CvMat* _var_idx = _data->get_var_idx();
+    const CvMat* _sample_idx = _data->get_train_sample_idx();
+    return svm->train_auto( _train_data, _responses, _var_idx, 
+        _sample_idx, _params, k_fold, C_grid, gamma_grid, p_grid, nu_grid, coef_grid, degree_grid );
+}
+float svm_calc_error( CvSVM* svm, CvMLData* _data, int type, vector<float> *resp )
+{
+    svm_check_data(_data);
+    float err = 0;\r
+    const CvMat* values = _data->get_values();\r
+    const CvMat* response = _data->get_responses();\r
+    const CvMat* sample_idx = (type == CV_TEST_ERROR) ? _data->get_test_sample_idx() : _data->get_train_sample_idx();\r
+    const CvMat* var_types = _data->get_var_types();\r
+    int* sidx = sample_idx ? sample_idx->data.i : 0;\r
+    int r_step = CV_IS_MAT_CONT(response->type) ?\r
+        1 : response->step / CV_ELEM_SIZE(response->type);\r
+    bool is_classifier = var_types->data.ptr[var_types->cols-1] == CV_VAR_CATEGORICAL;\r
+    int sample_count = sample_idx ? sample_idx->cols : 0;\r
+    sample_count = (type == CV_TRAIN_ERROR && sample_count == 0) ? values->rows : sample_count;\r
+    float* pred_resp = 0;\r
+    if( resp && (sample_count > 0) )\r
+    {\r
+        resp->resize( sample_count );\r
+        pred_resp = &((*resp)[0]);\r
+    }\r
+    if ( is_classifier )\r
+    {\r
+        for( int i = 0; i < sample_count; i++ )\r
+        {\r
+            CvMat sample;\r
+            int si = sidx ? sidx[i] : i;\r
+            cvGetRow( values, &sample, si ); \r
+            float r = svm->predict( &sample );\r
+            if( pred_resp )\r
+                pred_resp[i] = r;\r
+            int d = fabs((double)r - response->data.fl[si*r_step]) <= FLT_EPSILON ? 0 : 1;\r
+            err += d;\r
+        }\r
+        err = sample_count ? err / (float)sample_count * 100 : -FLT_MAX;\r
+    }\r
+    else\r
+    {\r
+        for( int i = 0; i < sample_count; i++ )\r
+        {\r
+            CvMat sample;\r
+            int si = sidx ? sidx[i] : i;\r
+            cvGetRow( values, &sample, si );\r
+            float r = svm->predict( &sample );\r
+            if( pred_resp )\r
+                pred_resp[i] = r;\r
+            float d = r - response->data.fl[si*r_step];\r
+            err += d*d;\r
+        }\r
+        err = sample_count ? err / (float)sample_count : -FLT_MAX;    \r
+    }\r
+    return err;
+}\r
+\r
+// 4. em\r
+// 5. ann\r
+int str_to_ann_train_method( string& str )\r
+{\r
+    if( !str.compare("BACKPROP") )\r
+        return CvANN_MLP_TrainParams::BACKPROP;\r
+    if( !str.compare("RPROP") )\r
+        return CvANN_MLP_TrainParams::RPROP;\r
+    CV_Error( CV_StsBadArg, "incorrect ann train method string" );\r
+    return -1;\r
+}\r
+void ann_check_data_and_get_predictors( CvMLData* _data, CvMat* _inputs )
+{
+    const CvMat* values = _data->get_values();
+    const CvMat* var_idx = _data->get_var_idx();
+    if( var_idx->cols + var_idx->rows != values->cols )\r
+        CV_Error( CV_StsBadArg, "var_idx is not supported" );\r
+    if( _data->get_missing() )\r
+        CV_Error( CV_StsBadArg, "missing values are not supported" );\r
+    int resp_idx = _data->get_response_idx();\r
+    if( resp_idx == 0)\r
+        cvGetCols( values, _inputs, 1, values->cols );\r
+    else if( resp_idx == values->cols - 1 )\r
+        cvGetCols( values, _inputs, 0, values->cols - 1 );\r
+    else\r
+        CV_Error( CV_StsBadArg, "outputs must be in the first or last column; other cases are not supported" );\r
+}
+void ann_get_new_responses( CvMLData* _data, Mat& new_responses, map<int, int>& cls_map )
+{
+    const CvMat* train_sidx = _data->get_train_sample_idx();
+    int* train_sidx_ptr = train_sidx->data.i;\r
+    const CvMat* responses = _data->get_responses();\r
+    float* responses_ptr = responses->data.fl;\r
+    int r_step = CV_IS_MAT_CONT(responses->type) ?\r
+        1 : responses->step / CV_ELEM_SIZE(responses->type);\r
+    int cls_count = 0;\r
+    // construct cls_map\r
+    cls_map.clear();\r
+    for( int si = 0; si < train_sidx->cols; si++ )\r
+    {\r
+        int sidx = train_sidx_ptr[si];\r
+        int r = cvRound(responses_ptr[sidx*r_step]);\r
+        CV_DbgAssert( fabs(responses_ptr[sidx*r_step]-r) < FLT_EPSILON );\r
+        int cls_map_size = (int)cls_map.size();\r
+        cls_map[r];\r
+        if ( (int)cls_map.size() > cls_map_size )\r
+            cls_map[r] = cls_count++;\r
+    }\r
+    new_responses.create( responses->rows, cls_count, CV_32F );\r
+    new_responses.setTo( 0 );\r
+    for( int si = 0; si < train_sidx->cols; si++ )\r
+    {\r
+        int sidx = train_sidx_ptr[si];\r
+        int r = cvRound(responses_ptr[sidx*r_step]);\r
+        int cidx = cls_map[r];\r
+        new_responses.ptr<float>(sidx)[cidx] = 1;\r
+    }
+}
+int ann_train( CvANN_MLP* ann, CvMLData* _data, Mat& new_responses, CvANN_MLP_TrainParams _params, int flags = 0 )
+{
+    const CvMat* train_sidx = _data->get_train_sample_idx();\r
+    CvMat predictors;\r
+    ann_check_data_and_get_predictors( _data, &predictors );\r
+    CvMat _new_responses = CvMat( new_responses );\r
+    return ann->train( &predictors, &_new_responses, 0, train_sidx, _params, flags );
+}
+float ann_calc_error( CvANN_MLP* ann, CvMLData* _data, map<int, int>& cls_map, int type , vector<float> *resp_labels )
+{
+    float err = 0;\r
+    const CvMat* responses = _data->get_responses();\r
+    const CvMat* sample_idx = (type == CV_TEST_ERROR) ? _data->get_test_sample_idx() : _data->get_train_sample_idx();\r
+    int* sidx = sample_idx ? sample_idx->data.i : 0;\r
+    int r_step = CV_IS_MAT_CONT(responses->type) ?\r
+        1 : responses->step / CV_ELEM_SIZE(responses->type);\r
+    CvMat predictors;\r
+    ann_check_data_and_get_predictors( _data, &predictors );\r
+    int sample_count = sample_idx ? sample_idx->cols : 0;\r
+    sample_count = (type == CV_TRAIN_ERROR && sample_count == 0) ? predictors.rows : sample_count;\r
+    float* pred_resp = 0;\r
+    vector<float> innresp;\r
+    if( sample_count > 0 )\r
+    {\r
+        if( resp_labels )\r
+        {\r
+            resp_labels->resize( sample_count );\r
+            pred_resp = &((*resp_labels)[0]);\r
+        }\r
+        else\r
+        {\r
+            innresp.resize( sample_count );\r
+            pred_resp = &(innresp[0]);\r
+        }\r
+    }\r
+    int cls_count = (int)cls_map.size();\r
+    Mat output( 1, cls_count, CV_32FC1 );\r
+    CvMat _output = CvMat(output);\r
+    map<int, int>::iterator b_it = cls_map.begin();\r
+    for( int i = 0; i < sample_count; i++ )\r
+    {\r
+        CvMat sample;\r
+        int si = sidx ? sidx[i] : i;\r
+        cvGetRow( &predictors, &sample, si ); \r
+        ann->predict( &sample, &_output );\r
+        CvPoint best_cls = {0,0};\r
+        cvMinMaxLoc( &_output, 0, 0, 0, &best_cls, 0 );
+        int r = cvRound(responses->data.fl[si*r_step]);
+        CV_DbgAssert( fabs(responses->data.fl[si*r_step]-r) < FLT_EPSILON );
+        r = cls_map[r];
+        int d = best_cls.x == r ? 0 : 1;
+        err += d;\r
+        pred_resp[i] = (float)best_cls.x;\r
+    }\r
+    err = sample_count ? err / (float)sample_count * 100 : -FLT_MAX;\r
+    return err;
+}\r
+\r
+// 6. dtree\r
+// 7. boost\r
+int str_to_boost_type( string& str )\r
+{\r
+    if ( !str.compare("DISCRETE") )\r
+        return CvBoost::DISCRETE;\r
+    if ( !str.compare("REAL") )\r
+        return CvBoost::REAL;    \r
+    if ( !str.compare("LOGIT") )\r
+        return CvBoost::LOGIT;\r
+    if ( !str.compare("GENTLE") )\r
+        return CvBoost::GENTLE;\r
+    CV_Error( CV_StsBadArg, "incorrect boost type string" );\r
+    return -1;\r
+}\r
+\r
+// 8. rtrees\r
+// 9. ertrees\r
+\r
+// ---------------------------------- MLBaseTest ---------------------------------------------------\r
+\r
+CV_MLBaseTest::CV_MLBaseTest( const char* _modelName, const char* _testName, const char* _testFuncs ) :
+CvTest( _testName, _testFuncs )
+{
+    modelName = _modelName;
+    nbayes = 0;\r
+    knearest = 0;\r
+    svm = 0;\r
+    em = 0;\r
+    ann = 0;\r
+    dtree = 0;\r
+    boost = 0;\r
+    rtrees = 0;\r
+    ertrees = 0;
+    if( !modelName.compare(CV_NBAYES) )
+        nbayes = new CvNormalBayesClassifier;
+    else if( !modelName.compare(CV_KNEAREST) )
+        knearest = new CvKNearest;
+    else if( !modelName.compare(CV_SVM) )
+        svm = new CvSVM;
+    else if( !modelName.compare(CV_EM) )
+        em = new CvEM;
+    else if( !modelName.compare(CV_ANN) )
+        ann = new CvANN_MLP;
+    else if( !modelName.compare(CV_DTREE) )
+        dtree = new CvDTree;
+    else if( !modelName.compare(CV_BOOST) )
+        boost = new CvBoost;
+    else if( !modelName.compare(CV_RTREES) )
+        rtrees = new CvRTrees;
+    else if( !modelName.compare(CV_ERTREES) )
+        ertrees = new CvERTrees;
+}
+\r
+int CV_MLBaseTest::init( CvTS* system )\r
+{\r
+    clear();
+    ts = system;
+
+    string filename = ts->get_data_path();
+    filename += get_validation_filename();
+    validationFS.open( filename, FileStorage::READ );
+    return read_params( *validationFS );\r
+}\r
+\r
+CV_MLBaseTest::~CV_MLBaseTest()\r
+{\r
+    if( validationFS.isOpened() )\r
+        validationFS.release();\r
+    if( nbayes )\r
+        delete nbayes;\r
+    if( knearest ) \r
+        delete knearest;\r
+    if( svm )\r
+        delete svm;\r
+    if( em )\r
+        delete em;\r
+    if( ann )\r
+        delete ann;\r
+    if( dtree )\r
+        delete dtree;\r
+    if( boost )\r
+        delete boost;\r
+    if( rtrees )\r
+        delete rtrees;\r
+    if( ertrees )\r
+        delete ertrees;\r
+}\r
+\r
+int CV_MLBaseTest::read_params( CvFileStorage* _fs )\r
+{\r
+    if( !_fs )\r
+        test_case_count = -1;\r
+    else\r
+    {\r
+        CvFileNode* fn = cvGetRootFileNode( _fs, 0 );\r
+        fn = (CvFileNode*)cvGetSeqElem( fn->data.seq, 0 );\r
+        fn = cvGetFileNodeByName( _fs, fn, "run_params" );\r
+        CvSeq* dataSetNamesSeq = cvGetFileNodeByName( _fs, fn, modelName.c_str() )->data.seq;\r
+        test_case_count = dataSetNamesSeq ? dataSetNamesSeq->total : -1;\r
+        if( test_case_count > 0 )\r
+        {\r
+            dataSetNames.resize( test_case_count );\r
+            vector<string>::iterator it = dataSetNames.begin();\r
+            for( int i = 0; i < test_case_count; i++, it++ )\r
+                *it = ((CvFileNode*)cvGetSeqElem( dataSetNamesSeq, i ))->data.str.ptr;\r
+        }\r
+    }\r
+    return CvTS::OK;;\r
+}\r
+\r
+void CV_MLBaseTest::run( int start_from )
+{
+    int code = CvTS::OK;
+    start_from = 0;
+    for (int i = 0; i < test_case_count; i++)
+    {
+        int temp_code = run_test_case( i );
+        if (temp_code == CvTS::OK)
+            temp_code = validate_test_results( i );
+        if (temp_code != CvTS::OK)
+            code = temp_code;
+    }
+    if ( test_case_count <= 0)
+    {\r
+        ts->printf( CvTS::LOG, "validation file is not determined or not correct" );
+        code = CvTS::FAIL_INVALID_TEST_DATA;\r
+    }
+    ts->set_failed_test_info( code );
+}
+
+int CV_MLBaseTest::prepare_test_case( int test_case_idx )
+{
+    int trainSampleCount, respIdx;
+    string varTypes;
+    clear();
+\r
+    string dataPath = ts->get_data_path();
+    if ( dataPath.empty() )\r
+    {\r
+        ts->printf( CvTS::LOG, "data path is empty" );
+        return CvTS::FAIL_INVALID_TEST_DATA;\r
+    }
+
+    string dataName = dataSetNames[test_case_idx],
+        filename = dataPath + dataName + ".data";
+    if ( data.read_csv( filename.c_str() ) != 0)
+    {\r
+        char msg[100];\r
+        sprintf( msg, "file %s can not be read", filename.c_str() );\r
+        ts->printf( CvTS::LOG, msg );
+        return CvTS::FAIL_INVALID_TEST_DATA;\r
+    }
+
+    FileNode dataParamsNode = validationFS.getFirstTopLevelNode()["validation"][modelName][dataName]["data_params"];
+    CV_DbgAssert( !dataParamsNode.empty() );
+
+    CV_DbgAssert( !dataParamsNode["LS"].empty() );
+    dataParamsNode["LS"] >> trainSampleCount;
+    CvTrainTestSplit spl( trainSampleCount );
+    data.set_train_test_split( &spl );
+
+    CV_DbgAssert( !dataParamsNode["resp_idx"].empty() );
+    dataParamsNode["resp_idx"] >> respIdx;
+    data.set_response_idx( respIdx );
+
+    CV_DbgAssert( !dataParamsNode["types"].empty() );
+    dataParamsNode["types"] >> varTypes;
+    data.set_var_types( varTypes.c_str() );
+
+    return CvTS::OK;
+}
+
+string& CV_MLBaseTest::get_validation_filename()
+{
+    return validationFN;
+}
+
+int CV_MLBaseTest::train( int testCaseIdx )\r
+{\r
+    bool is_trained = false;\r
+    FileNode modelParamsNode = 
+        validationFS.getFirstTopLevelNode()["validation"][modelName][dataSetNames[testCaseIdx]]["model_params"];
+\r
+    if( !modelName.compare(CV_NBAYES) )
+        is_trained = nbayes_train( nbayes, &data );
+    else if( !modelName.compare(CV_KNEAREST) )
+    {
+        assert( 0 );
+        //is_trained = knearest->train( &data );
+    }
+    else if( !modelName.compare(CV_SVM) )
+    {
+        string svm_type_str, kernel_type_str;
+        modelParamsNode["svm_type"] >> svm_type_str;
+        modelParamsNode["kernel_type"] >> kernel_type_str;
+        CvSVMParams params;
+        params.svm_type = str_to_svm_type( svm_type_str );
+        params.kernel_type = str_to_svm_kernel_type( kernel_type_str );
+        modelParamsNode["degree"] >> params.degree;
+        modelParamsNode["gamma"] >> params.gamma;
+        modelParamsNode["coef0"] >> params.coef0;
+        modelParamsNode["C"] >> params.C;
+        modelParamsNode["nu"] >> params.nu;
+        modelParamsNode["p"] >> params.p;
+        is_trained = svm_train( svm, &data, params );
+    }
+    else if( !modelName.compare(CV_EM) )
+    {
+        assert( 0 );
+    }
+    else if( !modelName.compare(CV_ANN) )
+    {
+        string train_method_str;
+        double param1, param2;
+        modelParamsNode["train_method"] >> train_method_str;
+        modelParamsNode["param1"] >> param1;
+        modelParamsNode["param2"] >> param2;
+        Mat new_responses;
+        ann_get_new_responses( &data, new_responses, cls_map );
+        int layer_sz[] = { data.get_values()->cols - 1, 100, 100, (int)cls_map.size() };
+        CvMat layer_sizes =
+            cvMat( 1, (int)(sizeof(layer_sz)/sizeof(layer_sz[0])), CV_32S, layer_sz );
+        ann->create( &layer_sizes );
+        is_trained = ann_train( ann, &data, new_responses, CvANN_MLP_TrainParams(cvTermCriteria(CV_TERMCRIT_ITER,300,0.01),
+            str_to_ann_train_method(train_method_str), param1, param2) ) >= 0;
+    }
+    else if( !modelName.compare(CV_DTREE) )
+    {
+        int MAX_DEPTH, MIN_SAMPLE_COUNT, MAX_CATEGORIES, CV_FOLDS;\r
+        float REG_ACCURACY = 0;\r
+        bool USE_SURROGATE, IS_PRUNED;
+        modelParamsNode["max_depth"] >> MAX_DEPTH;
+        modelParamsNode["min_sample_count"] >> MIN_SAMPLE_COUNT;
+        modelParamsNode["use_surrogate"] >> USE_SURROGATE;
+        modelParamsNode["max_categories"] >> MAX_CATEGORIES;
+        modelParamsNode["cv_folds"] >> CV_FOLDS;
+        modelParamsNode["is_pruned"] >> IS_PRUNED;
+        is_trained = dtree->train( &data, \r
+            CvDTreeParams(MAX_DEPTH, MIN_SAMPLE_COUNT, REG_ACCURACY, USE_SURROGATE,\r
+            MAX_CATEGORIES, CV_FOLDS, false, IS_PRUNED, 0 )) != 0;\r
+    }
+    else if( !modelName.compare(CV_BOOST) )
+    {
+        int BOOST_TYPE, WEAK_COUNT, MAX_DEPTH;\r
+        float WEIGHT_TRIM_RATE;\r
+        bool USE_SURROGATE;
+        string typeStr;
+        modelParamsNode["type"] >> typeStr;
+        BOOST_TYPE = str_to_boost_type( typeStr );
+        modelParamsNode["weak_count"] >> WEAK_COUNT;\r
+        modelParamsNode["weight_trim_rate"] >> WEIGHT_TRIM_RATE;\r
+        modelParamsNode["max_depth"] >> MAX_DEPTH;\r
+        modelParamsNode["use_surrogate"] >> USE_SURROGATE;\r
+        is_trained = boost->train( &data,\r
+            CvBoostParams(BOOST_TYPE, WEAK_COUNT, WEIGHT_TRIM_RATE, MAX_DEPTH, USE_SURROGATE, 0) ) != 0;
+    }
+    else if( !modelName.compare(CV_RTREES) )
+    {
+        int MAX_DEPTH, MIN_SAMPLE_COUNT, MAX_CATEGORIES, CV_FOLDS, NACTIVE_VARS, MAX_TREES_NUM;\r
+        float REG_ACCURACY = 0, OOB_EPS = 0.0;\r
+        bool USE_SURROGATE, IS_PRUNED;
+        modelParamsNode["max_depth"] >> MAX_DEPTH;
+        modelParamsNode["min_sample_count"] >> MIN_SAMPLE_COUNT;
+        modelParamsNode["use_surrogate"] >> USE_SURROGATE;
+        modelParamsNode["max_categories"] >> MAX_CATEGORIES;
+        modelParamsNode["cv_folds"] >> CV_FOLDS;
+        modelParamsNode["is_pruned"] >> IS_PRUNED;
+        modelParamsNode["nactive_vars"] >> NACTIVE_VARS;
+        modelParamsNode["max_trees_num"] >> MAX_TREES_NUM;
+        is_trained = rtrees->train( &data, CvRTParams( MAX_DEPTH, MIN_SAMPLE_COUNT, REG_ACCURACY,\r
+            USE_SURROGATE, MAX_CATEGORIES, 0, true, // (calc_var_importance == true) <=> RF processes variable importance\r
+            NACTIVE_VARS, MAX_TREES_NUM, OOB_EPS, CV_TERMCRIT_ITER)) != 0;
+    }
+    else if( !modelName.compare(CV_ERTREES) )
+    {\r
+        int MAX_DEPTH, MIN_SAMPLE_COUNT, MAX_CATEGORIES, CV_FOLDS, NACTIVE_VARS, MAX_TREES_NUM;\r
+        float REG_ACCURACY = 0, OOB_EPS = 0.0;\r
+        bool USE_SURROGATE, IS_PRUNED;\r
+        modelParamsNode["max_depth"] >> MAX_DEPTH;
+        modelParamsNode["min_sample_count"] >> MIN_SAMPLE_COUNT;
+        modelParamsNode["use_surrogate"] >> USE_SURROGATE;
+        modelParamsNode["max_categories"] >> MAX_CATEGORIES;
+        modelParamsNode["cv_folds"] >> CV_FOLDS;
+        modelParamsNode["is_pruned"] >> IS_PRUNED;
+        modelParamsNode["nactive_vars"] >> NACTIVE_VARS;
+        modelParamsNode["max_trees_num"] >> MAX_TREES_NUM;
+        is_trained = ertrees->train( &data, CvRTParams( MAX_DEPTH, MIN_SAMPLE_COUNT, REG_ACCURACY,\r
+            USE_SURROGATE, MAX_CATEGORIES, 0, false, // (calc_var_importance == true) <=> RF processes variable importance\r
+            NACTIVE_VARS, MAX_TREES_NUM, OOB_EPS, CV_TERMCRIT_ITER)) != 0;\r
+    }\r
+\r
+    if( !is_trained )\r
+    {\r
+        ts->printf( CvTS::LOG, "in test case %d model training was failed", testCaseIdx );\r
+        return CvTS::FAIL_INVALID_OUTPUT;\r
+    }\r
+    return CvTS::OK;\r
+}\r
+\r
+float CV_MLBaseTest::get_error( int testCaseIdx, int type, vector<float> *resp )\r
+{\r
+    float err = 0;\r
+    if( !modelName.compare(CV_NBAYES) )
+        err = nbayes_calc_error( nbayes, &data, type, resp );
+    else if( !modelName.compare(CV_KNEAREST) )
+    {
+        assert( 0 );
+        testCaseIdx = 0;
+        /*int k = 2;
+        validationFS.getFirstTopLevelNode()["validation"][modelName][dataSetNames[testCaseIdx]]["model_params"]["k"] >> k;
+        err = knearest->calc_error( &data, k, type, resp );*/
+    }
+    else if( !modelName.compare(CV_SVM) )
+        err = svm_calc_error( svm, &data, type, resp );
+    else if( !modelName.compare(CV_EM) )
+        assert( 0 );
+    else if( !modelName.compare(CV_ANN) )
+        err = ann_calc_error( ann, &data, cls_map, type, resp );
+    else if( !modelName.compare(CV_DTREE) )
+        err = dtree->calc_error( &data, type, resp );
+    else if( !modelName.compare(CV_BOOST) )
+        err = boost->calc_error( &data, type, resp );
+    else if( !modelName.compare(CV_RTREES) )
+        err = rtrees->calc_error( &data, type, resp );
+    else if( !modelName.compare(CV_ERTREES) )
+        err = ertrees->calc_error( &data, type, resp );\r
+    return err;\r
+}\r
+\r
+void CV_MLBaseTest::save( const char* filename )\r
+{\r
+    if( !modelName.compare(CV_NBAYES) )
+        nbayes->save( filename );
+    else if( !modelName.compare(CV_KNEAREST) )
+        knearest->save( filename );
+    else if( !modelName.compare(CV_SVM) )
+        svm->save( filename );
+    else if( !modelName.compare(CV_EM) )
+        em->save( filename );
+    else if( !modelName.compare(CV_ANN) )
+        ann->save( filename );
+    else if( !modelName.compare(CV_DTREE) )
+        dtree->save( filename );
+    else if( !modelName.compare(CV_BOOST) )
+        boost->save( filename );
+    else if( !modelName.compare(CV_RTREES) )
+        rtrees->save( filename );
+    else if( !modelName.compare(CV_ERTREES) )
+        ertrees->save( filename );\r
+}\r
+\r
+void CV_MLBaseTest::load( const char* filename )\r
+{\r
+    if( !modelName.compare(CV_NBAYES) )
+        nbayes->load( filename );
+    else if( !modelName.compare(CV_KNEAREST) )
+        knearest->load( filename );
+    else if( !modelName.compare(CV_SVM) )
+        svm->load( filename );
+    else if( !modelName.compare(CV_EM) )
+        em->load( filename );
+    else if( !modelName.compare(CV_ANN) )
+        ann->load( filename );
+    else if( !modelName.compare(CV_DTREE) )
+        dtree->load( filename );
+    else if( !modelName.compare(CV_BOOST) )
+        boost->load( filename );
+    else if( !modelName.compare(CV_RTREES) )
+        rtrees->load( filename );
+    else if( !modelName.compare(CV_ERTREES) )
+        ertrees->load( filename );\r
+}\r
+\r
+/* End of file. */\r