1 /*M///////////////////////////////////////////////////////////////////////////////////////
3 // IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
5 // By downloading, copying, installing or using the software you agree to this license.
6 // If you do not agree to this license, do not download, install,
7 // copy or use the software.
10 // Intel License Agreement
12 // Copyright (C) 2000, Intel Corporation, all rights reserved.
13 // Third party copyrights are property of their respective owners.
15 // Redistribution and use in source and binary forms, with or without modification,
16 // are permitted provided that the following conditions are met:
18 // * Redistribution's of source code must retain the above copyright notice,
19 // this list of conditions and the following disclaimer.
21 // * Redistribution's in binary form must reproduce the above copyright notice,
22 // this list of conditions and the following disclaimer in the documentation
23 // and/or other materials provided with the distribution.
25 // * The name of Intel Corporation may not be used to endorse or promote products
26 // derived from this software without specific prior written permission.
28 // This software is provided by the copyright holders and contributors "as is" and
29 // any express or implied warranties, including, but not limited to, the implied
30 // warranties of merchantability and fitness for a particular purpose are disclaimed.
31 // In no event shall the Intel Corporation or contributors be liable for any direct,
32 // indirect, incidental, special, exemplary, or consequential damages
33 // (including, but not limited to, procurement of substitute goods or services;
34 // loss of use, data, or profits; or business interruption) however caused
35 // and on any theory of liability, whether in contract, strict liability,
36 // or tort (including negligence or otherwise) arising in any way out of
37 // the use of this software, even if advised of the possibility of such damage.
// Numerically safe logit: returns log(p / (1 - p)) for a probability p.
// The argument is first clamped to [1e-5, 1 - 1e-5] so that neither the
// odds nor their logarithm can become infinite when p is at (or beyond)
// 0 or 1.
static double
log_ratio( double val )
{
    const double eps = 1e-5;

    double p = val < eps ? eps : val;   // lower clamp
    if( p > 1. - eps )                  // upper clamp
        p = 1. - eps;

    const double odds = p / (1. - p);
    return log( odds );
}
// Default boosting parameters. The assignments visible in this copy select
// Real AdaBoost and a weight trimming rate of 0.95, i.e. CvBoost::trim_weights
// trains each weak tree on the samples holding roughly 95% of the weight.
// NOTE(review): other lines of this constructor are missing from this copy
// (the embedded numbering jumps), so the remaining defaults cannot be
// confirmed here.
54 CvBoostParams::CvBoostParams()
56 boost_type = CvBoost::REAL;
58 weight_trim_rate = 0.95;
// Parameterized constructor: copies the boosting type, number of weak trees,
// weight trimming rate, maximum tree depth and surrogate-split flag from the
// arguments, and resets the split criterion to CvBoost::DEFAULT.
// CvBoost::set_params later replaces DEFAULT with the criterion that fits the
// boosting type.
// NOTE(review): none of the visible lines store _priors; the assignment may
// be among the lines missing from this copy. Confirm against the full source.
64 CvBoostParams::CvBoostParams( int _boost_type, int _weak_count,
65 double _weight_trim_rate, int _max_depth,
66 bool _use_surrogates, const float* _priors )
68 boost_type = _boost_type;
69 weak_count = _weak_count;
70 weight_trim_rate = _weight_trim_rate;
71 split_criteria = CvBoost::DEFAULT;
73 max_depth = _max_depth;
74 use_surrogates = _use_surrogates;
80 ///////////////////////////////// CvBoostTree ///////////////////////////////////
// Default constructor of a boosted weak tree.
// NOTE(review): the body is missing from this copy of the file.
82 CvBoostTree::CvBoostTree()
// Destructor of a boosted weak tree.
// NOTE(review): the body is missing from this copy of the file.
88 CvBoostTree::~CvBoostTree()
// Trains this weak tree on the shared training data for the given ensemble.
// The ensemble pointer is stored first because the split search and the
// node-value computation read the boosting weights through it. Tree growing
// itself is delegated to do_train(), restricted to _subsample_idx.
// NOTE(review): some statements between the header and do_train() are
// missing from this copy.
103 CvBoostTree::train( CvDTreeTrainData* _train_data,
104 const CvMat* _subsample_idx, CvBoost* _ensemble )
107 ensemble = _ensemble;
111 return do_train( _subsample_idx );
// Overload matching the generic CvDTree training interface (raw matrices plus
// CvDTreeParams). Every parameter is unnamed, so the arguments are ignored.
// CvBoost trains its weak trees through the (train data, subsample, ensemble)
// overload instead.
// NOTE(review): the body is missing from this copy; presumably it rejects
// the call. Confirm against the full source.
116 CvBoostTree::train( const CvMat*, int, const CvMat*, const CvMat*,
117 const CvMat*, const CvMat*, const CvMat*, CvDTreeParams )
// Overload matching CvDTree::train( CvDTreeTrainData*, const CvMat* ). Both
// parameters are unnamed, so the arguments are ignored. Boosted trees need
// their ensemble and are trained through the three-argument overload.
// NOTE(review): the body is missing from this copy.
125 CvBoostTree::train( CvDTreeTrainData*, const CvMat* )
// Multiplies the value of every node of the tree by `scale`. The traversal
// is iterative: it walks back up the tree through parent pointers instead
// of recursing.
// NOTE(review): part of the traversal loop is missing from this copy.
133 CvBoostTree::scale( double scale )
135 CvDTreeNode* node = root;
137 // traverse the tree and scale all the node values
143 node->value *= scale;
// climb up while we are returning from a right subtree ...
149 for( parent = node->parent; parent && parent->right == node;
150 node = parent, parent = parent->parent )
// ... then continue with the right subtree of the first unfinished ancestor
156 node = parent->right;
// Attempts to split `node` with the generic CvDTree procedure. If the node
// stays a leaf, its value is written into the ensemble's weak-response
// buffer for every training sample that reached it. CvBoost::update_weights
// therefore already knows the responses of the samples used for training.
162 CvBoostTree::try_split_node( CvDTreeNode* node )
164 CvDTree::try_split_node( node );
168 // if the node has not been split,
169 // store the responses for the corresponding training samples
170 double* weak_eval = ensemble->get_weak_response()->data.db;
// labels[i] is the sample's index into the ensemble-wide per-sample arrays
171 int* labels = data->get_labels( node );
172 int i, count = node->sample_count;
173 double value = node->value;
175 for( i = 0; i < count; i++ )
176 weak_eval[labels[i]] = value;
// Fills data->direction with the side each sample of `node` is sent to by the
// node's primary split (-1 left, +1 right, 0 when the value is missing).
// It also accumulates the subtree weights sent left (L) and right (R),
// stores max(L, R) in node->maxlr (the agreement any surrogate split must
// exceed), and returns the split quality divided by the total weight L + R.
// NOTE(review): several lines (declarations of L/R, the statements that
// write `dir`, parts of the loops) are missing from this copy.
182 CvBoostTree::calc_node_dir( CvDTreeNode* node )
184 char* dir = (char*)data->direction->data.ptr;
185 const double* weights = ensemble->get_subtree_weights()->data.db;
186 int i, n = node->sample_count, vi = node->split->var_idx;
// primary splits are never stored inverted
189 assert( !node->split->inversed );
191 if( data->get_var_type(vi) >= 0 ) // split on categorical var
193 const int* cat_labels = data->get_cat_var_data( node, vi );
194 const int* subset = node->split->subset;
195 double sum = 0, sum_abs = 0;
197 for( i = 0; i < n; i++ )
199 int idx = cat_labels[i];
200 double w = weights[i];
// a negative category index marks a missing value -> direction 0
201 int d = idx >= 0 ? CV_DTREE_CAT_DIR(idx,subset) : 0;
// d is -1, 0 or +1 and (d & 1) is 1 exactly when d != 0,
// so sum = R - L and sum_abs = R + L
202 sum += d*w; sum_abs += (d & 1)*w;
206 R = (sum_abs + sum) * 0.5;
207 L = (sum_abs - sum) * 0.5;
209 else // split on ordered var
211 const CvPair32s32f* sorted = data->get_ord_var_data(node,vi);
212 int split_point = node->split->ord.split_point;
213 int n1 = node->get_num_valid(vi);
215 assert( 0 <= split_point && split_point < n1-1 );
// samples 0..split_point of the sorted order are on one side ...
218 for( i = 0; i <= split_point; i++ )
220 int idx = sorted[i].i;
221 double w = weights[idx];
// ... the remaining valid samples on the other (loop header missing here)
228 int idx = sorted[i].i;
229 double w = weights[idx];
// presumably samples n1..n-1 (no valid value of vi) get direction 0
235 dir[sorted[i].i] = (char)0;
238 node->maxlr = MAX( L, R );
239 return node->split->quality/(L + R);
// Finds the best threshold on ordered variable `vi` for a two-class node,
// using the ensemble's subtree weights. rcw0 = weights + n points at the
// per-class weight totals of the node, which calc_node_value stores after
// the per-sample weights. The weight of samples beyond n1 (those without a
// valid value of vi) is subtracted, so only valid samples are partitioned.
// Criterion: GINI or MISCLASS. Any other value falls back to MISCLASS for
// Discrete AdaBoost and to GINI otherwise. A threshold is only considered
// between consecutive values that differ by more than `epsilon`. Returns a
// new ordered split at the midpoint of the best pair, or 0 if none was found.
// NOTE(review): some declarations and the best-value update lines are
// missing from this copy.
244 CvBoostTree::find_split_ord_class( CvDTreeNode* node, int vi )
246 const float epsilon = FLT_EPSILON*2;
247 const CvPair32s32f* sorted = data->get_ord_var_data(node, vi);
248 const int* responses = data->get_class_labels(node);
249 const double* weights = ensemble->get_subtree_weights()->data.db;
250 int n = node->sample_count;
251 int n1 = node->get_num_valid(vi);
252 const double* rcw0 = weights + n;
253 double lcw[2] = {0,0}, rcw[2];
256 int boost_type = ensemble->get_params().boost_type;
257 int split_criteria = ensemble->get_params().split_criteria;
259 rcw[0] = rcw0[0]; rcw[1] = rcw0[1];
// remove samples with a missing value of vi from the right-hand totals
260 for( i = n1; i < n; i++ )
262 int idx = sorted[i].i;
263 double w = weights[idx];
264 rcw[responses[idx]] -= w;
267 if( split_criteria != CvBoost::GINI && split_criteria != CvBoost::MISCLASS )
268 split_criteria = boost_type == CvBoost::DISCRETE ? CvBoost::MISCLASS : CvBoost::GINI;
270 if( split_criteria == CvBoost::GINI )
272 double L = 0, R = rcw[0] + rcw[1];
273 double lsum2 = 0, rsum2 = rcw[0]*rcw[0] + rcw[1]*rcw[1];
275 for( i = 0; i < n1 - 1; i++ )
277 int idx = sorted[i].i;
278 double w = weights[idx], w2 = w*w;
280 idx = responses[idx];
282 lv = lcw[idx]; rv = rcw[idx];
// incremental update of the squared class weights: (lv+w)^2 and (rv-w)^2
283 lsum2 += 2*lv*w + w2;
284 rsum2 -= 2*rv*w - w2;
285 lcw[idx] = lv + w; rcw[idx] = rv - w;
287 if( sorted[i].val + epsilon < sorted[i+1].val )
// Gini score: lsum2/L + rsum2/R
289 double val = (lsum2*R + rsum2*L)/(L*R);
300 for( i = 0; i < n1 - 1; i++ )
302 int idx = sorted[i].i;
303 double w = weights[idx];
304 idx = responses[idx];
308 if( sorted[i].val + epsilon < sorted[i+1].val )
// misclassification score: weight classified correctly by the better of
// the two possible labelings of the sides
310 double val = lcw[0] + rcw[1], val2 = lcw[1] + rcw[0];
311 val = MAX(val, val2);
321 return best_i >= 0 ? data->new_split_ord( vi,
322 (sorted[best_i].val + sorted[best_i+1].val)*0.5f, best_i,
323 0, (float)best_val ) : 0;
// Comparison for the pointer sort below: orders double* by pointed-to value.
327 #define CV_CMP_NUM_PTR(a,b) (*(a) < *(b))
// icvSortDblPtr( array, count, aux ): ascending sort of an array of double*
// by pointed-to value. Used to order per-category statistics in the
// categorical split finders.
328 static CV_IMPLEMENT_QSORT_EX( icvSortDblPtr, double*, CV_CMP_NUM_PTR, int )
// Finds the best subset split on categorical variable `vi` for a two-class
// node. It accumulates per-category class weights c_jk; cjk is offset so
// that row -1 can collect samples with a missing category. It then sorts
// the categories by their class-1 weight and evaluates only the mi-1 subsets
// formed by a prefix of that order. The criterion is GINI or MISCLASS, with
// the same fallback rule as find_split_ord_class. Categories with negligible
// weight are skipped. The chosen prefix is encoded as a bit mask in
// split->subset.
// NOTE(review): several lines (declarations, loop bodies, return statements)
// are missing from this copy.
331 CvBoostTree::find_split_cat_class( CvDTreeNode* node, int vi )
334 const int* cat_labels = data->get_cat_var_data(node, vi);
335 const int* responses = data->get_class_labels(node);
336 int ci = data->get_var_type(vi);
337 int n = node->sample_count;
338 int mi = data->cat_count->data.i[ci];
339 double lcw[2]={0,0}, rcw[2]={0,0};
340 double* cjk = (double*)cvStackAlloc(2*(mi+1)*sizeof(cjk[0]))+2;
341 const double* weights = ensemble->get_subtree_weights()->data.db;
342 double** dbl_ptr = (double**)cvStackAlloc( mi*sizeof(dbl_ptr[0]) );
346 int best_subset = -1, subset_i;
347 int boost_type = ensemble->get_params().boost_type;
348 int split_criteria = ensemble->get_params().split_criteria;
350 // init array of counters:
351 // c_{jk} - number of samples that have vi-th input variable = j and response = k.
352 for( j = -1; j < mi; j++ )
353 cjk[j*2] = cjk[j*2+1] = 0;
355 for( i = 0; i < n; i++ )
357 double w = weights[i];
363 for( j = 0; j < mi; j++ )
366 rcw[1] += cjk[j*2+1];
// sort key for category j: its class-1 weight
367 dbl_ptr[j] = cjk + j*2 + 1;
372 if( split_criteria != CvBoost::GINI && split_criteria != CvBoost::MISCLASS )
373 split_criteria = boost_type == CvBoost::DISCRETE ? CvBoost::MISCLASS : CvBoost::GINI;
375 // sort rows of c_jk by increasing c_j,1
376 // (i.e. by the weight of samples in j-th category that belong to class 1)
377 icvSortDblPtr( dbl_ptr, mi, 0 );
379 for( subset_i = 0; subset_i < mi-1; subset_i++ )
// recover the category index from the pointer into cjk (2 doubles per row)
381 idx = (int)(dbl_ptr[subset_i] - cjk)/2;
382 const double* crow = cjk + idx*2;
383 double w0 = crow[0], w1 = crow[1];
384 double weight = w0 + w1;
386 if( weight < FLT_EPSILON )
// move category idx from the right side to the left side
389 lcw[0] += w0; rcw[0] -= w0;
390 lcw[1] += w1; rcw[1] -= w1;
392 if( split_criteria == CvBoost::GINI )
394 double lsum2 = lcw[0]*lcw[0] + lcw[1]*lcw[1];
395 double rsum2 = rcw[0]*rcw[0] + rcw[1]*rcw[1];
// both sides must carry weight, otherwise the score divides by zero
400 if( L > FLT_EPSILON && R > FLT_EPSILON )
402 double val = (lsum2*R + rsum2*L)/(L*R);
406 best_subset = subset_i;
412 double val = lcw[0] + rcw[1];
413 double val2 = lcw[1] + rcw[0];
415 val = MAX(val, val2);
419 best_subset = subset_i;
424 if( best_subset < 0 )
427 split = data->new_split_cat( vi, (float)best_val );
// categories 0..best_subset of the sorted order form the subset
429 for( i = 0; i <= best_subset; i++ )
431 idx = (int)(dbl_ptr[i] - cjk) >> 1;
432 split->subset[idx >> 5] |= 1 << (idx & 31);
// Finds the best threshold on ordered variable `vi` for a regression node.
// The weak trees of LogitBoost and Gentle AdaBoost are regression trees. The
// search starts with every valid sample on the right, moves samples to the
// left in sorted order, and scores each threshold with lsum^2/L + rsum^2/R,
// where lsum/rsum are weighted response sums and L/R are weight sums.
// R starts at weights[n], the node's total weight stored by calc_node_value.
// Returns a new ordered split at the midpoint of the best pair, or 0.
// NOTE(review): rsum starts at node->value*n. calc_node_value sets
// node->value = sum*iw, which looks like a weight-normalized mean (iw's
// definition is missing from this copy). If so, the weighted response total
// is node->value*R, not node->value*n; the two agree only when the node
// weights sum to n. Confirm against the full calc_node_value.
440 CvBoostTree::find_split_ord_reg( CvDTreeNode* node, int vi )
442 const float epsilon = FLT_EPSILON*2;
443 const CvPair32s32f* sorted = data->get_ord_var_data(node, vi);
444 const float* responses = data->get_ord_responses(node);
445 const double* weights = ensemble->get_subtree_weights()->data.db;
446 int n = node->sample_count;
447 int n1 = node->get_num_valid(vi);
449 double best_val = 0, lsum = 0, rsum = node->value*n;
450 double L = 0, R = weights[n];
452 // compensate for missing values
453 for( i = n1; i < n; i++ )
455 int idx = sorted[i].i;
456 double w = weights[idx];
457 rsum -= responses[idx]*w;
461 // find the optimal split
462 for( i = 0; i < n1 - 1; i++ )
464 int idx = sorted[i].i;
465 double w = weights[idx];
466 double t = responses[idx]*w;
468 lsum += t; rsum -= t;
470 if( sorted[i].val + epsilon < sorted[i+1].val )
472 double val = (lsum*lsum*R + rsum*rsum*L)/(L*R);
481 return best_i >= 0 ? data->new_split_ord( vi,
482 (sorted[best_i].val + sorted[best_i+1].val)*0.5f, best_i,
483 0, (float)best_val ) : 0;
// Finds the best subset split on categorical variable `vi` for a regression
// node. It accumulates the weighted response sum and total weight of each
// category, orders the categories by their average response, and evaluates
// the mi-1 prefix subsets of that order with lsum^2/L + rsum^2/R. Categories
// with negligible weight are skipped. The chosen prefix is encoded as a bit
// mask in split->subset.
// NOTE(review): several lines (loop bodies, best-value update, returns) are
// missing from this copy.
488 CvBoostTree::find_split_cat_reg( CvDTreeNode* node, int vi )
491 const int* cat_labels = data->get_cat_var_data(node, vi);
492 const float* responses = data->get_ord_responses(node);
493 const double* weights = ensemble->get_subtree_weights()->data.db;
494 int ci = data->get_var_type(vi);
495 int n = node->sample_count;
496 int mi = data->cat_count->data.i[ci];
// sum and counts are offset by one so that index -1 (missing category) is valid
497 double* sum = (double*)cvStackAlloc( (mi+1)*sizeof(sum[0]) ) + 1;
498 double* counts = (double*)cvStackAlloc( (mi+1)*sizeof(counts[0]) ) + 1;
499 double** sum_ptr = (double**)cvStackAlloc( mi*sizeof(sum_ptr[0]) );
500 double L = 0, R = 0, best_val = 0, lsum = 0, rsum = 0;
501 int i, best_subset = -1, subset_i;
503 for( i = -1; i < mi; i++ )
504 sum[i] = counts[i] = 0;
506 // calculate sum response and weight of each category of the input var
507 for( i = 0; i < n; i++ )
509 int idx = cat_labels[i];
510 double w = weights[i];
511 double s = sum[idx] + responses[i]*w;
512 double nc = counts[idx] + w;
517 // calculate average response in each category
518 for( i = 0; i < mi; i++ )
523 sum_ptr[i] = sum + i;
526 icvSortDblPtr( sum_ptr, mi, 0 );
528 // revert back to unnormalized sums
529 // (there should be a very little loss in accuracy)
530 for( i = 0; i < mi; i++ )
533 for( subset_i = 0; subset_i < mi-1; subset_i++ )
// category index = pointer offset into `sum`
535 int idx = (int)(sum_ptr[subset_i] - sum);
536 double ni = counts[idx];
538 if( ni > FLT_EPSILON )
// both sides must carry weight, otherwise the score divides by zero
544 if( L > FLT_EPSILON && R > FLT_EPSILON )
546 double val = (lsum*lsum*R + rsum*rsum*L)/(L*R);
550 best_subset = subset_i;
556 if( best_subset < 0 )
559 split = data->new_split_cat( vi, (float)best_val );
560 for( i = 0; i <= best_subset; i++ )
562 int idx = (int)(sum_ptr[i] - sum);
563 split->subset[idx >> 5] |= 1 << (idx & 31);
// Searches ordered variable `vi` for the best surrogate of the node's primary
// split. A surrogate is a threshold, optionally inverted, whose left/right
// assignment agrees with the primary directions on the largest total subtree
// weight; the primary directions are in data->direction, filled by
// calc_node_dir. Agreement is LL + RR, or RL + LR when inverted. It must
// exceed node->maxlr, the weight obtained by sending every sample to the
// heavier side; otherwise 0 is returned.
// NOTE(review): a few lines (e.g. the definition of d and the LL/RL updates)
// are missing from this copy.
571 CvBoostTree::find_surrogate_split_ord( CvDTreeNode* node, int vi )
573 const float epsilon = FLT_EPSILON*2;
574 const CvPair32s32f* sorted = data->get_ord_var_data(node, vi);
575 const double* weights = ensemble->get_subtree_weights()->data.db;
576 const char* dir = (char*)data->direction->data.ptr;
577 int n1 = node->get_num_valid(vi);
578 // LL - number of samples that both the primary and the surrogate splits send to the left
579 // LR - ... primary split sends to the left and the surrogate split sends to the right
580 // RL - ... primary split sends to the right and the surrogate split sends to the left
581 // RR - ... both send to the right
582 int i, best_i = -1, best_inversed = 0;
584 double LL = 0, RL = 0, LR, RR;
585 double worst_val = node->maxlr;
586 double sum = 0, sum_abs = 0;
587 best_val = worst_val;
589 for( i = 0; i < n1; i++ )
591 int idx = sorted[i].i;
592 double w = weights[idx];
// d is presumably dir[idx] in {-1,0,+1} (its definition is missing here)
594 sum += d*w; sum_abs += (d & 1)*w;
597 // sum_abs = R + L; sum = R - L
598 RR = (sum_abs + sum)*0.5;
599 LR = (sum_abs - sum)*0.5;
601 // initially all the samples are sent to the right by the surrogate split,
602 // LR of them are sent to the left by primary split, and RR - to the right.
603 // now iteratively compute LL, LR, RL and RR for every possible surrogate split value.
604 for( i = 0; i < n1 - 1; i++ )
606 int idx = sorted[i].i;
607 double w = weights[idx];
// non-inverted surrogate: agreement is LL + RR
613 if( LL + RR > best_val && sorted[i].val + epsilon < sorted[i+1].val )
616 best_i = i; best_inversed = 0;
// inverted surrogate: agreement is RL + LR
622 if( RL + LR > best_val && sorted[i].val + epsilon < sorted[i+1].val )
625 best_i = i; best_inversed = 1;
630 return best_i >= 0 && best_val > node->maxlr ? data->new_split_ord( vi,
631 (sorted[best_i].val + sorted[best_i+1].val)*0.5f, best_i,
632 best_inversed, (float)best_val ) : 0;
// Builds a surrogate split on categorical variable `vi`. For each category it
// computes the subtree weight the primary split sends left (lc) and right
// (rc), then assigns the whole category to one side. The split is kept only
// if its quality (agreement weight) exceeds node->maxlr. Otherwise it is
// removed from data->split_heap and `split` is reset to 0.
// NOTE(review): several lines (the condition choosing each category's side,
// the best_val accumulation, the return) are missing from this copy.
637 CvBoostTree::find_surrogate_split_cat( CvDTreeNode* node, int vi )
639 const int* cat_labels = data->get_cat_var_data(node, vi);
640 const char* dir = (char*)data->direction->data.ptr;
641 const double* weights = ensemble->get_subtree_weights()->data.db;
642 int n = node->sample_count;
643 // LL - number of samples that both the primary and the surrogate splits send to the left
644 // LR - ... primary split sends to the left and the surrogate split sends to the right
645 // RL - ... primary split sends to the right and the surrogate split sends to the left
646 // RR - ... both send to the right
647 CvDTreeSplit* split = data->new_split_cat( vi, 0 );
648 int i, mi = data->cat_count->data.i[data->get_var_type(vi)];
// lc and rc share one buffer and are each offset by one so that index -1
// (missing category) is valid; rc[-1] aliases lc[mi]
650 double* lc = (double*)cvStackAlloc( (mi+1)*2*sizeof(lc[0]) ) + 1;
651 double* rc = lc + mi + 1;
653 for( i = -1; i < mi; i++ )
656 // 1. for each category calculate the weight of samples
657 // sent to the left (lc) and to the right (rc) by the primary split
658 for( i = 0; i < n; i++ )
660 int idx = cat_labels[i];
661 double w = weights[i];
// during accumulation lc holds R - L and rc holds R + L (d in {-1,0,+1})
663 double sum = lc[idx] + d*w;
664 double sum_abs = rc[idx] + (d & 1)*w;
665 lc[idx] = sum; rc[idx] = sum_abs;
// convert (R - L, R + L) per category into separate left/right weights
668 for( i = 0; i < mi; i++ )
671 double sum_abs = rc[i];
672 lc[i] = (sum_abs - sum) * 0.5;
673 rc[i] = (sum_abs + sum) * 0.5;
676 // 2. now form the split.
677 // in each category send all the samples to the same direction as majority
678 for( i = 0; i < mi; i++ )
680 double lval = lc[i], rval = rc[i];
// put category i into the surrogate subset (the condition is missing here)
683 split->subset[i >> 5] |= 1 << (i & 31);
690 split->quality = (float)best_val;
691 if( split->quality <= node->maxlr )
692 cvSetRemoveByPtr( data->split_heap, split ), split = 0;
// Computes the value of `node` from the current boosting weights. Each
// sample's weight is copied into the ensemble's subtree_weights buffer: one
// slot per node sample, plus two trailing slots (rcw[0], rcw[1]). The split
// finders read those trailing slots as weights[n] / weights[n+1].
// Classification (Discrete / Real AdaBoost):
//   * class_idx is the class with the larger total weight;
//   * Discrete: value is -1 or +1;
//   * Real: value is 0.5 * logit(P(class 1)), computed with log_ratio().
// Regression (weak trees of LogitBoost / Gentle AdaBoost): value is sum*iw
// and node_risk is rescaled for try_split_node. Several lines of that loop,
// including the definition of iw, are missing from this copy.
699 CvBoostTree::calc_node_value( CvDTreeNode* node )
701 int i, count = node->sample_count;
702 const double* weights = ensemble->get_weights()->data.db;
703 const int* labels = data->get_labels(node);
704 double* subtree_weights = ensemble->get_subtree_weights()->data.db;
705 double rcw[2] = {0,0};
706 int boost_type = ensemble->get_params().boost_type;
707 //const double* priors = data->priors->data.db;
709 if( data->is_classifier )
711 const int* responses = data->get_class_labels(node);
713 for( i = 0; i < count; i++ )
716 double w = weights[idx]/*priors[responses[i]]*/;
717 rcw[responses[i]] += w;
718 subtree_weights[i] = w;
// majority class by total weight
721 node->class_idx = rcw[1] > rcw[0];
723 if( boost_type == CvBoost::DISCRETE )
725 // ignore cat_map for responses, and use {-1,1},
726 // as the whole ensemble response is computed as sign(sum_i(weak_response_i))
727 node->value = node->class_idx*2 - 1;
731 double p = rcw[1]/(rcw[0] + rcw[1]);
732 assert( boost_type == CvBoost::REAL );
734 // store log-ratio of the probability
735 node->value = 0.5*log_ratio(p);
740 // in case of regression tree:
741 // * node value is 1/n*sum_i(Y_i), where Y_i is i-th response,
742 // n is the number of samples in the node.
743 // * node risk is the sum of squared errors: sum_i((Y_i - <node_value>)^2)
744 double sum = 0, sum2 = 0, iw;
745 const float* values = data->get_ord_responses(node);
747 for( i = 0; i < count; i++ )
750 double w = weights[idx]/*priors[values[i] > 0]*/;
751 double t = values[i];
753 subtree_weights[i] = w;
759 node->value = sum*iw;
760 node->node_risk = sum2 - (sum*iw)*sum;
762 // renormalize the risk, as in try_split_node the unweighted formula
763 // sqrt(risk)/n is used, rather than sqrt(risk)/sum(weights_i)
764 node->node_risk *= count*iw*count*iw;
767 // store summary weights
768 subtree_weights[count] = rcw[0];
769 subtree_weights[count+1] = rcw[1];
773 void CvBoostTree::read( CvFileStorage* fs, CvFileNode* fnode, CvBoost* _ensemble, CvDTreeTrainData* _data )
775 CvDTree::read( fs, fnode, _data );
776 ensemble = _ensemble;
// Two-argument read overload (same signature as CvDTree's reader). Both
// parameters are unnamed, so the arguments are ignored. Boosted trees are
// restored through the overload that also takes the owning ensemble.
// NOTE(review): the body is missing from this copy.
780 void CvBoostTree::read( CvFileStorage*, CvFileNode* )
// Three-argument read overload: delegates to CvDTree::read. The visible lines
// do not set `ensemble`, so a tree loaded this way is presumably not attached
// to a boosting ensemble. Confirm against the full source.
785 void CvBoostTree::read( CvFileStorage* _fs, CvFileNode* _node,
786 CvDTreeTrainData* _data )
788 CvDTree::read( _fs, _node, _data );
792 /////////////////////////////////// CvBoost /////////////////////////////////////
// Tail of the default CvBoost constructor (its header is missing from this
// copy): sets the default model name and nulls every per-sample buffer,
// including subtree_weights.
798 default_model_name = "my_boost_tree";
799 orig_response = sum_response = weak_eval = subsample_mask =
800 weights = subtree_weights = 0;
// Removes the weak trees in `slice` from the ensemble. Each tree in the slice
// is read from the `weak` sequence (it is presumably deleted on a line missing
// from this copy), then the slice is removed from the sequence.
806 void CvBoost::prune( CvSlice slice )
811 int i, count = cvSliceLength( slice, weak );
813 cvStartReadSeq( weak, &reader );
814 cvSetSeqReaderPos( &reader, slice.start_index );
816 for( i = 0; i < count; i++ )
819 CV_READ_SEQ_ELEM( w, reader );
823 cvSeqRemoveSlice( weak, slice );
// Releases what the ensemble owns: all weak trees (via prune) with their
// sequence storage, and the per-sample buffers. Also resets have_subsample.
// NOTE(review): subtree_weights (allocated in update_weights) is not released
// in the visible lines, which looks like a leak. Consider adding
// cvReleaseMat( &subtree_weights ) next to the other releases.
828 void CvBoost::clear()
832 prune( CV_WHOLE_SEQ );
833 cvReleaseMemStorage( &weak->storage );
839 cvReleaseMat( &orig_response );
840 cvReleaseMat( &sum_response );
841 cvReleaseMat( &weak_eval );
842 cvReleaseMat( &subsample_mask );
843 cvReleaseMat( &weights );
844 have_subsample = false;
// Constructs the ensemble and immediately trains it on the given data
// (see CvBoost::train for the meaning of the arguments).
// NOTE(review): unlike the default constructor, the visible lines do not set
// subtree_weights = 0. Unless a missing line does it, the pointer stays
// uninitialized until update_weights() allocates it.
854 CvBoost::CvBoost( const CvMat* _train_data, int _tflag,
855 const CvMat* _responses, const CvMat* _var_idx,
856 const CvMat* _sample_idx, const CvMat* _var_type,
857 const CvMat* _missing_mask, CvBoostParams _params )
861 default_model_name = "my_boost_tree";
862 orig_response = sum_response = weak_eval = subsample_mask = weights = 0;
864 train( _train_data, _tflag, _responses, _var_idx, _sample_idx,
865 _var_type, _missing_mask, _params );
// Validates and normalizes the boosting parameters:
//   * unknown boosting types raise CV_StsBadArg;
//   * weak_count is forced to be at least 1;
//   * weight_trim_rate is clamped to [0, 1], and a rate of ~0 becomes 1
//     (trim_weights then skips trimming);
//   * split criteria that do not fit the boosting type are replaced:
//     Discrete -> MISCLASS and Real -> GINI (unless GINI/MISCLASS was given),
//     Logit/Gentle -> always SQERR.
// NOTE(review): the lines copying _params into params and producing the return
// value are missing from this copy.
870 CvBoost::set_params( const CvBoostParams& _params )
874 CV_FUNCNAME( "CvBoost::set_params" );
879 if( params.boost_type != DISCRETE && params.boost_type != REAL &&
880 params.boost_type != LOGIT && params.boost_type != GENTLE )
881 CV_ERROR( CV_StsBadArg, "Unknown/unsupported boosting type" );
883 params.weak_count = MAX( params.weak_count, 1 );
884 params.weight_trim_rate = MAX( params.weight_trim_rate, 0. );
885 params.weight_trim_rate = MIN( params.weight_trim_rate, 1. );
886 if( params.weight_trim_rate < FLT_EPSILON )
887 params.weight_trim_rate = 1.f;
889 if( params.boost_type == DISCRETE &&
890 params.split_criteria != GINI && params.split_criteria != MISCLASS )
891 params.split_criteria = MISCLASS;
892 if( params.boost_type == REAL &&
893 params.split_criteria != GINI && params.split_criteria != MISCLASS )
894 params.split_criteria = GINI;
895 if( (params.boost_type == LOGIT || params.boost_type == GENTLE) &&
896 params.split_criteria != SQERR )
897 params.split_criteria = SQERR;
// Trains the boosted ensemble. It validates the parameters, then builds new
// shared training data unless _update is set and data already exists (only
// two-class problems are supported). It then repeatedly trains a CvBoostTree
// on the current subsample and weights, appends it to `weak` and updates the
// weights. Finally it restores data->is_classifier, which update_weights
// switches off for LogitBoost / Gentle AdaBoost.
// NOTE(review): several lines (initial weight setup, the else-branch header,
// failure handling) are missing from this copy.
908 CvBoost::train( const CvMat* _train_data, int _tflag,
909 const CvMat* _responses, const CvMat* _var_idx,
910 const CvMat* _sample_idx, const CvMat* _var_type,
911 const CvMat* _missing_mask,
912 CvBoostParams _params, bool _update )
915 CvMemStorage* storage = 0;
917 CV_FUNCNAME( "CvBoost::train" );
923 set_params( _params );
// build fresh training data unless updating an already trained model
925 if( !_update || !data )
928 data = new CvDTreeTrainData( _train_data, _tflag, _responses, _var_idx,
929 _sample_idx, _var_type, _missing_mask, _params, true, true );
931 if( data->get_num_classes() != 2 )
932 CV_ERROR( CV_StsNotImplemented,
933 "Boosted trees can only be used for 2-class classification." );
934 CV_CALL( storage = cvCreateMemStorage() );
935 weak = cvCreateSeq( 0, sizeof(CvSeq), sizeof(CvBoostTree*), storage );
// presumably the update branch: refill the existing training data
940 data->set_data( _train_data, _tflag, _responses, _var_idx,
941 _sample_idx, _var_type, _missing_mask, _params, true, true, true );
946 for( i = 0; i < params.weak_count; i++ )
948 CvBoostTree* tree = new CvBoostTree;
// the handling of a failed tree is among the missing lines
949 if( !tree->train( data, subsample_mask, this ) )
954 //cvCheckArr( get_weak_response());
955 cvSeqPush( weak, &tree );
956 update_weights( tree );
// LogitBoost / Gentle AdaBoost switch the data to regression mode in
// update_weights(); restore classifier mode
960 data->is_classifier = true;
// Updates the per-sample boosting state.
// tree == 0 (before the first tree): (re)allocates the per-sample buffers,
//   stores responses as {-1,+1} in orig_response, activates every sample in
//   subsample_mask, and sets the initial weights to 1/count, scaled per class
//   when priors are given. LogitBoost and Gentle AdaBoost also turn the labels
//   into regression targets and switch the training data to regression mode.
// tree != 0: computes the weak responses of samples the tree was not trained
//   on, updates the weights by the rule of the selected boosting type, and
//   rescales them (trim_weights assumes they then sum to 1).
// NOTE(review): subtree_weights is created below without first being
// released, unlike the other buffers. Calling this with tree == 0 again
// (e.g. retraining) therefore appears to leak the previous matrix.
970 CvBoost::update_weights( CvBoostTree* tree )
972 CV_FUNCNAME( "CvBoost::update_weights" );
976 int i, count = data->sample_count;
979 if( !tree ) // before training the first tree, initialize weights and other parameters
981 const int* class_labels = data->get_class_labels(data->data_root);
982 // in case of logitboost and gentle adaboost each weak tree is a regression tree,
983 // so we need to convert class labels to floating-point values
984 float* responses = data->get_ord_responses(data->data_root);
985 int* labels = data->get_labels(data->data_root);
986 double w0 = 1./count;
987 double p[2] = { 1, 1 };
989 cvReleaseMat( &orig_response );
990 cvReleaseMat( &sum_response );
991 cvReleaseMat( &weak_eval );
992 cvReleaseMat( &subsample_mask );
993 cvReleaseMat( &weights );
995 CV_CALL( orig_response = cvCreateMat( 1, count, CV_32S ));
996 CV_CALL( weak_eval = cvCreateMat( 1, count, CV_64F ));
997 CV_CALL( subsample_mask = cvCreateMat( 1, count, CV_8U ));
998 CV_CALL( weights = cvCreateMat( 1, count, CV_64F ));
// two extra slots hold the per-class weight totals of a node
999 CV_CALL( subtree_weights = cvCreateMat( 1, count + 2, CV_64F ));
1001 if( data->have_priors )
1003 // compute weight scale for each class from their prior probabilities
1005 for( i = 0; i < count; i++ )
1006 c1 += class_labels[i];
// p[k] = prior_k / (number of samples of class k), then normalized
1007 p[0] = data->priors->data.db[0]*(c1 < count ? 1./(count - c1) : 0.);
1008 p[1] = data->priors->data.db[1]*(c1 > 0 ? 1./c1 : 0.);
1009 p[0] /= p[0] + p[1];
1013 for( i = 0; i < count; i++ )
1015 // save original categorical responses {0,1}, convert them to {-1,1}
1016 orig_response->data.i[i] = class_labels[i]*2 - 1;
1017 // make all the samples active at start.
1018 // later, trim_weights() may deactivate/reactivate some of them, if needed
1019 subsample_mask->data.ptr[i] = (uchar)1;
1020 // make all the initial weights the same.
1021 weights->data.db[i] = w0*p[class_labels[i]];
1022 // set the labels to find (from within weak tree learning proc)
1023 // the particular sample weight, and where to store the response.
1027 if( params.boost_type == LOGIT )
1029 CV_CALL( sum_response = cvCreateMat( 1, count, CV_64F ));
1031 for( i = 0; i < count; i++ )
1033 sum_response->data.db[i] = 0;
1034 responses[i] = orig_response->data.i[i] > 0 ? 2.f : -2.f;
1037 // in case of logitboost each weak tree is a regression tree.
1038 // the target function values are recalculated for each of the trees
1039 data->is_classifier = false;
1041 else if( params.boost_type == GENTLE )
1043 for( i = 0; i < count; i++ )
1044 responses[i] = (float)orig_response->data.i[i];
1046 data->is_classifier = false;
1051 // at this moment, for all the samples that participated in the training of the most
1052 // recent weak classifier we know the responses. For other samples we need to compute them
1053 if( have_subsample )
// rows 1 and 2 of data->buf serve as scratch space for the sample vectors
// and their missing-value masks
1055 float* values = (float*)(data->buf->data.ptr + data->buf->step);
1056 uchar* missing = data->buf->data.ptr + data->buf->step*2;
1057 CvMat _sample, _mask;
1059 // invert the subsample mask
1060 cvXorS( subsample_mask, cvScalar(1.), subsample_mask );
1061 data->get_vectors( subsample_mask, values, missing, 0 );
1062 //data->get_vectors( 0, values, missing, 0 );
1064 _sample = cvMat( 1, data->var_count, CV_32F );
1065 _mask = cvMat( 1, data->var_count, CV_8U );
1067 // run tree through all the non-processed samples
1068 for( i = 0; i < count; i++ )
1069 if( subsample_mask->data.ptr[i] )
1071 _sample.data.fl = values;
1072 _mask.data.ptr = missing;
1073 values += _sample.cols;
1074 missing += _mask.cols;
1075 weak_eval->data.db[i] = tree->predict( &_sample, &_mask, true )->value;
1079 // now update weights and other parameters for each type of boosting
1080 if( params.boost_type == DISCRETE )
1082 // Discrete AdaBoost:
1083 // weak_eval[i] (=f(x_i)) is in {-1,1}
1084 // err = sum(w_i*(f(x_i) != y_i))/sum(w_i)
1085 // C = log((1-err)/err)
1086 // w_i *= exp(C*(f(x_i) != y_i))
1089 double scale[] = { 1., 0. };
1091 for( i = 0; i < count; i++ )
1093 double w = weights->data.db[i];
1095 err += w*(weak_eval->data.db[i] != orig_response->data.i[i]);
// C = log((1-err)/err) = -log_ratio(err); err is reused to hold C
1100 C = err = -log_ratio( err );
// correctly classified samples keep their weight, misclassified ones
// are scaled by exp(C)
1101 scale[1] = exp(err);
1104 for( i = 0; i < count; i++ )
1106 double w = weights->data.db[i]*
1107 scale[weak_eval->data.db[i] != orig_response->data.i[i]];
1109 weights->data.db[i] = w;
1114 else if( params.boost_type == REAL )
1117 // weak_eval[i] = f(x_i) = 0.5*log(p(x_i)/(1-p(x_i))), p(x_i)=P(y=1|x_i)
1118 // w_i *= exp(-y_i*f(x_i))
1120 for( i = 0; i < count; i++ )
1121 weak_eval->data.db[i] *= -orig_response->data.i[i];
1123 cvExp( weak_eval, weak_eval );
1125 for( i = 0; i < count; i++ )
1127 double w = weights->data.db[i]*weak_eval->data.db[i];
1129 weights->data.db[i] = w;
1132 else if( params.boost_type == LOGIT )
1135 // weak_eval[i] = f(x_i) in [-z_max,z_max]
1136 // sum_response = F(x_i).
1137 // F(x_i) += 0.5*f(x_i)
1138 // p(x_i) = exp(F(x_i))/(exp(F(x_i)) + exp(-F(x_i)))=1/(1+exp(-2*F(x_i)))
1139 // reuse weak_eval: weak_eval[i] <- p(x_i)
1140 // w_i = p(x_i)*(1 - p(x_i))
1141 // z_i = ((y_i+1)/2 - p(x_i))/(p(x_i)*(1 - p(x_i)))
1142 // store z_i to the data->data_root as the new target responses
1144 const double lb_weight_thresh = FLT_EPSILON;
1145 const double lb_z_max = 10.;
1146 float* responses = data->get_ord_responses(data->data_root);
1148 /*if( weak->total == 7 )
1151 for( i = 0; i < count; i++ )
1153 double s = sum_response->data.db[i] + 0.5*weak_eval->data.db[i];
1154 sum_response->data.db[i] = s;
1155 weak_eval->data.db[i] = -2*s;
1158 cvExp( weak_eval, weak_eval );
1160 for( i = 0; i < count; i++ )
1162 double p = 1./(1. + weak_eval->data.db[i]);
1163 double w = p*(1 - p), z;
// floor the weight so it never reaches zero
1164 w = MAX( w, lb_weight_thresh );
1165 weights->data.db[i] = w;
// the target z is clipped to [-lb_z_max, lb_z_max]
1167 if( orig_response->data.i[i] > 0 )
1170 responses[i] = (float)MIN(z, lb_z_max);
1175 responses[i] = (float)-MIN(z, lb_z_max);
1182 // weak_eval[i] = f(x_i) in [-1,1]
1183 // w_i *= exp(-y_i*f(x_i))
1184 assert( params.boost_type == GENTLE );
1186 for( i = 0; i < count; i++ )
1187 weak_eval->data.db[i] *= -orig_response->data.i[i];
1189 cvExp( weak_eval, weak_eval );
1191 for( i = 0; i < count; i++ )
1193 double w = weights->data.db[i] * weak_eval->data.db[i];
1194 weights->data.db[i] = w;
// sumw presumably holds 1/sum(w) at this point (its computation is missing
// from this copy)
1200 // renormalize weights
1201 if( sumw > FLT_EPSILON )
1204 for( i = 0; i < count; ++i )
1205 weights->data.db[i] *= sumw;
// icvSort_64f( array, count, aux ): ascending in-place sort of doubles; used
// by trim_weights to sort a copy of the sample weights.
1212 static CV_IMPLEMENT_QSORT_EX( icvSort_64f, double, CV_LT, int )
// Weight trimming. It sorts a copy of the weights, walks up from the smallest
// until their cumulative weight reaches 1 - weight_trim_rate (part of that
// loop is missing from this copy), and uses the weight reached as a
// threshold. Only samples with a weight strictly above the threshold stay
// active in subsample_mask. have_subsample records whether any sample was
// excluded. Rates outside (0, 1) skip trimming (the early-exit line itself
// is missing here).
1216 CvBoost::trim_weights()
1218 CV_FUNCNAME( "CvBoost::trim_weights" );
1222 int i, count = data->sample_count, nz_count = 0;
1223 double sum, threshold;
1225 if( params.weight_trim_rate <= 0. || params.weight_trim_rate >= 1. )
1228 // use weak_eval as temporary buffer for sorted weights
1229 cvCopy( weights, weak_eval );
1231 icvSort_64f( weak_eval->data.db, count, 0 );
1233 // as weight trimming occurs immediately after updating the weights,
1234 // where they are renormalized, we assume that the weight sum = 1.
1235 sum = 1. - params.weight_trim_rate;
1237 for( i = 0; i < count; i++ )
1239 double w = weak_eval->data.db[i];
// if the loop ran to the end, no sample passes the DBL_MAX threshold
1245 threshold = i < count ? weak_eval->data.db[i] : DBL_MAX;
1247 for( i = 0; i < count; i++ )
1249 double w = weights->data.db[i];
1250 int f = w > threshold;
1251 subsample_mask->data.ptr[i] = (uchar)f;
1255 have_subsample = nz_count < count;
// Evaluates the ensemble on one sample.
//   * Validates the sample: a 1-D CV_32FC1 vector of length var_all, or
//     var_count in raw_mode. Validates the optional missing mask: an 8-bit
//     mask of the same size.
//   * Clamps the requested slice of weak trees to the sequence length.
//   * Unless raw_mode, maps the input through var_idx and converts
//     categorical values to internal category indices via cat_map; a
//     non-integer categorical value is an error.
//   * Runs each weak tree with predict(..., true), optionally storing every
//     tree's response in weak_responses.
// The final result is a class index or label (the deciding lines are
// partially missing from this copy).
// NOTE(review): the missing mask is read as src_mask[i] while the sample is
// read as src_sample[idx*step]. mstep is computed but never used in the
// visible lines. With var_idx, or a non-continuous mask, the wrong mask
// element may be read; likely this should be src_mask[idx*mstep].
1262 CvBoost::predict( const CvMat* _sample, const CvMat* _missing,
1263 CvMat* weak_responses, CvSlice slice,
1264 bool raw_mode ) const
1267 bool allocated = false;
1268 float value = -FLT_MAX;
1270 CV_FUNCNAME( "CvBoost::predict" );
1274 int i, weak_count, var_count;
1275 CvMat sample, missing;
1285 CV_ERROR( CV_StsError, "The boosted tree ensemble has not been trained yet" );
1287 if( !CV_IS_MAT(_sample) || CV_MAT_TYPE(_sample->type) != CV_32FC1 ||
1288 _sample->cols != 1 && _sample->rows != 1 ||
1289 _sample->cols + _sample->rows - 1 != data->var_all && !raw_mode ||
1290 _sample->cols + _sample->rows - 1 != data->var_count && raw_mode )
1291 CV_ERROR( CV_StsBadArg,
1292 "the input sample must be 1d floating-point vector with the same "
1293 "number of elements as the total number of variables used for training" );
1297 if( !CV_IS_MAT(_missing) || !CV_IS_MASK_ARR(_missing) ||
1298 !CV_ARE_SIZES_EQ(_missing, _sample) )
1299 CV_ERROR( CV_StsBadArg,
1300 "the missing data mask must be 8-bit vector of the same size as input sample" );
// a slice covering the whole sequence (or more) means "all weak trees"
1303 weak_count = cvSliceLength( slice, weak );
1304 if( weak_count >= weak->total )
1306 weak_count = weak->total;
1307 slice.start_index = 0;
1310 if( weak_responses )
1312 if( !CV_IS_MAT(weak_responses) ||
1313 CV_MAT_TYPE(weak_responses->type) != CV_32FC1 ||
1314 weak_responses->cols != 1 && weak_responses->rows != 1 ||
1315 weak_responses->cols + weak_responses->rows - 1 != weak_count )
1316 CV_ERROR( CV_StsBadArg,
1317 "The output matrix of weak classifier responses must be valid "
1318 "floating-point vector of the same number of components as the length of input slice" );
1319 wstep = CV_IS_MAT_CONT(weak_responses->type) ? 1 : weak_responses->step/sizeof(float);
1322 var_count = data->var_count;
1323 vtype = data->var_type->data.i;
1324 cmap = data->cat_map->data.i;
1325 cofs = data->cat_ofs->data.i;
1327 // if need, preprocess the input vector
1328 if( !raw_mode && (data->cat_var_count > 0 || data->var_idx) )
1331 int step, mstep = 0;
1332 const float* src_sample;
1333 const uchar* src_mask = 0;
1336 const int* vidx = data->var_idx && !raw_mode ? data->var_idx->data.i : 0;
1337 bool have_mask = _missing != 0;
// one buffer holds the converted sample followed by its mask
1339 bufsize = var_count*(sizeof(float) + sizeof(uchar));
1340 if( bufsize <= CV_MAX_LOCAL_SIZE )
1341 buf = (float*)cvStackAlloc( bufsize );
1344 CV_CALL( buf = (float*)cvAlloc( bufsize ));
1348 dst_mask = (uchar*)(buf + var_count);
1350 src_sample = _sample->data.fl;
1351 step = CV_IS_MAT_CONT(_sample->type) ? 1 : _sample->step/sizeof(src_sample[0]);
1355 src_mask = _missing->data.ptr;
1356 mstep = CV_IS_MAT_CONT(_missing->type) ? 1 : _missing->step;
1359 for( i = 0; i < var_count; i++ )
1361 int idx = vidx ? vidx[i] : i;
1362 float val = src_sample[idx*step];
// NOTE(review): indexed by i rather than idx*mstep (see the header note)
1364 uchar m = src_mask ? src_mask[i] : (uchar)0;
1368 int a = cofs[ci], b = cofs[ci+1], c = a;
1369 int ival = cvRound(val);
1371 CV_ERROR( CV_StsBadArg,
1372 "one of input categorical variable is not an integer" );
// binary search for ival among the known values of this variable
1377 if( ival < cmap[c] )
1379 else if( ival > cmap[c] )
1385 if( c < 0 || ival != cmap[c] )
// replace the raw value by its category index within this variable
1392 val = (float)(c - cofs[ci]);
1396 dst_sample[i] = val;
1400 sample = cvMat( 1, var_count, CV_32F, dst_sample );
1405 missing = cvMat( 1, var_count, CV_8UC1, dst_mask );
1406 _missing = &missing;
1410 cvStartReadSeq( weak, &reader );
1411 cvSetSeqReaderPos( &reader, slice.start_index );
1413 for( i = 0; i < weak_count; i++ )
1418 CV_READ_SEQ_ELEM( wtree, reader );
1420 val = wtree->predict( _sample, _missing, true )->value;
1421 if( weak_responses )
1422 weak_responses->data.fl[i*wstep] = (float)val;
// internal class index, or the original class label via cat_map
1429 value = (float)cls_idx;
1431 value = (float)cmap[cofs[vtype[var_count]] + cls_idx];
1443 void CvBoost::write_params( CvFileStorage* fs )
1445 CV_FUNCNAME( "CvBoost::write_params" );
1449 const char* boost_type_str =
1450 params.boost_type == DISCRETE ? "DiscreteAdaboost" :
1451 params.boost_type == REAL ? "RealAdaboost" :
1452 params.boost_type == LOGIT ? "LogitBoost" :
1453 params.boost_type == GENTLE ? "GentleAdaboost" : 0;
1455 const char* split_crit_str =
1456 params.split_criteria == DEFAULT ? "Default" :
1457 params.split_criteria == GINI ? "Gini" :
1458 params.boost_type == MISCLASS ? "Misclassification" :
1459 params.boost_type == SQERR ? "SquaredErr" : 0;
1461 if( boost_type_str )
1462 cvWriteString( fs, "boosting_type", boost_type_str );
1464 cvWriteInt( fs, "boosting_type", params.boost_type );
1466 if( split_crit_str )
1467 cvWriteString( fs, "splitting_criteria", split_crit_str );
1469 cvWriteInt( fs, "splitting_criteria", params.split_criteria );
1471 cvWriteInt( fs, "ntrees", params.weak_count );
1472 cvWriteReal( fs, "weight_trimming_rate", params.weight_trim_rate );
1474 data->write_params( fs );
1480 void CvBoost::read_params( CvFileStorage* fs, CvFileNode* fnode )
1482 CV_FUNCNAME( "CvBoost::read_params" );
1488 if( !fnode || !CV_NODE_IS_MAP(fnode->tag) )
1491 data = new CvDTreeTrainData();
1492 CV_CALL( data->read_params(fs, fnode));
1493 data->shared = true;
1495 params.max_depth = data->params.max_depth;
1496 params.min_sample_count = data->params.min_sample_count;
1497 params.max_categories = data->params.max_categories;
1498 params.priors = data->params.priors;
1499 params.regression_accuracy = data->params.regression_accuracy;
1500 params.use_surrogates = data->params.use_surrogates;
1502 temp = cvGetFileNodeByName( fs, fnode, "boosting_type" );
1506 if( temp && CV_NODE_IS_STRING(temp->tag) )
1508 const char* boost_type_str = cvReadString( temp, "" );
1509 params.boost_type = strcmp( boost_type_str, "DiscreteAdaboost" ) == 0 ? DISCRETE :
1510 strcmp( boost_type_str, "RealAdaboost" ) == 0 ? REAL :
1511 strcmp( boost_type_str, "LogitBoost" ) == 0 ? LOGIT :
1512 strcmp( boost_type_str, "GentleAdaboost" ) == 0 ? GENTLE : -1;
1515 params.boost_type = cvReadInt( temp, -1 );
1517 if( params.boost_type < DISCRETE || params.boost_type > GENTLE )
1518 CV_ERROR( CV_StsBadArg, "Unknown boosting type" );
1520 temp = cvGetFileNodeByName( fs, fnode, "splitting_criteria" );
1521 if( temp && CV_NODE_IS_STRING(temp->tag) )
1523 const char* split_crit_str = cvReadString( temp, "" );
1524 params.split_criteria = strcmp( split_crit_str, "Default" ) == 0 ? DEFAULT :
1525 strcmp( split_crit_str, "Gini" ) == 0 ? GINI :
1526 strcmp( split_crit_str, "Misclassification" ) == 0 ? MISCLASS :
1527 strcmp( split_crit_str, "SquaredErr" ) == 0 ? SQERR : -1;
1530 params.split_criteria = cvReadInt( temp, -1 );
1532 if( params.split_criteria < DEFAULT || params.boost_type > SQERR )
1533 CV_ERROR( CV_StsBadArg, "Unknown boosting type" );
1535 params.weak_count = cvReadIntByName( fs, fnode, "ntrees" );
1536 params.weight_trim_rate = cvReadRealByName( fs, fnode, "weight_trimming_rate", 0. );
1544 CvBoost::read( CvFileStorage* fs, CvFileNode* node )
1546 CV_FUNCNAME( "CvRTrees::read" );
1551 CvFileNode* trees_fnode;
1552 CvMemStorage* storage;
1556 read_params( fs, node );
1561 trees_fnode = cvGetFileNodeByName( fs, node, "trees" );
1562 if( !trees_fnode || !CV_NODE_IS_SEQ(trees_fnode->tag) )
1563 CV_ERROR( CV_StsParseError, "<trees> tag is missing" );
1565 cvStartReadSeq( trees_fnode->data.seq, &reader );
1566 ntrees = trees_fnode->data.seq->total;
1568 if( ntrees != params.weak_count )
1569 CV_ERROR( CV_StsUnmatchedSizes,
1570 "The number of trees stored does not match <ntrees> tag value" );
1572 CV_CALL( storage = cvCreateMemStorage() );
1573 weak = cvCreateSeq( 0, sizeof(CvSeq), sizeof(CvBoostTree*), storage );
1575 for( i = 0; i < ntrees; i++ )
1577 CvBoostTree* tree = new CvBoostTree();
1578 CV_CALL(tree->read( fs, (CvFileNode*)reader.ptr, this, data ));
1579 CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
1580 cvSeqPush( weak, &tree );
1588 CvBoost::write( CvFileStorage* fs, const char* name )
1590 CV_FUNCNAME( "CvBoost::write" );
1597 cvStartWriteStruct( fs, name, CV_NODE_MAP, CV_TYPE_NAME_ML_BOOSTING );
1600 CV_ERROR( CV_StsBadArg, "The classifier has not been trained yet" );
1603 cvStartWriteStruct( fs, "trees", CV_NODE_SEQ );
1605 cvStartReadSeq( weak, &reader );
1607 for( i = 0; i < weak->total; i++ )
1610 CV_READ_SEQ_ELEM( tree, reader );
1611 cvStartWriteStruct( fs, 0, CV_NODE_MAP );
1613 cvEndWriteStruct( fs );
1616 cvEndWriteStruct( fs );
1617 cvEndWriteStruct( fs );
// Accessors for CvBoost model state.
// NOTE(review): this listing is missing the return-type lines, braces and most
// bodies of these functions, so only what is visible is documented here.

// get_weights(): presumably returns the per-sample boosting weights matrix;
// the body is not visible here -- TODO confirm.
1624 CvBoost::get_weights()
// get_subtree_weights(): returns the `subtree_weights` member.
1631 CvBoost::get_subtree_weights()
1633 return subtree_weights;
// get_weak_response(): presumably returns the stored weak-learner responses;
// the body is not visible here -- TODO confirm.
1638 CvBoost::get_weak_response()
// get_params(): const member returning a const reference to a CvBoostParams,
// presumably the `params` member filled by read_params() -- TODO confirm.
1644 const CvBoostParams&
1645 CvBoost::get_params() const
1650 CvSeq* CvBoost::get_weak_predictors()