1 /*M///////////////////////////////////////////////////////////////////////////////////////
3 // IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
5 // By downloading, copying, installing or using the software you agree to this license.
6 // If you do not agree to this license, do not download, install,
7 // copy or use the software.
10 // Intel License Agreement
12 // Copyright (C) 2000, Intel Corporation, all rights reserved.
13 // Third party copyrights are property of their respective owners.
15 // Redistribution and use in source and binary forms, with or without modification,
16 // are permitted provided that the following conditions are met:
18 // * Redistribution's of source code must retain the above copyright notice,
19 // this list of conditions and the following disclaimer.
21 // * Redistribution's in binary form must reproduce the above copyright notice,
22 // this list of conditions and the following disclaimer in the documentation
23 // and/or other materials provided with the distribution.
25 // * The name of Intel Corporation may not be used to endorse or promote products
26 // derived from this software without specific prior written permission.
28 // This software is provided by the copyright holders and contributors "as is" and
29 // any express or implied warranties, including, but not limited to, the implied
30 // warranties of merchantability and fitness for a particular purpose are disclaimed.
31 // In no event shall the Intel Corporation or contributors be liable for any direct,
32 // indirect, incidental, special, exemplary, or consequential damages
33 // (including, but not limited to, procurement of substitute goods or services;
34 // loss of use, data, or profits; or business interruption) however caused
35 // and on any theory of liability, whether in contract, strict liability,
36 // or tort (including negligence or otherwise) arising in any way out of
37 // the use of this software, even if advised of the possibility of such damage.
// Log-odds of a probability: log(val/(1 - val)).
// The argument is first clamped to [eps, 1 - eps] with eps = 1e-5, so the
// result stays finite even when val is exactly 0 or 1.
44 log_ratio( double val )
46 const double eps = 1e-5;
// clamp from below, then from above
48 val = MAX( val, eps );
49 val = MIN( val, 1. - eps );
50 return log( val/(1. - val) );
// Default boosting parameters: the boosting type is CvBoost::REAL and the
// weight trim rate is 0.95.
54 CvBoostParams::CvBoostParams()
56 boost_type = CvBoost::REAL;
58 weight_trim_rate = 0.95;
// Boosting parameters built from explicit values.
//   _boost_type       - boosting algorithm (a CvBoost type constant)
//   _weak_count       - number of weak trees to train
//   _weight_trim_rate - weight trimming rate
//   _max_depth        - maximum depth of each weak tree
//   _use_surrogates   - whether surrogate splits are built
// The split criterion cannot be chosen through this constructor; it is always
// set to CvBoost::DEFAULT.
64 CvBoostParams::CvBoostParams( int _boost_type, int _weak_count,
65 double _weight_trim_rate, int _max_depth,
66 bool _use_surrogates, const float* _priors )
68 boost_type = _boost_type;
69 weak_count = _weak_count;
70 weight_trim_rate = _weight_trim_rate;
71 split_criteria = CvBoost::DEFAULT;
73 max_depth = _max_depth;
74 use_surrogates = _use_surrogates;
80 ///////////////////////////////// CvBoostTree ///////////////////////////////////
82 CvBoostTree::CvBoostTree()
88 CvBoostTree::~CvBoostTree()
// Trains this weak tree as a member of the ensemble _ensemble.
// Remembers the owning ensemble (its weights and parameters are queried by the
// split-search and node-value routines) and delegates the actual work to
// do_train() with the given subsample index; the result of do_train() is returned.
103 CvBoostTree::train( CvDTreeTrainData* _train_data,
104 const CvMat* _subsample_idx, CvBoost* _ensemble )
107 ensemble = _ensemble;
111 return do_train( _subsample_idx );
116 CvBoostTree::train( const CvMat*, int, const CvMat*, const CvMat*,
117 const CvMat*, const CvMat*, const CvMat*, CvDTreeParams )
125 CvBoostTree::train( CvDTreeTrainData*, const CvMat* )
// Multiplies the stored value of the tree nodes by `scale`.
// The walk starts at the root and is done without recursion, using the
// parent/right links of the nodes.
133 CvBoostTree::scale( double scale )
135 CvDTreeNode* node = root;
137 // traverse the tree and scale all the node values
143 node->value *= scale;
// climb while the current node is the right child of its parent ...
149 for( parent = node->parent; parent && parent->right == node;
150 node = parent, parent = parent->parent )
// ... then continue with the right child of the parent reached
156 node = parent->right;
// Runs the base-class split attempt on `node`, then writes the node value into
// the ensemble's weak-response buffer for the training samples of the node.
162 CvBoostTree::try_split_node( CvDTreeNode* node )
164 CvDTree::try_split_node( node );
168 // if the node has not been split,
169 // store the responses for the corresponding training samples
170 double* weak_eval = ensemble->get_weak_response()->data.db;
171 int* labels_buf = data->get_cv_lables_buf();
172 const int* labels = 0;
173 data->get_cv_labels( node, labels_buf, &labels );
174 int i, count = node->sample_count;
175 double value = node->value;
// labels[i] is used as the position of the i-th node sample in weak_eval
177 for( i = 0; i < count; i++ )
178 weak_eval[labels[i]] = value;
// Evaluates the primary split of `node` over its samples, using the per-sample
// weights of the ensemble (get_subtree_weights()), and accumulates the total
// weight sent left (L) and right (R).
// Stores max(L, R) in node->maxlr and returns the split quality divided by L + R.
// Per-sample directions are kept in data->direction (the `dir` array).
184 CvBoostTree::calc_node_dir( CvDTreeNode* node )
186 char* dir = (char*)data->direction->data.ptr;
187 const double* weights = ensemble->get_subtree_weights()->data.db;
188 int i, n = node->sample_count, vi = node->split->var_idx;
// a primary split is never stored in inversed form
191 assert( !node->split->inversed );
193 if( data->get_var_type(vi) >= 0 ) // split on categorical var
195 int* cat_labels_buf = data->get_pred_int_buf();
196 const int* cat_labels = 0;
197 data->get_cat_var_data( node, vi, cat_labels_buf, &cat_labels );
198 const int* subset = node->split->subset;
199 double sum = 0, sum_abs = 0;
201 for( i = 0; i < n; i++ )
// label 65535 in a 16-bit buffer is mapped to -1; such samples get direction 0
203 int idx = ((cat_labels[i] == 65535) && data->is_buf_16u) ? -1 : cat_labels[i];
204 double w = weights[i];
205 int d = idx >= 0 ? CV_DTREE_CAT_DIR(idx,subset) : 0;
// sum accumulates signed weight, sum_abs the weight of samples with odd d
// (i.e. those with a non-zero direction, assuming d is in {-1,0,1} - TODO confirm)
206 sum += d*w; sum_abs += (d & 1)*w;
// with sum = R - L and sum_abs = R + L:
210 R = (sum_abs + sum) * 0.5;
211 L = (sum_abs - sum) * 0.5;
213 else // split on ordered var
215 float* values_buf = data->get_pred_float_buf();
216 const float* values = 0;
217 int* indices_buf = data->get_pred_int_buf();
218 const int* indices = 0;
219 data->get_ord_var_data( node, vi, values_buf, indices_buf, &values, &indices );
220 int split_point = node->split->ord.split_point;
221 int n1 = node->get_num_valid(vi);
// the split point must leave at least one valid sample on each side
223 assert( 0 <= split_point && split_point < n1-1 );
// samples at sorted positions 0..split_point
226 for( i = 0; i <= split_point; i++ )
228 int idx = indices[i];
229 double w = weights[idx];
// samples at the remaining sorted positions (loop header not visible in this listing)
236 int idx = indices[i];
237 double w = weights[idx];
// direction 0 for the samples handled here
// (presumably those at positions n1..n-1, i.e. without a valid value - TODO confirm)
243 dir[indices[i]] = (char)0;
246 node->maxlr = MAX( L, R );
247 return node->split->quality/(L + R);
// Searches the best threshold split on the ordered variable `vi` for a 2-class
// node, using sample weights from the ensemble.
//   init_quality - quality that a candidate has to beat (initial best_val)
//   _split       - optional split object to fill; when null a new one is
//                  allocated with data->new_split_ord()
// The criterion is Gini or misclassification; any other setting falls back to
// misclassification for DISCRETE boosting and to Gini otherwise.
252 CvBoostTree::find_split_ord_class( CvDTreeNode* node, int vi, float init_quality, CvDTreeSplit* _split )
// two values closer than this are treated as equal (no threshold is put between them)
254 const float epsilon = FLT_EPSILON*2;
256 const double* weights = ensemble->get_subtree_weights()->data.db;
257 int n = node->sample_count;
258 int n1 = node->get_num_valid(vi);
259 float* values_buf = data->get_pred_float_buf();
260 const float* values = 0;
261 int* indices_buf = data->get_pred_int_buf();
262 const int* indices = 0;
263 data->get_ord_var_data( node, vi, values_buf, indices_buf, &values, &indices );
264 int* responses_buf = data->get_resp_int_buf();
265 const int* responses = 0;
266 data->get_class_labels( node, responses_buf, &responses);
// weights[n] and weights[n+1] hold the per-class weight totals of the node
267 const double* rcw0 = weights + n;
// lcw/rcw: per-class weight on the left/right side of the candidate threshold
268 double lcw[2] = {0,0}, rcw[2];
270 double best_val = init_quality;
271 int boost_type = ensemble->get_params().boost_type;
272 int split_criteria = ensemble->get_params().split_criteria;
274 rcw[0] = rcw0[0]; rcw[1] = rcw0[1];
// remove the samples at sorted positions n1..n-1 (beyond the valid ones) from the totals
275 for( i = n1; i < n; i++ )
277 int idx = indices[i];
278 double w = weights[idx];
279 rcw[responses[idx]] -= w;
282 if( split_criteria != CvBoost::GINI && split_criteria != CvBoost::MISCLASS )
283 split_criteria = boost_type == CvBoost::DISCRETE ? CvBoost::MISCLASS : CvBoost::GINI;
285 if( split_criteria == CvBoost::GINI )
287 double L = 0, R = rcw[0] + rcw[1];
// lsum2/rsum2: sums of squared per-class weights on each side, updated incrementally
288 double lsum2 = 0, rsum2 = rcw[0]*rcw[0] + rcw[1]*rcw[1];
290 for( i = 0; i < n1 - 1; i++ )
292 int idx = indices[i];
293 double w = weights[idx], w2 = w*w;
// from here on idx is the class of the sample
295 idx = responses[idx];
297 lv = lcw[idx]; rv = rcw[idx];
// (lv + w)^2 - lv^2 = 2*lv*w + w^2 ;  rv^2 - (rv - w)^2 = 2*rv*w - w^2
298 lsum2 += 2*lv*w + w2;
299 rsum2 -= 2*rv*w - w2;
300 lcw[idx] = lv + w; rcw[idx] = rv - w;
302 if( values[i] + epsilon < values[i+1] )
// equals lsum2/L + rsum2/R
304 double val = (lsum2*R + rsum2*L)/(L*R);
// misclassification criterion
315 for( i = 0; i < n1 - 1; i++ )
317 int idx = indices[i];
318 double w = weights[idx];
319 idx = responses[idx];
323 if( values[i] + epsilon < values[i+1] )
// weight classified correctly under either assignment of classes to sides
325 double val = lcw[0] + rcw[1], val2 = lcw[1] + rcw[0];
326 val = MAX(val, val2);
336 CvDTreeSplit* split = 0;
339 split = _split ? _split : data->new_split_ord( 0, 0.0f, 0, 0, 0.0f );
// threshold halfway between the two neighbouring values
341 split->ord.c = (values[best_i] + values[best_i+1])*0.5f;
342 split->ord.split_point = best_i;
344 split->quality = (float)best_val;
// "Less than" for two pointers, comparing the values they point to.
350 #define CV_CMP_NUM_PTR(a,b) (*(a) < *(b))
// File-local sort routine for arrays of double*, ordered with the comparison above.
// It is used below to order categories by a per-category statistic.
351 static CV_IMPLEMENT_QSORT_EX( icvSortDblPtr, double*, CV_CMP_NUM_PTR, int )
// Searches the best subset split on the categorical variable `vi` for a 2-class
// node, using sample weights from the ensemble.
//   init_quality - quality that a candidate has to beat (initial best_val)
//   _split       - optional split object to fill; when null a new one is
//                  allocated with data->new_split_cat()
// Categories are sorted by their class-1 weight and only prefixes of that order
// are evaluated as candidate subsets.
354 CvBoostTree::find_split_cat_class( CvDTreeNode* node, int vi, float init_quality, CvDTreeSplit* _split )
356 int ci = data->get_var_type(vi);
357 int n = node->sample_count;
// number of categories of the variable
358 int mi = data->cat_count->data.i[ci];
359 int* cat_labels_buf = data->get_pred_int_buf();
360 const int* cat_labels = 0;
361 data->get_cat_var_data(node, vi, cat_labels_buf, &cat_labels);
362 int* responses_buf = data->get_resp_int_buf();
363 const int* responses = 0;
364 data->get_class_labels(node, responses_buf, &responses);
365 double lcw[2]={0,0}, rcw[2]={0,0};
// (mi+1) rows of 2 counters; the pointer is shifted by one row so that
// category index -1 is addressable
366 double* cjk = (double*)cvStackAlloc(2*(mi+1)*sizeof(cjk[0]))+2;
367 const double* weights = ensemble->get_subtree_weights()->data.db;
368 double** dbl_ptr = (double**)cvStackAlloc( mi*sizeof(dbl_ptr[0]) );
371 double best_val = init_quality;
372 int best_subset = -1, subset_i;
373 int boost_type = ensemble->get_params().boost_type;
374 int split_criteria = ensemble->get_params().split_criteria;
376 // init array of counters:
377 // c_{jk} - number of samples that have vi-th input variable = j and response = k.
378 for( j = -1; j < mi; j++ )
379 cjk[j*2] = cjk[j*2+1] = 0;
381 for( i = 0; i < n; i++ )
383 double w = weights[i];
// label 65535 in a 16-bit buffer is mapped to category -1
384 j = ((cat_labels[i] == 65535) && data->is_buf_16u) ? -1 : cat_labels[i];
// initially every category 0..mi-1 is on the right side
389 for( j = 0; j < mi; j++ )
392 rcw[1] += cjk[j*2+1];
// pointer to the class-1 counter of category j, used as the sort key
393 dbl_ptr[j] = cjk + j*2 + 1;
398 if( split_criteria != CvBoost::GINI && split_criteria != CvBoost::MISCLASS )
399 split_criteria = boost_type == CvBoost::DISCRETE ? CvBoost::MISCLASS : CvBoost::GINI;
401 // sort rows of c_jk by increasing c_j,1
402 // (i.e. by the weight of samples in j-th category that belong to class 1)
403 icvSortDblPtr( dbl_ptr, mi, 0 );
// move the categories to the left side one by one, in sorted order
405 for( subset_i = 0; subset_i < mi-1; subset_i++ )
// recover the category index from the pointer (offset is 2*idx + 1)
407 idx = (int)(dbl_ptr[subset_i] - cjk)/2;
408 const double* crow = cjk + idx*2;
409 double w0 = crow[0], w1 = crow[1];
410 double weight = w0 + w1;
// test for categories with (almost) no weight in this node
412 if( weight < FLT_EPSILON )
415 lcw[0] += w0; rcw[0] -= w0;
416 lcw[1] += w1; rcw[1] -= w1;
418 if( split_criteria == CvBoost::GINI )
420 double lsum2 = lcw[0]*lcw[0] + lcw[1]*lcw[1];
421 double rsum2 = rcw[0]*rcw[0] + rcw[1]*rcw[1];
// both sides must carry some weight
426 if( L > FLT_EPSILON && R > FLT_EPSILON )
// equals lsum2/L + rsum2/R
428 double val = (lsum2*R + rsum2*L)/(L*R);
432 best_subset = subset_i;
// misclassification criterion: weight classified correctly under either
// assignment of classes to sides
438 double val = lcw[0] + rcw[1];
439 double val2 = lcw[1] + rcw[0];
441 val = MAX(val, val2);
445 best_subset = subset_i;
450 CvDTreeSplit* split = 0;
451 if( best_subset >= 0 )
453 split = _split ? _split : data->new_split_cat( 0, -1.0f);
455 split->quality = (float)best_val;
456 memset( split->subset, 0, (data->max_c_count + 31)/32 * sizeof(int));
// set the bit of every category in the best prefix
457 for( i = 0; i <= best_subset; i++ )
459 idx = (int)(dbl_ptr[i] - cjk) >> 1;
460 split->subset[idx >> 5] |= 1 << (idx & 31);
// Searches the best threshold split on the ordered variable `vi` for a
// regression node, using sample weights from the ensemble.
//   init_quality - quality that a candidate has to beat (initial best_val)
//   _split       - optional split object to fill; when null a new one is
//                  allocated with data->new_split_ord()
468 CvBoostTree::find_split_ord_reg( CvDTreeNode* node, int vi, float init_quality, CvDTreeSplit* _split )
// two values closer than this are treated as equal (no threshold is put between them)
470 const float epsilon = FLT_EPSILON*2;
471 const double* weights = ensemble->get_subtree_weights()->data.db;
472 int n = node->sample_count;
473 int n1 = node->get_num_valid(vi);
475 float* values_buf = data->get_pred_float_buf();
476 const float* values = 0;
477 int* indices_buf = data->get_pred_int_buf();
478 const int* indices = 0;
479 data->get_ord_var_data( node, vi, values_buf, indices_buf, &values, &indices );
480 float* responses_buf = data->get_resp_float_buf();
481 const float* responses = 0;
482 data->get_ord_responses(node, responses_buf, &responses);
// weights[n] is a summary weight stored after the per-sample weights
485 double L = 0, R = weights[n];
// lsum/rsum: weighted response sums on each side; node->value*R is the total
486 double best_val = init_quality, lsum = 0, rsum = node->value*R;
488 // compensate for missing values
489 for( i = n1; i < n; i++ )
491 int idx = indices[i];
492 double w = weights[idx];
493 rsum -= responses[idx]*w;
497 // find the optimal split
498 for( i = 0; i < n1 - 1; i++ )
500 int idx = indices[i];
501 double w = weights[idx];
502 double t = responses[idx]*w;
504 lsum += t; rsum -= t;
506 if( values[i] + epsilon < values[i+1] )
// equals lsum^2/L + rsum^2/R
508 double val = (lsum*lsum*R + rsum*rsum*L)/(L*R);
517 CvDTreeSplit* split = 0;
520 split = _split ? _split : data->new_split_ord( 0, 0.0f, 0, 0, 0.0f );
// threshold halfway between the two neighbouring values
522 split->ord.c = (values[best_i] + values[best_i+1])*0.5f;
523 split->ord.split_point = best_i;
525 split->quality = (float)best_val;
// Searches the best subset split on the categorical variable `vi` for a
// regression node, using sample weights from the ensemble.
//   init_quality - quality that a candidate has to beat (initial best_val)
//   _split       - optional split object to fill; when null a new one is
//                  allocated with data->new_split_cat()
// Categories are sorted by their per-category statistic in `sum` and only
// prefixes of that order are evaluated as candidate subsets.
532 CvBoostTree::find_split_cat_reg( CvDTreeNode* node, int vi, float init_quality, CvDTreeSplit* _split )
534 const double* weights = ensemble->get_subtree_weights()->data.db;
535 int ci = data->get_var_type(vi);
536 int n = node->sample_count;
// number of categories of the variable
537 int mi = data->cat_count->data.i[ci];
538 int* cat_labels_buf = data->get_pred_int_buf();
539 const int* cat_labels = 0;
540 data->get_cat_var_data(node, vi, cat_labels_buf, &cat_labels);
541 float* responses_buf = data->get_resp_float_buf();
542 const float* responses = 0;
543 data->get_ord_responses(node, responses_buf, &responses);
// both arrays have mi+1 entries and are shifted by one so that index -1 is addressable
545 double* sum = (double*)cvStackAlloc( (mi+1)*sizeof(sum[0]) ) + 1;
546 double* counts = (double*)cvStackAlloc( (mi+1)*sizeof(counts[0]) ) + 1;
547 double** sum_ptr = (double**)cvStackAlloc( mi*sizeof(sum_ptr[0]) );
548 double L = 0, R = 0, best_val = init_quality, lsum = 0, rsum = 0;
549 int i, best_subset = -1, subset_i;
551 for( i = -1; i < mi; i++ )
552 sum[i] = counts[i] = 0;
554 // calculate sum response and weight of each category of the input var
555 for( i = 0; i < n; i++ )
// label 65535 in a 16-bit buffer is mapped to category -1
557 int idx = ((cat_labels[i] == 65535) && data->is_buf_16u) ? -1 : cat_labels[i];
558 double w = weights[i];
559 double s = sum[idx] + responses[i]*w;
560 double nc = counts[idx] + w;
565 // calculate average response in each category
566 for( i = 0; i < mi; i++ )
// pointer into `sum`, used as the sort key for category i
571 sum_ptr[i] = sum + i;
574 icvSortDblPtr( sum_ptr, mi, 0 );
576 // revert back to unnormalized sums
577 // (there should be a very little loss in accuracy)
578 for( i = 0; i < mi; i++ )
// move the categories to the left side one by one, in sorted order
581 for( subset_i = 0; subset_i < mi-1; subset_i++ )
// recover the category index from the pointer
583 int idx = (int)(sum_ptr[subset_i] - sum);
584 double ni = counts[idx];
// only categories that carry some weight are processed
586 if( ni > FLT_EPSILON )
592 if( L > FLT_EPSILON && R > FLT_EPSILON )
// equals lsum^2/L + rsum^2/R
594 double val = (lsum*lsum*R + rsum*rsum*L)/(L*R);
598 best_subset = subset_i;
604 CvDTreeSplit* split = 0;
605 if( best_subset >= 0 )
607 split = _split ? _split : data->new_split_cat( 0, -1.0f);
609 split->quality = (float)best_val;
610 memset( split->subset, 0, (data->max_c_count + 31)/32 * sizeof(int));
// set the bit of every category in the best prefix
611 for( i = 0; i <= best_subset; i++ )
613 int idx = (int)(sum_ptr[i] - sum);
614 split->subset[idx >> 5] |= 1 << (idx & 31);
// Searches a surrogate threshold split on the ordered variable `vi` that agrees
// best (by total sample weight) with the primary split directions stored in
// data->direction.
// Returns a new ordered split, possibly inversed, only when its agreement
// exceeds node->maxlr; otherwise returns 0.
622 CvBoostTree::find_surrogate_split_ord( CvDTreeNode* node, int vi )
// two values closer than this are treated as equal (no threshold is put between them)
624 const float epsilon = FLT_EPSILON*2;
625 float* values_buf = data->get_pred_float_buf();
626 const float* values = 0;
627 int* indices_buf = data->get_pred_int_buf();
628 const int* indices = 0;
629 data->get_ord_var_data( node, vi, values_buf, indices_buf, &values, &indices );
631 const double* weights = ensemble->get_subtree_weights()->data.db;
632 const char* dir = (char*)data->direction->data.ptr;
633 int n1 = node->get_num_valid(vi);
// the four quantities below are accumulated as sums of sample weights
634 // LL - number of samples that both the primary and the surrogate splits send to the left
635 // LR - ... primary split sends to the left and the surrogate split sends to the right
636 // RL - ... primary split sends to the right and the surrogate split sends to the left
637 // RR - ... both send to the right
638 int i, best_i = -1, best_inversed = 0;
640 double LL = 0, RL = 0, LR, RR;
// a surrogate is only useful if it beats the trivial "send everything to the
// heavier side" rule, whose score is node->maxlr
641 double worst_val = node->maxlr;
642 double sum = 0, sum_abs = 0;
643 best_val = worst_val;
645 for( i = 0; i < n1; i++ )
647 int idx = indices[i];
648 double w = weights[idx];
// d is declared on a line not visible in this listing
// (presumably the primary direction dir[idx] - TODO confirm)
650 sum += d*w; sum_abs += (d & 1)*w;
653 // sum_abs = R + L; sum = R - L
654 RR = (sum_abs + sum)*0.5;
655 LR = (sum_abs - sum)*0.5;
657 // initially all the samples are sent to the right by the surrogate split,
658 // LR of them are sent to the left by primary split, and RR - to the right.
659 // now iteratively compute LL, LR, RL and RR for every possible surrogate split value.
660 for( i = 0; i < n1 - 1; i++ )
662 int idx = indices[i];
663 double w = weights[idx];
// direct surrogate: agreement is LL + RR
669 if( LL + RR > best_val && values[i] + epsilon < values[i+1] )
672 best_i = i; best_inversed = 0;
// inversed surrogate: agreement is RL + LR
678 if( RL + LR > best_val && values[i] + epsilon < values[i+1] )
681 best_i = i; best_inversed = 1;
686 return best_i >= 0 && best_val > node->maxlr ? data->new_split_ord( vi,
687 (values[best_i] + values[best_i+1])*0.5f, best_i,
688 best_inversed, (float)best_val ) : 0;
// Builds a surrogate subset split on the categorical variable `vi`: for each
// category, the weight that the primary split sends left and right is
// accumulated, and the category is assigned a side from these two weights.
// The split object is removed from data->split_heap and reset to 0 when its
// quality does not exceed node->maxlr.
693 CvBoostTree::find_surrogate_split_cat( CvDTreeNode* node, int vi )
695 const char* dir = (char*)data->direction->data.ptr;
696 const double* weights = ensemble->get_subtree_weights()->data.db;
697 int n = node->sample_count;
698 int* cat_labels_buf = data->get_pred_int_buf();
699 const int* cat_labels = 0;
700 data->get_cat_var_data(node, vi, cat_labels_buf, &cat_labels);
// the quantities below are accumulated as sums of sample weights
702 // LL - number of samples that both the primary and the surrogate splits send to the left
703 // LR - ... primary split sends to the left and the surrogate split sends to the right
704 // RL - ... primary split sends to the right and the surrogate split sends to the left
705 // RR - ... both send to the right
706 CvDTreeSplit* split = data->new_split_cat( vi, 0 );
707 int i, mi = data->cat_count->data.i[data->get_var_type(vi)];
// one allocation for both arrays; each has mi+1 entries and index -1 is addressable
709 double* lc = (double*)cvStackAlloc( (mi+1)*2*sizeof(lc[0]) ) + 1;
710 double* rc = lc + mi + 1;
712 for( i = -1; i < mi; i++ )
715 // 1. for each category calculate the weight of samples
716 // sent to the left (lc) and to the right (rc) by the primary split
717 for( i = 0; i < n; i++ )
// label 65535 in a 16-bit buffer is mapped to category -1
719 int idx = ((cat_labels[i] == 65535) && data->is_buf_16u) ? -1 : cat_labels[i];
720 double w = weights[i];
// d is declared on a line not visible in this listing
// (presumably the primary direction dir[i] - TODO confirm).
// During this pass lc[] temporarily holds the signed sum and rc[] the absolute sum.
722 double sum = lc[idx] + d*w;
723 double sum_abs = rc[idx] + (d & 1)*w;
724 lc[idx] = sum; rc[idx] = sum_abs;
// convert (signed sum, absolute sum) into (left weight, right weight)
727 for( i = 0; i < mi; i++ )
730 double sum_abs = rc[i];
731 lc[i] = (sum_abs - sum) * 0.5;
732 rc[i] = (sum_abs + sum) * 0.5;
735 // 2. now form the split.
736 // in each category send all the samples to the same direction as majority
737 for( i = 0; i < mi; i++ )
739 double lval = lc[i], rval = rc[i];
// mark category i in the subset bit mask
742 split->subset[i >> 5] |= 1 << (i & 31);
749 split->quality = (float)best_val;
// discard the surrogate if it is not better than the trivial rule
750 if( split->quality <= node->maxlr )
751 cvSetRemoveByPtr( data->split_heap, split ), split = 0;
// Computes the value of `node` from the weighted training samples and fills the
// ensemble's subtree-weights buffer for the node: entries 0..n-1 receive the
// per-sample weights and entries n, n+1 the summary weights rcw[0], rcw[1].
//   classification: class_idx is the heavier class; the value is +/-1 for
//                   DISCRETE boosting, or half the log-ratio of the class-1
//                   probability for REAL boosting
//   regression:     value and risk are computed from weighted sums
758 CvBoostTree::calc_node_value( CvDTreeNode* node )
760 int i, n = node->sample_count;
761 const double* weights = ensemble->get_weights()->data.db;
762 int* labels_buf = data->get_cv_lables_buf();
763 const int* labels = 0;
764 data->get_cv_labels(node, labels_buf, &labels);
765 double* subtree_weights = ensemble->get_subtree_weights()->data.db;
766 double rcw[2] = {0,0};
767 int boost_type = ensemble->get_params().boost_type;
769 if( data->is_classifier )
771 int* _responses_buf = data->get_resp_int_buf();
772 const int* _responses = 0;
773 data->get_class_labels(node, _responses_buf, &_responses);
774 int m = data->get_num_classes();
775 int* cls_count = data->counts->data.i;
776 for( int k = 0; k < m; k++ )
779 for( i = 0; i < n; i++ )
// idx is declared on a line not visible in this listing
// (presumably labels[i], the position of the sample in the ensemble weights - TODO confirm)
782 double w = weights[idx];
783 int r = _responses[i];
786 subtree_weights[i] = w;
// the class with the larger total weight wins
789 node->class_idx = rcw[1] > rcw[0];
791 if( boost_type == CvBoost::DISCRETE )
793 // ignore cat_map for responses, and use {-1,1},
794 // as the whole ensemble response is computed as sign(sum_i(weak_response_i))
795 node->value = node->class_idx*2 - 1;
// weighted fraction of class 1 in the node
799 double p = rcw[1]/(rcw[0] + rcw[1]);
800 assert( boost_type == CvBoost::REAL );
802 // store log-ratio of the probability
803 node->value = 0.5*log_ratio(p);
808 // in case of regression tree:
809 // * node value is 1/n*sum_i(Y_i), where Y_i is i-th response,
810 // n is the number of samples in the node.
811 // * node risk is the sum of squared errors: sum_i((Y_i - <node_value>)^2)
812 double sum = 0, sum2 = 0, iw;
813 float* values_buf = data->get_resp_float_buf();
814 const float* values = 0;
815 data->get_ord_responses(node, values_buf, &values);
817 for( i = 0; i < n; i++ )
820 double w = weights[idx]/*priors[values[i] > 0]*/;
821 double t = values[i];
823 subtree_weights[i] = w;
// iw is assigned on a line not visible in this listing
// (presumably the inverse of the total weight - TODO confirm)
829 node->value = sum*iw;
830 node->node_risk = sum2 - (sum*iw)*sum;
832 // renormalize the risk, as in try_split_node the unweighted formula
833 // sqrt(risk)/n is used, rather than sqrt(risk)/sum(weights_i)
834 node->node_risk *= n*iw*n*iw;
837 // store summary weights
838 subtree_weights[n] = rcw[0];
839 subtree_weights[n+1] = rcw[1];
// Loads the tree from file-storage node `fnode` via the base-class reader and
// attaches it to the ensemble `_ensemble`.
843 void CvBoostTree::read( CvFileStorage* fs, CvFileNode* fnode, CvBoost* _ensemble, CvDTreeTrainData* _data )
845 CvDTree::read( fs, fnode, _data );
846 ensemble = _ensemble;
850 void CvBoostTree::read( CvFileStorage*, CvFileNode* )
// Loads the tree from file-storage node `_node`; simply forwards to the
// base-class reader. The owning ensemble is not touched by this overload.
855 void CvBoostTree::read( CvFileStorage* _fs, CvFileNode* _node,
856 CvDTreeTrainData* _data )
858 CvDTree::read( _fs, _node, _data );
862 /////////////////////////////////// CvBoost /////////////////////////////////////
// NOTE(review): the signature line of this block is not visible in this listing;
// presumably this is the body of the default CvBoost constructor - confirm.
// Sets the model name, nulls every matrix pointer owned by the ensemble and
// clears both state flags.
868 default_model_name = "my_boost_tree";
869 active_vars = active_vars_abs = orig_response = sum_response = weak_eval =
870 subsample_mask = weights = subtree_weights = 0;
871 have_active_cat_vars = have_subsample = false;
// Removes the weak classifiers in the range `slice` from the `weak` sequence.
// The elements of the slice are first visited one by one through a sequence
// reader (the per-element action is on lines not visible in this listing),
// then the slice itself is cut out of the sequence.
877 void CvBoost::prune( CvSlice slice )
// number of elements actually covered by the slice
882 int i, count = cvSliceLength( slice, weak );
884 cvStartReadSeq( weak, &reader );
885 cvSetSeqReaderPos( &reader, slice.start_index );
887 for( i = 0; i < count; i++ )
890 CV_READ_SEQ_ELEM( w, reader );
894 cvSeqRemoveSlice( weak, slice );
899 void CvBoost::clear()
903 prune( CV_WHOLE_SEQ );
904 cvReleaseMemStorage( &weak->storage );
910 cvReleaseMat( &active_vars );
911 cvReleaseMat( &active_vars_abs );
912 cvReleaseMat( &orig_response );
913 cvReleaseMat( &sum_response );
914 cvReleaseMat( &weak_eval );
915 cvReleaseMat( &subsample_mask );
916 cvReleaseMat( &weights );
917 have_subsample = false;
927 CvBoost::CvBoost( const CvMat* _train_data, int _tflag,
928 const CvMat* _responses, const CvMat* _var_idx,
929 const CvMat* _sample_idx, const CvMat* _var_type,
930 const CvMat* _missing_mask, CvBoostParams _params )
934 default_model_name = "my_boost_tree";
935 orig_response = sum_response = weak_eval = subsample_mask = weights = 0;
937 train( _train_data, _tflag, _responses, _var_idx, _sample_idx,
938 _var_type, _missing_mask, _params );
// Validates and normalizes the boosting parameters.
//   * an unknown boosting type raises CV_StsBadArg
//   * weak_count is forced to be at least 1
//   * weight_trim_rate is clamped to [0, 1]; a value below FLT_EPSILON is
//     replaced by 1
//   * a split criterion that does not fit the boosting type is replaced:
//     DISCRETE -> MISCLASS, REAL -> GINI (each unless GINI or MISCLASS was
//     requested), LOGIT and GENTLE -> SQERR always
943 CvBoost::set_params( const CvBoostParams& _params )
947 CV_FUNCNAME( "CvBoost::set_params" );
952 if( params.boost_type != DISCRETE && params.boost_type != REAL &&
953 params.boost_type != LOGIT && params.boost_type != GENTLE )
954 CV_ERROR( CV_StsBadArg, "Unknown/unsupported boosting type" );
956 params.weak_count = MAX( params.weak_count, 1 );
957 params.weight_trim_rate = MAX( params.weight_trim_rate, 0. );
958 params.weight_trim_rate = MIN( params.weight_trim_rate, 1. );
959 if( params.weight_trim_rate < FLT_EPSILON )
960 params.weight_trim_rate = 1.f;
962 if( params.boost_type == DISCRETE &&
963 params.split_criteria != GINI && params.split_criteria != MISCLASS )
964 params.split_criteria = MISCLASS;
965 if( params.boost_type == REAL &&
966 params.split_criteria != GINI && params.split_criteria != MISCLASS )
967 params.split_criteria = GINI;
968 if( (params.boost_type == LOGIT || params.boost_type == GENTLE) &&
969 params.split_criteria != SQERR )
970 params.split_criteria = SQERR;
// Trains the boosted ensemble (2-class problems only).
// Outline of the visible steps:
//   1. normalize the parameters and drop the cached active-variable maps
//   2. either build fresh training data and an empty `weak` sequence, or (when
//      _update is set and data already exists) load the new samples into the
//      existing training data
//   3. initialize the sample weights, then train up to params.weak_count trees,
//      updating the weights after each one
//   4. recompute the active-variable maps and free the training data
981 CvBoost::train( const CvMat* _train_data, int _tflag,
982 const CvMat* _responses, const CvMat* _var_idx,
983 const CvMat* _sample_idx, const CvMat* _var_type,
984 const CvMat* _missing_mask,
985 CvBoostParams _params, bool _update )
988 CvMemStorage* storage = 0;
990 CV_FUNCNAME( "CvBoost::train" );
996 set_params( _params );
998 cvReleaseMat( &active_vars );
999 cvReleaseMat( &active_vars_abs );
1001 if( !_update || !data )
1004 data = new CvDTreeTrainData( _train_data, _tflag, _responses, _var_idx,
1005 _sample_idx, _var_type, _missing_mask, _params, true, true );
1007 if( data->get_num_classes() != 2 )
1008 CV_ERROR( CV_StsNotImplemented,
1009 "Boosted trees can only be used for 2-class classification." );
1010 CV_CALL( storage = cvCreateMemStorage() );
// sequence of CvBoostTree* living in its own storage
1011 weak = cvCreateSeq( 0, sizeof(CvSeq), sizeof(CvBoostTree*), storage );
// update mode: reuse the existing training-data object
1016 data->set_data( _train_data, _tflag, _responses, _var_idx,
1017 _sample_idx, _var_type, _missing_mask, _params, true, true, true );
// LOGIT and GENTLE overwrite the responses while training (see update_weights),
// so they work on a copy
1020 if ( (_params.boost_type == LOGIT) || (_params.boost_type == GENTLE) )
1021 data->do_responses_copy();
// a null tree requests the initialization of the weights
1023 update_weights( 0 );
1025 for( i = 0; i < params.weak_count; i++ )
1027 CvBoostTree* tree = new CvBoostTree;
1028 if( !tree->train( data, subsample_mask, this ) )
1033 //cvCheckArr( get_weak_response());
1034 cvSeqPush( weak, &tree );
1035 update_weights( tree );
1039 get_active_vars(); // recompute active_vars* maps and condensed_idx's in the splits.
// update_weights() switches is_classifier off for LOGIT/GENTLE; restore it
1040 data->is_classifier = true;
1043 data->free_train_data();
// Convenience overload: pulls the value matrix, responses, missing-value mask,
// variable types, training-sample indices and variable indices out of `_data`
// and forwards them to the matrix-based train() with row-sample layout.
// Returns the result of that call (false if it was not reached).
1050 bool CvBoost::train( CvMLData* _data,
1051 CvBoostParams params,
1054 bool result = false;
1056 CV_FUNCNAME( "CvBoost::train" );
1060 const CvMat* values = _data->get_values();
1061 const CvMat* response = _data->get_responses();
1062 const CvMat* missing = _data->get_missing();
1063 const CvMat* var_types = _data->get_var_types();
1064 const CvMat* train_sidx = _data->get_train_sample_idx();
1065 const CvMat* var_idx = _data->get_var_idx();
1067 CV_CALL( result = train( values, CV_ROW_SAMPLE, response, var_idx,
1068 train_sidx, var_types, missing, params, update ) );
1076 CvBoost::update_weights( CvBoostTree* tree )
1078 CV_FUNCNAME( "CvBoost::update_weights" );
1082 int i, n = data->sample_count;
1086 int *sample_idx_buf;
1087 const int* sample_idx = 0;
1088 if ( (params.boost_type == LOGIT) || (params.boost_type == GENTLE) )
1090 step = CV_IS_MAT_CONT(data->responses_copy->type) ?
1091 1 : data->responses_copy->step / CV_ELEM_SIZE(data->responses_copy->type);
1092 fdata = data->responses_copy->data.fl;
1093 sample_idx_buf = (int*)cvStackAlloc(data->sample_count*sizeof(sample_idx_buf[0]));
1094 data->get_sample_indices( data->data_root, sample_idx_buf, &sample_idx );
1096 CvMat* buf = data->buf;
1097 if( !tree ) // before training the first tree, initialize weights and other parameters
1099 int n = data->sample_count;
1100 int* class_labels_buf = data->get_resp_int_buf();
1101 const int* class_labels = 0;
1102 data->get_class_labels(data->data_root, class_labels_buf, &class_labels);
1103 // in case of logitboost and gentle adaboost each weak tree is a regression tree,
1104 // so we need to convert class labels to floating-point values
1105 float* responses_buf = data->get_resp_float_buf();
1106 const float* responses = 0;
1107 data->get_ord_responses(data->data_root, responses_buf, &responses);
1110 double p[2] = { 1, 1 };
1112 cvReleaseMat( &orig_response );
1113 cvReleaseMat( &sum_response );
1114 cvReleaseMat( &weak_eval );
1115 cvReleaseMat( &subsample_mask );
1116 cvReleaseMat( &weights );
1118 CV_CALL( orig_response = cvCreateMat( 1, n, CV_32S ));
1119 CV_CALL( weak_eval = cvCreateMat( 1, n, CV_64F ));
1120 CV_CALL( subsample_mask = cvCreateMat( 1, n, CV_8U ));
1121 CV_CALL( weights = cvCreateMat( 1, n, CV_64F ));
1122 CV_CALL( subtree_weights = cvCreateMat( 1, n + 2, CV_64F ));
1124 if( data->have_priors )
1126 // compute weight scale for each class from their prior probabilities
1128 for( i = 0; i < n; i++ )
1129 c1 += class_labels[i];
1130 p[0] = data->priors->data.db[0]*(c1 < n ? 1./(n - c1) : 0.);
1131 p[1] = data->priors->data.db[1]*(c1 > 0 ? 1./c1 : 0.);
1132 p[0] /= p[0] + p[1];
1136 if (data->is_buf_16u)
1138 unsigned short* labels = (unsigned short*)(buf->data.s + data->data_root->buf_idx*buf->cols +
1139 data->data_root->offset + (data->work_var_count-1)*data->sample_count);
1140 for( i = 0; i < n; i++ )
1142 // save original categorical responses {0,1}, convert them to {-1,1}
1143 orig_response->data.i[i] = class_labels[i]*2 - 1;
1144 // make all the samples active at start.
1145 // later, in trim_weights() deactivate/reactive again some, if need
1146 subsample_mask->data.ptr[i] = (uchar)1;
1147 // make all the initial weights the same.
1148 weights->data.db[i] = w0*p[class_labels[i]];
1149 // set the labels to find (from within weak tree learning proc)
1150 // the particular sample weight, and where to store the response.
1151 labels[i] = (unsigned short)i;
1156 int* labels = buf->data.i + data->data_root->buf_idx*buf->cols +
1157 data->data_root->offset + (data->work_var_count-1)*data->sample_count;
1159 for( i = 0; i < n; i++ )
1161 // save original categorical responses {0,1}, convert them to {-1,1}
1162 orig_response->data.i[i] = class_labels[i]*2 - 1;
1163 // make all the samples active at start.
1164 // later, in trim_weights() deactivate/reactive again some, if need
1165 subsample_mask->data.ptr[i] = (uchar)1;
1166 // make all the initial weights the same.
1167 weights->data.db[i] = w0*p[class_labels[i]];
1168 // set the labels to find (from within weak tree learning proc)
1169 // the particular sample weight, and where to store the response.
1174 if( params.boost_type == LOGIT )
1176 CV_CALL( sum_response = cvCreateMat( 1, n, CV_64F ));
1178 for( i = 0; i < n; i++ )
1180 sum_response->data.db[i] = 0;
1181 fdata[sample_idx[i]*step] = orig_response->data.i[i] > 0 ? 2.f : -2.f;
1184 // in case of logitboost each weak tree is a regression tree.
1185 // the target function values are recalculated for each of the trees
1186 data->is_classifier = false;
1188 else if( params.boost_type == GENTLE )
1190 for( i = 0; i < n; i++ )
1191 fdata[sample_idx[i]*step] = (float)orig_response->data.i[i];
1193 data->is_classifier = false;
1198 // at this moment, for all the samples that participated in the training of the most
1199 // recent weak classifier we know the responses. For other samples we need to compute them
1200 if( have_subsample )
1202 float* values0, *values = (float*)cvStackAlloc(data->buf->step*sizeof(float));
1203 uchar* missing0, *missing = (uchar*)cvStackAlloc(data->buf->step*sizeof(uchar));
1204 CvMat _sample, _mask;
1208 // invert the subsample mask
1209 cvXorS( subsample_mask, cvScalar(1.), subsample_mask );
1210 data->get_vectors( subsample_mask, values, missing, 0 );
1212 _sample = cvMat( 1, data->var_count, CV_32F );
1213 _mask = cvMat( 1, data->var_count, CV_8U );
1215 // run tree through all the non-processed samples
1216 for( i = 0; i < n; i++ )
1217 if( subsample_mask->data.ptr[i] )
1219 _sample.data.fl = values;
1220 _mask.data.ptr = missing;
1221 values += _sample.cols;
1222 missing += _mask.cols;
1223 weak_eval->data.db[i] = tree->predict( &_sample, &_mask, true )->value;
1227 // now update weights and other parameters for each type of boosting
1228 if( params.boost_type == DISCRETE )
1230 // Discrete AdaBoost:
1231 // weak_eval[i] (=f(x_i)) is in {-1,1}
1232 // err = sum(w_i*(f(x_i) != y_i))/sum(w_i)
1233 // C = log((1-err)/err)
1234 // w_i *= exp(C*(f(x_i) != y_i))
1237 double scale[] = { 1., 0. };
1239 for( i = 0; i < n; i++ )
1241 double w = weights->data.db[i];
1243 err += w*(weak_eval->data.db[i] != orig_response->data.i[i]);
1248 C = err = -log_ratio( err );
1249 scale[1] = exp(err);
1252 for( i = 0; i < n; i++ )
1254 double w = weights->data.db[i]*
1255 scale[weak_eval->data.db[i] != orig_response->data.i[i]];
1257 weights->data.db[i] = w;
1262 else if( params.boost_type == REAL )
1265 // weak_eval[i] = f(x_i) = 0.5*log(p(x_i)/(1-p(x_i))), p(x_i)=P(y=1|x_i)
1266 // w_i *= exp(-y_i*f(x_i))
1268 for( i = 0; i < n; i++ )
1269 weak_eval->data.db[i] *= -orig_response->data.i[i];
1271 cvExp( weak_eval, weak_eval );
1273 for( i = 0; i < n; i++ )
1275 double w = weights->data.db[i]*weak_eval->data.db[i];
1277 weights->data.db[i] = w;
1280 else if( params.boost_type == LOGIT )
1283 // weak_eval[i] = f(x_i) in [-z_max,z_max]
1284 // sum_response = F(x_i).
1285 // F(x_i) += 0.5*f(x_i)
1286 // p(x_i) = exp(F(x_i))/(exp(F(x_i)) + exp(-F(x_i))=1/(1+exp(-2*F(x_i)))
1287 // reuse weak_eval: weak_eval[i] <- p(x_i)
1288 // w_i = p(x_i)*1(1 - p(x_i))
1289 // z_i = ((y_i+1)/2 - p(x_i))/(p(x_i)*(1 - p(x_i)))
1290 // store z_i to the data->data_root as the new target responses
1292 const double lb_weight_thresh = FLT_EPSILON;
1293 const double lb_z_max = 10.;
1294 float* responses_buf = data->get_resp_float_buf();
1295 const float* responses = 0;
1296 data->get_ord_responses(data->data_root, responses_buf, &responses);
1298 /*if( weak->total == 7 )
1301 for( i = 0; i < n; i++ )
1303 double s = sum_response->data.db[i] + 0.5*weak_eval->data.db[i];
1304 sum_response->data.db[i] = s;
1305 weak_eval->data.db[i] = -2*s;
1308 cvExp( weak_eval, weak_eval );
1310 for( i = 0; i < n; i++ )
1312 double p = 1./(1. + weak_eval->data.db[i]);
1313 double w = p*(1 - p), z;
1314 w = MAX( w, lb_weight_thresh );
1315 weights->data.db[i] = w;
1317 if( orig_response->data.i[i] > 0 )
1320 fdata[sample_idx[i]*step] = (float)MIN(z, lb_z_max);
1325 fdata[sample_idx[i]*step] = (float)-MIN(z, lb_z_max);
1332 // weak_eval[i] = f(x_i) in [-1,1]
1333 // w_i *= exp(-y_i*f(x_i))
1334 assert( params.boost_type == GENTLE );
1336 for( i = 0; i < n; i++ )
1337 weak_eval->data.db[i] *= -orig_response->data.i[i];
1339 cvExp( weak_eval, weak_eval );
1341 for( i = 0; i < n; i++ )
1343 double w = weights->data.db[i] * weak_eval->data.db[i];
1344 weights->data.db[i] = w;
1350 // renormalize weights
1351 if( sumw > FLT_EPSILON )
1354 for( i = 0; i < n; ++i )
1355 weights->data.db[i] *= sumw;
// Instantiates a file-local (static) sorting routine named icvSort_64f for
// arrays of double through the CV_IMPLEMENT_QSORT_EX macro, using CV_LT as the
// comparison. CvBoost::trim_weights() calls it as icvSort_64f( ptr, count, 0 ).
// NOTE(review): CV_LT presumably gives ascending order and the trailing `int`
// is presumably the type of the extra user argument - confirm against the macro.
1362 static CV_IMPLEMENT_QSORT_EX( icvSort_64f, double, CV_LT, int )
// Fills subsample_mask with a 0/1 flag per sample (1 when the sample weight is
// at least a trimming threshold) and sets have_subsample to (nz_count < count).
// The threshold is taken from a sorted copy of the weights kept in weak_eval.
// NOTE(review): the statement guarded by the weight_trim_rate range check, the
// part of the first loop that consumes `sum`, and the update of nz_count are
// not present among these lines - presumably an early exit, a running
// subtraction of the sorted weights, and a count of the selected samples.
1366 CvBoost::trim_weights()
1368 //CV_FUNCNAME( "CvBoost::trim_weights" );
1372 int i, count = data->sample_count, nz_count = 0;
1373 double sum, threshold;
// trimming only applies for a rate strictly between 0 and 1
1375 if( params.weight_trim_rate <= 0. || params.weight_trim_rate >= 1. )
1378 // use weak_eval as temporary buffer for sorted weights
1379 cvCopy( weights, weak_eval );
1381 icvSort_64f( weak_eval->data.db, count, 0 );
1383 // as weight trimming occurs immediately after updating the weights,
1384 // where they are renormalized, we assume that the weight sum = 1.
1385 sum = 1. - params.weight_trim_rate;
// walk the sorted weights; `i` is left at the position used for the threshold
1387 for( i = 0; i < count; i++ )
1389 double w = weak_eval->data.db[i];
// if the walk ran off the end, DBL_MAX is used as the threshold
1395 threshold = i < count ? weak_eval->data.db[i] : DBL_MAX;
// second loop works on the original (unsorted) weights, one flag per sample
1397 for( i = 0; i < count; i++ )
1399 double w = weights->data.db[i];
1400 int f = w >= threshold;
1401 subsample_mask->data.ptr[i] = (uchar)f;
// a subsample is in effect only when fewer than `count` samples were counted
1405 have_subsample = nz_count < count;
// Returns the list of variables actually referenced by the splits of the weak
// trees: active_vars_abs when absolute_idx is true, active_vars otherwise.
// The two lists are built on first use (when either is still null) in two
// passes over the `weak` sequence: the first marks every split->var_idx in a
// temporary mask, the second rewrites each split->condensed_idx to the
// variable's position inside the compacted list. The temporary mask and
// inv_map matrices are released before returning.
// NOTE(review): the condition guarding the "has not been trained yet" error
// and parts of the tree traversal loops are not present among these lines.
1412 CvBoost::get_active_vars( bool absolute_idx )
1418 CV_FUNCNAME( "CvBoost::get_active_vars" );
1423 CV_ERROR( CV_StsError, "The boosted tree ensemble has not been trained yet" );
// lazily build both index lists
1425 if( !active_vars || !active_vars_abs )
1428 int i, j, nactive_vars;
1430 const CvDTreeNode* node;
// both lists are expected to be created together, never one without the other
1432 assert(!active_vars && !active_vars_abs);
// mask: one byte per variable; inv_map: variable index -> condensed index (-1 = unused)
1433 mask = cvCreateMat( 1, data->var_count, CV_8U );
1434 inv_map = cvCreateMat( 1, data->var_count, CV_32S );
1436 cvSet( inv_map, cvScalar(-1) );
1438 // first pass: compute the mask of used variables
1439 cvStartReadSeq( weak, &reader );
1440 for( i = 0; i < weak->total; i++ )
1442 CV_READ_SEQ_ELEM(wtree, reader);
1444 node = wtree->get_root();
1445 assert( node != 0 );
1448 const CvDTreeNode* parent;
// every split in the node's split chain contributes its variable
1451 CvDTreeSplit* split = node->split;
1452 for( ; split != 0; split = split->next )
1453 mask->data.ptr[split->var_idx] = 1;
// climb while the current node is its parent's right child, then step to the
// parent's right child
1459 for( parent = node->parent; parent && parent->right == node;
1460 node = parent, parent = parent->parent )
1466 node = parent->right;
1470 nactive_vars = cvCountNonZero(mask);
1472 //if ( nactive_vars > 0 )
1474 active_vars = cvCreateMat( 1, nactive_vars, CV_32S );
1475 active_vars_abs = cvCreateMat( 1, nactive_vars, CV_32S );
1477 have_active_cat_vars = false;
// compact the marked variables; j is the running condensed index
// NOTE(review): the increment of j is not present among these lines
1479 for( i = j = 0; i < data->var_count; i++ )
1481 if( mask->data.ptr[i] )
1483 active_vars->data.i[j] = i;
// map through data->var_idx when a variable subset is in use, else identity
1484 active_vars_abs->data.i[j] = data->var_idx ? data->var_idx->data.i[i] : i;
1485 inv_map->data.i[i] = j;
// a non-negative var_type entry flags the presence of an active categorical variable
1486 if( data->var_type->data.i[i] >= 0 )
1487 have_active_cat_vars = true;
1493 // second pass: now compute the condensed indices
1494 cvStartReadSeq( weak, &reader );
1495 for( i = 0; i < weak->total; i++ )
1497 CV_READ_SEQ_ELEM(wtree, reader);
1498 node = wtree->get_root();
1501 const CvDTreeNode* parent;
1504 CvDTreeSplit* split = node->split;
1505 for( ; split != 0; split = split->next )
1507 split->condensed_idx = inv_map->data.i[split->var_idx];
// every variable seen here was marked in the first pass, so it must be mapped
1508 assert( split->condensed_idx >= 0 );
1516 for( parent = node->parent; parent && parent->right == node;
1517 node = parent, parent = parent->parent )
1523 node = parent->right;
1529 result = absolute_idx ? active_vars_abs : active_vars;
// temporaries are released regardless of which list is returned
1533 cvReleaseMat( &mask );
1534 cvReleaseMat( &inv_map );
// Runs one sample through the weak trees selected by `slice` and produces the
// ensemble prediction.
//   _sample        - CV_32FC1 row or column vector; its length must equal
//                    data->var_all when raw_mode is false, or active_vars->cols
//                    when raw_mode is true.
//   _missing       - 8-bit mask of the same size as _sample; validated below.
//                    NOTE(review): the null check guarding that validation is
//                    not present among these lines - presumably optional.
//   weak_responses - optional CV_32FC1 vector with exactly weak_count elements;
//                    receives the leaf value reached in each weak tree.
//   slice          - range of weak trees to evaluate; when it is at least as
//                    long as the whole sequence it is reset to start at 0.
//   raw_mode       - see the length rule above; the error text below also
//                    requires continuous input vectors in raw mode.
//   return_sum     - NOTE(review): its use is not present among these lines;
//                    presumably returns the raw sum instead of a class - confirm.
// The return variable starts as -FLT_MAX and is assigned from the sign of `sum`.
1541 CvBoost::predict( const CvMat* _sample, const CvMat* _missing,
1542 CvMat* weak_responses, CvSlice slice,
1543 bool raw_mode, bool return_sum ) const
1546 bool allocated = false;
1547 float value = -FLT_MAX;
1549 CV_FUNCNAME( "CvBoost::predict" );
1553 int i, weak_count, var_count;
1554 CvMat sample, missing;
1561 const float* sample_data;
1564 CV_ERROR( CV_StsError, "The boosted tree ensemble has not been trained yet" );
// --- argument validation ---
1566 if( !CV_IS_MAT(_sample) || CV_MAT_TYPE(_sample->type) != CV_32FC1 ||
1567 (_sample->cols != 1 && _sample->rows != 1) ||
1568 (_sample->cols + _sample->rows - 1 != data->var_all && !raw_mode) ||
1569 (active_vars && _sample->cols + _sample->rows - 1 != active_vars->cols && raw_mode) )
1570 CV_ERROR( CV_StsBadArg,
1571 "the input sample must be 1d floating-point vector with the same "
1572 "number of elements as the total number of variables or "
1573 "as the number of variables used for training" );
1577 if( !CV_IS_MAT(_missing) || !CV_IS_MASK_ARR(_missing) ||
1578 !CV_ARE_SIZES_EQ(_missing, _sample) )
1579 CV_ERROR( CV_StsBadArg,
1580 "the missing data mask must be 8-bit vector of the same size as input sample" );
// clamp the slice to the available weak trees
1583 weak_count = cvSliceLength( slice, weak );
1584 if( weak_count >= weak->total )
1586 weak_count = weak->total;
1587 slice.start_index = 0;
1590 if( weak_responses )
1592 if( !CV_IS_MAT(weak_responses) ||
1593 CV_MAT_TYPE(weak_responses->type) != CV_32FC1 ||
1594 (weak_responses->cols != 1 && weak_responses->rows != 1) ||
1595 weak_responses->cols + weak_responses->rows - 1 != weak_count )
1596 CV_ERROR( CV_StsBadArg,
1597 "The output matrix of weak classifier responses must be valid "
1598 "floating-point vector of the same number of components as the length of input slice" );
// element stride of the output vector (1 when continuous, else row step in floats)
1599 wstep = CV_IS_MAT_CONT(weak_responses->type) ? 1 : weak_responses->step/sizeof(float);
// shortcuts into the training-data description
1602 var_count = active_vars->cols;
1603 vtype = data->var_type->data.i;
1604 cmap = data->cat_map->data.i;
1605 cofs = data->cat_ofs->data.i;
1607 // if need, preprocess the input vector
1611 int step, mstep = 0;
1612 const float* src_sample;
1613 const uchar* src_mask = 0;
1616 const int* vidx = active_vars->data.i;
1617 const int* vidx_abs = active_vars_abs->data.i;
1618 bool have_mask = _missing != 0;
// one float plus one mask byte per active variable; small buffers go on the
// stack, larger ones on the heap
1620 bufsize = var_count*(sizeof(float) + sizeof(uchar));
1621 if( bufsize <= CV_MAX_LOCAL_SIZE )
1622 buf = (float*)cvStackAlloc( bufsize );
1625 CV_CALL( buf = (float*)cvAlloc( bufsize ));
// the mask bytes live right after the var_count floats in the same buffer
1629 dst_mask = (uchar*)(buf + var_count);
1631 src_sample = _sample->data.fl;
1632 step = CV_IS_MAT_CONT(_sample->type) ? 1 : _sample->step/sizeof(src_sample[0]);
1636 src_mask = _missing->data.ptr;
1637 mstep = CV_IS_MAT_CONT(_missing->type) ? 1 : _missing->step;
// gather the active variables from the full-length input into the compact buffer
1640 for( i = 0; i < var_count; i++ )
1642 int idx = vidx[i], idx_abs = vidx_abs[i];
1643 float val = src_sample[idx_abs*step];
1644 int ci = vtype[idx];
1645 uchar m = src_mask ? src_mask[idx_abs*mstep] : (uchar)0;
// [a, b) is the range of this categorical variable's entries inside cat_map
1649 int a = cofs[ci], b = (ci+1 >= data->cat_ofs->cols) ? data->cat_map->cols : cofs[ci+1],
1651 int ival = cvRound(val);
// a non-integer categorical value is an error unless the value is flagged missing
1652 if ( (ival != val) && (!m) )
1653 CV_ERROR( CV_StsBadArg,
1654 "one of input categorical variable is not an integer" );
// NOTE(review): looks like a binary search of ival in cmap[a..b) - the loop
// header and the bound updates are not present among these lines
1659 if( ival < cmap[c] )
1661 else if( ival > cmap[c] )
1667 if( c < 0 || ival != cmap[c] )
// store the category as an index relative to the variable's first cat_map entry
1674 val = (float)(c - cofs[ci]);
1678 dst_sample[i] = val;
// from here on the compacted copies stand in for the caller's inputs
1682 sample = cvMat( 1, var_count, CV_32F, dst_sample );
1687 missing = cvMat( 1, var_count, CV_8UC1, dst_mask );
1688 _missing = &missing;
// continuity flags of both headers are AND-ed; -1 keeps the sample's flag when
// there is no mask
1693 if( !CV_IS_MAT_CONT(_sample->type & (_missing ? _missing->type : -1)) )
1694 CV_ERROR( CV_StsBadArg, "In raw mode the input vectors must be continuous" );
// position the reader on the first weak tree of the slice
1697 cvStartReadSeq( weak, &reader );
1698 cvSetSeqReaderPos( &reader, slice.start_index );
1700 sample_data = _sample->data.fl;
// fast path: ordered variables only, no missing mask, no per-tree output
1702 if( !have_active_cat_vars && !_missing && !weak_responses )
1704 for( i = 0; i < weak_count; i++ )
1707 const CvDTreeNode* node;
1708 CV_READ_SEQ_ELEM( wtree, reader );
1710 node = wtree->get_root();
1713 CvDTreeSplit* split = node->split;
1714 int vi = split->condensed_idx;
1715 float val = sample_data[vi];
// go left when the value does not exceed the split threshold
1716 int dir = val <= split->ord.c ? -1 : 1;
1717 if( split->inversed )
1719 node = dir < 0 ? node->left : node->right;
1726 const int* avars = active_vars->data.i;
1727 const uchar* m = _missing ? _missing->data.ptr : 0;
1729 // full-featured version
1730 for( i = 0; i < weak_count; i++ )
1733 const CvDTreeNode* node;
1734 CV_READ_SEQ_ELEM( wtree, reader );
1736 node = wtree->get_root();
1739 const CvDTreeSplit* split = node->split;
// try the splits in the chain until one of them yields a direction
1741 for( ; !dir && split != 0; split = split->next )
1743 int vi = split->condensed_idx;
1744 int ci = vtype[avars[vi]];
1745 float val = sample_data[vi];
1748 if( ci < 0 ) // ordered
1749 dir = val <= split->ord.c ? -1 : 1;
1752 int c = cvRound(val);
1753 dir = CV_DTREE_CAT_DIR(c, split->subset);
1755 if( split->inversed )
// fallback direction: follow the child that holds more training samples
1761 int diff = node->right->sample_count - node->left->sample_count;
1762 dir = diff < 0 ? -1 : 1;
1764 node = dir < 0 ? node->left : node->right;
1766 if( weak_responses )
1767 weak_responses->data.fl[i*wstep] = (float)node->value;
// class index is 1 for a non-negative sum, 0 otherwise
1776 int cls_idx = sum >= 0;
1778 value = (float)cls_idx;
// translate the class index into the original label through the response
// variable's slice of the category map
1780 value = (float)cmap[cofs[vtype[data->var_count]] + cls_idx];
// Evaluates the model on the train or test part of a CvMLData set.
//   _data - data set supplying values, responses, missing mask, sample
//           indices and variable types.
//   type  - CV_TEST_ERROR selects the test sample indices, anything else the
//           train sample indices.
//   resp  - optional output; when non-null and there is at least one sample it
//           is resized to sample_count.
// For a categorical response the error is scaled by 100/sample_count; otherwise
// it is divided by sample_count. With no samples the result is -FLT_MAX.
// NOTE(review): the accumulation into `err`, the stores into pred_resp and the
// return statement are not present among these lines; in the second branch the
// accumulated term is presumably the squared difference - confirm.
// NOTE(review): responses are read through data.fl, i.e. assumed to be float.
1791 float CvBoost::calc_error( CvMLData* _data, int type, vector<float> *resp )
1794 const CvMat* values = _data->get_values();
1795 const CvMat* response = _data->get_responses();
1796 const CvMat* missing = _data->get_missing();
1797 const CvMat* sample_idx = (type == CV_TEST_ERROR) ? _data->get_test_sample_idx() : _data->get_train_sample_idx();
1798 const CvMat* var_types = _data->get_var_types();
1799 int* sidx = sample_idx ? sample_idx->data.i : 0;
// element stride between consecutive responses
1800 int r_step = CV_IS_MAT_CONT(response->type) ?
1801 1 : response->step / CV_ELEM_SIZE(response->type);
// the last var_types entry describes the response variable
1802 bool is_classifier = var_types->data.ptr[var_types->cols-1] == CV_VAR_CATEGORICAL;
1803 int sample_count = sample_idx ? sample_idx->cols : 0;
// train error without an explicit index list means: use every row
1804 sample_count = (type == CV_TRAIN_ERROR && sample_count == 0) ? values->rows : sample_count;
1805 float* pred_resp = 0;
1806 if( resp && (sample_count > 0) )
1808 resp->resize( sample_count );
1809 pred_resp = &((*resp)[0]);
1811 if ( is_classifier )
1813 for( int i = 0; i < sample_count; i++ )
1816 int si = sidx ? sidx[i] : i;
1817 cvGetRow( values, &sample, si );
1819 cvGetRow( missing, &miss, si );
1820 float r = (float)predict( &sample, missing ? &miss : 0 );
// d is 0 for a prediction matching the stored response within FLT_EPSILON, else 1
1823 int d = fabs((double)r - response->data.fl[si*r_step]) <= FLT_EPSILON ? 0 : 1;
1826 err = sample_count ? err / (float)sample_count * 100 : -FLT_MAX;
1830 for( int i = 0; i < sample_count; i++ )
1833 int si = sidx ? sidx[i] : i;
1834 cvGetRow( values, &sample, si );
1836 cvGetRow( missing, &miss, si );
1837 float r = (float)predict( &sample, missing ? &miss : 0 );
1840 float d = r - response->data.fl[si*r_step];
1843 err = sample_count ? err / (float)sample_count : -FLT_MAX;
1848 void CvBoost::write_params( CvFileStorage* fs ) const
1850 //CV_FUNCNAME( "CvBoost::write_params" );
1854 const char* boost_type_str =
1855 params.boost_type == DISCRETE ? "DiscreteAdaboost" :
1856 params.boost_type == REAL ? "RealAdaboost" :
1857 params.boost_type == LOGIT ? "LogitBoost" :
1858 params.boost_type == GENTLE ? "GentleAdaboost" : 0;
1860 const char* split_crit_str =
1861 params.split_criteria == DEFAULT ? "Default" :
1862 params.split_criteria == GINI ? "Gini" :
1863 params.boost_type == MISCLASS ? "Misclassification" :
1864 params.boost_type == SQERR ? "SquaredErr" : 0;
1866 if( boost_type_str )
1867 cvWriteString( fs, "boosting_type", boost_type_str );
1869 cvWriteInt( fs, "boosting_type", params.boost_type );
1871 if( split_crit_str )
1872 cvWriteString( fs, "splitting_criteria", split_crit_str );
1874 cvWriteInt( fs, "splitting_criteria", params.split_criteria );
1876 cvWriteInt( fs, "ntrees", params.weak_count );
1877 cvWriteReal( fs, "weight_trimming_rate", params.weight_trim_rate );
1879 data->write_params( fs );
1885 void CvBoost::read_params( CvFileStorage* fs, CvFileNode* fnode )
1887 CV_FUNCNAME( "CvBoost::read_params" );
1893 if( !fnode || !CV_NODE_IS_MAP(fnode->tag) )
1896 data = new CvDTreeTrainData();
1897 CV_CALL( data->read_params(fs, fnode));
1898 data->shared = true;
1900 params.max_depth = data->params.max_depth;
1901 params.min_sample_count = data->params.min_sample_count;
1902 params.max_categories = data->params.max_categories;
1903 params.priors = data->params.priors;
1904 params.regression_accuracy = data->params.regression_accuracy;
1905 params.use_surrogates = data->params.use_surrogates;
1907 temp = cvGetFileNodeByName( fs, fnode, "boosting_type" );
1911 if( temp && CV_NODE_IS_STRING(temp->tag) )
1913 const char* boost_type_str = cvReadString( temp, "" );
1914 params.boost_type = strcmp( boost_type_str, "DiscreteAdaboost" ) == 0 ? DISCRETE :
1915 strcmp( boost_type_str, "RealAdaboost" ) == 0 ? REAL :
1916 strcmp( boost_type_str, "LogitBoost" ) == 0 ? LOGIT :
1917 strcmp( boost_type_str, "GentleAdaboost" ) == 0 ? GENTLE : -1;
1920 params.boost_type = cvReadInt( temp, -1 );
1922 if( params.boost_type < DISCRETE || params.boost_type > GENTLE )
1923 CV_ERROR( CV_StsBadArg, "Unknown boosting type" );
1925 temp = cvGetFileNodeByName( fs, fnode, "splitting_criteria" );
1926 if( temp && CV_NODE_IS_STRING(temp->tag) )
1928 const char* split_crit_str = cvReadString( temp, "" );
1929 params.split_criteria = strcmp( split_crit_str, "Default" ) == 0 ? DEFAULT :
1930 strcmp( split_crit_str, "Gini" ) == 0 ? GINI :
1931 strcmp( split_crit_str, "Misclassification" ) == 0 ? MISCLASS :
1932 strcmp( split_crit_str, "SquaredErr" ) == 0 ? SQERR : -1;
1935 params.split_criteria = cvReadInt( temp, -1 );
1937 if( params.split_criteria < DEFAULT || params.boost_type > SQERR )
1938 CV_ERROR( CV_StsBadArg, "Unknown boosting type" );
1940 params.weak_count = cvReadIntByName( fs, fnode, "ntrees" );
1941 params.weight_trim_rate = cvReadRealByName( fs, fnode, "weight_trimming_rate", 0. );
1949 CvBoost::read( CvFileStorage* fs, CvFileNode* node )
1951 CV_FUNCNAME( "CvRTrees::read" );
1956 CvFileNode* trees_fnode;
1957 CvMemStorage* storage;
1961 read_params( fs, node );
1966 trees_fnode = cvGetFileNodeByName( fs, node, "trees" );
1967 if( !trees_fnode || !CV_NODE_IS_SEQ(trees_fnode->tag) )
1968 CV_ERROR( CV_StsParseError, "<trees> tag is missing" );
1970 cvStartReadSeq( trees_fnode->data.seq, &reader );
1971 ntrees = trees_fnode->data.seq->total;
1973 if( ntrees != params.weak_count )
1974 CV_ERROR( CV_StsUnmatchedSizes,
1975 "The number of trees stored does not match <ntrees> tag value" );
1977 CV_CALL( storage = cvCreateMemStorage() );
1978 weak = cvCreateSeq( 0, sizeof(CvSeq), sizeof(CvBoostTree*), storage );
1980 for( i = 0; i < ntrees; i++ )
1982 CvBoostTree* tree = new CvBoostTree();
1983 CV_CALL(tree->read( fs, (CvFileNode*)reader.ptr, this, data ));
1984 CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
1985 cvSeqPush( weak, &tree );
// Serializes the ensemble into file storage `fs` under the key `name`.
// Layout, as far as these lines show: an outer map tagged
// CV_TYPE_NAME_ML_BOOSTING that contains a "trees" sequence with one unnamed
// map per weak tree read from `weak`.
// NOTE(review): the condition guarding the "has not been trained yet" error,
// the parameter output and the per-tree write call are not present among these
// lines. The outer struct is opened before that error can be raised, so the
// storage may be left with an unclosed struct on failure - confirm.
1994 CvBoost::write( CvFileStorage* fs, const char* name ) const
1996 CV_FUNCNAME( "CvBoost::write" );
// outer map for the whole model
2003 cvStartWriteStruct( fs, name, CV_NODE_MAP, CV_TYPE_NAME_ML_BOOSTING );
2006 CV_ERROR( CV_StsBadArg, "The classifier has not been trained yet" );
// sequence holding the weak trees
2009 cvStartWriteStruct( fs, "trees", CV_NODE_SEQ );
2011 cvStartReadSeq( weak, &reader );
2013 for( i = 0; i < weak->total; i++ )
2016 CV_READ_SEQ_ELEM( tree, reader );
// each tree gets its own unnamed map, closed right after
2017 cvStartWriteStruct( fs, 0, CV_NODE_MAP );
2019 cvEndWriteStruct( fs );
// close "trees", then the outer model map
2022 cvEndWriteStruct( fs );
2023 cvEndWriteStruct( fs );
// Simple accessors of CvBoost.
// NOTE(review): only get_subtree_weights() has its return statement present
// among these lines; the other bodies are presumably one-line returns of the
// correspondingly named members (weights, weak_eval, params, weak, data) -
// confirm.

// Accessor for the per-sample weights.
2030 CvBoost::get_weights()
// Returns the subtree_weights member.
2037 CvBoost::get_subtree_weights()
2039 return subtree_weights;
// Accessor for the latest weak-classifier responses.
2044 CvBoost::get_weak_response()
// Read-only access to the boosting parameters.
2050 const CvBoostParams&
2051 CvBoost::get_params() const
// Accessor for the sequence of weak predictors.
2056 CvSeq* CvBoost::get_weak_predictors()
// Read-only access to the training-data descriptor.
2061 const CvDTreeTrainData* CvBoost::get_data() const
// Constructs a model from cv::Mat inputs and trains it immediately.
// Sets the default model name, nulls the working matrices (orig_response,
// sum_response, weak_eval, subsample_mask, weights) and forwards every
// argument unchanged to train(); the result of train() is not examined.
// NOTE(review): initialisation of other members is not present among these lines.
2068 CvBoost::CvBoost( const Mat& _train_data, int _tflag,
2069 const Mat& _responses, const Mat& _var_idx,
2070 const Mat& _sample_idx, const Mat& _var_type,
2071 const Mat& _missing_mask,
2072 CvBoostParams _params )
2076 default_model_name = "my_boost_tree";
2077 orig_response = sum_response = weak_eval = subsample_mask = weights = 0;
2079 train( _train_data, _tflag, _responses, _var_idx, _sample_idx,
2080 _var_type, _missing_mask, _params );
2085 CvBoost::train( const Mat& _train_data, int _tflag,
2086 const Mat& _responses, const Mat& _var_idx,
2087 const Mat& _sample_idx, const Mat& _var_type,
2088 const Mat& _missing_mask,
2089 CvBoostParams _params, bool _update )
2091 CvMat tdata = _train_data, responses = _responses, vidx = _var_idx,
2092 sidx = _sample_idx, vtype = _var_type, mmask = _missing_mask;
2093 return train(&tdata, _tflag, &responses, vidx.data.ptr ? &vidx : 0,
2094 sidx.data.ptr ? &sidx : 0, vtype.data.ptr ? &vtype : 0,
2095 mmask.data.ptr ? &mmask : 0, _params, _update);
2099 CvBoost::predict( const Mat& _sample, const Mat& _missing,
2100 Mat* weak_responses, CvSlice slice,
2101 bool raw_mode, bool return_sum ) const
2103 CvMat sample = _sample, mmask = _missing, wr, *pwr = 0;
2104 if( weak_responses )
2106 int weak_count = cvSliceLength( slice, weak );
2107 if( weak_count >= weak->total )
2109 weak_count = weak->total;
2110 slice.start_index = 0;
2113 if( !(weak_responses->data && weak_responses->type() == CV_32FC1 &&
2114 (weak_responses->cols == 1 || weak_responses->rows == 1) &&
2115 weak_responses->cols + weak_responses->rows - 1 == weak_count) )
2116 weak_responses->create(weak_count, 1, CV_32FC1);
2117 pwr = &(wr = *weak_responses);
2119 return predict(&sample, &mmask, pwr, slice, raw_mode, return_sum);