Update to 2.0.0 tree from current Fremantle build
[opencv] / apps / traincascade / boost.h
1 #ifndef _OPENCV_BOOST_H_
2 #define _OPENCV_BOOST_H_
3
4 #include "features.h"
5 #include "ml.h"
6
7 struct CvCascadeBoostParams : CvBoostParams
8 {
9     float minHitRate;
10     float maxFalseAlarm;
11     
12     CvCascadeBoostParams();
13     CvCascadeBoostParams( int _boostType, float _minHitRate, float _maxFalseAlarm,
14                           double _weightTrimRate, int _maxDepth, int _maxWeakCount );
15     virtual ~CvCascadeBoostParams() {}
16     void write( FileStorage &fs ) const;
17     bool read( const FileNode &node );
18     virtual void printDefaults() const;
19     virtual void printAttrs() const;
20     virtual bool scanAttr( const String prmName, const String val);
21 };
22
23 struct CvCascadeBoostTrainData : CvDTreeTrainData
24 {
25     CvCascadeBoostTrainData( const CvFeatureEvaluator* _featureEvaluator,
26                              const CvDTreeParams& _params );
27     CvCascadeBoostTrainData( const CvFeatureEvaluator* _featureEvaluator,
28                              int _numSamples, int _precalcValBufSize, int _precalcIdxBufSize,
29                              const CvDTreeParams& _params = CvDTreeParams() );
30     virtual void setData( const CvFeatureEvaluator* _featureEvaluator,
31                           int _numSamples, int _precalcValBufSize, int _precalcIdxBufSize,
32                           const CvDTreeParams& _params=CvDTreeParams() );
33     void precalculate();
34
35     virtual void get_class_labels( CvDTreeNode* n, int* labelsBuf, const int** labels );
36     virtual void get_cv_labels( CvDTreeNode* n, int* labelsBuf, const int** labels );
37     virtual void get_sample_indices( CvDTreeNode* n, int* indicesBuf, const int** labels );
38     
39     virtual int get_ord_var_data( CvDTreeNode* n, int vi, float* ordValuesBuf, int* indicesBuf,
40                                   const float** ordValues, const int** indices );
41     virtual int get_cat_var_data( CvDTreeNode* n, int vi, int* catValuesBuf, const int** catValues );
42     virtual float getVarValue( int vi, int si );
43     virtual void free_train_data();
44
45     const CvFeatureEvaluator* featureEvaluator;
46     Mat valCache; // precalculated feature values (CV_32FC1)
47     CvMat _resp; // for casting
48     int numPrecalcVal, numPrecalcIdx;
49 };
50
51 class CvCascadeBoostTree : public CvBoostTree
52 {
53 public:
54     virtual CvDTreeNode* predict( int sampleIdx ) const;
55     void write( FileStorage &fs, const Mat& featureMap );
56     void read( const FileNode &node, CvBoost* _ensemble, CvDTreeTrainData* _data );
57     void markFeaturesInMap( Mat& featureMap );
58 protected:
59     virtual void split_node_data( CvDTreeNode* n );
60 };
61
62 class CvCascadeBoost : public CvBoost
63 {
64 public:
65     virtual bool train( const CvFeatureEvaluator* _featureEvaluator,
66                         int _numSamples, int _precalcValBufSize, int _precalcIdxBufSize,
67                         const CvCascadeBoostParams& _params=CvCascadeBoostParams() );
68     virtual float predict( int sampleIdx, bool returnSum = false ) const;
69
70     float getThreshold() const { return threshold; }; 
71     void write( FileStorage &fs, const Mat& featureMap ) const;
72     bool read( const FileNode &node, const CvFeatureEvaluator* _featureEvaluator,
73                const CvCascadeBoostParams& _params );
74     void markUsedFeaturesInMap( Mat& featureMap );
75 protected:
76     virtual bool set_params( const CvBoostParams& _params );
77     virtual void update_weights( CvBoostTree* tree );
78     virtual bool isErrDesired();
79
80     float threshold;
81     float minHitRate, maxFalseAlarm;
82 };
83
84 #endif