Update to 2.0.0 tree from current Fremantle build
[opencv] / apps / traincascade / cascadeclassifier.cpp
1 #include "cascadeclassifier.h"
2 #include <queue>
3
4 using namespace std;
5
6 static const char* stageTypes[] = { CC_BOOST };
7 static const char* featureTypes[] = { CC_HAAR, CC_LBP };
8
9 CvCascadeParams::CvCascadeParams() : stageType( defaultStageType ), 
10     featureType( defaultFeatureType ), winSize( cvSize(24, 24) )
11
12     name = CC_CASCADE_PARAMS; 
13 }
14 CvCascadeParams::CvCascadeParams( int _stageType, int _featureType ) : stageType( _stageType ),
15     featureType( _featureType ), winSize( cvSize(24, 24) )
16
17     name = CC_CASCADE_PARAMS;
18 }
19
20 //---------------------------- CascadeParams --------------------------------------
21
22 void CvCascadeParams::write( FileStorage &fs ) const
23 {
24     String stageTypeStr = stageType == BOOST ? CC_BOOST : String();
25     CV_Assert( !stageTypeStr.empty() );
26     fs << CC_STAGE_TYPE << stageTypeStr;
27     String featureTypeStr = featureType == CvFeatureParams::HAAR ? CC_HAAR :
28                             featureType == CvFeatureParams::LBP ? CC_LBP : 0;
29     CV_Assert( !stageTypeStr.empty() );
30     fs << CC_FEATURE_TYPE << featureTypeStr;
31     fs << CC_HEIGHT << winSize.height;
32     fs << CC_WIDTH << winSize.width;
33 }
34
35 bool CvCascadeParams::read( const FileNode &node )
36 {
37     if ( node.empty() )
38         return false;
39     String stageTypeStr, featureTypeStr;
40     FileNode rnode = node[CC_STAGE_TYPE];
41     if ( !rnode.isString() )
42         return false;
43     rnode >> stageTypeStr;
44     stageType = !stageTypeStr.compare( CC_BOOST ) ? BOOST : -1;
45     if (stageType == -1)
46         return false;
47     rnode = node[CC_FEATURE_TYPE];
48     if ( !rnode.isString() )
49         return false;
50     rnode >> featureTypeStr;
51     featureType = !featureTypeStr.compare( CC_HAAR ) ? CvFeatureParams::HAAR :
52                   !featureTypeStr.compare( CC_LBP ) ? CvFeatureParams::LBP : -1;
53     if (featureType == -1)
54         return false;
55     node[CC_HEIGHT] >> winSize.height;
56     node[CC_WIDTH] >> winSize.width;
57     return winSize.height > 0 && winSize.width > 0;
58 }
59
60 void CvCascadeParams::printDefaults() const
61 {
62     CvParams::printDefaults();
63     cout << "  [-stageType <";
64     for( int i = 0; i < (int)(sizeof(stageTypes)/sizeof(stageTypes[0])); i++ )
65     {
66         cout << (i ? " | " : "") << stageTypes[i];
67         if ( i == defaultStageType )
68             cout << "(default)";
69     }
70     cout << ">]" << endl;
71
72     cout << "  [-featureType <{";
73     for( int i = 0; i < (int)(sizeof(featureTypes)/sizeof(featureTypes[0])); i++ )
74     {
75         cout << (i ? ", " : "") << featureTypes[i];
76         if ( i == defaultStageType )
77             cout << "(default)";
78     }
79     cout << "}>]" << endl;
80     cout << "  [-w <sampleWidth = " << winSize.width << ">]" << endl;
81     cout << "  [-h <sampleHeight = " << winSize.height << ">]" << endl;
82 }
83
84 void CvCascadeParams::printAttrs() const
85 {
86     cout << "stageType: " << stageTypes[stageType] << endl;
87     cout << "featureType: " << featureTypes[featureType] << endl;
88     cout << "sampleWidth: " << winSize.width << endl;
89     cout << "sampleHeight: " << winSize.height << endl;
90 }
91
92 bool CvCascadeParams::scanAttr( const String prmName, const String val )
93 {
94     bool res = true;
95     if( !prmName.compare( "-stageType" ) )
96     {
97         for( int i = 0; i < (int)(sizeof(stageTypes)/sizeof(stageTypes[0])); i++ )
98             if( !val.compare( stageTypes[i] ) )
99                 stageType = i;
100     }
101     else if( !prmName.compare( "-featureType" ) )
102     {
103         for( int i = 0; i < (int)(sizeof(featureTypes)/sizeof(featureTypes[0])); i++ )
104             if( !val.compare( featureTypes[i] ) )
105                 featureType = i;
106     }
107     else if( !prmName.compare( "-w" ) )
108     {
109         winSize.width = atoi( val.c_str() );
110     }
111     else if( !prmName.compare( "-h" ) )
112     {
113         winSize.height = atoi( val.c_str() );
114     }
115     else
116         res = false;
117     return res;
118 }
119
120 //---------------------------- CascadeClassifier --------------------------------------
121
122 bool CvCascadeClassifier::train( const String _cascadeDirName,
123                                 const String _posFilename,
124                                 const String _negFilename, 
125                                 int _numPos, int _numNeg, 
126                                 int _precalcValBufSize, int _precalcIdxBufSize,
127                                 int _numStages,
128                                 const CvCascadeParams& _cascadeParams,
129                                 const CvFeatureParams& _featureParams,
130                                 const CvCascadeBoostParams& _stageParams,
131                                 bool baseFormatSave )
132 {   
133     if( _cascadeDirName.empty() || _posFilename.empty() || _negFilename.empty() )
134         CV_Error( CV_StsBadArg, "_cascadeDirName or _bgfileName or _vecFileName is NULL" );
135
136     String dirName;
137     if ( _cascadeDirName.find('/') )
138         dirName = _cascadeDirName + '/';
139     else
140         dirName = _cascadeDirName + '\\';
141
142     numPos = _numPos;
143     numNeg = _numNeg;
144     numStages = _numStages;
145     if ( !imgReader.create( _posFilename, _negFilename, cascadeParams.winSize ) )
146         return false;
147     if ( !load( dirName ) )
148     {
149         cascadeParams = _cascadeParams;
150         featureParams = CvFeatureParams::create(cascadeParams.featureType);
151         featureParams->init(_featureParams);
152         stageParams = new CvCascadeBoostParams;
153         *stageParams = _stageParams;
154         featureEvaluator = CvFeatureEvaluator::create(cascadeParams.featureType);
155         featureEvaluator->init( (CvFeatureParams*)featureParams, numPos + numNeg, cascadeParams.winSize );
156         stageClassifiers.reserve( numStages );
157     }
158     cout << "PARAMETERS:" << endl;
159     cout << "cascadeDirName: " << _cascadeDirName << endl;
160     cout << "vecFileName: " << _posFilename << endl;
161     cout << "bgFileName: " << _negFilename << endl;
162     cout << "numPos: " << _numPos << endl;
163     cout << "numNeg: " << _numNeg << endl;
164     cout << "numStages: " << numStages << endl;
165     cout << "precalcValBufSize[Mb] : " << _precalcValBufSize << endl;
166     cout << "precalcIdxBufSize[Mb] : " << _precalcIdxBufSize << endl;
167     cascadeParams.printAttrs();
168     stageParams->printAttrs();
169     featureParams->printAttrs();
170
171     int startNumStages = (int)stageClassifiers.size();
172     if ( startNumStages > 1 )
173         cout << endl << "Stages 0-" << startNumStages-1 << " are loaded" << endl;
174     else if ( startNumStages == 1)
175         cout << endl << "Stage 0 is loaded" << endl;
176     
177     double requiredLeafFARate = pow( (double) stageParams->maxFalseAlarm, (double) numStages ) /
178                                 (double)stageParams->max_depth;
179     double tempLeafFARate;
180     
181     for( int i = startNumStages; i < numStages; i++ )
182     {
183         cout << endl << "===== TRAINING " << i << "-stage =====" << endl;
184         cout << "<BEGIN" << endl;
185
186         if ( !updateTrainingSet( tempLeafFARate ) ) 
187         {
188             cout << "Train dataset for temp stage can not be filled."
189                 "Branch training terminated." << endl;
190             break;
191         }
192         if( tempLeafFARate <= requiredLeafFARate )
193         {
194             cout << "Required leaf false alarm rate achieved. "
195                  "Branch training terminated." << endl;
196             break;
197         }
198
199         CvCascadeBoost* tempStage = new CvCascadeBoost;
200         tempStage->train( (CvFeatureEvaluator*)featureEvaluator,
201                            curNumSamples, _precalcValBufSize, _precalcIdxBufSize,
202                           *((CvCascadeBoostParams*)stageParams) );
203         stageClassifiers.push_back( tempStage );
204
205         cout << "END>" << endl;
206         
207         // save params
208         String filename;
209         if ( i == 0) 
210         {
211             filename = dirName + CC_PARAMS_FILENAME;
212             FileStorage fs( filename, FileStorage::WRITE);
213             if ( !fs.isOpened() )
214                 return false;
215             fs << FileStorage::getDefaultObjectName(filename) << "{";
216             writeParams( fs );
217             fs << "}";
218         }
219         // save temp stage
220         char buf[10];
221         sprintf(buf, "%s%d", "stage", i );
222         filename = dirName + buf + ".xml";
223         FileStorage fs( filename, FileStorage::WRITE );
224         if ( !fs.isOpened() )
225             return false;
226         fs << FileStorage::getDefaultObjectName(filename) << "{";
227         tempStage->write( fs, Mat() );
228         fs << "}";
229     }
230     save( dirName + CC_CASCADE_FILENAME, baseFormatSave );
231     return true;
232 }
233
234 int CvCascadeClassifier::predict( int sampleIdx )
235 {
236     CV_DbgAssert( sampleIdx < numPos + numNeg );
237     for (vector< Ptr<CvCascadeBoost> >::iterator it = stageClassifiers.begin();
238         it != stageClassifiers.end(); it++ )
239     {
240         if ( (*it)->predict( sampleIdx ) == 0.f )
241             return 0;
242     }
243     return 1;
244 }
245
246 bool CvCascadeClassifier::updateTrainingSet( double& acceptanceRatio)
247 {
248     int64 posConsumed = 0, negConsumed = 0;
249     imgReader.restart();
250     int posCount = fillPassedSamles( 0, numPos, true, posConsumed );
251     if( !posCount )
252         return false;
253     cout << "POS count : consumed   " << posCount << " : " << (int)posConsumed << endl;
254
255     int negCount = fillPassedSamles( numPos, numNeg, false, negConsumed );
256     if ( !negCount )
257         return false;
258     curNumSamples = posCount + negCount;
259     acceptanceRatio = negConsumed == 0 ? 0 : ( (double)negCount/(double)(int64)negConsumed );
260     cout << "NEG count : acceptanceRatio    " << negCount << " : " << acceptanceRatio << endl;
261     return true;
262 }
263
264 int CvCascadeClassifier::fillPassedSamles( int first, int count, bool isPositive, int64& consumed )
265 {
266     int getcount = 0;
267     Mat img(cascadeParams.winSize, CV_8UC1);
268     for( int i = first; i < first + count; i++ )
269     {
270         for( ; ; )
271         {
272             bool isGetImg = isPositive ? imgReader.getPos( img ) :
273                                            imgReader.getNeg( img );
274             if( !isGetImg ) 
275                 return getcount;
276             consumed++;
277
278             featureEvaluator->setImage( img, isPositive ? 1 : 0, i );
279             if( predict( i ) == 1.0F )
280             {
281                 getcount++;
282                 break;
283             }
284         }
285     }
286     return getcount;
287 }
288
289 void CvCascadeClassifier::writeParams( FileStorage &fs ) const
290 {
291     cascadeParams.write( fs );
292     fs << CC_STAGE_PARAMS << "{"; stageParams->write( fs ); fs << "}";
293     fs << CC_FEATURE_PARAMS << "{"; featureParams->write( fs ); fs << "}";
294 }
295
296 void CvCascadeClassifier::writeFeatures( FileStorage &fs, const Mat& featureMap ) const
297 {
298     ((CvFeatureEvaluator*)((Ptr<CvFeatureEvaluator>)featureEvaluator))->writeFeatures( fs, featureMap ); 
299 }
300
301 void CvCascadeClassifier::writeStages( FileStorage &fs, const Mat& featureMap ) const
302 {
303     //char cmnt[30];
304     //int i = 0;
305     fs << CC_STAGES << "["; 
306     for( vector< Ptr<CvCascadeBoost> >::const_iterator it = stageClassifiers.begin();
307         it != stageClassifiers.end(); it++/*, i++*/ )
308     {
309         /*sprintf( cmnt, "stage %d", i );
310         CV_CALL( cvWriteComment( fs, cmnt, 0 ) );*/
311         fs << "{";
312         ((CvCascadeBoost*)((Ptr<CvCascadeBoost>)*it))->write( fs, featureMap );
313         fs << "}";
314     }
315     fs << "]";
316 }
317
318 bool CvCascadeClassifier::readParams( const FileNode &node )
319 {
320     if ( !node.isMap() || !cascadeParams.read( node ) )
321         return false;
322     
323     stageParams = new CvCascadeBoostParams;
324     FileNode rnode = node[CC_STAGE_PARAMS];
325     if ( !stageParams->read( rnode ) )
326         return false;
327     
328     featureParams = CvFeatureParams::create(cascadeParams.featureType);
329     rnode = node[CC_FEATURE_PARAMS];
330     if ( !featureParams->read( rnode ) )
331         return false;
332     return true;    
333 }
334
335 bool CvCascadeClassifier::readStages( const FileNode &node)
336 {
337     FileNode rnode = node[CC_STAGES];
338     if (!rnode.empty() || !rnode.isSeq())
339         return false;
340     stageClassifiers.reserve(numStages);
341     FileNodeIterator it = rnode.begin();
342     for( int i = 0; i < min( (int)rnode.size(), numStages ); i++, it++ )
343     {
344         CvCascadeBoost* tempStage = new CvCascadeBoost;
345         if ( !tempStage->read( *it, (CvFeatureEvaluator *)featureEvaluator, *((CvCascadeBoostParams*)stageParams) ) )
346         {
347             delete tempStage;
348             return false;
349         }
350         stageClassifiers.push_back(tempStage);
351     }
352     return true;
353 }
354
355 // For old Haar Classifier file saving
356 #define ICV_HAAR_SIZE_NAME            "size"
357 #define ICV_HAAR_STAGES_NAME          "stages"
358 #define ICV_HAAR_TREES_NAME             "trees"
359 #define ICV_HAAR_FEATURE_NAME             "feature"
360 #define ICV_HAAR_RECTS_NAME                 "rects"
361 #define ICV_HAAR_TILTED_NAME                "tilted"
362 #define ICV_HAAR_THRESHOLD_NAME           "threshold"
363 #define ICV_HAAR_LEFT_NODE_NAME           "left_node"
364 #define ICV_HAAR_LEFT_VAL_NAME            "left_val"
365 #define ICV_HAAR_RIGHT_NODE_NAME          "right_node"
366 #define ICV_HAAR_RIGHT_VAL_NAME           "right_val"
367 #define ICV_HAAR_STAGE_THRESHOLD_NAME   "stage_threshold"
368 #define ICV_HAAR_PARENT_NAME            "parent"
369 #define ICV_HAAR_NEXT_NAME              "next"
370
371 void CvCascadeClassifier::save( const String filename, bool baseFormat )
372 {
373     FileStorage fs( filename, FileStorage::WRITE );
374
375     if ( !fs.isOpened() )
376         return;
377
378     fs << FileStorage::getDefaultObjectName(filename) << "{";
379     if ( !baseFormat )
380     {
381         Mat featureMap; 
382         getUsedFeaturesIdxMap( featureMap );
383         writeParams( fs );
384         fs << CC_STAGE_NUM << (int)stageClassifiers.size();
385         writeStages( fs, featureMap );
386         writeFeatures( fs, featureMap );
387     }
388     else
389     {
390         //char buf[256];
391         CvSeq* weak;
392         if ( cascadeParams.featureType != CvFeatureParams::HAAR )
393             CV_Error( CV_StsBadFunc, "old file format is used for Haar-like features only");
394         fs << ICV_HAAR_SIZE_NAME << "[:" << cascadeParams.winSize.width << 
395             cascadeParams.winSize.height << "]";
396         fs << ICV_HAAR_STAGES_NAME << "[";
397         for( size_t si = 0; si < stageClassifiers.size(); si++ )
398         {
399             fs << "{"; //stage
400             /*sprintf( buf, "stage %d", si );
401             CV_CALL( cvWriteComment( fs, buf, 1 ) );*/
402             weak = stageClassifiers[si]->get_weak_predictors();
403             fs << ICV_HAAR_TREES_NAME << "[";
404             for( int wi = 0; wi < weak->total; wi++ )
405             {
406                 int inner_node_idx = -1, total_inner_node_idx = -1;
407                 queue<const CvDTreeNode*> inner_nodes_queue;
408                 CvCascadeBoostTree* tree = *((CvCascadeBoostTree**) cvGetSeqElem( weak, wi ));
409                 
410                 fs << "[";
411                 /*sprintf( buf, "tree %d", wi );
412                 CV_CALL( cvWriteComment( fs, buf, 1 ) );*/
413
414                 const CvDTreeNode* tempNode;
415                 
416                 inner_nodes_queue.push( tree->get_root() );
417                 total_inner_node_idx++;
418                 
419                 while (!inner_nodes_queue.empty())
420                 {
421                     tempNode = inner_nodes_queue.front();
422                     inner_node_idx++;
423
424                     fs << "{";
425                     fs << ICV_HAAR_FEATURE_NAME << "{";
426                     ((CvHaarEvaluator*)((CvFeatureEvaluator*)featureEvaluator))->writeFeature( fs, tempNode->split->var_idx );
427                     fs << "}";
428
429                     fs << ICV_HAAR_THRESHOLD_NAME << tempNode->split->ord.c;
430
431                     if( tempNode->left->left || tempNode->left->right )
432                     {
433                         inner_nodes_queue.push( tempNode->left );
434                         total_inner_node_idx++;
435                         fs << ICV_HAAR_LEFT_NODE_NAME << total_inner_node_idx;
436                     }
437                     else
438                         fs << ICV_HAAR_LEFT_VAL_NAME << tempNode->left->value;
439
440                     if( tempNode->right->left || tempNode->right->right )
441                     {
442                         inner_nodes_queue.push( tempNode->right );
443                         total_inner_node_idx++;
444                         fs << ICV_HAAR_RIGHT_NODE_NAME << total_inner_node_idx;
445                     }
446                     else
447                         fs << ICV_HAAR_RIGHT_VAL_NAME << tempNode->right->value;
448                     fs << "}"; // ICV_HAAR_FEATURE_NAME
449                     inner_nodes_queue.pop();
450                 }
451                 fs << "]";
452             }
453             fs << "]"; //ICV_HAAR_TREES_NAME
454             fs << ICV_HAAR_STAGE_THRESHOLD_NAME << stageClassifiers[si]->getThreshold();
455             fs << ICV_HAAR_PARENT_NAME << (int)si-1 << ICV_HAAR_NEXT_NAME << -1;
456             fs << "}"; //stage
457         } /* for each stage */
458         fs << "]"; //ICV_HAAR_STAGES_NAME
459     }
460     fs << "}";
461 }
462
463 bool CvCascadeClassifier::load( const String cascadeDirName )
464 {
465     FileStorage fs( cascadeDirName + CC_PARAMS_FILENAME, FileStorage::READ );
466     if ( !fs.isOpened() )
467         return false;
468     FileNode node = fs.getFirstTopLevelNode();
469     if ( !readParams( node ) )
470         return false;
471     featureEvaluator = CvFeatureEvaluator::create(cascadeParams.featureType);
472     featureEvaluator->init( ((CvFeatureParams*)featureParams), numPos + numNeg, cascadeParams.winSize );
473     fs.release();
474
475     char buf[10];
476     for ( int si = 0; si < numStages; si++ )
477     {
478         sprintf( buf, "%s%d", "stage", si);
479         fs.open( cascadeDirName + buf + ".xml", FileStorage::READ );
480         node = fs.getFirstTopLevelNode();
481         if ( !fs.isOpened() )
482             break;
483         CvCascadeBoost *tempStage = new CvCascadeBoost; 
484
485         if ( !tempStage->read( node, (CvFeatureEvaluator*)featureEvaluator, *((CvCascadeBoostParams*)stageParams )) )
486         {
487             delete tempStage;
488             fs.release();
489             break;
490         }
491         stageClassifiers.push_back(tempStage);
492     }
493     return true;
494 }
495
496 void CvCascadeClassifier::getUsedFeaturesIdxMap( Mat& featureMap )
497 {
498     featureMap.create( 1, featureEvaluator->getNumFeatures(), CV_32SC1 );
499     featureMap.setTo(Scalar(-1));
500     
501     for( vector< Ptr<CvCascadeBoost> >::const_iterator it = stageClassifiers.begin();
502         it != stageClassifiers.end(); it++ )
503         ((CvCascadeBoost*)((Ptr<CvCascadeBoost>)(*it)))->markUsedFeaturesInMap( featureMap );
504     
505     for( int fi = 0, idx = 0; fi < featureEvaluator->getNumFeatures(); fi++ )
506         if ( featureMap.at<int>(0, fi) >= 0 )
507             featureMap.ptr<int>(0)[fi] = idx++;
508 }