9342c1b000a6d8655b20d1f8d9350a562dfed2f6
[opencv] / tests / ml / src / mltest.h
1 /*M///////////////////////////////////////////////////////////////////////////////////////
2 //
3 //  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
4 //
5 //  By downloading, copying, installing or using the software you agree to this license.
6 //  If you do not agree to this license, do not download, install,
7 //  copy or use the software.
8 //
9 //
10 //                           License Agreement
11 //                For Open Source Computer Vision Library
12 //
13 // Copyright (C) 2000-2008, Intel Corporation, all rights reserved.
14 // Copyright (C) 2009, Willow Garage Inc., all rights reserved.
15 // Third party copyrights are property of their respective owners.
16 //
17 // Redistribution and use in source and binary forms, with or without modification,
18 // are permitted provided that the following conditions are met:
19 //
20 //   * Redistribution's of source code must retain the above copyright notice,
21 //     this list of conditions and the following disclaimer.
22 //
23 //   * Redistribution's in binary form must reproduce the above copyright notice,
24 //     this list of conditions and the following disclaimer in the documentation
25 //     and/or other materials provided with the distribution.
26 //
27 //   * The name of the copyright holders may not be used to endorse or promote products
28 //     derived from this software without specific prior written permission.
29 //
30 // This software is provided by the copyright holders and contributors "as is" and
31 // any express or implied warranties, including, but not limited to, the implied
32 // warranties of merchantability and fitness for a particular purpose are disclaimed.
33 // In no event shall the Intel Corporation or contributors be liable for any direct,
34 // indirect, incidental, special, exemplary, or consequential damages
35 // (including, but not limited to, procurement of substitute goods or services;
36 // loss of use, data, or profits; or business interruption) however caused
37 // and on any theory of liability, whether in contract, strict liability,
38 // or tort (including negligence or otherwise) arising in any way out of
39 // the use of this software, even if advised of the possibility of such damage.
40 //
41 //M*/
42
43 #ifndef _OPENCV_MLTEST_H_
44 #define _OPENCV_MLTEST_H_
45
46 #if defined _MSC_VER && _MSC_VER >= 1200
47 #pragma warning( disable: 4710 4711 4514 4996 )
48 #endif
49
50 #include "cxcore.h"
51 #include "cxmisc.h"
52 #include "cxts.h"
53 #include "ml.h"
54 #include <map>
55 #include <string>
56 #include <iostream>
57
58 using namespace std;
59 using namespace cv;
60
61 #define CV_NBAYES   "nbayes"
62 #define CV_KNEAREST "knearest"
63 #define CV_SVM      "svm"
64 #define CV_EM       "em"
65 #define CV_ANN      "ann"
66 #define CV_DTREE    "dtree"
67 #define CV_BOOST    "boost"
68 #define CV_RTREES   "rtrees"
69 #define CV_ERTREES  "ertrees"
70
71 class CV_MLBaseTest : public CvTest
72 {
73 public:
74     CV_MLBaseTest( const char* _modelName, const char* _testName, const char* _testFuncs );
75     virtual ~CV_MLBaseTest();
76     virtual int init( CvTS* system );
77 protected:
78     virtual int read_params( CvFileStorage* fs );
79     virtual void run( int startFrom );
80     virtual int prepare_test_case( int testCaseIdx );
81     virtual string& get_validation_filename();
82     virtual int run_test_case( int testCaseIdx ) = 0;
83     virtual int validate_test_results( int testCaseIdx ) = 0;
84
85     int train( int testCaseIdx );
86     float get_error( int testCaseIdx, int type, vector<float> *resp = 0 );
87     void save( const char* filename );
88     void load( const char* filename );
89
90     CvMLData data;
91     string modelName, validationFN;
92     vector<string> dataSetNames;
93     FileStorage validationFS;
94
95     // MLL models
96     CvNormalBayesClassifier* nbayes;
97     CvKNearest* knearest;
98     CvSVM* svm;
99     CvEM* em;
100     CvANN_MLP* ann;
101     CvDTree* dtree;
102     CvBoost* boost;
103     CvRTrees* rtrees;
104     CvERTrees* ertrees;
105
106     map<int, int> cls_map;
107 };
108
109 class CV_AMLTest : public CV_MLBaseTest
110 {
111 public:
112     CV_AMLTest( const char* _modelName, const char* _testName ); 
113 protected:
114     virtual int run_test_case( int testCaseIdx );
115     virtual int validate_test_results( int testCaseIdx );
116 };
117
118 class CV_SLMLTest : public CV_MLBaseTest
119 {
120 public:
121     CV_SLMLTest( const char* _modelName, const char* _testName ); 
122 protected:
123     virtual int run_test_case( int testCaseIdx );
124     virtual int validate_test_results( int testCaseIdx );
125
126     vector<float> test_resps1, test_resps2; // predicted responses for test data
127     char fname1[50], fname2[50];
128 };
129
130 /* End of file. */
131
132 #endif