Update to 2.0.0 tree from current Fremantle build
[opencv] / apps / traincascade / cascadeclassifier.h
1 #ifndef _OPENCV_CASCADECLASSIFIER_H_
2 #define _OPENCV_CASCADECLASSIFIER_H_
3
4 #include <ctime>
5 #include "features.h"
6 #include "haarfeatures.h"
7 #include "lbpfeatures.h"
8 #include "boost.h"
9 #include "cv.h"
10 #include "cxcore.h"
11
12 #define CC_CASCADE_FILENAME "cascade.xml"
13 #define CC_PARAMS_FILENAME "params.xml"
14
15 #define CC_CASCADE_PARAMS "cascadeParams"
16 #define CC_STAGE_TYPE "stageType"
17 #define CC_FEATURE_TYPE "featureType"
18 #define CC_HEIGHT "height"
19 #define CC_WIDTH  "width"
20
21 #define CC_STAGE_NUM    "stageNum"
22 #define CC_STAGES       "stages"
23 #define CC_STAGE_PARAMS "stageParams"
24
25 #define CC_BOOST            "BOOST"
26 #define CC_BOOST_TYPE       "boostType"
27 #define CC_DISCRETE_BOOST   "DAB"
28 #define CC_REAL_BOOST       "RAB"
29 #define CC_LOGIT_BOOST      "LB"
30 #define CC_GENTLE_BOOST     "GAB"
31 #define CC_MINHITRATE       "minHitRate"
32 #define CC_MAXFALSEALARM    "maxFalseAlarm"
33 #define CC_TRIM_RATE        "weightTrimRate"
34 #define CC_MAX_DEPTH        "maxDepth"
35 #define CC_WEAK_COUNT       "maxWeakCount"
36 #define CC_STAGE_THRESHOLD  "stageThreshold"
37 #define CC_WEAK_CLASSIFIERS "weakClassifiers"
38 #define CC_INTERNAL_NODES   "internalNodes"
39 #define CC_LEAF_VALUES      "leafValues"
40
41 #define CC_FEATURES       FEATURES
42 #define CC_FEATURE_PARAMS "featureParams"
43 #define CC_MAX_CAT_COUNT  "maxCatCount"
44
45 #define CC_HAAR        "HAAR"
46 #define CC_MODE        "mode"
47 #define CC_MODE_BASIC  "BASIC"
48 #define CC_MODE_CORE   "CORE"
49 #define CC_MODE_ALL    "ALL"
50 #define CC_RECTS       "rects"
51 #define CC_TILTED      "tilted"
52
53 #define CC_LBP  "LBP"
54 #define CC_RECT "rect"
55
56 #ifdef _WIN32
57 #define TIME( arg ) (((double) clock()) / CLOCKS_PER_SEC)
58 #else
59 #define TIME( arg ) (time( arg ))
60 #endif
61
62 class CvCascadeParams : public CvParams
63 {
64 public:
65     enum { BOOST = 0 };
66     static const int defaultStageType = BOOST;
67     static const int defaultFeatureType = CvFeatureParams::HAAR;
68
69     CvCascadeParams();
70     CvCascadeParams( int _stageType, int _featureType );
71     void write( FileStorage &fs ) const;
72     bool read( const FileNode &node );
73
74     void printDefaults() const;
75     void printAttrs() const;    
76     bool scanAttr( const String prmName, const String val );
77
78     int stageType;
79     int featureType;
80     Size winSize;
81 };
82
83 class CvCascadeClassifier
84 {
85 public:
86     bool train( const String _cascadeDirName,
87                 const String _posFilename,
88                 const String _negFilename, 
89                 int _numPos, int _numNeg, 
90                 int _precalcValBufSize, int _precalcIdxBufSize,
91                 int _numStages,
92                 const CvCascadeParams& _cascadeParams,
93                 const CvFeatureParams& _featureParams,
94                 const CvCascadeBoostParams& _stageParams,
95                 bool baseFormatSave = false );
96 private:
97     int predict( int sampleIdx );
98     void save( const String cascadeDirName, bool baseFormat = false );
99     bool load( const String cascadeDirName );
100     bool updateTrainingSet( double& acceptanceRatio );
101     int fillPassedSamles( int first, int count, bool isPositive, int64& consumed );
102
103     void writeParams( FileStorage &fs ) const;
104     void writeStages( FileStorage &fs, const Mat& featureMap ) const;
105     void writeFeatures( FileStorage &fs, const Mat& featureMap ) const;
106     bool readParams( const FileNode &node );
107     bool readStages( const FileNode &node );
108     
109     void getUsedFeaturesIdxMap( Mat& featureMap );
110
111     CvCascadeParams cascadeParams;
112     Ptr<CvFeatureParams> featureParams;
113     Ptr<CvCascadeBoostParams> stageParams;
114
115     Ptr<CvFeatureEvaluator> featureEvaluator;    
116     vector< Ptr<CvCascadeBoost> > stageClassifiers;
117     CvCascadeImageReader imgReader;
118     int numStages, curNumSamples;
119     int numPos, numNeg;
120 };
121
122 #endif