Move the sources to trunk
[opencv] / ml / src / mlboost.cpp
1 /*M///////////////////////////////////////////////////////////////////////////////////////
2 //
3 //  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
4 //
5 //  By downloading, copying, installing or using the software you agree to this license.
6 //  If you do not agree to this license, do not download, install,
7 //  copy or use the software.
8 //
9 //
10 //                        Intel License Agreement
11 //
12 // Copyright (C) 2000, Intel Corporation, all rights reserved.
13 // Third party copyrights are property of their respective owners.
14 //
15 // Redistribution and use in source and binary forms, with or without modification,
16 // are permitted provided that the following conditions are met:
17 //
18 //   * Redistribution's of source code must retain the above copyright notice,
19 //     this list of conditions and the following disclaimer.
20 //
21 //   * Redistribution's in binary form must reproduce the above copyright notice,
22 //     this list of conditions and the following disclaimer in the documentation
23 //     and/or other materials provided with the distribution.
24 //
25 //   * The name of Intel Corporation may not be used to endorse or promote products
26 //     derived from this software without specific prior written permission.
27 //
28 // This software is provided by the copyright holders and contributors "as is" and
29 // any express or implied warranties, including, but not limited to, the implied
30 // warranties of merchantability and fitness for a particular purpose are disclaimed.
31 // In no event shall the Intel Corporation or contributors be liable for any direct,
32 // indirect, incidental, special, exemplary, or consequential damages
33 // (including, but not limited to, procurement of substitute goods or services;
34 // loss of use, data, or profits; or business interruption) however caused
35 // and on any theory of liability, whether in contract, strict liability,
36 // or tort (including negligence or otherwise) arising in any way out of
37 // the use of this software, even if advised of the possibility of such damage.
38 //
39 //M*/
40
41 #include "_ml.h"
42
43 static inline double
44 log_ratio( double val )
45 {
46     const double eps = 1e-5;
47     
48     val = MAX( val, eps );
49     val = MIN( val, 1. - eps );
50     return log( val/(1. - val) );
51 }
52
53
54 CvBoostParams::CvBoostParams()
55 {
56     boost_type = CvBoost::REAL;
57     weak_count = 100;
58     weight_trim_rate = 0.95;
59     cv_folds = 0;
60     max_depth = 1;
61 }
62
63
64 CvBoostParams::CvBoostParams( int _boost_type, int _weak_count,
65                                         double _weight_trim_rate, int _max_depth,
66                                         bool _use_surrogates, const float* _priors )
67 {
68     boost_type = _boost_type;
69     weak_count = _weak_count;
70     weight_trim_rate = _weight_trim_rate;
71     split_criteria = CvBoost::DEFAULT;
72     cv_folds = 0;
73     max_depth = _max_depth;
74     use_surrogates = _use_surrogates;
75     priors = _priors;
76 }
77
78
79
80 ///////////////////////////////// CvBoostTree ///////////////////////////////////
81
82 CvBoostTree::CvBoostTree()
83 {
84     ensemble = 0;
85 }
86
87
88 CvBoostTree::~CvBoostTree()
89 {
90     clear();
91 }
92
93
94 void
95 CvBoostTree::clear()
96 {
97     CvDTree::clear();
98     ensemble = 0;
99 }
100
101
102 bool
103 CvBoostTree::train( CvDTreeTrainData* _train_data,
104                     const CvMat* _subsample_idx, CvBoost* _ensemble )
105 {
106     clear();
107     ensemble = _ensemble;
108     data = _train_data;
109     data->shared = true;
110     
111     return do_train( _subsample_idx );
112 }
113
114
115 bool
116 CvBoostTree::train( const CvMat*, int, const CvMat*, const CvMat*,
117                     const CvMat*, const CvMat*, const CvMat*, CvDTreeParams )
118 {
119     assert(0);
120     return false;
121 }
122
123
124 bool
125 CvBoostTree::train( CvDTreeTrainData*, const CvMat* )
126 {
127     assert(0);
128     return false;
129 }
130
131
132 void
133 CvBoostTree::scale( double scale )
134 {
135     CvDTreeNode* node = root;
136
137     // traverse the tree and scale all the node values
138     for(;;)
139     {
140         CvDTreeNode* parent;
141         for(;;)
142         {
143             node->value *= scale;
144             if( !node->left )
145                 break;
146             node = node->left;
147         }
148         
149         for( parent = node->parent; parent && parent->right == node;
150             node = parent, parent = parent->parent )
151             ;
152
153         if( !parent )
154             break;
155
156         node = parent->right;
157     }
158 }
159
160
161 void
162 CvBoostTree::try_split_node( CvDTreeNode* node )
163 {
164     CvDTree::try_split_node( node );
165
166     if( !node->left )
167     {
168         // if the node has not been split,
169         // store the responses for the corresponding training samples
170         double* weak_eval = ensemble->get_weak_response()->data.db;
171         int* labels = data->get_labels( node );
172         int i, count = node->sample_count;
173         double value = node->value;
174
175         for( i = 0; i < count; i++ )
176             weak_eval[labels[i]] = value;
177     }
178 }
179
180
181 double
182 CvBoostTree::calc_node_dir( CvDTreeNode* node )
183 {
184     char* dir = (char*)data->direction->data.ptr;
185     const double* weights = ensemble->get_subtree_weights()->data.db;
186     int i, n = node->sample_count, vi = node->split->var_idx;
187     double L, R;
188
189     assert( !node->split->inversed );
190
191     if( data->get_var_type(vi) >= 0 ) // split on categorical var
192     {
193         const int* cat_labels = data->get_cat_var_data( node, vi );
194         const int* subset = node->split->subset;
195         double sum = 0, sum_abs = 0;
196
197         for( i = 0; i < n; i++ )
198         {
199             int idx = cat_labels[i];
200             double w = weights[i];
201             int d = idx >= 0 ? CV_DTREE_CAT_DIR(idx,subset) : 0;
202             sum += d*w; sum_abs += (d & 1)*w;
203             dir[i] = (char)d;
204         }
205
206         R = (sum_abs + sum) * 0.5;
207         L = (sum_abs - sum) * 0.5;
208     }
209     else // split on ordered var
210     {
211         const CvPair32s32f* sorted = data->get_ord_var_data(node,vi);
212         int split_point = node->split->ord.split_point;
213         int n1 = node->get_num_valid(vi);
214
215         assert( 0 <= split_point && split_point < n1-1 );
216         L = R = 0;
217
218         for( i = 0; i <= split_point; i++ )
219         {
220             int idx = sorted[i].i;
221             double w = weights[idx];
222             dir[idx] = (char)-1;
223             L += w;
224         }
225
226         for( ; i < n1; i++ )
227         {
228             int idx = sorted[i].i;
229             double w = weights[idx];
230             dir[idx] = (char)1;
231             R += w;
232         }
233
234         for( ; i < n; i++ )
235             dir[sorted[i].i] = (char)0;
236     }
237
238     node->maxlr = MAX( L, R );
239     return node->split->quality/(L + R);
240 }
241
242
243 CvDTreeSplit*
244 CvBoostTree::find_split_ord_class( CvDTreeNode* node, int vi )
245 {
246     const float epsilon = FLT_EPSILON*2;
247     const CvPair32s32f* sorted = data->get_ord_var_data(node, vi);
248     const int* responses = data->get_class_labels(node);
249     const double* weights = ensemble->get_subtree_weights()->data.db;
250     int n = node->sample_count;
251     int n1 = node->get_num_valid(vi);
252     const double* rcw0 = weights + n;
253     double lcw[2] = {0,0}, rcw[2];
254     int i, best_i = -1;
255     double best_val = 0;
256     int boost_type = ensemble->get_params().boost_type;
257     int split_criteria = ensemble->get_params().split_criteria;
258
259     rcw[0] = rcw0[0]; rcw[1] = rcw0[1];
260     for( i = n1; i < n; i++ )
261     {
262         int idx = sorted[i].i;
263         double w = weights[idx];
264         rcw[responses[idx]] -= w;
265     }
266
267     if( split_criteria != CvBoost::GINI && split_criteria != CvBoost::MISCLASS )
268         split_criteria = boost_type == CvBoost::DISCRETE ? CvBoost::MISCLASS : CvBoost::GINI;
269
270     if( split_criteria == CvBoost::GINI )
271     {
272         double L = 0, R = rcw[0] + rcw[1];
273         double lsum2 = 0, rsum2 = rcw[0]*rcw[0] + rcw[1]*rcw[1];
274
275         for( i = 0; i < n1 - 1; i++ )
276         {
277             int idx = sorted[i].i;
278             double w = weights[idx], w2 = w*w;
279             double lv, rv;
280             idx = responses[idx];
281             L += w; R -= w;
282             lv = lcw[idx]; rv = rcw[idx];
283             lsum2 += 2*lv*w + w2;
284             rsum2 -= 2*rv*w - w2;
285             lcw[idx] = lv + w; rcw[idx] = rv - w;
286
287             if( sorted[i].val + epsilon < sorted[i+1].val )
288             {
289                 double val = (lsum2*R + rsum2*L)/(L*R);
290                 if( best_val < val )
291                 {
292                     best_val = val;
293                     best_i = i;
294                 }
295             }
296         }
297     }
298     else
299     {
300         for( i = 0; i < n1 - 1; i++ )
301         {
302             int idx = sorted[i].i;
303             double w = weights[idx];
304             idx = responses[idx];
305             lcw[idx] += w;
306             rcw[idx] -= w;
307
308             if( sorted[i].val + epsilon < sorted[i+1].val )
309             {
310                 double val = lcw[0] + rcw[1], val2 = lcw[1] + rcw[0];
311                 val = MAX(val, val2);
312                 if( best_val < val )
313                 {
314                     best_val = val;
315                     best_i = i;
316                 }
317             }
318         }
319     }
320
321     return best_i >= 0 ? data->new_split_ord( vi,
322         (sorted[best_i].val + sorted[best_i+1].val)*0.5f, best_i,
323         0, (float)best_val ) : 0;
324 }
325
326
327 #define CV_CMP_NUM_PTR(a,b) (*(a) < *(b))
328 static CV_IMPLEMENT_QSORT_EX( icvSortDblPtr, double*, CV_CMP_NUM_PTR, int )
329
330 CvDTreeSplit*
331 CvBoostTree::find_split_cat_class( CvDTreeNode* node, int vi )
332 {
333     CvDTreeSplit* split;
334     const int* cat_labels = data->get_cat_var_data(node, vi);
335     const int* responses = data->get_class_labels(node);
336     int ci = data->get_var_type(vi);
337     int n = node->sample_count;
338     int mi = data->cat_count->data.i[ci];
339     double lcw[2]={0,0}, rcw[2]={0,0};
340     double* cjk = (double*)cvStackAlloc(2*(mi+1)*sizeof(cjk[0]))+2;
341     const double* weights = ensemble->get_subtree_weights()->data.db;
342     double** dbl_ptr = (double**)cvStackAlloc( mi*sizeof(dbl_ptr[0]) );
343     int i, j, k, idx;
344     double L = 0, R;
345     double best_val = 0;
346     int best_subset = -1, subset_i;
347     int boost_type = ensemble->get_params().boost_type;
348     int split_criteria = ensemble->get_params().split_criteria;
349
350     // init array of counters:
351     // c_{jk} - number of samples that have vi-th input variable = j and response = k.
352     for( j = -1; j < mi; j++ )
353         cjk[j*2] = cjk[j*2+1] = 0;
354
355     for( i = 0; i < n; i++ )
356     {
357         double w = weights[i];
358         j = cat_labels[i];
359         k = responses[i];
360         cjk[j*2 + k] += w;
361     }
362
363     for( j = 0; j < mi; j++ )
364     {
365         rcw[0] += cjk[j*2];
366         rcw[1] += cjk[j*2+1];
367         dbl_ptr[j] = cjk + j*2 + 1;
368     }
369
370     R = rcw[0] + rcw[1];
371
372     if( split_criteria != CvBoost::GINI && split_criteria != CvBoost::MISCLASS )
373         split_criteria = boost_type == CvBoost::DISCRETE ? CvBoost::MISCLASS : CvBoost::GINI;
374
375     // sort rows of c_jk by increasing c_j,1
376     // (i.e. by the weight of samples in j-th category that belong to class 1)
377     icvSortDblPtr( dbl_ptr, mi, 0 );
378
379     for( subset_i = 0; subset_i < mi-1; subset_i++ )
380     {
381         idx = (int)(dbl_ptr[subset_i] - cjk)/2;
382         const double* crow = cjk + idx*2;
383         double w0 = crow[0], w1 = crow[1];
384         double weight = w0 + w1;
385
386         if( weight < FLT_EPSILON )
387             continue;
388
389         lcw[0] += w0; rcw[0] -= w0;
390         lcw[1] += w1; rcw[1] -= w1;
391
392         if( split_criteria == CvBoost::GINI )
393         {
394             double lsum2 = lcw[0]*lcw[0] + lcw[1]*lcw[1];
395             double rsum2 = rcw[0]*rcw[0] + rcw[1]*rcw[1];
396         
397             L += weight;
398             R -= weight;
399
400             if( L > FLT_EPSILON && R > FLT_EPSILON )
401             {
402                 double val = (lsum2*R + rsum2*L)/(L*R);
403                 if( best_val < val )
404                 {
405                     best_val = val;
406                     best_subset = subset_i;
407                 }
408             }
409         }
410         else
411         {
412             double val = lcw[0] + rcw[1];
413             double val2 = lcw[1] + rcw[0];
414
415             val = MAX(val, val2);
416             if( best_val < val )
417             {
418                 best_val = val;
419                 best_subset = subset_i;
420             }
421         }
422     }
423
424     if( best_subset < 0 )
425         return 0;
426
427     split = data->new_split_cat( vi, (float)best_val );
428
429     for( i = 0; i <= best_subset; i++ )
430     {
431         idx = (int)(dbl_ptr[i] - cjk) >> 1;
432         split->subset[idx >> 5] |= 1 << (idx & 31);
433     }
434
435     return split;
436 }
437
438
439 CvDTreeSplit*
440 CvBoostTree::find_split_ord_reg( CvDTreeNode* node, int vi )
441 {
442     const float epsilon = FLT_EPSILON*2;
443     const CvPair32s32f* sorted = data->get_ord_var_data(node, vi);
444     const float* responses = data->get_ord_responses(node);
445     const double* weights = ensemble->get_subtree_weights()->data.db;
446     int n = node->sample_count;
447     int n1 = node->get_num_valid(vi);
448     int i, best_i = -1;
449     double best_val = 0, lsum = 0, rsum = node->value*n;
450     double L = 0, R = weights[n];
451
452     // compensate for missing values
453     for( i = n1; i < n; i++ )
454     {
455         int idx = sorted[i].i;
456         double w = weights[idx];
457         rsum -= responses[idx]*w;
458         R -= w;
459     }
460
461     // find the optimal split
462     for( i = 0; i < n1 - 1; i++ )
463     {
464         int idx = sorted[i].i;
465         double w = weights[idx];
466         double t = responses[idx]*w;
467         L += w; R -= w;
468         lsum += t; rsum -= t;
469
470         if( sorted[i].val + epsilon < sorted[i+1].val )
471         {
472             double val = (lsum*lsum*R + rsum*rsum*L)/(L*R);
473             if( best_val < val )
474             {
475                 best_val = val;
476                 best_i = i;
477             }
478         }
479     }
480
481     return best_i >= 0 ? data->new_split_ord( vi,
482         (sorted[best_i].val + sorted[best_i+1].val)*0.5f, best_i,
483         0, (float)best_val ) : 0;
484 }
485
486
487 CvDTreeSplit*
488 CvBoostTree::find_split_cat_reg( CvDTreeNode* node, int vi )
489 {
490     CvDTreeSplit* split;
491     const int* cat_labels = data->get_cat_var_data(node, vi);
492     const float* responses = data->get_ord_responses(node);
493     const double* weights = ensemble->get_subtree_weights()->data.db;
494     int ci = data->get_var_type(vi);
495     int n = node->sample_count;
496     int mi = data->cat_count->data.i[ci];
497     double* sum = (double*)cvStackAlloc( (mi+1)*sizeof(sum[0]) ) + 1;
498     double* counts = (double*)cvStackAlloc( (mi+1)*sizeof(counts[0]) ) + 1;
499     double** sum_ptr = (double**)cvStackAlloc( mi*sizeof(sum_ptr[0]) );
500     double L = 0, R = 0, best_val = 0, lsum = 0, rsum = 0;
501     int i, best_subset = -1, subset_i;
502
503     for( i = -1; i < mi; i++ )
504         sum[i] = counts[i] = 0;
505
506     // calculate sum response and weight of each category of the input var
507     for( i = 0; i < n; i++ )
508     {
509         int idx = cat_labels[i];
510         double w = weights[i];
511         double s = sum[idx] + responses[i]*w;
512         double nc = counts[idx] + w;
513         sum[idx] = s;
514         counts[idx] = nc;
515     }
516
517     // calculate average response in each category
518     for( i = 0; i < mi; i++ )
519     {
520         R += counts[i];
521         rsum += sum[i];
522         sum[i] /= counts[i];
523         sum_ptr[i] = sum + i;
524     }
525
526     icvSortDblPtr( sum_ptr, mi, 0 );
527
528     // revert back to unnormalized sums
529     // (there should be a very little loss in accuracy)
530     for( i = 0; i < mi; i++ )
531         sum[i] *= counts[i];
532
533     for( subset_i = 0; subset_i < mi-1; subset_i++ )
534     {
535         int idx = (int)(sum_ptr[subset_i] - sum);
536         double ni = counts[idx];
537
538         if( ni > FLT_EPSILON )
539         {
540             double s = sum[idx];
541             lsum += s; L += ni;
542             rsum -= s; R -= ni;
543             
544             if( L > FLT_EPSILON && R > FLT_EPSILON )
545             {
546                 double val = (lsum*lsum*R + rsum*rsum*L)/(L*R);
547                 if( best_val < val )
548                 {
549                     best_val = val;
550                     best_subset = subset_i;
551                 }
552             }
553         }
554     }
555
556     if( best_subset < 0 )
557         return 0;
558
559     split = data->new_split_cat( vi, (float)best_val );
560     for( i = 0; i <= best_subset; i++ )
561     {
562         int idx = (int)(sum_ptr[i] - sum);
563         split->subset[idx >> 5] |= 1 << (idx & 31);
564     }
565
566     return split;
567 }
568
569
570 CvDTreeSplit*
571 CvBoostTree::find_surrogate_split_ord( CvDTreeNode* node, int vi )
572 {
573     const float epsilon = FLT_EPSILON*2;
574     const CvPair32s32f* sorted = data->get_ord_var_data(node, vi);
575     const double* weights = ensemble->get_subtree_weights()->data.db;
576     const char* dir = (char*)data->direction->data.ptr;
577     int n1 = node->get_num_valid(vi);
578     // LL - number of samples that both the primary and the surrogate splits send to the left
579     // LR - ... primary split sends to the left and the surrogate split sends to the right
580     // RL - ... primary split sends to the right and the surrogate split sends to the left
581     // RR - ... both send to the right
582     int i, best_i = -1, best_inversed = 0;
583     double best_val; 
584     double LL = 0, RL = 0, LR, RR;
585     double worst_val = node->maxlr;
586     double sum = 0, sum_abs = 0;
587     best_val = worst_val;
588     
589     for( i = 0; i < n1; i++ )
590     {
591         int idx = sorted[i].i;
592         double w = weights[idx];
593         int d = dir[idx];
594         sum += d*w; sum_abs += (d & 1)*w;
595     }
596
597     // sum_abs = R + L; sum = R - L
598     RR = (sum_abs + sum)*0.5;
599     LR = (sum_abs - sum)*0.5;
600
601     // initially all the samples are sent to the right by the surrogate split,
602     // LR of them are sent to the left by primary split, and RR - to the right.
603     // now iteratively compute LL, LR, RL and RR for every possible surrogate split value.
604     for( i = 0; i < n1 - 1; i++ )
605     {
606         int idx = sorted[i].i;
607         double w = weights[idx];
608         int d = dir[idx];
609
610         if( d < 0 )
611         {
612             LL += w; LR -= w;
613             if( LL + RR > best_val && sorted[i].val + epsilon < sorted[i+1].val )
614             {
615                 best_val = LL + RR;
616                 best_i = i; best_inversed = 0;
617             }
618         }
619         else if( d > 0 )
620         {
621             RL += w; RR -= w;
622             if( RL + LR > best_val && sorted[i].val + epsilon < sorted[i+1].val )
623             {
624                 best_val = RL + LR;
625                 best_i = i; best_inversed = 1;
626             }
627         }
628     }
629
630     return best_i >= 0 && best_val > node->maxlr ? data->new_split_ord( vi,
631         (sorted[best_i].val + sorted[best_i+1].val)*0.5f, best_i,
632         best_inversed, (float)best_val ) : 0;
633 }
634
635
636 CvDTreeSplit*
637 CvBoostTree::find_surrogate_split_cat( CvDTreeNode* node, int vi )
638 {
639     const int* cat_labels = data->get_cat_var_data(node, vi);
640     const char* dir = (char*)data->direction->data.ptr;
641     const double* weights = ensemble->get_subtree_weights()->data.db;
642     int n = node->sample_count;
643     // LL - number of samples that both the primary and the surrogate splits send to the left
644     // LR - ... primary split sends to the left and the surrogate split sends to the right
645     // RL - ... primary split sends to the right and the surrogate split sends to the left
646     // RR - ... both send to the right
647     CvDTreeSplit* split = data->new_split_cat( vi, 0 );
648     int i, mi = data->cat_count->data.i[data->get_var_type(vi)];
649     double best_val = 0;
650     double* lc = (double*)cvStackAlloc( (mi+1)*2*sizeof(lc[0]) ) + 1;
651     double* rc = lc + mi + 1;
652     
653     for( i = -1; i < mi; i++ )
654         lc[i] = rc[i] = 0;
655
656     // 1. for each category calculate the weight of samples
657     // sent to the left (lc) and to the right (rc) by the primary split
658     for( i = 0; i < n; i++ )
659     {
660         int idx = cat_labels[i];
661         double w = weights[i];
662         int d = dir[i];
663         double sum = lc[idx] + d*w;
664         double sum_abs = rc[idx] + (d & 1)*w;
665         lc[idx] = sum; rc[idx] = sum_abs;
666     }
667
668     for( i = 0; i < mi; i++ )
669     {
670         double sum = lc[i];
671         double sum_abs = rc[i];
672         lc[i] = (sum_abs - sum) * 0.5;
673         rc[i] = (sum_abs + sum) * 0.5;
674     }
675
676     // 2. now form the split.
677     // in each category send all the samples to the same direction as majority
678     for( i = 0; i < mi; i++ )
679     {
680         double lval = lc[i], rval = rc[i];
681         if( lval > rval )
682         {
683             split->subset[i >> 5] |= 1 << (i & 31);
684             best_val += lval;
685         }
686         else
687             best_val += rval;
688     }
689
690     split->quality = (float)best_val;
691     if( split->quality <= node->maxlr )
692         cvSetRemoveByPtr( data->split_heap, split ), split = 0;
693
694     return split;
695 }
696
697
698 void
699 CvBoostTree::calc_node_value( CvDTreeNode* node )
700 {
701     int i, count = node->sample_count;
702     const double* weights = ensemble->get_weights()->data.db;
703     const int* labels = data->get_labels(node);
704     double* subtree_weights = ensemble->get_subtree_weights()->data.db;
705     double rcw[2] = {0,0};
706     int boost_type = ensemble->get_params().boost_type;
707     //const double* priors = data->priors->data.db;
708
709     if( data->is_classifier )
710     {
711         const int* responses = data->get_class_labels(node);
712         
713         for( i = 0; i < count; i++ )
714         {
715             int idx = labels[i];
716             double w = weights[idx]/*priors[responses[i]]*/;
717             rcw[responses[i]] += w;
718             subtree_weights[i] = w;
719         }
720
721         node->class_idx = rcw[1] > rcw[0];
722
723         if( boost_type == CvBoost::DISCRETE )
724         {
725             // ignore cat_map for responses, and use {-1,1},
726             // as the whole ensemble response is computes as sign(sum_i(weak_response_i)
727             node->value = node->class_idx*2 - 1;
728         }
729         else
730         {
731             double p = rcw[1]/(rcw[0] + rcw[1]);
732             assert( boost_type == CvBoost::REAL );
733             
734             // store log-ratio of the probability
735             node->value = 0.5*log_ratio(p);
736         }
737     }
738     else
739     {
740         // in case of regression tree:
741         //  * node value is 1/n*sum_i(Y_i), where Y_i is i-th response,
742         //    n is the number of samples in the node.
743         //  * node risk is the sum of squared errors: sum_i((Y_i - <node_value>)^2)
744         double sum = 0, sum2 = 0, iw;
745         const float* values = data->get_ord_responses(node);
746         
747         for( i = 0; i < count; i++ )
748         {
749             int idx = labels[i];
750             double w = weights[idx]/*priors[values[i] > 0]*/;
751             double t = values[i];
752             rcw[0] += w;
753             subtree_weights[i] = w;
754             sum += t*w;
755             sum2 += t*t*w;
756         }
757
758         iw = 1./rcw[0];
759         node->value = sum*iw;
760         node->node_risk = sum2 - (sum*iw)*sum;
761         
762         // renormalize the risk, as in try_split_node the unweighted formula
763         // sqrt(risk)/n is used, rather than sqrt(risk)/sum(weights_i)
764         node->node_risk *= count*iw*count*iw;
765     }
766
767     // store summary weights
768     subtree_weights[count] = rcw[0];
769     subtree_weights[count+1] = rcw[1];
770 }
771
772
773 void CvBoostTree::read( CvFileStorage* fs, CvFileNode* fnode, CvBoost* _ensemble, CvDTreeTrainData* _data )
774 {
775     CvDTree::read( fs, fnode, _data );
776     ensemble = _ensemble;
777 }
778
779
780 void CvBoostTree::read( CvFileStorage*, CvFileNode* )
781 {
782     assert(0);
783 }
784
785 void CvBoostTree::read( CvFileStorage* _fs, CvFileNode* _node,
786                         CvDTreeTrainData* _data )
787 {
788     CvDTree::read( _fs, _node, _data );
789 }
790
791
792 /////////////////////////////////// CvBoost /////////////////////////////////////
793
794 CvBoost::CvBoost()
795 {
796     data = 0;
797     weak = 0;
798     default_model_name = "my_boost_tree";
799     orig_response = sum_response = weak_eval = subsample_mask =
800         weights = subtree_weights = 0;
801
802     clear();
803 }
804
805
806 void CvBoost::prune( CvSlice slice )
807 {
808     if( weak )
809     {
810         CvSeqReader reader;
811         int i, count = cvSliceLength( slice, weak );
812         
813         cvStartReadSeq( weak, &reader );
814         cvSetSeqReaderPos( &reader, slice.start_index );
815
816         for( i = 0; i < count; i++ )
817         {
818             CvBoostTree* w;
819             CV_READ_SEQ_ELEM( w, reader );
820             delete w;
821         }
822
823         cvSeqRemoveSlice( weak, slice );
824     }
825 }
826
827
828 void CvBoost::clear()
829 {
830     if( weak )
831     {
832         prune( CV_WHOLE_SEQ );
833         cvReleaseMemStorage( &weak->storage );
834     }
835     if( data )
836         delete data;
837     weak = 0;
838     data = 0;
839     cvReleaseMat( &orig_response );
840     cvReleaseMat( &sum_response );
841     cvReleaseMat( &weak_eval );
842     cvReleaseMat( &subsample_mask );
843     cvReleaseMat( &weights );
844     have_subsample = false;
845 }
846
847
848 CvBoost::~CvBoost()
849 {
850     clear();
851 }
852
853
854 CvBoost::CvBoost( const CvMat* _train_data, int _tflag,
855                   const CvMat* _responses, const CvMat* _var_idx,
856                   const CvMat* _sample_idx, const CvMat* _var_type,
857                   const CvMat* _missing_mask, CvBoostParams _params )
858 {
859     weak = 0;
860     data = 0;
861     default_model_name = "my_boost_tree";
862     orig_response = sum_response = weak_eval = subsample_mask = weights = 0;
863
864     train( _train_data, _tflag, _responses, _var_idx, _sample_idx,
865            _var_type, _missing_mask, _params );
866 }
867
868
869 bool
870 CvBoost::set_params( const CvBoostParams& _params )
871 {
872     bool ok = false;
873     
874     CV_FUNCNAME( "CvBoost::set_params" );
875
876     __BEGIN__;
877
878     params = _params;
879     if( params.boost_type != DISCRETE && params.boost_type != REAL &&
880         params.boost_type != LOGIT && params.boost_type != GENTLE )
881         CV_ERROR( CV_StsBadArg, "Unknown/unsupported boosting type" );
882
883     params.weak_count = MAX( params.weak_count, 1 );
884     params.weight_trim_rate = MAX( params.weight_trim_rate, 0. );
885     params.weight_trim_rate = MIN( params.weight_trim_rate, 1. );
886     if( params.weight_trim_rate < FLT_EPSILON )
887         params.weight_trim_rate = 1.f;
888
889     if( params.boost_type == DISCRETE &&
890         params.split_criteria != GINI && params.split_criteria != MISCLASS )
891         params.split_criteria = MISCLASS;
892     if( params.boost_type == REAL &&
893         params.split_criteria != GINI && params.split_criteria != MISCLASS )
894         params.split_criteria = GINI;
895     if( (params.boost_type == LOGIT || params.boost_type == GENTLE) &&
896         params.split_criteria != SQERR )
897         params.split_criteria = SQERR;
898     
899     ok = true;
900     
901     __END__;
902
903     return ok;
904 }
905
906
907 bool
908 CvBoost::train( const CvMat* _train_data, int _tflag,
909               const CvMat* _responses, const CvMat* _var_idx,
910               const CvMat* _sample_idx, const CvMat* _var_type,
911               const CvMat* _missing_mask,
912               CvBoostParams _params, bool _update )
913 {
914     bool ok = false;
915     CvMemStorage* storage = 0;
916
917     CV_FUNCNAME( "CvBoost::train" );
918
919     __BEGIN__;
920
921     int i;
922
923     set_params( _params );
924
925     if( !_update || !data )
926     {
927         clear();
928         data = new CvDTreeTrainData( _train_data, _tflag, _responses, _var_idx,
929             _sample_idx, _var_type, _missing_mask, _params, true, true );
930
931         if( data->get_num_classes() != 2 )
932             CV_ERROR( CV_StsNotImplemented,
933             "Boosted trees can only be used for 2-class classification." );
934         CV_CALL( storage = cvCreateMemStorage() );
935         weak = cvCreateSeq( 0, sizeof(CvSeq), sizeof(CvBoostTree*), storage );
936         storage = 0;
937     }
938     else
939     {
940         data->set_data( _train_data, _tflag, _responses, _var_idx,
941             _sample_idx, _var_type, _missing_mask, _params, true, true, true );
942     }
943
944     update_weights( 0 );
945
946     for( i = 0; i < params.weak_count; i++ )
947     {
948         CvBoostTree* tree = new CvBoostTree;
949         if( !tree->train( data, subsample_mask, this ) )
950         {
951             delete tree;
952             continue;
953         }
954         //cvCheckArr( get_weak_response());
955         cvSeqPush( weak, &tree );
956         update_weights( tree );
957         trim_weights();
958     }
959
960     data->is_classifier = true;
961     ok = true;
962
963     __END__;
964
965     return ok;
966 }
967
968
969 void
970 CvBoost::update_weights( CvBoostTree* tree )
971 {
972     CV_FUNCNAME( "CvBoost::update_weights" );
973
974     __BEGIN__;
975
976     int i, count = data->sample_count;
977     double sumw = 0.;
978
979     if( !tree ) // before training the first tree, initialize weights and other parameters
980     {
981         const int* class_labels = data->get_class_labels(data->data_root);
982         // in case of logitboost and gentle adaboost each weak tree is a regression tree,
983         // so we need to convert class labels to floating-point values
984         float* responses = data->get_ord_responses(data->data_root);
985         int* labels = data->get_labels(data->data_root);
986         double w0 = 1./count;
987         double p[2] = { 1, 1 };
988         
989         cvReleaseMat( &orig_response );
990         cvReleaseMat( &sum_response );
991         cvReleaseMat( &weak_eval );
992         cvReleaseMat( &subsample_mask );
993         cvReleaseMat( &weights );
994
995         CV_CALL( orig_response = cvCreateMat( 1, count, CV_32S ));
996         CV_CALL( weak_eval = cvCreateMat( 1, count, CV_64F ));
997         CV_CALL( subsample_mask = cvCreateMat( 1, count, CV_8U ));
998         CV_CALL( weights = cvCreateMat( 1, count, CV_64F ));
999         CV_CALL( subtree_weights = cvCreateMat( 1, count + 2, CV_64F ));
1000
1001         if( data->have_priors )
1002         {
1003             // compute weight scale for each class from their prior probabilities
1004             int c1 = 0;
1005             for( i = 0; i < count; i++ )
1006                 c1 += class_labels[i];
1007             p[0] = data->priors->data.db[0]*(c1 < count ? 1./(count - c1) : 0.);
1008             p[1] = data->priors->data.db[1]*(c1 > 0 ? 1./c1 : 0.);
1009             p[0] /= p[0] + p[1];
1010             p[1] = 1. - p[0];
1011         }
1012
1013         for( i = 0; i < count; i++ )
1014         {
1015             // save original categorical responses {0,1}, convert them to {-1,1}
1016             orig_response->data.i[i] = class_labels[i]*2 - 1;
1017             // make all the samples active at start.
1018             // later, in trim_weights() deactivate/reactive again some, if need
1019             subsample_mask->data.ptr[i] = (uchar)1;
1020             // make all the initial weights the same.
1021             weights->data.db[i] = w0*p[class_labels[i]];
1022             // set the labels to find (from within weak tree learning proc)
1023             // the particular sample weight, and where to store the response.
1024             labels[i] = i;
1025         }
1026
1027         if( params.boost_type == LOGIT )
1028         {
1029             CV_CALL( sum_response = cvCreateMat( 1, count, CV_64F ));
1030             
1031             for( i = 0; i < count; i++ )
1032             {
1033                 sum_response->data.db[i] = 0;
1034                 responses[i] = orig_response->data.i[i] > 0 ? 2.f : -2.f;
1035             }
1036
1037             // in case of logitboost each weak tree is a regression tree.
1038             // the target function values are recalculated for each of the trees
1039             data->is_classifier = false;
1040         }
1041         else if( params.boost_type == GENTLE )
1042         {
1043             for( i = 0; i < count; i++ )
1044                 responses[i] = (float)orig_response->data.i[i];
1045
1046             data->is_classifier = false;
1047         }
1048     }
1049     else
1050     {
1051         // at this moment, for all the samples that participated in the training of the most
1052         // recent weak classifier we know the responses. For other samples we need to compute them
1053         if( have_subsample )
1054         {
1055             float* values = (float*)(data->buf->data.ptr + data->buf->step);
1056             uchar* missing = data->buf->data.ptr + data->buf->step*2;
1057             CvMat _sample, _mask;
1058
1059             // invert the subsample mask
1060             cvXorS( subsample_mask, cvScalar(1.), subsample_mask );
1061             data->get_vectors( subsample_mask, values, missing, 0 );
1062             //data->get_vectors( 0, values, missing, 0 );
1063
1064             _sample = cvMat( 1, data->var_count, CV_32F );
1065             _mask = cvMat( 1, data->var_count, CV_8U );
1066
1067             // run tree through all the non-processed samples
1068             for( i = 0; i < count; i++ )
1069                 if( subsample_mask->data.ptr[i] )
1070                 {
1071                     _sample.data.fl = values;
1072                     _mask.data.ptr = missing;
1073                     values += _sample.cols;
1074                     missing += _mask.cols;
1075                     weak_eval->data.db[i] = tree->predict( &_sample, &_mask, true )->value;
1076                 }
1077         }
1078
1079         // now update weights and other parameters for each type of boosting
1080         if( params.boost_type == DISCRETE )
1081         {
1082             // Discrete AdaBoost:
1083             //   weak_eval[i] (=f(x_i)) is in {-1,1}
1084             //   err = sum(w_i*(f(x_i) != y_i))/sum(w_i)
1085             //   C = log((1-err)/err)
1086             //   w_i *= exp(C*(f(x_i) != y_i))
1087             
1088             double C, err = 0.;
1089             double scale[] = { 1., 0. };
1090
1091             for( i = 0; i < count; i++ )
1092             {
1093                 double w = weights->data.db[i];
1094                 sumw += w;
1095                 err += w*(weak_eval->data.db[i] != orig_response->data.i[i]);
1096             }
1097             
1098             if( sumw != 0 )
1099                 err /= sumw;
1100             C = err = -log_ratio( err );
1101             scale[1] = exp(err);
1102     
1103             sumw = 0;
1104             for( i = 0; i < count; i++ )
1105             {
1106                 double w = weights->data.db[i]*
1107                     scale[weak_eval->data.db[i] != orig_response->data.i[i]];
1108                 sumw += w;
1109                 weights->data.db[i] = w;
1110             }
1111
1112             tree->scale( C );
1113         }
1114         else if( params.boost_type == REAL )
1115         {
1116             // Real AdaBoost:
1117             //   weak_eval[i] = f(x_i) = 0.5*log(p(x_i)/(1-p(x_i))), p(x_i)=P(y=1|x_i)
1118             //   w_i *= exp(-y_i*f(x_i))
1119             
1120             for( i = 0; i < count; i++ )
1121                 weak_eval->data.db[i] *= -orig_response->data.i[i];
1122
1123             cvExp( weak_eval, weak_eval );
1124
1125             for( i = 0; i < count; i++ )
1126             {
1127                 double w = weights->data.db[i]*weak_eval->data.db[i];
1128                 sumw += w;
1129                 weights->data.db[i] = w;
1130             }
1131         }
1132         else if( params.boost_type == LOGIT )
1133         {
1134             // LogitBoost:
1135             //   weak_eval[i] = f(x_i) in [-z_max,z_max]
1136             //   sum_response = F(x_i).
1137             //   F(x_i) += 0.5*f(x_i)
1138             //   p(x_i) = exp(F(x_i))/(exp(F(x_i)) + exp(-F(x_i))=1/(1+exp(-2*F(x_i)))
1139             //   reuse weak_eval: weak_eval[i] <- p(x_i)
1140             //   w_i = p(x_i)*1(1 - p(x_i))
1141             //   z_i = ((y_i+1)/2 - p(x_i))/(p(x_i)*(1 - p(x_i)))
1142             //   store z_i to the data->data_root as the new target responses
1143
1144             const double lb_weight_thresh = FLT_EPSILON;
1145             const double lb_z_max = 10.;
1146             float* responses = data->get_ord_responses(data->data_root);
1147
1148             /*if( weak->total == 7 )
1149                 putchar('*');*/
1150
1151             for( i = 0; i < count; i++ )
1152             {
1153                 double s = sum_response->data.db[i] + 0.5*weak_eval->data.db[i];
1154                 sum_response->data.db[i] = s;
1155                 weak_eval->data.db[i] = -2*s;
1156             }
1157
1158             cvExp( weak_eval, weak_eval );
1159             
1160             for( i = 0; i < count; i++ )
1161             {
1162                 double p = 1./(1. + weak_eval->data.db[i]);
1163                 double w = p*(1 - p), z;
1164                 w = MAX( w, lb_weight_thresh );
1165                 weights->data.db[i] = w;
1166                 sumw += w;
1167                 if( orig_response->data.i[i] > 0 )
1168                 {
1169                     z = 1./p;
1170                     responses[i] = (float)MIN(z, lb_z_max);
1171                 }
1172                 else
1173                 {
1174                     z = 1./(1-p);
1175                     responses[i] = (float)-MIN(z, lb_z_max);
1176                 }
1177             }
1178         }
1179         else
1180         {
1181             // Gentle AdaBoost:
1182             //   weak_eval[i] = f(x_i) in [-1,1]
1183             //   w_i *= exp(-y_i*f(x_i))
1184             assert( params.boost_type == GENTLE );
1185             
1186             for( i = 0; i < count; i++ )
1187                 weak_eval->data.db[i] *= -orig_response->data.i[i];
1188
1189             cvExp( weak_eval, weak_eval );
1190
1191             for( i = 0; i < count; i++ )
1192             {
1193                 double w = weights->data.db[i] * weak_eval->data.db[i];
1194                 weights->data.db[i] = w;
1195                 sumw += w;
1196             }
1197         }
1198     }
1199
1200     // renormalize weights
1201     if( sumw > FLT_EPSILON )
1202     {
1203         sumw = 1./sumw;
1204         for( i = 0; i < count; ++i )
1205             weights->data.db[i] *= sumw;
1206     }
1207
1208     __END__;
1209 }
1210
1211
1212 static CV_IMPLEMENT_QSORT_EX( icvSort_64f, double, CV_LT, int )
1213
1214
1215 void
1216 CvBoost::trim_weights()
1217 {
1218     CV_FUNCNAME( "CvBoost::trim_weights" );
1219
1220     __BEGIN__;
1221
1222     int i, count = data->sample_count, nz_count = 0;
1223     double sum, threshold;
1224
1225     if( params.weight_trim_rate <= 0. || params.weight_trim_rate >= 1. )
1226         EXIT;
1227
1228     // use weak_eval as temporary buffer for sorted weights
1229     cvCopy( weights, weak_eval );
1230
1231     icvSort_64f( weak_eval->data.db, count, 0 );
1232
1233     // as weight trimming occurs immediately after updating the weights,
1234     // where they are renormalized, we assume that the weight sum = 1.
1235     sum = 1. - params.weight_trim_rate;
1236
1237     for( i = 0; i < count; i++ )
1238     {
1239         double w = weak_eval->data.db[i];
1240         if( sum > w )
1241             break;
1242         sum -= w;
1243     }
1244
1245     threshold = i < count ? weak_eval->data.db[i] : DBL_MAX;
1246
1247     for( i = 0; i < count; i++ )
1248     {
1249         double w = weights->data.db[i];
1250         int f = w > threshold;
1251         subsample_mask->data.ptr[i] = (uchar)f;
1252         nz_count += f;
1253     }
1254
1255     have_subsample = nz_count < count;
1256
1257     __END__;
1258 }
1259
1260
1261 float
1262 CvBoost::predict( const CvMat* _sample, const CvMat* _missing,
1263                   CvMat* weak_responses, CvSlice slice,
1264                   bool raw_mode ) const
1265 {
1266     float* buf = 0;
1267     bool allocated = false;
1268     float value = -FLT_MAX;
1269     
1270     CV_FUNCNAME( "CvBoost::predict" );
1271
1272     __BEGIN__;
1273
1274     int i, weak_count, var_count;
1275     CvMat sample, missing;
1276     CvSeqReader reader;
1277     double sum = 0;
1278     int cls_idx;
1279     int wstep = 0;
1280     const int* vtype;
1281     const int* cmap;
1282     const int* cofs;
1283
1284     if( !weak )
1285         CV_ERROR( CV_StsError, "The boosted tree ensemble has not been trained yet" );
1286
1287     if( !CV_IS_MAT(_sample) || CV_MAT_TYPE(_sample->type) != CV_32FC1 ||
1288         _sample->cols != 1 && _sample->rows != 1 ||
1289         _sample->cols + _sample->rows - 1 != data->var_all && !raw_mode ||
1290         _sample->cols + _sample->rows - 1 != data->var_count && raw_mode )
1291             CV_ERROR( CV_StsBadArg,
1292         "the input sample must be 1d floating-point vector with the same "
1293         "number of elements as the total number of variables used for training" );
1294
1295     if( _missing )
1296     {
1297         if( !CV_IS_MAT(_missing) || !CV_IS_MASK_ARR(_missing) ||
1298             !CV_ARE_SIZES_EQ(_missing, _sample) )
1299             CV_ERROR( CV_StsBadArg,
1300             "the missing data mask must be 8-bit vector of the same size as input sample" );
1301     }
1302
1303     weak_count = cvSliceLength( slice, weak );
1304     if( weak_count >= weak->total )
1305     {
1306         weak_count = weak->total;
1307         slice.start_index = 0;
1308     }
1309
1310     if( weak_responses )
1311     {
1312         if( !CV_IS_MAT(weak_responses) ||
1313             CV_MAT_TYPE(weak_responses->type) != CV_32FC1 ||
1314             weak_responses->cols != 1 && weak_responses->rows != 1 ||
1315             weak_responses->cols + weak_responses->rows - 1 != weak_count )
1316             CV_ERROR( CV_StsBadArg,
1317             "The output matrix of weak classifier responses must be valid "
1318             "floating-point vector of the same number of components as the length of input slice" );
1319         wstep = CV_IS_MAT_CONT(weak_responses->type) ? 1 : weak_responses->step/sizeof(float);
1320     }
1321
1322     var_count = data->var_count;
1323     vtype = data->var_type->data.i;
1324     cmap = data->cat_map->data.i;
1325     cofs = data->cat_ofs->data.i;
1326
1327     // if need, preprocess the input vector
1328     if( !raw_mode && (data->cat_var_count > 0 || data->var_idx) )
1329     {
1330         int bufsize;
1331         int step, mstep = 0;
1332         const float* src_sample;
1333         const uchar* src_mask = 0;
1334         float* dst_sample;
1335         uchar* dst_mask;
1336         const int* vidx = data->var_idx && !raw_mode ? data->var_idx->data.i : 0;
1337         bool have_mask = _missing != 0;
1338
1339         bufsize = var_count*(sizeof(float) + sizeof(uchar));
1340         if( bufsize <= CV_MAX_LOCAL_SIZE )
1341             buf = (float*)cvStackAlloc( bufsize );
1342         else
1343         {
1344             CV_CALL( buf = (float*)cvAlloc( bufsize ));
1345             allocated = true;
1346         }
1347         dst_sample = buf;
1348         dst_mask = (uchar*)(buf + var_count);
1349
1350         src_sample = _sample->data.fl;
1351         step = CV_IS_MAT_CONT(_sample->type) ? 1 : _sample->step/sizeof(src_sample[0]);
1352
1353         if( _missing )
1354         {
1355             src_mask = _missing->data.ptr;
1356             mstep = CV_IS_MAT_CONT(_missing->type) ? 1 : _missing->step;
1357         }
1358
1359         for( i = 0; i < var_count; i++ )
1360         {
1361             int idx = vidx ? vidx[i] : i;
1362             float val = src_sample[idx*step];
1363             int ci = vtype[i];
1364             uchar m = src_mask ? src_mask[i] : (uchar)0;
1365
1366             if( ci >= 0 )
1367             {
1368                 int a = cofs[ci], b = cofs[ci+1], c = a;
1369                 int ival = cvRound(val);
1370                 if( ival != val )
1371                     CV_ERROR( CV_StsBadArg,
1372                     "one of input categorical variable is not an integer" );
1373
1374                 while( a < b )
1375                 {
1376                     c = (a + b) >> 1;
1377                     if( ival < cmap[c] )
1378                         b = c;
1379                     else if( ival > cmap[c] )
1380                         a = c+1;
1381                     else
1382                         break;
1383                 }
1384
1385                 if( c < 0 || ival != cmap[c] )
1386                 {
1387                     m = 1;
1388                     have_mask = true;
1389                 }
1390                 else
1391                 {
1392                     val = (float)(c - cofs[ci]);
1393                 }
1394             }
1395
1396             dst_sample[i] = val;
1397             dst_mask[i] = m;
1398         }
1399
1400         sample = cvMat( 1, var_count, CV_32F, dst_sample );
1401         _sample = &sample;
1402
1403         if( have_mask )
1404         {
1405             missing = cvMat( 1, var_count, CV_8UC1, dst_mask );
1406             _missing = &missing;
1407         }
1408     }
1409
1410     cvStartReadSeq( weak, &reader );
1411     cvSetSeqReaderPos( &reader, slice.start_index );
1412
1413     for( i = 0; i < weak_count; i++ )
1414     {
1415         CvBoostTree* wtree;
1416         double val;
1417
1418         CV_READ_SEQ_ELEM( wtree, reader );
1419
1420         val = wtree->predict( _sample, _missing, true )->value;
1421         if( weak_responses )
1422             weak_responses->data.fl[i*wstep] = (float)val;
1423
1424         sum += val;
1425     }
1426
1427     cls_idx = sum >= 0;
1428     if( raw_mode )
1429         value = (float)cls_idx;
1430     else
1431         value = (float)cmap[cofs[vtype[var_count]] + cls_idx];
1432
1433     __END__;
1434
1435     if( allocated )
1436         cvFree( &buf );
1437
1438     return value;
1439 }
1440
1441
1442
1443 void CvBoost::write_params( CvFileStorage* fs )
1444 {
1445     CV_FUNCNAME( "CvBoost::write_params" );
1446
1447     __BEGIN__;
1448
1449     const char* boost_type_str =
1450         params.boost_type == DISCRETE ? "DiscreteAdaboost" :
1451         params.boost_type == REAL ? "RealAdaboost" :
1452         params.boost_type == LOGIT ? "LogitBoost" :
1453         params.boost_type == GENTLE ? "GentleAdaboost" : 0;
1454
1455     const char* split_crit_str =
1456         params.split_criteria == DEFAULT ? "Default" :
1457         params.split_criteria == GINI ? "Gini" :
1458         params.boost_type == MISCLASS ? "Misclassification" :
1459         params.boost_type == SQERR ? "SquaredErr" : 0;
1460
1461     if( boost_type_str )
1462         cvWriteString( fs, "boosting_type", boost_type_str );
1463     else
1464         cvWriteInt( fs, "boosting_type", params.boost_type );
1465
1466     if( split_crit_str )
1467         cvWriteString( fs, "splitting_criteria", split_crit_str );
1468     else
1469         cvWriteInt( fs, "splitting_criteria", params.split_criteria );
1470
1471     cvWriteInt( fs, "ntrees", params.weak_count );
1472     cvWriteReal( fs, "weight_trimming_rate", params.weight_trim_rate );
1473
1474     data->write_params( fs );
1475
1476     __END__;
1477 }
1478
1479
1480 void CvBoost::read_params( CvFileStorage* fs, CvFileNode* fnode )
1481 {
1482     CV_FUNCNAME( "CvBoost::read_params" );
1483
1484     __BEGIN__;
1485
1486     CvFileNode* temp;
1487
1488     if( !fnode || !CV_NODE_IS_MAP(fnode->tag) )
1489         return;
1490
1491     data = new CvDTreeTrainData();
1492     CV_CALL( data->read_params(fs, fnode));
1493     data->shared = true;
1494
1495     params.max_depth = data->params.max_depth;
1496     params.min_sample_count = data->params.min_sample_count;
1497     params.max_categories = data->params.max_categories;
1498     params.priors = data->params.priors;
1499     params.regression_accuracy = data->params.regression_accuracy;
1500     params.use_surrogates = data->params.use_surrogates;
1501
1502     temp = cvGetFileNodeByName( fs, fnode, "boosting_type" );
1503     if( !temp )
1504         return;
1505
1506     if( temp && CV_NODE_IS_STRING(temp->tag) )
1507     {
1508         const char* boost_type_str = cvReadString( temp, "" );
1509         params.boost_type = strcmp( boost_type_str, "DiscreteAdaboost" ) == 0 ? DISCRETE :
1510                             strcmp( boost_type_str, "RealAdaboost" ) == 0 ? REAL :
1511                             strcmp( boost_type_str, "LogitBoost" ) == 0 ? LOGIT :
1512                             strcmp( boost_type_str, "GentleAdaboost" ) == 0 ? GENTLE : -1;
1513     }
1514     else
1515         params.boost_type = cvReadInt( temp, -1 );
1516
1517     if( params.boost_type < DISCRETE || params.boost_type > GENTLE )
1518         CV_ERROR( CV_StsBadArg, "Unknown boosting type" );
1519
1520     temp = cvGetFileNodeByName( fs, fnode, "splitting_criteria" );
1521     if( temp && CV_NODE_IS_STRING(temp->tag) )
1522     {
1523         const char* split_crit_str = cvReadString( temp, "" );
1524         params.split_criteria = strcmp( split_crit_str, "Default" ) == 0 ? DEFAULT :
1525                                 strcmp( split_crit_str, "Gini" ) == 0 ? GINI :
1526                                 strcmp( split_crit_str, "Misclassification" ) == 0 ? MISCLASS :
1527                                 strcmp( split_crit_str, "SquaredErr" ) == 0 ? SQERR : -1;
1528     }
1529     else
1530         params.split_criteria = cvReadInt( temp, -1 );
1531
1532     if( params.split_criteria < DEFAULT || params.boost_type > SQERR )
1533         CV_ERROR( CV_StsBadArg, "Unknown boosting type" );
1534
1535     params.weak_count = cvReadIntByName( fs, fnode, "ntrees" );
1536     params.weight_trim_rate = cvReadRealByName( fs, fnode, "weight_trimming_rate", 0. );
1537
1538     __END__;
1539 }
1540
1541
1542
1543 void
1544 CvBoost::read( CvFileStorage* fs, CvFileNode* node )
1545 {
1546     CV_FUNCNAME( "CvRTrees::read" );
1547
1548     __BEGIN__;
1549
1550     CvSeqReader reader;
1551     CvFileNode* trees_fnode;
1552     CvMemStorage* storage;
1553     int i, ntrees;
1554
1555     clear();
1556     read_params( fs, node );
1557
1558     if( !data )
1559         EXIT;
1560         
1561     trees_fnode = cvGetFileNodeByName( fs, node, "trees" );
1562     if( !trees_fnode || !CV_NODE_IS_SEQ(trees_fnode->tag) )
1563         CV_ERROR( CV_StsParseError, "<trees> tag is missing" );
1564
1565     cvStartReadSeq( trees_fnode->data.seq, &reader );
1566     ntrees = trees_fnode->data.seq->total;
1567
1568     if( ntrees != params.weak_count )
1569         CV_ERROR( CV_StsUnmatchedSizes,
1570         "The number of trees stored does not match <ntrees> tag value" );
1571
1572     CV_CALL( storage = cvCreateMemStorage() );
1573     weak = cvCreateSeq( 0, sizeof(CvSeq), sizeof(CvBoostTree*), storage );
1574
1575     for( i = 0; i < ntrees; i++ )
1576     {
1577         CvBoostTree* tree = new CvBoostTree();
1578         CV_CALL(tree->read( fs, (CvFileNode*)reader.ptr, this, data ));
1579         CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
1580         cvSeqPush( weak, &tree );
1581     }
1582
1583     __END__;
1584 }
1585
1586
1587 void
1588 CvBoost::write( CvFileStorage* fs, const char* name )
1589 {
1590     CV_FUNCNAME( "CvBoost::write" );
1591
1592     __BEGIN__;
1593     
1594     CvSeqReader reader;
1595     int i;
1596
1597     cvStartWriteStruct( fs, name, CV_NODE_MAP, CV_TYPE_NAME_ML_BOOSTING );
1598
1599     if( !weak )
1600         CV_ERROR( CV_StsBadArg, "The classifier has not been trained yet" );
1601         
1602     write_params( fs );
1603     cvStartWriteStruct( fs, "trees", CV_NODE_SEQ );
1604
1605     cvStartReadSeq( weak, &reader );
1606
1607     for( i = 0; i < weak->total; i++ )
1608     {
1609         CvBoostTree* tree;
1610         CV_READ_SEQ_ELEM( tree, reader );
1611         cvStartWriteStruct( fs, 0, CV_NODE_MAP );
1612         tree->write( fs );
1613         cvEndWriteStruct( fs );
1614     }
1615
1616     cvEndWriteStruct( fs );
1617     cvEndWriteStruct( fs );
1618
1619     __END__;
1620 }
1621
1622
1623 CvMat*
1624 CvBoost::get_weights()
1625 {
1626     return weights;
1627 }
1628
1629
1630 CvMat*
1631 CvBoost::get_subtree_weights()
1632 {
1633     return subtree_weights;
1634 }
1635
1636
1637 CvMat*
1638 CvBoost::get_weak_response()
1639 {
1640     return weak_eval;
1641 }
1642
1643
1644 const CvBoostParams&
1645 CvBoost::get_params() const
1646 {
1647     return params;
1648 }
1649
1650 CvSeq* CvBoost::get_weak_predictors()
1651 {
1652     return weak;
1653 }
1654
1655 /* End of file. */