Update to 2.0.0 tree from current Fremantle build
[opencv] / src / ml / mltree.cpp
1 /*M///////////////////////////////////////////////////////////////////////////////////////
2 //
3 //  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
4 //
5 //  By downloading, copying, installing or using the software you agree to this license.
6 //  If you do not agree to this license, do not download, install,
7 //  copy or use the software.
8 //
9 //
10 //                        Intel License Agreement
11 //
12 // Copyright (C) 2000, Intel Corporation, all rights reserved.
13 // Third party copyrights are property of their respective owners.
14 //
15 // Redistribution and use in source and binary forms, with or without modification,
16 // are permitted provided that the following conditions are met:
17 //
18 //   * Redistribution's of source code must retain the above copyright notice,
19 //     this list of conditions and the following disclaimer.
20 //
21 //   * Redistribution's in binary form must reproduce the above copyright notice,
22 //     this list of conditions and the following disclaimer in the documentation
23 //     and/or other materials provided with the distribution.
24 //
25 //   * The name of Intel Corporation may not be used to endorse or promote products
26 //     derived from this software without specific prior written permission.
27 //
28 // This software is provided by the copyright holders and contributors "as is" and
29 // any express or implied warranties, including, but not limited to, the implied
30 // warranties of merchantability and fitness for a particular purpose are disclaimed.
31 // In no event shall the Intel Corporation or contributors be liable for any direct,
32 // indirect, incidental, special, exemplary, or consequential damages
33 // (including, but not limited to, procurement of substitute goods or services;
34 // loss of use, data, or profits; or business interruption) however caused
35 // and on any theory of liability, whether in contract, strict liability,
36 // or tort (including negligence or otherwise) arising in any way out of
37 // the use of this software, even if advised of the possibility of such damage.
38 //
39 //M*/
40
41 #include "_ml.h"
42 #include <ctype.h>
43
44 using namespace cv;
45
46 static const float ord_nan = FLT_MAX*0.5f;
47 static const int min_block_size = 1 << 16;
48 static const int block_size_delta = 1 << 10;
49
50 CvDTreeTrainData::CvDTreeTrainData()
51 {
52     var_idx = var_type = cat_count = cat_ofs = cat_map =
53         priors = priors_mult = counts = buf = direction = split_buf = responses_copy = 0;
54     tree_storage = temp_storage = 0;
55
56     clear();
57 }
58
59
60 CvDTreeTrainData::CvDTreeTrainData( const CvMat* _train_data, int _tflag,
61                       const CvMat* _responses, const CvMat* _var_idx,
62                       const CvMat* _sample_idx, const CvMat* _var_type,
63                       const CvMat* _missing_mask, const CvDTreeParams& _params,
64                       bool _shared, bool _add_labels )
65 {
66     var_idx = var_type = cat_count = cat_ofs = cat_map =
67         priors = priors_mult = counts = buf = direction = split_buf = responses_copy = 0;
68
69     tree_storage = temp_storage = 0;
70
71     set_data( _train_data, _tflag, _responses, _var_idx, _sample_idx,
72               _var_type, _missing_mask, _params, _shared, _add_labels );
73 }
74
75
76 CvDTreeTrainData::~CvDTreeTrainData()
77 {
78     clear();
79 }
80
81
82 bool CvDTreeTrainData::set_params( const CvDTreeParams& _params )
83 {
84     bool ok = false;
85
86     CV_FUNCNAME( "CvDTreeTrainData::set_params" );
87
88     __BEGIN__;
89
90     // set parameters
91     params = _params;
92
93     if( params.max_categories < 2 )
94         CV_ERROR( CV_StsOutOfRange, "params.max_categories should be >= 2" );
95     params.max_categories = MIN( params.max_categories, 15 );
96
97     if( params.max_depth < 0 )
98         CV_ERROR( CV_StsOutOfRange, "params.max_depth should be >= 0" );
99     params.max_depth = MIN( params.max_depth, 25 );
100
101     params.min_sample_count = MAX(params.min_sample_count,1);
102
103     if( params.cv_folds < 0 )
104         CV_ERROR( CV_StsOutOfRange,
105         "params.cv_folds should be =0 (the tree is not pruned) "
106         "or n>0 (tree is pruned using n-fold cross-validation)" );
107
108     if( params.cv_folds == 1 )
109         params.cv_folds = 0;
110
111     if( params.regression_accuracy < 0 )
112         CV_ERROR( CV_StsOutOfRange, "params.regression_accuracy should be >= 0" );
113
114     ok = true;
115
116     __END__;
117
118     return ok;
119 }
120
121 #define CV_CMP_NUM_PTR(a,b) (*(a) < *(b))
122 static CV_IMPLEMENT_QSORT_EX( icvSortIntPtr, int*, CV_CMP_NUM_PTR, int )
123 static CV_IMPLEMENT_QSORT_EX( icvSortDblPtr, double*, CV_CMP_NUM_PTR, int )
124
125 #define CV_CMP_NUM_IDX(i,j) (aux[i] < aux[j])
126 static CV_IMPLEMENT_QSORT_EX( icvSortIntAux, int, CV_CMP_NUM_IDX, const float* )
127 static CV_IMPLEMENT_QSORT_EX( icvSortUShAux, unsigned short, CV_CMP_NUM_IDX, const float* )
128
129 #define CV_CMP_PAIRS(a,b) (*((a).i) < *((b).i))
130 static CV_IMPLEMENT_QSORT_EX( icvSortPairs, CvPair16u32s, CV_CMP_PAIRS, int )
131
132 void CvDTreeTrainData::set_data( const CvMat* _train_data, int _tflag,
133     const CvMat* _responses, const CvMat* _var_idx, const CvMat* _sample_idx,
134     const CvMat* _var_type, const CvMat* _missing_mask, const CvDTreeParams& _params,
135     bool _shared, bool _add_labels, bool _update_data )
136 {
137     CvMat* sample_indices = 0;
138     CvMat* var_type0 = 0;
139     CvMat* tmp_map = 0;
140     int** int_ptr = 0;
141     CvPair16u32s* pair16u32s_ptr = 0;
142     CvDTreeTrainData* data = 0;
143     float *_fdst = 0;
144     int *_idst = 0;
145     unsigned short* udst = 0;
146     int* idst = 0;
147
148     CV_FUNCNAME( "CvDTreeTrainData::set_data" );
149
150     __BEGIN__;
151
152     int sample_all = 0, r_type = 0, cv_n;
153     int total_c_count = 0;
154     int tree_block_size, temp_block_size, max_split_size, nv_size, cv_size = 0;
155     int ds_step, dv_step, ms_step = 0, mv_step = 0; // {data|mask}{sample|var}_step
156     int vi, i, size;
157     char err[100];
158     const int *sidx = 0, *vidx = 0;
159     
160     if( _update_data && data_root )
161     {
162         data = new CvDTreeTrainData( _train_data, _tflag, _responses, _var_idx,
163             _sample_idx, _var_type, _missing_mask, _params, _shared, _add_labels );
164
165         // compare new and old train data
166         if( !(data->var_count == var_count &&
167             cvNorm( data->var_type, var_type, CV_C ) < FLT_EPSILON &&
168             cvNorm( data->cat_count, cat_count, CV_C ) < FLT_EPSILON &&
169             cvNorm( data->cat_map, cat_map, CV_C ) < FLT_EPSILON) )
170             CV_ERROR( CV_StsBadArg,
171             "The new training data must have the same types and the input and output variables "
172             "and the same categories for categorical variables" );
173
174         cvReleaseMat( &priors );
175         cvReleaseMat( &priors_mult );
176         cvReleaseMat( &buf );
177         cvReleaseMat( &direction );
178         cvReleaseMat( &split_buf );
179         cvReleaseMemStorage( &temp_storage );
180
181         priors = data->priors; data->priors = 0;
182         priors_mult = data->priors_mult; data->priors_mult = 0;
183         buf = data->buf; data->buf = 0;
184         buf_count = data->buf_count; buf_size = data->buf_size;
185         sample_count = data->sample_count;
186
187         direction = data->direction; data->direction = 0;
188         split_buf = data->split_buf; data->split_buf = 0;
189         temp_storage = data->temp_storage; data->temp_storage = 0;
190         nv_heap = data->nv_heap; cv_heap = data->cv_heap;
191
192         data_root = new_node( 0, sample_count, 0, 0 );
193         EXIT;
194     }
195
196     clear();
197
198     var_all = 0;
199     rng = cvRNG(-1);
200
201     CV_CALL( set_params( _params ));
202
203     // check parameter types and sizes
204     CV_CALL( cvCheckTrainData( _train_data, _tflag, _missing_mask, &var_all, &sample_all ));
205
206     train_data = _train_data;
207     responses = _responses;
208
209     if( _tflag == CV_ROW_SAMPLE )
210     {
211         ds_step = _train_data->step/CV_ELEM_SIZE(_train_data->type);
212         dv_step = 1;
213         if( _missing_mask )
214             ms_step = _missing_mask->step, mv_step = 1;
215     }
216     else
217     {
218         dv_step = _train_data->step/CV_ELEM_SIZE(_train_data->type);
219         ds_step = 1;
220         if( _missing_mask )
221             mv_step = _missing_mask->step, ms_step = 1;
222     }
223     tflag = _tflag;
224
225     sample_count = sample_all;
226     var_count = var_all;
227     
228     if( _sample_idx )
229     {
230         CV_CALL( sample_indices = cvPreprocessIndexArray( _sample_idx, sample_all ));
231         sidx = sample_indices->data.i;
232         sample_count = sample_indices->rows + sample_indices->cols - 1;
233     }
234
235     if( _var_idx )
236     {
237         CV_CALL( var_idx = cvPreprocessIndexArray( _var_idx, var_all ));
238         vidx = var_idx->data.i;
239         var_count = var_idx->rows + var_idx->cols - 1;
240     }
241
242     is_buf_16u = false;     
243     if ( sample_count < 65536 ) 
244         is_buf_16u = true;                                
245     
246     if( !CV_IS_MAT(_responses) ||
247         (CV_MAT_TYPE(_responses->type) != CV_32SC1 &&
248          CV_MAT_TYPE(_responses->type) != CV_32FC1) ||
249         (_responses->rows != 1 && _responses->cols != 1) ||
250         _responses->rows + _responses->cols - 1 != sample_all )
251         CV_ERROR( CV_StsBadArg, "The array of _responses must be an integer or "
252                   "floating-point vector containing as many elements as "
253                   "the total number of samples in the training data matrix" );
254    
255   
256     CV_CALL( var_type0 = cvPreprocessVarType( _var_type, var_idx, var_count, &r_type ));
257
258     CV_CALL( var_type = cvCreateMat( 1, var_count+2, CV_32SC1 ));
259    
260     
261     cat_var_count = 0;
262     ord_var_count = -1;
263
264     is_classifier = r_type == CV_VAR_CATEGORICAL;
265
266     // step 0. calc the number of categorical vars
267     for( vi = 0; vi < var_count; vi++ )
268     {
269         var_type->data.i[vi] = var_type0->data.ptr[vi] == CV_VAR_CATEGORICAL ?
270             cat_var_count++ : ord_var_count--;
271     }
272
273     ord_var_count = ~ord_var_count;
274     cv_n = params.cv_folds;
275     // set the two last elements of var_type array to be able
276     // to locate responses and cross-validation labels using
277     // the corresponding get_* functions.
278     var_type->data.i[var_count] = cat_var_count;
279     var_type->data.i[var_count+1] = cat_var_count+1;
280
281     // in case of single ordered predictor we need dummy cv_labels
282     // for safe split_node_data() operation
283     have_labels = cv_n > 0 || (ord_var_count == 1 && cat_var_count == 0) || _add_labels;
284
285     work_var_count = var_count + (is_classifier ? 1 : 0) + (have_labels ? 1 : 0);
286     buf_size = (work_var_count + 1)*sample_count;
287     shared = _shared;
288     buf_count = shared ? 2 : 1;
289     
290     if ( is_buf_16u )
291     {
292         CV_CALL( buf = cvCreateMat( buf_count, buf_size, CV_16UC1 ));
293         CV_CALL( pair16u32s_ptr = (CvPair16u32s*)cvAlloc( sample_count*sizeof(pair16u32s_ptr[0]) ));
294     }
295     else
296     {
297         CV_CALL( buf = cvCreateMat( buf_count, buf_size, CV_32SC1 ));
298         CV_CALL( int_ptr = (int**)cvAlloc( sample_count*sizeof(int_ptr[0]) ));
299     }    
300
301     size = is_classifier ? (cat_var_count+1) : cat_var_count;
302     size = !size ? 1 : size;
303     CV_CALL( cat_count = cvCreateMat( 1, size, CV_32SC1 ));
304     CV_CALL( cat_ofs = cvCreateMat( 1, size, CV_32SC1 ));
305         
306     size = is_classifier ? (cat_var_count + 1)*params.max_categories : cat_var_count*params.max_categories;
307     size = !size ? 1 : size;
308     CV_CALL( cat_map = cvCreateMat( 1, size, CV_32SC1 ));
309
310     // now calculate the maximum size of split,
311     // create memory storage that will keep nodes and splits of the decision tree
312     // allocate root node and the buffer for the whole training data
313     max_split_size = cvAlign(sizeof(CvDTreeSplit) +
314         (MAX(0,sample_count - 33)/32)*sizeof(int),sizeof(void*));
315     tree_block_size = MAX((int)sizeof(CvDTreeNode)*8, max_split_size);
316     tree_block_size = MAX(tree_block_size + block_size_delta, min_block_size);
317     CV_CALL( tree_storage = cvCreateMemStorage( tree_block_size ));
318     CV_CALL( node_heap = cvCreateSet( 0, sizeof(*node_heap), sizeof(CvDTreeNode), tree_storage ));
319
320     nv_size = var_count*sizeof(int);
321     nv_size = cvAlign(MAX( nv_size, (int)sizeof(CvSetElem) ), sizeof(void*));
322
323     temp_block_size = nv_size;
324
325     if( cv_n )
326     {
327         if( sample_count < cv_n*MAX(params.min_sample_count,10) )
328             CV_ERROR( CV_StsOutOfRange,
329                 "The many folds in cross-validation for such a small dataset" );
330
331         cv_size = cvAlign( cv_n*(sizeof(int) + sizeof(double)*2), sizeof(double) );
332         temp_block_size = MAX(temp_block_size, cv_size);
333     }
334
335     temp_block_size = MAX( temp_block_size + block_size_delta, min_block_size );
336     CV_CALL( temp_storage = cvCreateMemStorage( temp_block_size ));
337     CV_CALL( nv_heap = cvCreateSet( 0, sizeof(*nv_heap), nv_size, temp_storage ));
338     if( cv_size )
339         CV_CALL( cv_heap = cvCreateSet( 0, sizeof(*cv_heap), cv_size, temp_storage ));
340
341     CV_CALL( data_root = new_node( 0, sample_count, 0, 0 ));
342
343     max_c_count = 1;
344
345     _fdst = 0;
346     _idst = 0;
347     if (ord_var_count)
348         _fdst = (float*)cvAlloc(sample_count*sizeof(_fdst[0]));
349     if (is_buf_16u && (cat_var_count || is_classifier))
350         _idst = (int*)cvAlloc(sample_count*sizeof(_idst[0]));
351
352     // transform the training data to convenient representation
353     for( vi = 0; vi <= var_count; vi++ )
354     {
355         int ci;
356         const uchar* mask = 0;
357         int m_step = 0, step;
358         const int* idata = 0;
359         const float* fdata = 0;
360         int num_valid = 0;
361
362         if( vi < var_count ) // analyze i-th input variable
363         {
364             int vi0 = vidx ? vidx[vi] : vi;
365             ci = get_var_type(vi);
366             step = ds_step; m_step = ms_step;
367             if( CV_MAT_TYPE(_train_data->type) == CV_32SC1 )
368                 idata = _train_data->data.i + vi0*dv_step;
369             else
370                 fdata = _train_data->data.fl + vi0*dv_step;
371             if( _missing_mask )
372                 mask = _missing_mask->data.ptr + vi0*mv_step;
373         }
374         else // analyze _responses
375         {
376             ci = cat_var_count;
377             step = CV_IS_MAT_CONT(_responses->type) ?
378                 1 : _responses->step / CV_ELEM_SIZE(_responses->type);
379             if( CV_MAT_TYPE(_responses->type) == CV_32SC1 )
380                 idata = _responses->data.i;
381             else
382                 fdata = _responses->data.fl;
383         }
384
385         if( (vi < var_count && ci>=0) ||
386             (vi == var_count && is_classifier) ) // process categorical variable or response
387         {
388             int c_count, prev_label;
389             int* c_map;
390             
391             if (is_buf_16u)
392                 udst = (unsigned short*)(buf->data.s + vi*sample_count);
393             else
394                 idst = buf->data.i + vi*sample_count;
395             
396             // copy data
397             for( i = 0; i < sample_count; i++ )
398             {
399                 int val = INT_MAX, si = sidx ? sidx[i] : i;
400                 if( !mask || !mask[si*m_step] )
401                 {
402                     if( idata )
403                         val = idata[si*step];
404                     else
405                     {
406                         float t = fdata[si*step];
407                         val = cvRound(t);
408                         if( fabs(t - val) > FLT_EPSILON )
409                         {
410                             sprintf( err, "%d-th value of %d-th (categorical) "
411                                 "variable is not an integer", i, vi );
412                             CV_ERROR( CV_StsBadArg, err );
413                         }
414                     }
415
416                     if( val == INT_MAX )
417                     {
418                         sprintf( err, "%d-th value of %d-th (categorical) "
419                             "variable is too large", i, vi );
420                         CV_ERROR( CV_StsBadArg, err );
421                     }
422                     num_valid++;
423                 }
424                 if (is_buf_16u)
425                 {
426                     _idst[i] = val;
427                     pair16u32s_ptr[i].u = udst + i;
428                     pair16u32s_ptr[i].i = _idst + i;
429                 }   
430                 else
431                 {
432                     idst[i] = val;
433                     int_ptr[i] = idst + i;
434                 }
435             }
436
437             c_count = num_valid > 0;
438             if (is_buf_16u)
439             {
440                 icvSortPairs( pair16u32s_ptr, sample_count, 0 );
441                 // count the categories
442                 for( i = 1; i < num_valid; i++ )
443                     if (*pair16u32s_ptr[i].i != *pair16u32s_ptr[i-1].i)
444                         c_count ++ ;
445             }
446             else
447             {
448                 icvSortIntPtr( int_ptr, sample_count, 0 );
449                 // count the categories
450                 for( i = 1; i < num_valid; i++ )
451                     c_count += *int_ptr[i] != *int_ptr[i-1];
452             }
453
454             if( vi > 0 )
455                 max_c_count = MAX( max_c_count, c_count );
456             cat_count->data.i[ci] = c_count;
457             cat_ofs->data.i[ci] = total_c_count;
458
459             // resize cat_map, if need
460             if( cat_map->cols < total_c_count + c_count )
461             {
462                 tmp_map = cat_map;
463                 CV_CALL( cat_map = cvCreateMat( 1,
464                     MAX(cat_map->cols*3/2,total_c_count+c_count), CV_32SC1 ));
465                 for( i = 0; i < total_c_count; i++ )
466                     cat_map->data.i[i] = tmp_map->data.i[i];
467                 cvReleaseMat( &tmp_map );
468             }
469
470             c_map = cat_map->data.i + total_c_count;
471             total_c_count += c_count;
472
473             c_count = -1;
474             if (is_buf_16u)
475             {
476                 // compact the class indices and build the map
477                 prev_label = ~*pair16u32s_ptr[0].i;
478                 for( i = 0; i < num_valid; i++ )
479                 {
480                     int cur_label = *pair16u32s_ptr[i].i;
481                     if( cur_label != prev_label )
482                         c_map[++c_count] = prev_label = cur_label;
483                     *pair16u32s_ptr[i].u = (unsigned short)c_count;
484                 }
485                 // replace labels for missing values with -1
486                 for( ; i < sample_count; i++ )
487                     *pair16u32s_ptr[i].u = 65535;
488             }
489             else
490             {
491                 // compact the class indices and build the map
492                 prev_label = ~*int_ptr[0];
493                 for( i = 0; i < num_valid; i++ )
494                 {
495                     int cur_label = *int_ptr[i];
496                     if( cur_label != prev_label )
497                         c_map[++c_count] = prev_label = cur_label;
498                     *int_ptr[i] = c_count;
499                 }
500                 // replace labels for missing values with -1
501                 for( ; i < sample_count; i++ )
502                     *int_ptr[i] = -1;
503             }           
504         }
505         else if( ci < 0 ) // process ordered variable
506         {
507             if (is_buf_16u)
508                 udst = (unsigned short*)(buf->data.s + vi*sample_count);
509             else
510                 idst = buf->data.i + vi*sample_count;
511
512             for( i = 0; i < sample_count; i++ )
513             {
514                 float val = ord_nan;
515                 int si = sidx ? sidx[i] : i;
516                 if( !mask || !mask[si*m_step] )
517                 {
518                     if( idata )
519                         val = (float)idata[si*step];
520                     else
521                         val = fdata[si*step];
522
523                     if( fabs(val) >= ord_nan )
524                     {
525                         sprintf( err, "%d-th value of %d-th (ordered) "
526                             "variable (=%g) is too large", i, vi, val );
527                         CV_ERROR( CV_StsBadArg, err );
528                     }
529                 }
530                 num_valid++;
531                 if (is_buf_16u)
532                     udst[i] = (unsigned short)i;
533                 else
534                     idst[i] = i; // Ã”ÂÂÌÂÒÚˠ‚˚¯Â â€š if( idata )
535                 _fdst[i] = val;
536                 
537             }
538             if (is_buf_16u)
539                 icvSortUShAux( udst, num_valid, _fdst);
540             else
541                 icvSortIntAux( idst, /*or num_valid?\*/ sample_count, _fdst );
542         }
543        
544         if( vi < var_count )
545             data_root->set_num_valid(vi, num_valid);
546     }
547
548     // set sample labels
549     if (is_buf_16u)
550         udst = (unsigned short*)(buf->data.s + work_var_count*sample_count);
551     else
552         idst = buf->data.i + work_var_count*sample_count;
553
554     for (i = 0; i < sample_count; i++)
555     {
556         if (udst)
557             udst[i] = sidx ? (unsigned short)sidx[i] : (unsigned short)i;
558         else
559             idst[i] = sidx ? sidx[i] : i;
560     }
561
562     if( cv_n )
563     {
564         unsigned short* udst = 0;
565         int* idst = 0;
566         CvRNG* r = &rng;
567
568         if (is_buf_16u)
569         {
570             udst = (unsigned short*)(buf->data.s + (get_work_var_count()-1)*sample_count);
571             for( i = vi = 0; i < sample_count; i++ )
572             {
573                 udst[i] = (unsigned short)vi++;
574                 vi &= vi < cv_n ? -1 : 0;
575             }
576
577             for( i = 0; i < sample_count; i++ )
578             {
579                 int a = cvRandInt(r) % sample_count;
580                 int b = cvRandInt(r) % sample_count;
581                 unsigned short unsh = (unsigned short)vi;
582                 CV_SWAP( udst[a], udst[b], unsh );
583             }
584         }
585         else
586         {
587             idst = buf->data.i + (get_work_var_count()-1)*sample_count;
588             for( i = vi = 0; i < sample_count; i++ )
589             {
590                 idst[i] = vi++;
591                 vi &= vi < cv_n ? -1 : 0;
592             }
593
594             for( i = 0; i < sample_count; i++ )
595             {
596                 int a = cvRandInt(r) % sample_count;
597                 int b = cvRandInt(r) % sample_count;
598                 CV_SWAP( idst[a], idst[b], vi );
599             }
600         }
601     }
602
603     if ( cat_map ) 
604         cat_map->cols = MAX( total_c_count, 1 );
605
606     max_split_size = cvAlign(sizeof(CvDTreeSplit) +
607         (MAX(0,max_c_count - 33)/32)*sizeof(int),sizeof(void*));
608     CV_CALL( split_heap = cvCreateSet( 0, sizeof(*split_heap), max_split_size, tree_storage ));
609
610     have_priors = is_classifier && params.priors;
611     if( is_classifier )
612     {
613         int m = get_num_classes();
614         double sum = 0;
615         CV_CALL( priors = cvCreateMat( 1, m, CV_64F ));
616         for( i = 0; i < m; i++ )
617         {
618             double val = have_priors ? params.priors[i] : 1.;
619             if( val <= 0 )
620                 CV_ERROR( CV_StsOutOfRange, "Every class weight should be positive" );
621             priors->data.db[i] = val;
622             sum += val;
623         }
624
625         // normalize weights
626         if( have_priors )
627             cvScale( priors, priors, 1./sum );
628
629         CV_CALL( priors_mult = cvCloneMat( priors ));
630         CV_CALL( counts = cvCreateMat( 1, m, CV_32SC1 ));
631     }
632
633
634     CV_CALL( direction = cvCreateMat( 1, sample_count, CV_8UC1 ));
635     CV_CALL( split_buf = cvCreateMat( 1, sample_count, CV_32SC1 ));
636
637     {
638         int maxNumThreads = 1;
639 #ifdef _OPENMP
640         maxNumThreads = cv::getNumThreads();
641 #endif
642         pred_float_buf.resize(maxNumThreads);
643         pred_int_buf.resize(maxNumThreads);
644         resp_float_buf.resize(maxNumThreads);
645         resp_int_buf.resize(maxNumThreads);
646         cv_lables_buf.resize(maxNumThreads);
647         sample_idx_buf.resize(maxNumThreads);
648         for( int ti = 0; ti < maxNumThreads; ti++ )
649         {
650             pred_float_buf[ti].resize(sample_count);
651             pred_int_buf[ti].resize(sample_count);
652             resp_float_buf[ti].resize(sample_count);
653             resp_int_buf[ti].resize(sample_count);
654             cv_lables_buf[ti].resize(sample_count);
655             sample_idx_buf[ti].resize(sample_count);
656         }
657     }
658
659     __END__;
660
661     if( data )
662         delete data;
663
664     if (_fdst)
665         cvFree( &_fdst );
666     if (_idst)
667         cvFree( &_idst );
668     cvFree( &int_ptr );
669     cvReleaseMat( &var_type0 );
670     cvReleaseMat( &sample_indices );
671     cvReleaseMat( &tmp_map );
672 }
673
674 void CvDTreeTrainData::do_responses_copy()
675 {
676     responses_copy = cvCreateMat( responses->rows, responses->cols, responses->type );
677     cvCopy( responses, responses_copy);
678     responses = responses_copy;
679 }
680
681 CvDTreeNode* CvDTreeTrainData::subsample_data( const CvMat* _subsample_idx )
682 {
683     CvDTreeNode* root = 0;
684     CvMat* isubsample_idx = 0;
685     CvMat* subsample_co = 0;
686
687     CV_FUNCNAME( "CvDTreeTrainData::subsample_data" );
688
689     __BEGIN__;
690
691     if( !data_root )
692         CV_ERROR( CV_StsError, "No training data has been set" );
693
694     if( _subsample_idx )
695         CV_CALL( isubsample_idx = cvPreprocessIndexArray( _subsample_idx, sample_count ));
696
697     if( !isubsample_idx )
698     {
699         // make a copy of the root node
700         CvDTreeNode temp;
701         int i;
702         root = new_node( 0, 1, 0, 0 );
703         temp = *root;
704         *root = *data_root;
705         root->num_valid = temp.num_valid;
706         if( root->num_valid )
707         {
708             for( i = 0; i < var_count; i++ )
709                 root->num_valid[i] = data_root->num_valid[i];
710         }
711         root->cv_Tn = temp.cv_Tn;
712         root->cv_node_risk = temp.cv_node_risk;
713         root->cv_node_error = temp.cv_node_error;
714     }
715     else
716     {
717         int* sidx = isubsample_idx->data.i;
718         // co - array of count/offset pairs (to handle duplicated values in _subsample_idx)
719         int* co, cur_ofs = 0;
720         int vi, i;
721         int work_var_count = get_work_var_count();
722         int count = isubsample_idx->rows + isubsample_idx->cols - 1;
723
724         root = new_node( 0, count, 1, 0 );
725
726         CV_CALL( subsample_co = cvCreateMat( 1, sample_count*2, CV_32SC1 ));
727         cvZero( subsample_co );
728         co = subsample_co->data.i;
729         for( i = 0; i < count; i++ )
730             co[sidx[i]*2]++;
731         for( i = 0; i < sample_count; i++ )
732         {
733             if( co[i*2] )
734             {
735                 co[i*2+1] = cur_ofs;
736                 cur_ofs += co[i*2];
737             }
738             else
739                 co[i*2+1] = -1;
740         }
741
742         for( vi = 0; vi < work_var_count; vi++ )
743         {
744             int ci = get_var_type(vi);
745
746             if( ci >= 0 || vi >= var_count )
747             {
748                 int* src_buf = get_pred_int_buf();
749                 const int* src = 0;
750                 int num_valid = 0;
751                 
752                 get_cat_var_data( data_root, vi, src_buf, &src );
753
754                 if (is_buf_16u)
755                 {
756                     unsigned short* udst = (unsigned short*)(buf->data.s + root->buf_idx*buf->cols + 
757                         vi*sample_count + root->offset);
758                     for( i = 0; i < count; i++ )
759                     {
760                         int val = src[sidx[i]];
761                         udst[i] = (unsigned short)val;
762                         num_valid += val >= 0;
763                     }
764                 }
765                 else
766                 {
767                     int* idst = buf->data.i + root->buf_idx*buf->cols + 
768                         vi*sample_count + root->offset;
769                     for( i = 0; i < count; i++ )
770                     {
771                         int val = src[sidx[i]];
772                         idst[i] = val;
773                         num_valid += val >= 0;
774                     }
775                 }
776
777                 if( vi < var_count )
778                     root->set_num_valid(vi, num_valid);
779             }
780             else
781             {
782                 int *src_idx_buf = get_pred_int_buf();
783                 const int* src_idx = 0;
784                 float *src_val_buf = get_pred_float_buf();
785                 const float* src_val = 0;
786                 int j = 0, idx, count_i;
787                 int num_valid = data_root->get_num_valid(vi);
788
789                 get_ord_var_data( data_root, vi, src_val_buf, src_idx_buf, &src_val, &src_idx );
790                 if (is_buf_16u)
791                 {
792                     unsigned short* udst_idx = (unsigned short*)(buf->data.s + root->buf_idx*buf->cols + 
793                         vi*sample_count + data_root->offset);
794                     for( i = 0; i < num_valid; i++ )
795                     {
796                         idx = src_idx[i];
797                         count_i = co[idx*2];
798                         if( count_i )
799                             for( cur_ofs = co[idx*2+1]; count_i > 0; count_i--, j++, cur_ofs++ )
800                                 udst_idx[j] = (unsigned short)cur_ofs;
801                     }
802
803                     root->set_num_valid(vi, j);
804
805                     for( ; i < sample_count; i++ )
806                     {
807                         idx = src_idx[i];
808                         count_i = co[idx*2];
809                         if( count_i )
810                             for( cur_ofs = co[idx*2+1]; count_i > 0; count_i--, j++, cur_ofs++ )
811                                 udst_idx[j] = (unsigned short)cur_ofs;
812                     }
813                 }
814                 else
815                 {
816                     int* idst_idx = buf->data.i + root->buf_idx*buf->cols + 
817                         vi*sample_count + root->offset;
818                     for( i = 0; i < num_valid; i++ )
819                     {
820                         idx = src_idx[i];
821                         count_i = co[idx*2];
822                         if( count_i )
823                             for( cur_ofs = co[idx*2+1]; count_i > 0; count_i--, j++, cur_ofs++ )
824                                 idst_idx[j] = cur_ofs;
825                     }
826
827                     root->set_num_valid(vi, j);
828
829                     for( ; i < sample_count; i++ )
830                     {
831                         idx = src_idx[i];
832                         count_i = co[idx*2];
833                         if( count_i )
834                             for( cur_ofs = co[idx*2+1]; count_i > 0; count_i--, j++, cur_ofs++ )
835                                 idst_idx[j] = cur_ofs;
836                     }
837                 }
838             }
839         }
840         // sample indices subsampling
841         int* sample_idx_src_buf = get_sample_idx_buf();
842         const int* sample_idx_src = 0;
843         get_sample_indices(data_root, sample_idx_src_buf, &sample_idx_src);
844         if (is_buf_16u)
845         {
846             unsigned short* sample_idx_dst = (unsigned short*)(buf->data.s + root->buf_idx*buf->cols + 
847                 get_work_var_count()*sample_count + root->offset);            
848             for (i = 0; i < count; i++)
849                 sample_idx_dst[i] = (unsigned short)sample_idx_src[sidx[i]];
850         }
851         else
852         {
853             int* sample_idx_dst = buf->data.i + root->buf_idx*buf->cols + 
854                 get_work_var_count()*sample_count + root->offset;            
855             for (i = 0; i < count; i++)
856                 sample_idx_dst[i] = sample_idx_src[sidx[i]];
857         }
858     }
859
860     __END__;
861
862     cvReleaseMat( &isubsample_idx );
863     cvReleaseMat( &subsample_co );
864
865     return root;
866 }
867
868
869 void CvDTreeTrainData::get_vectors( const CvMat* _subsample_idx,
870                                     float* values, uchar* missing,
871                                     float* responses, bool get_class_idx )
872 {
873     CvMat* subsample_idx = 0;
874     CvMat* subsample_co = 0;
875
876     CV_FUNCNAME( "CvDTreeTrainData::get_vectors" );
877
878     __BEGIN__;
879
880     int i, vi, total = sample_count, count = total, cur_ofs = 0;
881     int* sidx = 0;
882     int* co = 0;
883
884     if( _subsample_idx )
885     {
886         CV_CALL( subsample_idx = cvPreprocessIndexArray( _subsample_idx, sample_count ));
887         sidx = subsample_idx->data.i;
888         CV_CALL( subsample_co = cvCreateMat( 1, sample_count*2, CV_32SC1 ));
889         co = subsample_co->data.i;
890         cvZero( subsample_co );
891         count = subsample_idx->cols + subsample_idx->rows - 1;
892         for( i = 0; i < count; i++ )
893             co[sidx[i]*2]++;
894         for( i = 0; i < total; i++ )
895         {
896             int count_i = co[i*2];
897             if( count_i )
898             {
899                 co[i*2+1] = cur_ofs*var_count;
900                 cur_ofs += count_i;
901             }
902         }
903     }
904
905     if( missing )
906         memset( missing, 1, count*var_count );
907
908     for( vi = 0; vi < var_count; vi++ )
909     {
910         int ci = get_var_type(vi);
911         if( ci >= 0 ) // categorical
912         {
913             float* dst = values + vi;
914             uchar* m = missing ? missing + vi : 0;
915             int* src_buf = get_pred_int_buf();
916             const int* src = 0; 
917             get_cat_var_data(data_root, vi, src_buf, &src);
918
919             for( i = 0; i < count; i++, dst += var_count )
920             {
921                 int idx = sidx ? sidx[i] : i;
922                 int val = src[idx];
923                 *dst = (float)val;
924                 if( m )
925                 {
926                     *m = (!is_buf_16u && val < 0) || (is_buf_16u && (val == 65535));
927                     m += var_count;
928                 }
929             }
930         }
931         else // ordered
932         {
933             float* dst = values + vi;
934             uchar* m = missing ? missing + vi : 0;
935             int count1 = data_root->get_num_valid(vi);
936             float *src_val_buf = get_pred_float_buf();
937             const float *src_val = 0;
938             int* src_idx_buf = get_pred_int_buf();
939             const int* src_idx = 0;
940             get_ord_var_data(data_root, vi, src_val_buf, src_idx_buf, &src_val, &src_idx);
941
942             for( i = 0; i < count1; i++ )
943             {
944                 int idx = src_idx[i];
945                 int count_i = 1;
946                 if( co )
947                 {
948                     count_i = co[idx*2];
949                     cur_ofs = co[idx*2+1];
950                 }
951                 else
952                     cur_ofs = idx*var_count;
953                 if( count_i )
954                 {
955                     float val = src_val[i];
956                     for( ; count_i > 0; count_i--, cur_ofs += var_count )
957                     {
958                         dst[cur_ofs] = val;
959                         if( m )
960                             m[cur_ofs] = 0;
961                     }
962                 }
963             }
964         }
965     }
966
967     // copy responses
968     if( responses )
969     {
970         if( is_classifier )
971         {
972             int* src_buf = get_resp_int_buf();
973             const int* src = 0;
974             get_class_labels(data_root, src_buf, &src);
975             for( i = 0; i < count; i++ )
976             {
977                 int idx = sidx ? sidx[i] : i;
978                 int val = get_class_idx ? src[idx] :
979                     cat_map->data.i[cat_ofs->data.i[cat_var_count]+src[idx]];
980                 responses[i] = (float)val;
981             }
982         }
983         else
984         {
985             float *_values_buf = get_resp_float_buf();
986             const float* _values = 0;
987             get_ord_responses(data_root, _values_buf, &_values);
988             for( i = 0; i < count; i++ )
989             {
990                 int idx = sidx ? sidx[i] : i;
991                 responses[i] = _values[idx];
992             }
993         }
994     }
995
996     __END__;
997
998     cvReleaseMat( &subsample_idx );
999     cvReleaseMat( &subsample_co );
1000 }
1001
1002
1003 CvDTreeNode* CvDTreeTrainData::new_node( CvDTreeNode* parent, int count,
1004                                          int storage_idx, int offset )
1005 {
1006     CvDTreeNode* node = (CvDTreeNode*)cvSetNew( node_heap );
1007
1008     node->sample_count = count;
1009     node->depth = parent ? parent->depth + 1 : 0;
1010     node->parent = parent;
1011     node->left = node->right = 0;
1012     node->split = 0;
1013     node->value = 0;
1014     node->class_idx = 0;
1015     node->maxlr = 0.;
1016
1017     node->buf_idx = storage_idx;
1018     node->offset = offset;
1019     if( nv_heap )
1020         node->num_valid = (int*)cvSetNew( nv_heap );
1021     else
1022         node->num_valid = 0;
1023     node->alpha = node->node_risk = node->tree_risk = node->tree_error = 0.;
1024     node->complexity = 0;
1025
1026     if( params.cv_folds > 0 && cv_heap )
1027     {
1028         int cv_n = params.cv_folds;
1029         node->Tn = INT_MAX;
1030         node->cv_Tn = (int*)cvSetNew( cv_heap );
1031         node->cv_node_risk = (double*)cvAlignPtr(node->cv_Tn + cv_n, sizeof(double));
1032         node->cv_node_error = node->cv_node_risk + cv_n;
1033     }
1034     else
1035     {
1036         node->Tn = 0;
1037         node->cv_Tn = 0;
1038         node->cv_node_risk = 0;
1039         node->cv_node_error = 0;
1040     }
1041
1042     return node;
1043 }
1044
1045
1046 CvDTreeSplit* CvDTreeTrainData::new_split_ord( int vi, float cmp_val,
1047                 int split_point, int inversed, float quality )
1048 {
1049     CvDTreeSplit* split = (CvDTreeSplit*)cvSetNew( split_heap );
1050     split->var_idx = vi;
1051     split->condensed_idx = INT_MIN;
1052     split->ord.c = cmp_val;
1053     split->ord.split_point = split_point;
1054     split->inversed = inversed;
1055     split->quality = quality;
1056     split->next = 0;
1057
1058     return split;
1059 }
1060
1061
1062 CvDTreeSplit* CvDTreeTrainData::new_split_cat( int vi, float quality )
1063 {
1064     CvDTreeSplit* split = (CvDTreeSplit*)cvSetNew( split_heap );
1065     int i, n = (max_c_count + 31)/32;
1066
1067     split->var_idx = vi;
1068     split->condensed_idx = INT_MIN;
1069     split->inversed = 0;
1070     split->quality = quality;
1071     for( i = 0; i < n; i++ )
1072         split->subset[i] = 0;
1073     split->next = 0;
1074
1075     return split;
1076 }
1077
1078
1079 void CvDTreeTrainData::free_node( CvDTreeNode* node )
1080 {
1081     CvDTreeSplit* split = node->split;
1082     free_node_data( node );
1083     while( split )
1084     {
1085         CvDTreeSplit* next = split->next;
1086         cvSetRemoveByPtr( split_heap, split );
1087         split = next;
1088     }
1089     node->split = 0;
1090     cvSetRemoveByPtr( node_heap, node );
1091 }
1092
1093
1094 void CvDTreeTrainData::free_node_data( CvDTreeNode* node )
1095 {
1096     if( node->num_valid )
1097     {
1098         cvSetRemoveByPtr( nv_heap, node->num_valid );
1099         node->num_valid = 0;
1100     }
1101     // do not free cv_* fields, as all the cross-validation related data is released at once.
1102 }
1103
1104
1105 void CvDTreeTrainData::free_train_data()
1106 {
1107     cvReleaseMat( &counts );
1108     cvReleaseMat( &buf );
1109     cvReleaseMat( &direction );
1110     cvReleaseMat( &split_buf );
1111     cvReleaseMemStorage( &temp_storage );
1112     cvReleaseMat( &responses_copy );
1113     pred_float_buf.clear();
1114     pred_int_buf.clear();
1115     resp_float_buf.clear();
1116     resp_int_buf.clear();
1117     cv_lables_buf.clear();
1118     sample_idx_buf.clear();
1119
1120     cv_heap = nv_heap = 0;
1121 }
1122
1123
1124 void CvDTreeTrainData::clear()
1125 {
1126     free_train_data();
1127
1128     cvReleaseMemStorage( &tree_storage );
1129
1130     cvReleaseMat( &var_idx );
1131     cvReleaseMat( &var_type );
1132     cvReleaseMat( &cat_count );
1133     cvReleaseMat( &cat_ofs );
1134     cvReleaseMat( &cat_map );
1135     cvReleaseMat( &priors );
1136     cvReleaseMat( &priors_mult );
1137     
1138     node_heap = split_heap = 0;
1139
1140     sample_count = var_all = var_count = max_c_count = ord_var_count = cat_var_count = 0;
1141     have_labels = have_priors = is_classifier = false;
1142
1143     buf_count = buf_size = 0;
1144     shared = false;
1145     
1146     data_root = 0;
1147
1148     rng = cvRNG(-1);
1149 }
1150
1151
1152 int CvDTreeTrainData::get_num_classes() const
1153 {
1154     return is_classifier ? cat_count->data.i[cat_var_count] : 0;
1155 }
1156
1157
1158 int CvDTreeTrainData::get_var_type(int vi) const
1159 {
1160     return var_type->data.i[vi];
1161 }
1162
1163 int CvDTreeTrainData::get_ord_var_data( CvDTreeNode* n, int vi, float* ord_values_buf, int* indices_buf, const float** ord_values, const int** indices )
1164 {
1165     int vidx = var_idx ? var_idx->data.i[vi] : vi;
1166     int node_sample_count = n->sample_count; 
1167     int* sample_indices_buf = get_sample_idx_buf();
1168     const int* sample_indices = 0;
1169     int td_step = train_data->step/CV_ELEM_SIZE(train_data->type);
1170
1171     get_sample_indices(n, sample_indices_buf, &sample_indices);
1172
1173     if( !is_buf_16u )
1174         *indices = buf->data.i + n->buf_idx*buf->cols + 
1175         vi*sample_count + n->offset;
1176     else {
1177         const unsigned short* short_indices = (const unsigned short*)(buf->data.s + n->buf_idx*buf->cols + 
1178             vi*sample_count + n->offset );
1179         for( int i = 0; i < node_sample_count; i++ )
1180             indices_buf[i] = short_indices[i];
1181         *indices = indices_buf;
1182     }
1183     
1184     if( tflag == CV_ROW_SAMPLE )
1185     {
1186         for( int i = 0; i < node_sample_count && 
1187             ((((*indices)[i] >= 0) && !is_buf_16u) || (((*indices)[i] != 65535) && is_buf_16u)); i++ )
1188         {
1189             int idx = (*indices)[i];
1190             idx = sample_indices[idx];
1191             ord_values_buf[i] = *(train_data->data.fl + idx * td_step + vidx);
1192         }
1193     }
1194     else
1195         for( int i = 0; i < node_sample_count && 
1196             ((((*indices)[i] >= 0) && !is_buf_16u) || (((*indices)[i] != 65535) && is_buf_16u)); i++ )
1197         {
1198             int idx = (*indices)[i];
1199             idx = sample_indices[idx];
1200             ord_values_buf[i] = *(train_data->data.fl + vidx* td_step + idx);
1201         }
1202     
1203     *ord_values = ord_values_buf;
1204     return 0; //TODO: return the number of non-missing values
1205 }
1206
1207
1208 void CvDTreeTrainData::get_class_labels( CvDTreeNode* n, int* labels_buf, const int** labels )
1209 {
1210     if (is_classifier)
1211         get_cat_var_data( n, var_count, labels_buf, labels );
1212 }
1213
1214 void CvDTreeTrainData::get_sample_indices( CvDTreeNode* n, int* indices_buf, const int** indices )
1215 {
1216     get_cat_var_data( n, get_work_var_count(), indices_buf, indices );
1217 }
1218
1219 void CvDTreeTrainData::get_ord_responses( CvDTreeNode* n, float* values_buf, const float** values)
1220 {
1221     int sample_count = n->sample_count;
1222     int* indices_buf = get_sample_idx_buf();
1223     const int* indices = 0;
1224
1225     int r_step = responses->step/CV_ELEM_SIZE(responses->type);
1226
1227     get_sample_indices(n, indices_buf, &indices);
1228
1229     
1230     for( int i = 0; i < sample_count && 
1231         (((indices[i] >= 0) && !is_buf_16u) || ((indices[i] != 65535) && is_buf_16u)); i++ )
1232     {
1233         int idx = indices[i];
1234         values_buf[i] = *(responses->data.fl + idx * r_step);
1235     }
1236     
1237     *values = values_buf;    
1238 }
1239
1240
1241 void CvDTreeTrainData::get_cv_labels( CvDTreeNode* n, int* labels_buf, const int** labels )
1242 {
1243     if (have_labels)
1244         get_cat_var_data( n, get_work_var_count()- 1, labels_buf, labels );
1245 }
1246
1247
1248 int CvDTreeTrainData::get_cat_var_data( CvDTreeNode* n, int vi, int* cat_values_buf, const int** cat_values )
1249 {
1250     if( !is_buf_16u )
1251         *cat_values = buf->data.i + n->buf_idx*buf->cols + 
1252         vi*sample_count + n->offset;
1253     else {
1254         const unsigned short* short_values = (const unsigned short*)(buf->data.s + n->buf_idx*buf->cols + 
1255             vi*sample_count + n->offset);
1256         for( int i = 0; i < n->sample_count; i++ )
1257             cat_values_buf[i] = short_values[i];
1258         *cat_values = cat_values_buf;
1259     }
1260
1261     return 0; //TODO: return the number of non-missing values
1262 }
1263
1264
1265 int CvDTreeTrainData::get_child_buf_idx( CvDTreeNode* n )
1266 {
1267     int idx = n->buf_idx + 1;
1268     if( idx >= buf_count )
1269         idx = shared ? 1 : 0;
1270     return idx;
1271 }
1272
1273
1274 void CvDTreeTrainData::write_params( CvFileStorage* fs ) const
1275 {
1276     CV_FUNCNAME( "CvDTreeTrainData::write_params" );
1277
1278     __BEGIN__;
1279
1280     int vi, vcount = var_count;
1281
1282     cvWriteInt( fs, "is_classifier", is_classifier ? 1 : 0 );
1283     cvWriteInt( fs, "var_all", var_all );
1284     cvWriteInt( fs, "var_count", var_count );
1285     cvWriteInt( fs, "ord_var_count", ord_var_count );
1286     cvWriteInt( fs, "cat_var_count", cat_var_count );
1287
1288     cvStartWriteStruct( fs, "training_params", CV_NODE_MAP );
1289     cvWriteInt( fs, "use_surrogates", params.use_surrogates ? 1 : 0 );
1290
1291     if( is_classifier )
1292     {
1293         cvWriteInt( fs, "max_categories", params.max_categories );
1294     }
1295     else
1296     {
1297         cvWriteReal( fs, "regression_accuracy", params.regression_accuracy );
1298     }
1299
1300     cvWriteInt( fs, "max_depth", params.max_depth );
1301     cvWriteInt( fs, "min_sample_count", params.min_sample_count );
1302     cvWriteInt( fs, "cross_validation_folds", params.cv_folds );
1303
1304     if( params.cv_folds > 1 )
1305     {
1306         cvWriteInt( fs, "use_1se_rule", params.use_1se_rule ? 1 : 0 );
1307         cvWriteInt( fs, "truncate_pruned_tree", params.truncate_pruned_tree ? 1 : 0 );
1308     }
1309
1310     if( priors )
1311         cvWrite( fs, "priors", priors );
1312
1313     cvEndWriteStruct( fs );
1314
1315     if( var_idx )
1316         cvWrite( fs, "var_idx", var_idx );
1317
1318     cvStartWriteStruct( fs, "var_type", CV_NODE_SEQ+CV_NODE_FLOW );
1319
1320     for( vi = 0; vi < vcount; vi++ )
1321         cvWriteInt( fs, 0, var_type->data.i[vi] >= 0 );
1322
1323     cvEndWriteStruct( fs );
1324
1325     if( cat_count && (cat_var_count > 0 || is_classifier) )
1326     {
1327         CV_ASSERT( cat_count != 0 );
1328         cvWrite( fs, "cat_count", cat_count );
1329         cvWrite( fs, "cat_map", cat_map );
1330     }
1331
1332     __END__;
1333 }
1334
1335
1336 void CvDTreeTrainData::read_params( CvFileStorage* fs, CvFileNode* node )
1337 {
1338     CV_FUNCNAME( "CvDTreeTrainData::read_params" );
1339
1340     __BEGIN__;
1341
1342     CvFileNode *tparams_node, *vartype_node;
1343     CvSeqReader reader;
1344     int vi, max_split_size, tree_block_size;
1345
1346     is_classifier = (cvReadIntByName( fs, node, "is_classifier" ) != 0);
1347     var_all = cvReadIntByName( fs, node, "var_all" );
1348     var_count = cvReadIntByName( fs, node, "var_count", var_all );
1349     cat_var_count = cvReadIntByName( fs, node, "cat_var_count" );
1350     ord_var_count = cvReadIntByName( fs, node, "ord_var_count" );
1351
1352     tparams_node = cvGetFileNodeByName( fs, node, "training_params" );
1353
1354     if( tparams_node ) // training parameters are not necessary
1355     {
1356         params.use_surrogates = cvReadIntByName( fs, tparams_node, "use_surrogates", 1 ) != 0;
1357
1358         if( is_classifier )
1359         {
1360             params.max_categories = cvReadIntByName( fs, tparams_node, "max_categories" );
1361         }
1362         else
1363         {
1364             params.regression_accuracy =
1365                 (float)cvReadRealByName( fs, tparams_node, "regression_accuracy" );
1366         }
1367
1368         params.max_depth = cvReadIntByName( fs, tparams_node, "max_depth" );
1369         params.min_sample_count = cvReadIntByName( fs, tparams_node, "min_sample_count" );
1370         params.cv_folds = cvReadIntByName( fs, tparams_node, "cross_validation_folds" );
1371
1372         if( params.cv_folds > 1 )
1373         {
1374             params.use_1se_rule = cvReadIntByName( fs, tparams_node, "use_1se_rule" ) != 0;
1375             params.truncate_pruned_tree =
1376                 cvReadIntByName( fs, tparams_node, "truncate_pruned_tree" ) != 0;
1377         }
1378
1379         priors = (CvMat*)cvReadByName( fs, tparams_node, "priors" );
1380         if( priors )
1381         {
1382             if( !CV_IS_MAT(priors) )
1383                 CV_ERROR( CV_StsParseError, "priors must stored as a matrix" );
1384             priors_mult = cvCloneMat( priors );
1385         }
1386     }
1387
1388     CV_CALL( var_idx = (CvMat*)cvReadByName( fs, node, "var_idx" ));
1389     if( var_idx )
1390     {
1391         if( !CV_IS_MAT(var_idx) ||
1392             (var_idx->cols != 1 && var_idx->rows != 1) ||
1393             var_idx->cols + var_idx->rows - 1 != var_count ||
1394             CV_MAT_TYPE(var_idx->type) != CV_32SC1 )
1395             CV_ERROR( CV_StsParseError,
1396                 "var_idx (if exist) must be valid 1d integer vector containing <var_count> elements" );
1397
1398         for( vi = 0; vi < var_count; vi++ )
1399             if( (unsigned)var_idx->data.i[vi] >= (unsigned)var_all )
1400                 CV_ERROR( CV_StsOutOfRange, "some of var_idx elements are out of range" );
1401     }
1402
1403     ////// read var type
1404     CV_CALL( var_type = cvCreateMat( 1, var_count + 2, CV_32SC1 ));
1405
1406     cat_var_count = 0;
1407     ord_var_count = -1;
1408     vartype_node = cvGetFileNodeByName( fs, node, "var_type" );
1409
1410     if( vartype_node && CV_NODE_TYPE(vartype_node->tag) == CV_NODE_INT && var_count == 1 )
1411         var_type->data.i[0] = vartype_node->data.i ? cat_var_count++ : ord_var_count--;
1412     else
1413     {
1414         if( !vartype_node || CV_NODE_TYPE(vartype_node->tag) != CV_NODE_SEQ ||
1415             vartype_node->data.seq->total != var_count )
1416             CV_ERROR( CV_StsParseError, "var_type must exist and be a sequence of 0's and 1's" );
1417
1418         cvStartReadSeq( vartype_node->data.seq, &reader );
1419
1420         for( vi = 0; vi < var_count; vi++ )
1421         {
1422             CvFileNode* n = (CvFileNode*)reader.ptr;
1423             if( CV_NODE_TYPE(n->tag) != CV_NODE_INT || (n->data.i & ~1) )
1424                 CV_ERROR( CV_StsParseError, "var_type must exist and be a sequence of 0's and 1's" );
1425             var_type->data.i[vi] = n->data.i ? cat_var_count++ : ord_var_count--;
1426             CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
1427         }
1428     }
1429     var_type->data.i[var_count] = cat_var_count;
1430
1431     ord_var_count = ~ord_var_count;
1432     if( cat_var_count != cat_var_count || ord_var_count != ord_var_count )
1433         CV_ERROR( CV_StsParseError, "var_type is inconsistent with cat_var_count and ord_var_count" );
1434     //////
1435
1436     if( cat_var_count > 0 || is_classifier )
1437     {
1438         int ccount, total_c_count = 0;
1439         CV_CALL( cat_count = (CvMat*)cvReadByName( fs, node, "cat_count" ));
1440         CV_CALL( cat_map = (CvMat*)cvReadByName( fs, node, "cat_map" ));
1441
1442         if( !CV_IS_MAT(cat_count) || !CV_IS_MAT(cat_map) ||
1443             (cat_count->cols != 1 && cat_count->rows != 1) ||
1444             CV_MAT_TYPE(cat_count->type) != CV_32SC1 ||
1445             cat_count->cols + cat_count->rows - 1 != cat_var_count + is_classifier ||
1446             (cat_map->cols != 1 && cat_map->rows != 1) ||
1447             CV_MAT_TYPE(cat_map->type) != CV_32SC1 )
1448             CV_ERROR( CV_StsParseError,
1449             "Both cat_count and cat_map must exist and be valid 1d integer vectors of an appropriate size" );
1450
1451         ccount = cat_var_count + is_classifier;
1452
1453         CV_CALL( cat_ofs = cvCreateMat( 1, ccount + 1, CV_32SC1 ));
1454         cat_ofs->data.i[0] = 0;
1455         max_c_count = 1;
1456
1457         for( vi = 0; vi < ccount; vi++ )
1458         {
1459             int val = cat_count->data.i[vi];
1460             if( val <= 0 )
1461                 CV_ERROR( CV_StsOutOfRange, "some of cat_count elements are out of range" );
1462             max_c_count = MAX( max_c_count, val );
1463             cat_ofs->data.i[vi+1] = total_c_count += val;
1464         }
1465
1466         if( cat_map->cols + cat_map->rows - 1 != total_c_count )
1467             CV_ERROR( CV_StsBadSize,
1468             "cat_map vector length is not equal to the total number of categories in all categorical vars" );
1469     }
1470
1471     max_split_size = cvAlign(sizeof(CvDTreeSplit) +
1472         (MAX(0,max_c_count - 33)/32)*sizeof(int),sizeof(void*));
1473
1474     tree_block_size = MAX((int)sizeof(CvDTreeNode)*8, max_split_size);
1475     tree_block_size = MAX(tree_block_size + block_size_delta, min_block_size);
1476     CV_CALL( tree_storage = cvCreateMemStorage( tree_block_size ));
1477     CV_CALL( node_heap = cvCreateSet( 0, sizeof(node_heap[0]),
1478             sizeof(CvDTreeNode), tree_storage ));
1479     CV_CALL( split_heap = cvCreateSet( 0, sizeof(split_heap[0]),
1480             max_split_size, tree_storage ));
1481
1482     __END__;
1483 }
1484
1485 float* CvDTreeTrainData::get_pred_float_buf()
1486 {
1487     return &pred_float_buf[cv::getThreadNum()][0];
1488 }
1489 int* CvDTreeTrainData::get_pred_int_buf()
1490 {
1491     return &pred_int_buf[cv::getThreadNum()][0];
1492 }
1493 float* CvDTreeTrainData::get_resp_float_buf()
1494 {
1495     return &resp_float_buf[cv::getThreadNum()][0];
1496 }
1497 int* CvDTreeTrainData::get_resp_int_buf()
1498 {
1499     return &resp_int_buf[cv::getThreadNum()][0];
1500 }
1501 int* CvDTreeTrainData::get_cv_lables_buf()
1502 {
1503     return &cv_lables_buf[cv::getThreadNum()][0];
1504 }
1505 int* CvDTreeTrainData::get_sample_idx_buf()
1506 {
1507     return &sample_idx_buf[cv::getThreadNum()][0];
1508 }
1509
1510 /////////////////////// Decision Tree /////////////////////////
1511
1512 CvDTree::CvDTree()
1513 {
1514     data = 0;
1515     var_importance = 0;
1516     default_model_name = "my_tree";
1517
1518     clear();
1519 }
1520
1521
1522 void CvDTree::clear()
1523 {
1524     cvReleaseMat( &var_importance );
1525     if( data )
1526     {
1527         if( !data->shared )
1528             delete data;
1529         else
1530             free_tree();
1531         data = 0;
1532     }
1533     root = 0;
1534     pruned_tree_idx = -1;
1535 }
1536
1537
1538 CvDTree::~CvDTree()
1539 {
1540     clear();
1541 }
1542
1543
1544 const CvDTreeNode* CvDTree::get_root() const
1545 {
1546     return root;
1547 }
1548
1549
1550 int CvDTree::get_pruned_tree_idx() const
1551 {
1552     return pruned_tree_idx;
1553 }
1554
1555
1556 CvDTreeTrainData* CvDTree::get_data()
1557 {
1558     return data;
1559 }
1560
1561
1562 bool CvDTree::train( const CvMat* _train_data, int _tflag,
1563                      const CvMat* _responses, const CvMat* _var_idx,
1564                      const CvMat* _sample_idx, const CvMat* _var_type,
1565                      const CvMat* _missing_mask, CvDTreeParams _params )
1566 {
1567     bool result = false;
1568
1569     CV_FUNCNAME( "CvDTree::train" );
1570
1571     __BEGIN__;
1572
1573     clear();
1574     data = new CvDTreeTrainData( _train_data, _tflag, _responses,
1575                                  _var_idx, _sample_idx, _var_type,
1576                                  _missing_mask, _params, false );
1577     CV_CALL( result = do_train(0) );
1578
1579     __END__;
1580
1581     return result;
1582 }
1583
1584 bool CvDTree::train( const Mat& _train_data, int _tflag,
1585                     const Mat& _responses, const Mat& _var_idx,
1586                     const Mat& _sample_idx, const Mat& _var_type,
1587                     const Mat& _missing_mask, CvDTreeParams _params )
1588 {
1589     CvMat tdata = _train_data, responses = _responses, vidx=_var_idx,
1590         sidx=_sample_idx, vtype=_var_type, mmask=_missing_mask; 
1591     return train(&tdata, _tflag, &responses, vidx.data.ptr ? &vidx : 0, sidx.data.ptr ? &sidx : 0,
1592                  vtype.data.ptr ? &vtype : 0, mmask.data.ptr ? &mmask : 0, _params);
1593 }
1594
1595
1596 bool CvDTree::train( CvMLData* _data, CvDTreeParams _params )
1597 {
1598    bool result = false;
1599
1600     CV_FUNCNAME( "CvDTree::train" );
1601
1602     __BEGIN__;
1603
1604     const CvMat* values = _data->get_values();
1605     const CvMat* response = _data->get_responses();
1606     const CvMat* missing = _data->get_missing();
1607     const CvMat* var_types = _data->get_var_types();
1608     const CvMat* train_sidx = _data->get_train_sample_idx();
1609     const CvMat* var_idx = _data->get_var_idx();
1610
1611     CV_CALL( result = train( values, CV_ROW_SAMPLE, response, var_idx,
1612         train_sidx, var_types, missing, _params ) );
1613
1614     __END__;
1615
1616     return result;
1617 }
1618
1619 bool CvDTree::train( CvDTreeTrainData* _data, const CvMat* _subsample_idx )
1620 {
1621     bool result = false;
1622
1623     CV_FUNCNAME( "CvDTree::train" );
1624
1625     __BEGIN__;
1626
1627     clear();
1628     data = _data;
1629     data->shared = true;
1630     CV_CALL( result = do_train(_subsample_idx));
1631
1632     __END__;
1633
1634     return result;
1635 }
1636
1637
1638 bool CvDTree::do_train( const CvMat* _subsample_idx )
1639 {
1640     bool result = false;
1641
1642     CV_FUNCNAME( "CvDTree::do_train" );
1643
1644     __BEGIN__;
1645
1646     root = data->subsample_data( _subsample_idx );
1647
1648     CV_CALL( try_split_node(root));
1649
1650     if( data->params.cv_folds > 0 )
1651         CV_CALL( prune_cv());
1652
1653     if( !data->shared )
1654         data->free_train_data();
1655
1656     result = true;
1657
1658     __END__;
1659
1660     return result;
1661 }
1662
1663
1664 void CvDTree::try_split_node( CvDTreeNode* node )
1665 {
1666     CvDTreeSplit* best_split = 0;
1667     int i, n = node->sample_count, vi;
1668     bool can_split = true;
1669     double quality_scale;
1670
1671     calc_node_value( node );
1672
1673     if( node->sample_count <= data->params.min_sample_count ||
1674         node->depth >= data->params.max_depth )
1675         can_split = false;
1676
1677     if( can_split && data->is_classifier )
1678     {
1679         // check if we have a "pure" node,
1680         // we assume that cls_count is filled by calc_node_value()
1681         int* cls_count = data->counts->data.i;
1682         int nz = 0, m = data->get_num_classes();
1683         for( i = 0; i < m; i++ )
1684             nz += cls_count[i] != 0;
1685         if( nz == 1 ) // there is only one class
1686             can_split = false;
1687     }
1688     else if( can_split )
1689     {
1690         if( sqrt(node->node_risk)/n < data->params.regression_accuracy )
1691             can_split = false;
1692     }
1693
1694     if( can_split )
1695     {
1696         best_split = find_best_split(node);
1697         // TODO: check the split quality ...
1698         node->split = best_split;
1699     }
1700     if( !can_split || !best_split )
1701     {
1702         data->free_node_data(node);
1703         return;
1704     }
1705
1706     quality_scale = calc_node_dir( node );
1707     if( data->params.use_surrogates )
1708     {
1709         // find all the surrogate splits
1710         // and sort them by their similarity to the primary one
1711         for( vi = 0; vi < data->var_count; vi++ )
1712         {
1713             CvDTreeSplit* split;
1714             int ci = data->get_var_type(vi);
1715
1716             if( vi == best_split->var_idx )
1717                 continue;
1718
1719             if( ci >= 0 )
1720                 split = find_surrogate_split_cat( node, vi );
1721             else
1722                 split = find_surrogate_split_ord( node, vi );
1723
1724             if( split )
1725             {
1726                 // insert the split
1727                 CvDTreeSplit* prev_split = node->split;
1728                 split->quality = (float)(split->quality*quality_scale);
1729
1730                 while( prev_split->next &&
1731                        prev_split->next->quality > split->quality )
1732                     prev_split = prev_split->next;
1733                 split->next = prev_split->next;
1734                 prev_split->next = split;
1735             }
1736         }
1737     }
1738     split_node_data( node );
1739     try_split_node( node->left );
1740     try_split_node( node->right );
1741 }
1742
1743
1744 // calculate direction (left(-1),right(1),missing(0))
1745 // for each sample using the best split
1746 // the function returns scale coefficients for surrogate split quality factors.
1747 // the scale is applied to normalize surrogate split quality relatively to the
1748 // best (primary) split quality. That is, if a surrogate split is absolutely
1749 // identical to the primary split, its quality will be set to the maximum value =
1750 // quality of the primary split; otherwise, it will be lower.
1751 // besides, the function compute node->maxlr,
1752 // minimum possible quality (w/o considering the above mentioned scale)
1753 // for a surrogate split. Surrogate splits with quality less than node->maxlr
1754 // are not discarded.
1755 double CvDTree::calc_node_dir( CvDTreeNode* node )
1756 {
1757     char* dir = (char*)data->direction->data.ptr;
1758     int i, n = node->sample_count, vi = node->split->var_idx;
1759     double L, R;
1760
1761     assert( !node->split->inversed );
1762
1763     if( data->get_var_type(vi) >= 0 ) // split on categorical var
1764     {
1765         int* labels_buf = data->get_pred_int_buf();
1766         const int* labels = 0;
1767         const int* subset = node->split->subset;
1768         data->get_cat_var_data( node, vi, labels_buf, &labels );
1769         if( !data->have_priors )
1770         {
1771             int sum = 0, sum_abs = 0;
1772
1773             for( i = 0; i < n; i++ )
1774             {
1775                 int idx = labels[i];
1776                 int d = ( ((idx >= 0)&&(!data->is_buf_16u)) || ((idx != 65535)&&(data->is_buf_16u)) ) ?
1777                     CV_DTREE_CAT_DIR(idx,subset) : 0;
1778                 sum += d; sum_abs += d & 1;
1779                 dir[i] = (char)d;
1780             }
1781
1782             R = (sum_abs + sum) >> 1;
1783             L = (sum_abs - sum) >> 1;
1784         }
1785         else
1786         {
1787             const double* priors = data->priors_mult->data.db;
1788             double sum = 0, sum_abs = 0;
1789             int *responses_buf = data->get_resp_int_buf();
1790             const int* responses;
1791             data->get_class_labels(node, responses_buf, &responses);
1792
1793             for( i = 0; i < n; i++ )
1794             {
1795                 int idx = labels[i];
1796                 double w = priors[responses[i]];
1797                 int d = idx >= 0 ? CV_DTREE_CAT_DIR(idx,subset) : 0;
1798                 sum += d*w; sum_abs += (d & 1)*w;
1799                 dir[i] = (char)d;
1800             }
1801
1802             R = (sum_abs + sum) * 0.5;
1803             L = (sum_abs - sum) * 0.5;
1804         }
1805     }
1806     else // split on ordered var
1807     {
1808         int split_point = node->split->ord.split_point;
1809         int n1 = node->get_num_valid(vi);
1810         
1811         float* val_buf = data->get_pred_float_buf();
1812         const float* val = 0;
1813         int* sorted_buf = data->get_pred_int_buf();
1814         const int* sorted = 0;
1815         data->get_ord_var_data( node, vi, val_buf, sorted_buf, &val, &sorted);
1816         
1817         assert( 0 <= split_point && split_point < n1-1 );
1818
1819         if( !data->have_priors )
1820         {
1821             for( i = 0; i <= split_point; i++ )
1822                 dir[sorted[i]] = (char)-1;
1823             for( ; i < n1; i++ )
1824                 dir[sorted[i]] = (char)1;
1825             for( ; i < n; i++ )
1826                 dir[sorted[i]] = (char)0;
1827
1828             L = split_point-1;
1829             R = n1 - split_point + 1;
1830         }
1831         else
1832         {
1833             const double* priors = data->priors_mult->data.db;
1834             int* responses_buf = data->get_resp_int_buf();
1835             const int* responses = 0;
1836             data->get_class_labels(node, responses_buf, &responses);
1837             L = R = 0;
1838
1839             for( i = 0; i <= split_point; i++ )
1840             {
1841                 int idx = sorted[i];
1842                 double w = priors[responses[idx]];
1843                 dir[idx] = (char)-1;
1844                 L += w;
1845             }
1846
1847             for( ; i < n1; i++ )
1848             {
1849                 int idx = sorted[i];
1850                 double w = priors[responses[idx]];
1851                 dir[idx] = (char)1;
1852                 R += w;
1853             }
1854
1855             for( ; i < n; i++ )
1856                 dir[sorted[i]] = (char)0;
1857         }
1858     }
1859     node->maxlr = MAX( L, R );
1860     return node->split->quality/(L + R);
1861 }
1862
1863 CvDTreeSplit* CvDTree::find_best_split( CvDTreeNode* node )
1864 {
1865     int vi;
1866     CvDTreeSplit *bestSplit = 0;
1867     int maxNumThreads = 1;
1868 #ifdef _OPENMP
1869     maxNumThreads = cv::getNumThreads();
1870 #endif
1871     vector<CvDTreeSplit*> splits(maxNumThreads);
1872     vector<CvDTreeSplit*> bestSplits(maxNumThreads);
1873     vector<int> canSplit(maxNumThreads);
1874     CvDTreeSplit **splitsPtr = &splits[0], ** bestSplitsPtr = &bestSplits[0];
1875     int* canSplitPtr = &canSplit[0];
1876     for (int i = 0; i < maxNumThreads; i++)
1877     {
1878         splitsPtr[i] = data->new_split_cat( 0, -1.0f );
1879         bestSplitsPtr[i] = data->new_split_cat( 0, -1.0f );
1880         canSplitPtr[i] = 0;
1881     }
1882
1883 #ifdef _OPENMP
1884 #pragma omp parallel for num_threads(maxNumThreads) schedule(dynamic)
1885 #endif
1886     for( vi = 0; vi < data->var_count; vi++ )
1887     {
1888         CvDTreeSplit *res, *t;
1889         int threadIdx = cv::getThreadNum();
1890         int ci = data->get_var_type(vi);
1891         if( node->get_num_valid(vi) <= 1 )
1892             continue;
1893
1894         if( data->is_classifier )
1895         {
1896             if( ci >= 0 )
1897                 res = find_split_cat_class( node, vi, bestSplitsPtr[threadIdx]->quality, splitsPtr[threadIdx] );
1898             else
1899                 res = find_split_ord_class( node, vi, bestSplitsPtr[threadIdx]->quality, splitsPtr[threadIdx] );
1900         }
1901         else
1902         {
1903             if( ci >= 0 )
1904                 res = find_split_cat_reg( node, vi, bestSplitsPtr[threadIdx]->quality, splitsPtr[threadIdx] );
1905             else
1906                 res = find_split_ord_reg( node, vi, bestSplitsPtr[threadIdx]->quality, splitsPtr[threadIdx] );
1907         }
1908
1909         if( res )
1910         {
1911             canSplitPtr[threadIdx] = 1;
1912             if( bestSplitsPtr[threadIdx]->quality < splitsPtr[threadIdx]->quality )
1913                 CV_SWAP( bestSplitsPtr[threadIdx], splitsPtr[threadIdx], t );
1914         }
1915     }
1916     int ti = 0;
1917     for( ; ti < maxNumThreads; ti++ )
1918     {
1919         if( canSplitPtr[ti] )
1920         {
1921             bestSplit = bestSplitsPtr[ti];
1922             break;
1923         }
1924     }
1925     for( ; ti < maxNumThreads; ti++ )
1926     {
1927         if( bestSplit->quality < bestSplitsPtr[ti]->quality )
1928             bestSplit = bestSplitsPtr[ti];
1929     }
1930     for(int i = 0; i < maxNumThreads; i++)
1931     {
1932         cvSetRemoveByPtr( data->split_heap, splitsPtr[i] );
1933         if( bestSplitsPtr[i] != bestSplit )
1934             cvSetRemoveByPtr( data->split_heap, bestSplitsPtr[i] );
1935     }
1936     return bestSplit;
1937 }
1938
1939 CvDTreeSplit* CvDTree::find_split_ord_class( CvDTreeNode* node, int vi,
1940                                             float init_quality, CvDTreeSplit* _split )
1941 {
1942     const float epsilon = FLT_EPSILON*2;
1943     int n = node->sample_count;
1944     int n1 = node->get_num_valid(vi);
1945     int m = data->get_num_classes();
1946
1947     float* values_buf = data->get_pred_float_buf();
1948     const float* values = 0;
1949     int* indices_buf = data->get_pred_int_buf();
1950     const int* indices = 0;
1951     data->get_ord_var_data( node, vi, values_buf, indices_buf, &values, &indices );
1952     int* responses_buf =  data->get_resp_int_buf();
1953     const int* responses = 0;
1954     data->get_class_labels( node, responses_buf, &responses );
1955
1956     const int* rc0 = data->counts->data.i;
1957     int* lc = (int*)cvStackAlloc(m*sizeof(lc[0]));
1958     int* rc = (int*)cvStackAlloc(m*sizeof(rc[0]));
1959     int i, best_i = -1;
1960     double lsum2 = 0, rsum2 = 0, best_val = init_quality;
1961     const double* priors = data->have_priors ? data->priors_mult->data.db : 0;
1962
1963     // init arrays of class instance counters on both sides of the split
1964     for( i = 0; i < m; i++ )
1965     {
1966         lc[i] = 0;
1967         rc[i] = rc0[i];
1968     }
1969
1970     // compensate for missing values
1971     for( i = n1; i < n; i++ )
1972     {
1973         rc[responses[indices[i]]]--;
1974     }
1975
1976     if( !priors )
1977     {
1978         int L = 0, R = n1;
1979
1980         for( i = 0; i < m; i++ )
1981             rsum2 += (double)rc[i]*rc[i];
1982
1983         for( i = 0; i < n1 - 1; i++ )
1984         {
1985             int idx = responses[indices[i]];
1986             int lv, rv;
1987             L++; R--;
1988             lv = lc[idx]; rv = rc[idx];
1989             lsum2 += lv*2 + 1;
1990             rsum2 -= rv*2 - 1;
1991             lc[idx] = lv + 1; rc[idx] = rv - 1;
1992
1993             if( values[i] + epsilon < values[i+1] )
1994             {
1995                 double val = (lsum2*R + rsum2*L)/((double)L*R);
1996                 if( best_val < val )
1997                 {
1998                     best_val = val;
1999                     best_i = i;
2000                 }
2001             }
2002         }
2003     }
2004     else
2005     {
2006         double L = 0, R = 0;
2007         for( i = 0; i < m; i++ )
2008         {
2009             double wv = rc[i]*priors[i];
2010             R += wv;
2011             rsum2 += wv*wv;
2012         }
2013
2014         for( i = 0; i < n1 - 1; i++ )
2015         {
2016             int idx = responses[indices[i]];
2017             int lv, rv;
2018             double p = priors[idx], p2 = p*p;
2019             L += p; R -= p;
2020             lv = lc[idx]; rv = rc[idx];
2021             lsum2 += p2*(lv*2 + 1);
2022             rsum2 -= p2*(rv*2 - 1);
2023             lc[idx] = lv + 1; rc[idx] = rv - 1;
2024
2025             if( values[i] + epsilon < values[i+1] )
2026             {
2027                 double val = (lsum2*R + rsum2*L)/((double)L*R);
2028                 if( best_val < val )
2029                 {
2030                     best_val = val;
2031                     best_i = i;
2032                 }
2033             }
2034         }
2035     }
2036
2037     CvDTreeSplit* split = 0;
2038     if( best_i >= 0 )
2039     {
2040         split = _split ? _split : data->new_split_ord( 0, 0.0f, 0, 0, 0.0f );
2041         split->var_idx = vi;
2042         split->ord.c = (values[best_i] + values[best_i+1])*0.5f;
2043         split->ord.split_point = best_i;
2044         split->inversed = 0;
2045         split->quality = (float)best_val;
2046     }
2047     return split;
2048 }
2049
2050
2051 void CvDTree::cluster_categories( const int* vectors, int n, int m,
2052                                 int* csums, int k, int* labels )
2053 {
2054     // TODO: consider adding priors (class weights) and sample weights to the clustering algorithm
2055     int iters = 0, max_iters = 100;
2056     int i, j, idx;
2057     double* buf = (double*)cvStackAlloc( (n + k)*sizeof(buf[0]) );
2058     double *v_weights = buf, *c_weights = buf + k;
2059     bool modified = true;
2060     CvRNG* r = &data->rng;
2061
2062     // assign labels randomly
2063     for( i = idx = 0; i < n; i++ )
2064     {
2065         int sum = 0;
2066         const int* v = vectors + i*m;
2067         labels[i] = idx++;
2068         idx &= idx < k ? -1 : 0;
2069
2070         // compute weight of each vector
2071         for( j = 0; j < m; j++ )
2072             sum += v[j];
2073         v_weights[i] = sum ? 1./sum : 0.;
2074     }
2075
2076     for( i = 0; i < n; i++ )
2077     {
2078         int i1 = cvRandInt(r) % n;
2079         int i2 = cvRandInt(r) % n;
2080         CV_SWAP( labels[i1], labels[i2], j );
2081     }
2082
2083     for( iters = 0; iters <= max_iters; iters++ )
2084     {
2085         // calculate csums
2086         for( i = 0; i < k; i++ )
2087         {
2088             for( j = 0; j < m; j++ )
2089                 csums[i*m + j] = 0;
2090         }
2091
2092         for( i = 0; i < n; i++ )
2093         {
2094             const int* v = vectors + i*m;
2095             int* s = csums + labels[i]*m;
2096             for( j = 0; j < m; j++ )
2097                 s[j] += v[j];
2098         }
2099
2100         // exit the loop here, when we have up-to-date csums
2101         if( iters == max_iters || !modified )
2102             break;
2103
2104         modified = false;
2105
2106         // calculate weight of each cluster
2107         for( i = 0; i < k; i++ )
2108         {
2109             const int* s = csums + i*m;
2110             int sum = 0;
2111             for( j = 0; j < m; j++ )
2112                 sum += s[j];
2113             c_weights[i] = sum ? 1./sum : 0;
2114         }
2115
2116         // now for each vector determine the closest cluster
2117         for( i = 0; i < n; i++ )
2118         {
2119             const int* v = vectors + i*m;
2120             double alpha = v_weights[i];
2121             double min_dist2 = DBL_MAX;
2122             int min_idx = -1;
2123
2124             for( idx = 0; idx < k; idx++ )
2125             {
2126                 const int* s = csums + idx*m;
2127                 double dist2 = 0., beta = c_weights[idx];
2128                 for( j = 0; j < m; j++ )
2129                 {
2130                     double t = v[j]*alpha - s[j]*beta;
2131                     dist2 += t*t;
2132                 }
2133                 if( min_dist2 > dist2 )
2134                 {
2135                     min_dist2 = dist2;
2136                     min_idx = idx;
2137                 }
2138             }
2139
2140             if( min_idx != labels[i] )
2141                 modified = true;
2142             labels[i] = min_idx;
2143         }
2144     }
2145 }
2146
2147
2148 CvDTreeSplit* CvDTree::find_split_cat_class( CvDTreeNode* node, int vi, float init_quality, CvDTreeSplit* _split )
2149 {
2150     int ci = data->get_var_type(vi);
2151     int n = node->sample_count;
2152     int m = data->get_num_classes();
2153     int _mi = data->cat_count->data.i[ci], mi = _mi;
2154
2155     int* labels_buf = data->get_pred_int_buf();
2156     const int* labels = 0;
2157     data->get_cat_var_data(node, vi, labels_buf, &labels);
2158     int *responses_buf = data->get_resp_int_buf();
2159     const int* responses = 0;
2160     data->get_class_labels(node, responses_buf, &responses);
2161
2162     int* lc = (int*)cvStackAlloc(m*sizeof(lc[0]));
2163     int* rc = (int*)cvStackAlloc(m*sizeof(rc[0]));
2164     int* _cjk = (int*)cvStackAlloc(m*(mi+1)*sizeof(_cjk[0]))+m, *cjk = _cjk;
2165     double* c_weights = (double*)cvStackAlloc( mi*sizeof(c_weights[0]) );
2166     int* cluster_labels = 0;
2167     int** int_ptr = 0;
2168     int i, j, k, idx;
2169     double L = 0, R = 0;
2170     double best_val = init_quality;
2171     int prevcode = 0, best_subset = -1, subset_i, subset_n, subtract = 0;
2172     const double* priors = data->priors_mult->data.db;
2173
2174     // init array of counters:
2175     // c_{jk} - number of samples that have vi-th input variable = j and response = k.
2176     for( j = -1; j < mi; j++ )
2177         for( k = 0; k < m; k++ )
2178             cjk[j*m + k] = 0;
2179
2180     for( i = 0; i < n; i++ )
2181     {
2182        j = ( labels[i] == 65535 && data->is_buf_16u) ? -1 : labels[i];
2183        k = responses[i];
2184        cjk[j*m + k]++;
2185     }
2186
2187     if( m > 2 )
2188     {
2189         if( mi > data->params.max_categories )
2190         {
2191             mi = MIN(data->params.max_categories, n);
2192             cjk += _mi*m;
2193             cluster_labels = (int*)cvStackAlloc(mi*sizeof(cluster_labels[0]));
2194             cluster_categories( _cjk, _mi, m, cjk, mi, cluster_labels );
2195         }
2196         subset_i = 1;
2197         subset_n = 1 << mi;
2198     }
2199     else
2200     {
2201         assert( m == 2 );
2202         int_ptr = (int**)cvStackAlloc( mi*sizeof(int_ptr[0]) );
2203         for( j = 0; j < mi; j++ )
2204             int_ptr[j] = cjk + j*2 + 1;
2205         icvSortIntPtr( int_ptr, mi, 0 );
2206         subset_i = 0;
2207         subset_n = mi;
2208     }
2209
2210     for( k = 0; k < m; k++ )
2211     {
2212         int sum = 0;
2213         for( j = 0; j < mi; j++ )
2214             sum += cjk[j*m + k];
2215         rc[k] = sum;
2216         lc[k] = 0;
2217     }
2218
2219     for( j = 0; j < mi; j++ )
2220     {
2221         double sum = 0;
2222         for( k = 0; k < m; k++ )
2223             sum += cjk[j*m + k]*priors[k];
2224         c_weights[j] = sum;
2225         R += c_weights[j];
2226     }
2227
2228     for( ; subset_i < subset_n; subset_i++ )
2229     {
2230         double weight;
2231         int* crow;
2232         double lsum2 = 0, rsum2 = 0;
2233
2234         if( m == 2 )
2235             idx = (int)(int_ptr[subset_i] - cjk)/2;
2236         else
2237         {
2238             int graycode = (subset_i>>1)^subset_i;
2239             int diff = graycode ^ prevcode;
2240
2241             // determine index of the changed bit.
2242             Cv32suf u;
2243             idx = diff >= (1 << 16) ? 16 : 0;
2244             u.f = (float)(((diff >> 16) | diff) & 65535);
2245             idx += (u.i >> 23) - 127;
2246             subtract = graycode < prevcode;
2247             prevcode = graycode;
2248         }
2249
2250         crow = cjk + idx*m;
2251         weight = c_weights[idx];
2252         if( weight < FLT_EPSILON )
2253             continue;
2254
2255         if( !subtract )
2256         {
2257             for( k = 0; k < m; k++ )
2258             {
2259                 int t = crow[k];
2260                 int lval = lc[k] + t;
2261                 int rval = rc[k] - t;
2262                 double p = priors[k], p2 = p*p;
2263                 lsum2 += p2*lval*lval;
2264                 rsum2 += p2*rval*rval;
2265                 lc[k] = lval; rc[k] = rval;
2266             }
2267             L += weight;
2268             R -= weight;
2269         }
2270         else
2271         {
2272             for( k = 0; k < m; k++ )
2273             {
2274                 int t = crow[k];
2275                 int lval = lc[k] - t;
2276                 int rval = rc[k] + t;
2277                 double p = priors[k], p2 = p*p;
2278                 lsum2 += p2*lval*lval;
2279                 rsum2 += p2*rval*rval;
2280                 lc[k] = lval; rc[k] = rval;
2281             }
2282             L -= weight;
2283             R += weight;
2284         }
2285
2286         if( L > FLT_EPSILON && R > FLT_EPSILON )
2287         {
2288             double val = (lsum2*R + rsum2*L)/((double)L*R);
2289             if( best_val < val )
2290             {
2291                 best_val = val;
2292                 best_subset = subset_i;
2293             }
2294         }
2295     }
2296
2297     CvDTreeSplit* split = 0;
2298     if( best_subset >= 0 ) 
2299     {
2300         split = _split ? _split : data->new_split_cat( 0, -1.0f );
2301         split->var_idx = vi;
2302         split->quality = (float)best_val;
2303         memset( split->subset, 0, (data->max_c_count + 31)/32 * sizeof(int));
2304         if( m == 2 )
2305         {
2306             for( i = 0; i <= best_subset; i++ )
2307             {
2308                 idx = (int)(int_ptr[i] - cjk) >> 1;
2309                 split->subset[idx >> 5] |= 1 << (idx & 31);
2310             }
2311         }
2312         else
2313         {
2314             for( i = 0; i < _mi; i++ )
2315             {
2316                 idx = cluster_labels ? cluster_labels[i] : i;
2317                 if( best_subset & (1 << idx) )
2318                     split->subset[i >> 5] |= 1 << (i & 31);
2319             }
2320         }
2321     }
2322     return split;
2323 }
2324
2325
2326 CvDTreeSplit* CvDTree::find_split_ord_reg( CvDTreeNode* node, int vi, float init_quality, CvDTreeSplit* _split )
2327 {
2328     const float epsilon = FLT_EPSILON*2;
2329     int n = node->sample_count;
2330     int n1 = node->get_num_valid(vi);
2331
2332     float* values_buf = data->get_pred_float_buf();
2333     const float* values = 0;
2334     int* indices_buf = data->get_pred_int_buf();
2335     const int* indices = 0;
2336     data->get_ord_var_data( node, vi, values_buf, indices_buf, &values, &indices );
2337     float* responses_buf =  data->get_resp_float_buf();
2338     const float* responses = 0;
2339     data->get_ord_responses( node, responses_buf, &responses );
2340
2341     int i, best_i = -1;
2342     double best_val = init_quality, lsum = 0, rsum = node->value*n;
2343     int L = 0, R = n1;
2344
2345     // compensate for missing values
2346     for( i = n1; i < n; i++ )
2347         rsum -= responses[indices[i]];
2348
2349     // find the optimal split
2350     for( i = 0; i < n1 - 1; i++ )
2351     {
2352         float t = responses[indices[i]];
2353         L++; R--;
2354         lsum += t;
2355         rsum -= t;
2356
2357         if( values[i] + epsilon < values[i+1] )
2358         {
2359             double val = (lsum*lsum*R + rsum*rsum*L)/((double)L*R);
2360             if( best_val < val )
2361             {
2362                 best_val = val;
2363                 best_i = i;
2364             }
2365         }
2366     }
2367
2368     CvDTreeSplit* split = 0;
2369     if( best_i >= 0 )
2370     {
2371         split = _split ? _split : data->new_split_ord( 0, 0.0f, 0, 0, 0.0f );
2372         split->var_idx = vi;
2373         split->ord.c = (values[best_i] + values[best_i+1])*0.5f;
2374         split->ord.split_point = best_i;
2375         split->inversed = 0;
2376         split->quality = (float)best_val;
2377     }
2378     return split;
2379 }
2380
2381 CvDTreeSplit* CvDTree::find_split_cat_reg( CvDTreeNode* node, int vi, float init_quality, CvDTreeSplit* _split )
2382 {
2383     int ci = data->get_var_type(vi);
2384     int n = node->sample_count;
2385     int mi = data->cat_count->data.i[ci];
2386     int* labels_buf = data->get_pred_int_buf();
2387     const int* labels = 0;
2388     float* responses_buf = data->get_resp_float_buf();
2389     const float* responses = 0;
2390     data->get_cat_var_data(node, vi, labels_buf, &labels);
2391     data->get_ord_responses(node, responses_buf, &responses);
2392
2393     double* sum = (double*)cvStackAlloc( (mi+1)*sizeof(sum[0]) ) + 1;
2394     int* counts = (int*)cvStackAlloc( (mi+1)*sizeof(counts[0]) ) + 1;
2395     double** sum_ptr = (double**)cvStackAlloc( (mi+1)*sizeof(sum_ptr[0]) );
2396     int i, L = 0, R = 0;
2397     double best_val = init_quality, lsum = 0, rsum = 0;
2398     int best_subset = -1, subset_i;
2399
2400     for( i = -1; i < mi; i++ )
2401         sum[i] = counts[i] = 0;
2402
2403     // calculate sum response and weight of each category of the input var
2404     for( i = 0; i < n; i++ )
2405     {
2406         int idx = ( (labels[i] == 65535) && data->is_buf_16u ) ? -1 : labels[i];
2407         double s = sum[idx] + responses[i];
2408         int nc = counts[idx] + 1;
2409         sum[idx] = s;
2410         counts[idx] = nc;
2411     }
2412
2413     // calculate average response in each category
2414     for( i = 0; i < mi; i++ )
2415     {
2416         R += counts[i];
2417         rsum += sum[i];
2418         sum[i] /= MAX(counts[i],1);
2419         sum_ptr[i] = sum + i;
2420     }
2421
2422     icvSortDblPtr( sum_ptr, mi, 0 );
2423
2424     // revert back to unnormalized sums
2425     // (there should be a very little loss of accuracy)
2426     for( i = 0; i < mi; i++ )
2427         sum[i] *= counts[i];
2428
2429     for( subset_i = 0; subset_i < mi-1; subset_i++ )
2430     {
2431         int idx = (int)(sum_ptr[subset_i] - sum);
2432         int ni = counts[idx];
2433
2434         if( ni )
2435         {
2436             double s = sum[idx];
2437             lsum += s; L += ni;
2438             rsum -= s; R -= ni;
2439
2440             if( L && R )
2441             {
2442                 double val = (lsum*lsum*R + rsum*rsum*L)/((double)L*R);
2443                 if( best_val < val )
2444                 {
2445                     best_val = val;
2446                     best_subset = subset_i;
2447                 }
2448             }
2449         }
2450     }
2451
2452     CvDTreeSplit* split = 0;
2453     if( best_subset >= 0 )
2454     {
2455         split = _split ? _split : data->new_split_cat( 0, -1.0f);
2456         split->var_idx = vi;
2457         split->quality = (float)best_val;
2458         memset( split->subset, 0, (data->max_c_count + 31)/32 * sizeof(int));
2459         for( i = 0; i <= best_subset; i++ )
2460         {
2461             int idx = (int)(sum_ptr[i] - sum);
2462             split->subset[idx >> 5] |= 1 << (idx & 31);
2463         }
2464     }
2465     return split;
2466 }
2467
2468 CvDTreeSplit* CvDTree::find_surrogate_split_ord( CvDTreeNode* node, int vi )
2469 {
2470     const float epsilon = FLT_EPSILON*2;
2471     const char* dir = (char*)data->direction->data.ptr;
2472     int n1 = node->get_num_valid(vi);
2473     float* values_buf = data->get_pred_float_buf();
2474     const float* values = 0;
2475     int* indices_buf = data->get_pred_int_buf();
2476     const int* indices = 0;
2477     data->get_ord_var_data( node, vi, values_buf, indices_buf, &values, &indices );
2478     // LL - number of samples that both the primary and the surrogate splits send to the left
2479     // LR - ... primary split sends to the left and the surrogate split sends to the right
2480     // RL - ... primary split sends to the right and the surrogate split sends to the left
2481     // RR - ... both send to the right
2482     int i, best_i = -1, best_inversed = 0;
2483     double best_val;
2484
2485     if( !data->have_priors )
2486     {
2487         int LL = 0, RL = 0, LR, RR;
2488         int worst_val = cvFloor(node->maxlr), _best_val = worst_val;
2489         int sum = 0, sum_abs = 0;
2490
2491         for( i = 0; i < n1; i++ )
2492         {
2493             int d = dir[indices[i]];
2494             sum += d; sum_abs += d & 1;
2495         }
2496
2497         // sum_abs = R + L; sum = R - L
2498         RR = (sum_abs + sum) >> 1;
2499         LR = (sum_abs - sum) >> 1;
2500
2501         // initially all the samples are sent to the right by the surrogate split,
2502         // LR of them are sent to the left by primary split, and RR - to the right.
2503         // now iteratively compute LL, LR, RL and RR for every possible surrogate split value.
2504         for( i = 0; i < n1 - 1; i++ )
2505         {
2506             int d = dir[indices[i]];
2507
2508             if( d < 0 )
2509             {
2510                 LL++; LR--;
2511                 if( LL + RR > _best_val && values[i] + epsilon < values[i+1] )
2512                 {
2513                     best_val = LL + RR;
2514                     best_i = i; best_inversed = 0;
2515                 }
2516             }
2517             else if( d > 0 )
2518             {
2519                 RL++; RR--;
2520                 if( RL + LR > _best_val && values[i] + epsilon < values[i+1] )
2521                 {
2522                     best_val = RL + LR;
2523                     best_i = i; best_inversed = 1;
2524                 }
2525             }
2526         }
2527         best_val = _best_val;
2528     }
2529     else
2530     {
2531         double LL = 0, RL = 0, LR, RR;
2532         double worst_val = node->maxlr;
2533         double sum = 0, sum_abs = 0;
2534         const double* priors = data->priors_mult->data.db;
2535         int* responses_buf = data->get_resp_int_buf();
2536         const int* responses = 0;
2537         data->get_class_labels(node, responses_buf, &responses);
2538         best_val = worst_val;
2539
2540         for( i = 0; i < n1; i++ )
2541         {
2542             int idx = indices[i];
2543             double w = priors[responses[idx]];
2544             int d = dir[idx];
2545             sum += d*w; sum_abs += (d & 1)*w;
2546         }
2547
2548         // sum_abs = R + L; sum = R - L
2549         RR = (sum_abs + sum)*0.5;
2550         LR = (sum_abs - sum)*0.5;
2551
2552         // initially all the samples are sent to the right by the surrogate split,
2553         // LR of them are sent to the left by primary split, and RR - to the right.
2554         // now iteratively compute LL, LR, RL and RR for every possible surrogate split value.
2555         for( i = 0; i < n1 - 1; i++ )
2556         {
2557             int idx = indices[i];
2558             double w = priors[responses[idx]];
2559             int d = dir[idx];
2560
2561             if( d < 0 )
2562             {
2563                 LL += w; LR -= w;
2564                 if( LL + RR > best_val && values[i] + epsilon < values[i+1] )
2565                 {
2566                     best_val = LL + RR;
2567                     best_i = i; best_inversed = 0;
2568                 }
2569             }
2570             else if( d > 0 )
2571             {
2572                 RL += w; RR -= w;
2573                 if( RL + LR > best_val && values[i] + epsilon < values[i+1] )
2574                 {
2575                     best_val = RL + LR;
2576                     best_i = i; best_inversed = 1;
2577                 }
2578             }
2579         }
2580     }
2581     return best_i >= 0 && best_val > node->maxlr ? data->new_split_ord( vi,
2582         (values[best_i] + values[best_i+1])*0.5f, best_i, best_inversed, (float)best_val ) : 0;
2583 }
2584
2585
2586 CvDTreeSplit* CvDTree::find_surrogate_split_cat( CvDTreeNode* node, int vi )
2587 {
2588     const char* dir = (char*)data->direction->data.ptr;
2589     int n = node->sample_count;
2590     int* labels_buf = data->get_pred_int_buf();
2591     const int* labels = 0;
2592     data->get_cat_var_data(node, vi, labels_buf, &labels);
2593     // LL - number of samples that both the primary and the surrogate splits send to the left
2594     // LR - ... primary split sends to the left and the surrogate split sends to the right
2595     // RL - ... primary split sends to the right and the surrogate split sends to the left
2596     // RR - ... both send to the right
2597     CvDTreeSplit* split = data->new_split_cat( vi, 0 );
2598     int i, mi = data->cat_count->data.i[data->get_var_type(vi)], l_win = 0;
2599     double best_val = 0;
2600     double* lc = (double*)cvStackAlloc( (mi+1)*2*sizeof(lc[0]) ) + 1;
2601     double* rc = lc + mi + 1;
2602
2603     for( i = -1; i < mi; i++ )
2604         lc[i] = rc[i] = 0;
2605
2606     // for each category calculate the weight of samples
2607     // sent to the left (lc) and to the right (rc) by the primary split
2608     if( !data->have_priors )
2609     {
2610         int* _lc = (int*)cvStackAlloc((mi+2)*2*sizeof(_lc[0])) + 1;
2611         int* _rc = _lc + mi + 1;
2612
2613         for( i = -1; i < mi; i++ )
2614             _lc[i] = _rc[i] = 0;
2615
2616         for( i = 0; i < n; i++ )
2617         {
2618             int idx = ( (labels[i] == 65535) && (data->is_buf_16u) ) ? -1 : labels[i];
2619             int d = dir[i];
2620             int sum = _lc[idx] + d;
2621             int sum_abs = _rc[idx] + (d & 1);
2622             _lc[idx] = sum; _rc[idx] = sum_abs;
2623         }
2624
2625         for( i = 0; i < mi; i++ )
2626         {
2627             int sum = _lc[i];
2628             int sum_abs = _rc[i];
2629             lc[i] = (sum_abs - sum) >> 1;
2630             rc[i] = (sum_abs + sum) >> 1;
2631         }
2632     }
2633     else
2634     {
2635         const double* priors = data->priors_mult->data.db;
2636         int* responses_buf = data->get_resp_int_buf();
2637         const int* responses = 0;
2638         data->get_class_labels(node, responses_buf, &responses);
2639
2640         for( i = 0; i < n; i++ )
2641         {
2642             int idx = ( (labels[i] == 65535) && (data->is_buf_16u) ) ? -1 : labels[i];
2643             double w = priors[responses[i]];
2644             int d = dir[i];
2645             double sum = lc[idx] + d*w;
2646             double sum_abs = rc[idx] + (d & 1)*w;
2647             lc[idx] = sum; rc[idx] = sum_abs;
2648         }
2649
2650         for( i = 0; i < mi; i++ )
2651         {
2652             double sum = lc[i];
2653             double sum_abs = rc[i];
2654             lc[i] = (sum_abs - sum) * 0.5;
2655             rc[i] = (sum_abs + sum) * 0.5;
2656         }
2657     }
2658
2659     // 2. now form the split.
2660     // in each category send all the samples to the same direction as majority
2661     for( i = 0; i < mi; i++ )
2662     {
2663         double lval = lc[i], rval = rc[i];
2664         if( lval > rval )
2665         {
2666             split->subset[i >> 5] |= 1 << (i & 31);
2667             best_val += lval;
2668             l_win++;
2669         }
2670         else
2671             best_val += rval;
2672     }
2673
2674     split->quality = (float)best_val;
2675     if( split->quality <= node->maxlr || l_win == 0 || l_win == mi )
2676         cvSetRemoveByPtr( data->split_heap, split ), split = 0;
2677
2678     return split;
2679 }
2680
2681
2682 void CvDTree::calc_node_value( CvDTreeNode* node )
2683 {
2684     int i, j, k, n = node->sample_count, cv_n = data->params.cv_folds;
2685     int* cv_labels_buf = data->get_cv_lables_buf();
2686     const int* cv_labels = 0;
2687     data->get_cv_labels(node, cv_labels_buf, &cv_labels);
2688
2689     if( data->is_classifier )
2690     {
2691         // in case of classification tree:
2692         //  * node value is the label of the class that has the largest weight in the node.
2693         //  * node risk is the weighted number of misclassified samples,
2694         //  * j-th cross-validation fold value and risk are calculated as above,
2695         //    but using the samples with cv_labels(*)!=j.
2696         //  * j-th cross-validation fold error is calculated as the weighted number of
2697         //    misclassified samples with cv_labels(*)==j.
2698
2699         // compute the number of instances of each class
2700         int* cls_count = data->counts->data.i;
2701         int* responses_buf = data->get_resp_int_buf();
2702         const int* responses = 0;
2703         data->get_class_labels(node, responses_buf, &responses);
2704         int m = data->get_num_classes();
2705         int* cv_cls_count = (int*)cvStackAlloc(m*cv_n*sizeof(cv_cls_count[0]));
2706         double max_val = -1, total_weight = 0;
2707         int max_k = -1;
2708         double* priors = data->priors_mult->data.db;
2709
2710         for( k = 0; k < m; k++ )
2711             cls_count[k] = 0;
2712
2713         if( cv_n == 0 )
2714         {
2715             for( i = 0; i < n; i++ )
2716                 cls_count[responses[i]]++;
2717         }
2718         else
2719         {
2720             for( j = 0; j < cv_n; j++ )
2721                 for( k = 0; k < m; k++ )
2722                     cv_cls_count[j*m + k] = 0;
2723
2724             for( i = 0; i < n; i++ )
2725             {
2726                 j = cv_labels[i]; k = responses[i];
2727                 cv_cls_count[j*m + k]++;
2728             }
2729
2730             for( j = 0; j < cv_n; j++ )
2731                 for( k = 0; k < m; k++ )
2732                     cls_count[k] += cv_cls_count[j*m + k];
2733         }
2734
2735         if( data->have_priors && node->parent == 0 )
2736         {
2737             // compute priors_mult from priors, take the sample ratio into account.
2738             double sum = 0;
2739             for( k = 0; k < m; k++ )
2740             {
2741                 int n_k = cls_count[k];
2742                 priors[k] = data->priors->data.db[k]*(n_k ? 1./n_k : 0.);
2743                 sum += priors[k];
2744             }
2745             sum = 1./sum;
2746             for( k = 0; k < m; k++ )
2747                 priors[k] *= sum;
2748         }
2749
2750         for( k = 0; k < m; k++ )
2751         {
2752             double val = cls_count[k]*priors[k];
2753             total_weight += val;
2754             if( max_val < val )
2755             {
2756                 max_val = val;
2757                 max_k = k;
2758             }
2759         }
2760
2761         node->class_idx = max_k;
2762         node->value = data->cat_map->data.i[
2763             data->cat_ofs->data.i[data->cat_var_count] + max_k];
2764         node->node_risk = total_weight - max_val;
2765
2766         for( j = 0; j < cv_n; j++ )
2767         {
2768             double sum_k = 0, sum = 0, max_val_k = 0;
2769             max_val = -1; max_k = -1;
2770
2771             for( k = 0; k < m; k++ )
2772             {
2773                 double w = priors[k];
2774                 double val_k = cv_cls_count[j*m + k]*w;
2775                 double val = cls_count[k]*w - val_k;
2776                 sum_k += val_k;
2777                 sum += val;
2778                 if( max_val < val )
2779                 {
2780                     max_val = val;
2781                     max_val_k = val_k;
2782                     max_k = k;
2783                 }
2784             }
2785
2786             node->cv_Tn[j] = INT_MAX;
2787             node->cv_node_risk[j] = sum - max_val;
2788             node->cv_node_error[j] = sum_k - max_val_k;
2789         }
2790     }
2791     else
2792     {
2793         // in case of regression tree:
2794         //  * node value is 1/n*sum_i(Y_i), where Y_i is i-th response,
2795         //    n is the number of samples in the node.
2796         //  * node risk is the sum of squared errors: sum_i((Y_i - <node_value>)^2)
2797         //  * j-th cross-validation fold value and risk are calculated as above,
2798         //    but using the samples with cv_labels(*)!=j.
2799         //  * j-th cross-validation fold error is calculated
2800         //    using samples with cv_labels(*)==j as the test subset:
2801         //    error_j = sum_(i,cv_labels(i)==j)((Y_i - <node_value_j>)^2),
2802         //    where node_value_j is the node value calculated
2803         //    as described in the previous bullet, and summation is done
2804         //    over the samples with cv_labels(*)==j.
2805
2806         double sum = 0, sum2 = 0;
2807         float* values_buf = data->get_resp_float_buf();
2808         const float* values = 0;
2809         data->get_ord_responses(node, values_buf, &values);
2810         double *cv_sum = 0, *cv_sum2 = 0;
2811         int* cv_count = 0;
2812
2813         if( cv_n == 0 )
2814         {
2815             for( i = 0; i < n; i++ )
2816             {
2817                 double t = values[i];
2818                 sum += t;
2819                 sum2 += t*t;
2820             }
2821         }
2822         else
2823         {
2824             cv_sum = (double*)cvStackAlloc( cv_n*sizeof(cv_sum[0]) );
2825             cv_sum2 = (double*)cvStackAlloc( cv_n*sizeof(cv_sum2[0]) );
2826             cv_count = (int*)cvStackAlloc( cv_n*sizeof(cv_count[0]) );
2827
2828             for( j = 0; j < cv_n; j++ )
2829             {
2830                 cv_sum[j] = cv_sum2[j] = 0.;
2831                 cv_count[j] = 0;
2832             }
2833
2834             for( i = 0; i < n; i++ )
2835             {
2836                 j = cv_labels[i];
2837                 double t = values[i];
2838                 double s = cv_sum[j] + t;
2839                 double s2 = cv_sum2[j] + t*t;
2840                 int nc = cv_count[j] + 1;
2841                 cv_sum[j] = s;
2842                 cv_sum2[j] = s2;
2843                 cv_count[j] = nc;
2844             }
2845
2846             for( j = 0; j < cv_n; j++ )
2847             {
2848                 sum += cv_sum[j];
2849                 sum2 += cv_sum2[j];
2850             }
2851         }
2852
2853         node->node_risk = sum2 - (sum/n)*sum;
2854         node->value = sum/n;
2855
2856         for( j = 0; j < cv_n; j++ )
2857         {
2858             double s = cv_sum[j], si = sum - s;
2859             double s2 = cv_sum2[j], s2i = sum2 - s2;
2860             int c = cv_count[j], ci = n - c;
2861             double r = si/MAX(ci,1);
2862             node->cv_node_risk[j] = s2i - r*r*ci;
2863             node->cv_node_error[j] = s2 - 2*r*s + c*r*r;
2864             node->cv_Tn[j] = INT_MAX;
2865         }
2866     }
2867 }
2868
2869
2870 void CvDTree::complete_node_dir( CvDTreeNode* node )
2871 {
2872     int vi, i, n = node->sample_count, nl, nr, d0 = 0, d1 = -1;
2873     int nz = n - node->get_num_valid(node->split->var_idx);
2874     char* dir = (char*)data->direction->data.ptr;
2875
2876     // try to complete direction using surrogate splits
2877     if( nz && data->params.use_surrogates )
2878     {
2879         CvDTreeSplit* split = node->split->next;
2880         for( ; split != 0 && nz; split = split->next )
2881         {
2882             int inversed_mask = split->inversed ? -1 : 0;
2883             vi = split->var_idx;
2884
2885             if( data->get_var_type(vi) >= 0 ) // split on categorical var
2886             {
2887                 int* labels_buf = data->get_pred_int_buf();
2888                 const int* labels = 0;
2889                 data->get_cat_var_data(node, vi, labels_buf, &labels);
2890                 const int* subset = split->subset;
2891
2892                 for( i = 0; i < n; i++ )
2893                 {
2894                     int idx = labels[i];
2895                     if( !dir[i] && ( ((idx >= 0)&&(!data->is_buf_16u)) || ((idx != 65535)&&(data->is_buf_16u)) ))
2896                         
2897                     {
2898                         int d = CV_DTREE_CAT_DIR(idx,subset);
2899                         dir[i] = (char)((d ^ inversed_mask) - inversed_mask);
2900                         if( --nz )
2901                             break;
2902                     }
2903                 }
2904             }
2905             else // split on ordered var
2906             {
2907                 float* values_buf = data->get_pred_float_buf();
2908                 const float* values = 0;
2909                 int* indices_buf = data->get_pred_int_buf();
2910                 const int* indices = 0;
2911                 data->get_ord_var_data( node, vi, values_buf, indices_buf, &values, &indices );
2912                 int split_point = split->ord.split_point;
2913                 int n1 = node->get_num_valid(vi);
2914
2915                 assert( 0 <= split_point && split_point < n-1 );
2916
2917                 for( i = 0; i < n1; i++ )
2918                 {
2919                     int idx = indices[i];
2920                     if( !dir[idx] )
2921                     {
2922                         int d = i <= split_point ? -1 : 1;
2923                         dir[idx] = (char)((d ^ inversed_mask) - inversed_mask);
2924                         if( --nz )
2925                             break;
2926                     }
2927                 }
2928             }
2929         }
2930     }
2931
2932     // find the default direction for the rest
2933     if( nz )
2934     {
2935         for( i = nr = 0; i < n; i++ )
2936             nr += dir[i] > 0;
2937         nl = n - nr - nz;
2938         d0 = nl > nr ? -1 : nr > nl;
2939     }
2940
2941     // make sure that every sample is directed either to the left or to the right
2942     for( i = 0; i < n; i++ )
2943     {
2944         int d = dir[i];
2945         if( !d )
2946         {
2947             d = d0;
2948             if( !d )
2949                 d = d1, d1 = -d1;
2950         }
2951         d = d > 0;
2952         dir[i] = (char)d; // remap (-1,1) to (0,1)
2953     }
2954 }
2955
2956
2957 void CvDTree::split_node_data( CvDTreeNode* node )
2958 {
2959     int vi, i, n = node->sample_count, nl, nr, scount = data->sample_count;
2960     char* dir = (char*)data->direction->data.ptr;
2961     CvDTreeNode *left = 0, *right = 0;
2962     int* new_idx = data->split_buf->data.i;
2963     int new_buf_idx = data->get_child_buf_idx( node );
2964     int work_var_count = data->get_work_var_count();
2965     CvMat* buf = data->buf;
2966     cv::AutoBuffer<int, 1<<14> _temp_buf(n);
2967     int* temp_buf = _temp_buf;
2968
2969     complete_node_dir(node);
2970
2971     for( i = nl = nr = 0; i < n; i++ )
2972     {
2973         int d = dir[i];
2974         // initialize new indices for splitting ordered variables
2975         new_idx[i] = (nl & (d-1)) | (nr & -d); // d ? ri : li
2976         nr += d;
2977         nl += d^1;
2978     }
2979
2980
2981     bool split_input_data;
2982     node->left = left = data->new_node( node, nl, new_buf_idx, node->offset );
2983     node->right = right = data->new_node( node, nr, new_buf_idx, node->offset + nl );
2984
2985     split_input_data = node->depth + 1 < data->params.max_depth &&
2986         (node->left->sample_count > data->params.min_sample_count ||
2987         node->right->sample_count > data->params.min_sample_count);
2988
2989     // split ordered variables, keep both halves sorted.
2990     for( vi = 0; vi < data->var_count; vi++ )
2991     {
2992         int ci = data->get_var_type(vi);
2993         int n1 = node->get_num_valid(vi);
2994         int *src_idx_buf = data->get_pred_int_buf();
2995         const int* src_idx = 0;
2996         float *src_val_buf = data->get_pred_float_buf();
2997         const float* src_val = 0;
2998         
2999         if( ci >= 0 || !split_input_data )
3000             continue;
3001
3002         data->get_ord_var_data(node, vi, src_val_buf, src_idx_buf, &src_val, &src_idx);
3003
3004         for(i = 0; i < n; i++)
3005             temp_buf[i] = src_idx[i];
3006
3007         if (data->is_buf_16u)
3008         {
3009             unsigned short *ldst, *rdst, *ldst0, *rdst0;
3010             //unsigned short tl, tr;
3011             ldst0 = ldst = (unsigned short*)(buf->data.s + left->buf_idx*buf->cols + 
3012                 vi*scount + left->offset);
3013             rdst0 = rdst = (unsigned short*)(ldst + nl);
3014
3015             // split sorted
3016             for( i = 0; i < n1; i++ )
3017             {
3018                 int idx = temp_buf[i];
3019                 int d = dir[idx];
3020                 idx = new_idx[idx];
3021                 if (d)
3022                 {
3023                     *rdst = (unsigned short)idx;
3024                     rdst++;
3025                 }
3026                 else
3027                 {
3028                     *ldst = (unsigned short)idx;
3029                     ldst++;
3030                 }
3031             }
3032
3033             left->set_num_valid(vi, (int)(ldst - ldst0));
3034             right->set_num_valid(vi, (int)(rdst - rdst0));
3035
3036             // split missing
3037             for( ; i < n; i++ )
3038             {
3039                 int idx = temp_buf[i];
3040                 int d = dir[idx];
3041                 idx = new_idx[idx];
3042                 if (d)
3043                 {
3044                     *rdst = (unsigned short)idx;
3045                     rdst++;
3046                 }
3047                 else
3048                 {
3049                     *ldst = (unsigned short)idx;
3050                     ldst++;
3051                 }
3052             }
3053         }
3054         else
3055         {
3056             int *ldst0, *ldst, *rdst0, *rdst;
3057             ldst0 = ldst = buf->data.i + left->buf_idx*buf->cols + 
3058                 vi*scount + left->offset;
3059             rdst0 = rdst = buf->data.i + right->buf_idx*buf->cols + 
3060                 vi*scount + right->offset;
3061
3062             // split sorted
3063             for( i = 0; i < n1; i++ )
3064             {
3065                 int idx = temp_buf[i];
3066                 int d = dir[idx];
3067                 idx = new_idx[idx];
3068                 if (d)
3069                 {
3070                     *rdst = idx;
3071                     rdst++;
3072                 }
3073                 else
3074                 {
3075                     *ldst = idx;
3076                     ldst++;
3077                 }
3078             }
3079
3080             left->set_num_valid(vi, (int)(ldst - ldst0));
3081             right->set_num_valid(vi, (int)(rdst - rdst0));
3082
3083             // split missing
3084             for( ; i < n; i++ )
3085             {
3086                 int idx = temp_buf[i];
3087                 int d = dir[idx];
3088                 idx = new_idx[idx];
3089                 if (d)
3090                 {
3091                     *rdst = idx;
3092                     rdst++;
3093                 }
3094                 else
3095                 {
3096                     *ldst = idx;
3097                     ldst++;
3098                 }
3099             }
3100         }
3101     }
3102
3103     // split categorical vars, responses and cv_labels using new_idx relocation table
3104     for( vi = 0; vi < work_var_count; vi++ )
3105     {
3106         int ci = data->get_var_type(vi);
3107         int n1 = node->get_num_valid(vi), nr1 = 0;
3108         
3109         if( ci < 0 || (vi < data->var_count && !split_input_data) )
3110             continue;
3111
3112         int *src_lbls_buf = data->get_pred_int_buf();
3113         const int* src_lbls = 0;
3114         data->get_cat_var_data(node, vi, src_lbls_buf, &src_lbls);
3115
3116         for(i = 0; i < n; i++)
3117             temp_buf[i] = src_lbls[i];
3118
3119         if (data->is_buf_16u)
3120         {
3121             unsigned short *ldst = (unsigned short *)(buf->data.s + left->buf_idx*buf->cols + 
3122                 vi*scount + left->offset);
3123             unsigned short *rdst = (unsigned short *)(buf->data.s + right->buf_idx*buf->cols + 
3124                 vi*scount + right->offset);
3125             
3126             for( i = 0; i < n; i++ )
3127             {
3128                 int d = dir[i];
3129                 int idx = temp_buf[i];
3130                 if (d)
3131                 {
3132                     *rdst = (unsigned short)idx;
3133                     rdst++;
3134                     nr1 += (idx != 65535 )&d;
3135                 }
3136                 else
3137                 {
3138                     *ldst = (unsigned short)idx;
3139                     ldst++;
3140                 }
3141             }
3142
3143             if( vi < data->var_count )
3144             {
3145                 left->set_num_valid(vi, n1 - nr1);
3146                 right->set_num_valid(vi, nr1);
3147             }
3148         }
3149         else
3150         {
3151             int *ldst = buf->data.i + left->buf_idx*buf->cols + 
3152                 vi*scount + left->offset;
3153             int *rdst = buf->data.i + right->buf_idx*buf->cols + 
3154                 vi*scount + right->offset;
3155             
3156             for( i = 0; i < n; i++ )
3157             {
3158                 int d = dir[i];
3159                 int idx = temp_buf[i];
3160                 if (d)
3161                 {
3162                     *rdst = idx;
3163                     rdst++;
3164                     nr1 += (idx >= 0)&d;
3165                 }
3166                 else
3167                 {
3168                     *ldst = idx;
3169                     ldst++;
3170                 }
3171                 
3172             }
3173
3174             if( vi < data->var_count )
3175             {
3176                 left->set_num_valid(vi, n1 - nr1);
3177                 right->set_num_valid(vi, nr1);
3178             }
3179         }        
3180     }
3181
3182
3183     // split sample indices
3184     int *sample_idx_src_buf = data->get_sample_idx_buf();
3185     const int* sample_idx_src = 0;
3186     data->get_sample_indices(node, sample_idx_src_buf, &sample_idx_src);
3187
3188     for(i = 0; i < n; i++)
3189         temp_buf[i] = sample_idx_src[i];
3190
3191     int pos = data->get_work_var_count();
3192     if (data->is_buf_16u)
3193     {
3194         unsigned short* ldst = (unsigned short*)(buf->data.s + left->buf_idx*buf->cols + 
3195             pos*scount + left->offset);
3196         unsigned short* rdst = (unsigned short*)(buf->data.s + right->buf_idx*buf->cols + 
3197             pos*scount + right->offset);
3198         for (i = 0; i < n; i++)
3199         {
3200             int d = dir[i];
3201             unsigned short idx = (unsigned short)temp_buf[i];
3202             if (d)
3203             {
3204                 *rdst = idx;
3205                 rdst++;
3206             }
3207             else
3208             {
3209                 *ldst = idx;
3210                 ldst++;
3211             }
3212         }
3213     }
3214     else
3215     {
3216         int* ldst = buf->data.i + left->buf_idx*buf->cols + 
3217             pos*scount + left->offset;
3218         int* rdst = buf->data.i + right->buf_idx*buf->cols + 
3219             pos*scount + right->offset;
3220         for (i = 0; i < n; i++)
3221         {
3222             int d = dir[i];
3223             int idx = temp_buf[i];
3224             if (d)
3225             {
3226                 *rdst = idx;
3227                 rdst++;
3228             }
3229             else
3230             {
3231                 *ldst = idx;
3232                 ldst++;
3233             }
3234         }
3235     }
3236     
3237     // deallocate the parent node data that is not needed anymore
3238     data->free_node_data(node);    
3239 }
3240
3241 float CvDTree::calc_error( CvMLData* _data, int type, vector<float> *resp )
3242 {
3243     float err = 0;
3244     const CvMat* values = _data->get_values();
3245     const CvMat* response = _data->get_responses();
3246     const CvMat* missing = _data->get_missing();
3247     const CvMat* sample_idx = (type == CV_TEST_ERROR) ? _data->get_test_sample_idx() : _data->get_train_sample_idx();
3248     const CvMat* var_types = _data->get_var_types();
3249     int* sidx = sample_idx ? sample_idx->data.i : 0;
3250     int r_step = CV_IS_MAT_CONT(response->type) ?
3251                 1 : response->step / CV_ELEM_SIZE(response->type);
3252     bool is_classifier = var_types->data.ptr[var_types->cols-1] == CV_VAR_CATEGORICAL;
3253     int sample_count = sample_idx ? sample_idx->cols : 0;
3254     sample_count = (type == CV_TRAIN_ERROR && sample_count == 0) ? values->rows : sample_count;
3255     float* pred_resp = 0;
3256     if( resp && (sample_count > 0) )
3257     {
3258         resp->resize( sample_count );
3259         pred_resp = &((*resp)[0]);
3260     }
3261
3262     if ( is_classifier )
3263     {
3264         for( int i = 0; i < sample_count; i++ )
3265         {
3266             CvMat sample, miss;
3267             int si = sidx ? sidx[i] : i;
3268             cvGetRow( values, &sample, si ); 
3269             if( missing ) 
3270                 cvGetRow( missing, &miss, si );             
3271             float r = (float)predict( &sample, missing ? &miss : 0 )->value;
3272             if( pred_resp )
3273                 pred_resp[i] = r;
3274             int d = fabs((double)r - response->data.fl[si*r_step]) <= FLT_EPSILON ? 0 : 1;
3275             err += d;
3276         }
3277         err = sample_count ? err / (float)sample_count * 100 : -FLT_MAX;
3278     }
3279     else
3280     {
3281         for( int i = 0; i < sample_count; i++ )
3282         {
3283             CvMat sample, miss;
3284             int si = sidx ? sidx[i] : i;
3285             cvGetRow( values, &sample, si ); 
3286             if( missing ) 
3287                 cvGetRow( missing, &miss, si );             
3288             float r = (float)predict( &sample, missing ? &miss : 0 )->value;
3289             if( pred_resp )
3290                 pred_resp[i] = r;
3291             float d = r - response->data.fl[si*r_step];
3292             err += d*d;
3293         }
3294         err = sample_count ? err / (float)sample_count : -FLT_MAX;    
3295     }
3296     return err;
3297 }
3298
3299 void CvDTree::prune_cv()
3300 {
3301     CvMat* ab = 0;
3302     CvMat* temp = 0;
3303     CvMat* err_jk = 0;
3304
3305     // 1. build tree sequence for each cv fold, calculate error_{Tj,beta_k}.
3306     // 2. choose the best tree index (if need, apply 1SE rule).
3307     // 3. store the best index and cut the branches.
3308
3309     CV_FUNCNAME( "CvDTree::prune_cv" );
3310
3311     __BEGIN__;
3312
3313     int ti, j, tree_count = 0, cv_n = data->params.cv_folds, n = root->sample_count;
3314     // currently, 1SE for regression is not implemented
3315     bool use_1se = data->params.use_1se_rule != 0 && data->is_classifier;
3316     double* err;
3317     double min_err = 0, min_err_se = 0;
3318     int min_idx = -1;
3319
3320     CV_CALL( ab = cvCreateMat( 1, 256, CV_64F ));
3321
3322     // build the main tree sequence, calculate alpha's
3323     for(;;tree_count++)
3324     {
3325         double min_alpha = update_tree_rnc(tree_count, -1);
3326         if( cut_tree(tree_count, -1, min_alpha) )
3327             break;
3328
3329         if( ab->cols <= tree_count )
3330         {
3331             CV_CALL( temp = cvCreateMat( 1, ab->cols*3/2, CV_64F ));
3332             for( ti = 0; ti < ab->cols; ti++ )
3333                 temp->data.db[ti] = ab->data.db[ti];
3334             cvReleaseMat( &ab );
3335             ab = temp;
3336             temp = 0;
3337         }
3338
3339         ab->data.db[tree_count] = min_alpha;
3340     }
3341
3342     ab->data.db[0] = 0.;
3343
3344     if( tree_count > 0 )
3345     {
3346         for( ti = 1; ti < tree_count-1; ti++ )
3347             ab->data.db[ti] = sqrt(ab->data.db[ti]*ab->data.db[ti+1]);
3348         ab->data.db[tree_count-1] = DBL_MAX*0.5;
3349
3350         CV_CALL( err_jk = cvCreateMat( cv_n, tree_count, CV_64F ));
3351         err = err_jk->data.db;
3352
3353         for( j = 0; j < cv_n; j++ )
3354         {
3355             int tj = 0, tk = 0;
3356             for( ; tk < tree_count; tj++ )
3357             {
3358                 double min_alpha = update_tree_rnc(tj, j);
3359                 if( cut_tree(tj, j, min_alpha) )
3360                     min_alpha = DBL_MAX;
3361
3362                 for( ; tk < tree_count; tk++ )
3363                 {
3364                     if( ab->data.db[tk] > min_alpha )
3365                         break;
3366                     err[j*tree_count + tk] = root->tree_error;
3367                 }
3368             }
3369         }
3370
3371         for( ti = 0; ti < tree_count; ti++ )
3372         {
3373             double sum_err = 0;
3374             for( j = 0; j < cv_n; j++ )
3375                 sum_err += err[j*tree_count + ti];
3376             if( ti == 0 || sum_err < min_err )
3377             {
3378                 min_err = sum_err;
3379                 min_idx = ti;
3380                 if( use_1se )
3381                     min_err_se = sqrt( sum_err*(n - sum_err) );
3382             }
3383             else if( sum_err < min_err + min_err_se )
3384                 min_idx = ti;
3385         }
3386     }
3387
3388     pruned_tree_idx = min_idx;
3389     free_prune_data(data->params.truncate_pruned_tree != 0);
3390
3391     __END__;
3392
3393     cvReleaseMat( &err_jk );
3394     cvReleaseMat( &ab );
3395     cvReleaseMat( &temp );
3396 }
3397
3398
3399 double CvDTree::update_tree_rnc( int T, int fold )
3400 {
3401     CvDTreeNode* node = root;
3402     double min_alpha = DBL_MAX;
3403
3404     for(;;)
3405     {
3406         CvDTreeNode* parent;
3407         for(;;)
3408         {
3409             int t = fold >= 0 ? node->cv_Tn[fold] : node->Tn;
3410             if( t <= T || !node->left )
3411             {
3412                 node->complexity = 1;
3413                 node->tree_risk = node->node_risk;
3414                 node->tree_error = 0.;
3415                 if( fold >= 0 )
3416                 {
3417                     node->tree_risk = node->cv_node_risk[fold];
3418                     node->tree_error = node->cv_node_error[fold];
3419                 }
3420                 break;
3421             }
3422             node = node->left;
3423         }
3424
3425         for( parent = node->parent; parent && parent->right == node;
3426             node = parent, parent = parent->parent )
3427         {
3428             parent->complexity += node->complexity;
3429             parent->tree_risk += node->tree_risk;
3430             parent->tree_error += node->tree_error;
3431
3432             parent->alpha = ((fold >= 0 ? parent->cv_node_risk[fold] : parent->node_risk)
3433                 - parent->tree_risk)/(parent->complexity - 1);
3434             min_alpha = MIN( min_alpha, parent->alpha );
3435         }
3436
3437         if( !parent )
3438             break;
3439
3440         parent->complexity = node->complexity;
3441         parent->tree_risk = node->tree_risk;
3442         parent->tree_error = node->tree_error;
3443         node = parent->right;
3444     }
3445
3446     return min_alpha;
3447 }
3448
3449
3450 int CvDTree::cut_tree( int T, int fold, double min_alpha )
3451 {
3452     CvDTreeNode* node = root;
3453     if( !node->left )
3454         return 1;
3455
3456     for(;;)
3457     {
3458         CvDTreeNode* parent;
3459         for(;;)
3460         {
3461             int t = fold >= 0 ? node->cv_Tn[fold] : node->Tn;
3462             if( t <= T || !node->left )
3463                 break;
3464             if( node->alpha <= min_alpha + FLT_EPSILON )
3465             {
3466                 if( fold >= 0 )
3467                     node->cv_Tn[fold] = T;
3468                 else
3469                     node->Tn = T;
3470                 if( node == root )
3471                     return 1;
3472                 break;
3473             }
3474             node = node->left;
3475         }
3476
3477         for( parent = node->parent; parent && parent->right == node;
3478             node = parent, parent = parent->parent )
3479             ;
3480
3481         if( !parent )
3482             break;
3483
3484         node = parent->right;
3485     }
3486
3487     return 0;
3488 }
3489
3490
3491 void CvDTree::free_prune_data(bool cut_tree)
3492 {
3493     CvDTreeNode* node = root;
3494
3495     for(;;)
3496     {
3497         CvDTreeNode* parent;
3498         for(;;)
3499         {
3500             // do not call cvSetRemoveByPtr( cv_heap, node->cv_Tn )
3501             // as we will clear the whole cross-validation heap at the end
3502             node->cv_Tn = 0;
3503             node->cv_node_error = node->cv_node_risk = 0;
3504             if( !node->left )
3505                 break;
3506             node = node->left;
3507         }
3508
3509         for( parent = node->parent; parent && parent->right == node;
3510             node = parent, parent = parent->parent )
3511         {
3512             if( cut_tree && parent->Tn <= pruned_tree_idx )
3513             {
3514                 data->free_node( parent->left );
3515                 data->free_node( parent->right );
3516                 parent->left = parent->right = 0;
3517             }
3518         }
3519
3520         if( !parent )
3521             break;
3522
3523         node = parent->right;
3524     }
3525
3526     if( data->cv_heap )
3527         cvClearSet( data->cv_heap );
3528 }
3529
3530
3531 void CvDTree::free_tree()
3532 {
3533     if( root && data && data->shared )
3534     {
3535         pruned_tree_idx = INT_MIN;
3536         free_prune_data(true);
3537         data->free_node(root);
3538         root = 0;
3539     }
3540 }
3541
3542 CvDTreeNode* CvDTree::predict( const CvMat* _sample,
3543     const CvMat* _missing, bool preprocessed_input ) const
3544 {
3545     CvDTreeNode* result = 0;
3546     int* catbuf = 0;
3547
3548     CV_FUNCNAME( "CvDTree::predict" );
3549
3550     __BEGIN__;
3551
3552     int i, step, mstep = 0;
3553     const float* sample;
3554     const uchar* m = 0;
3555     CvDTreeNode* node = root;
3556     const int* vtype;
3557     const int* vidx;
3558     const int* cmap;
3559     const int* cofs;
3560
3561     if( !node )
3562         CV_ERROR( CV_StsError, "The tree has not been trained yet" );
3563
3564     if( !CV_IS_MAT(_sample) || CV_MAT_TYPE(_sample->type) != CV_32FC1 ||
3565         (_sample->cols != 1 && _sample->rows != 1) ||
3566         (_sample->cols + _sample->rows - 1 != data->var_all && !preprocessed_input) ||
3567         (_sample->cols + _sample->rows - 1 != data->var_count && preprocessed_input) )
3568             CV_ERROR( CV_StsBadArg,
3569         "the input sample must be 1d floating-point vector with the same "
3570         "number of elements as the total number of variables used for training" );
3571
3572     sample = _sample->data.fl;
3573     step = CV_IS_MAT_CONT(_sample->type) ? 1 : _sample->step/sizeof(sample[0]);
3574
3575     if( data->cat_count && !preprocessed_input ) // cache for categorical variables
3576     {
3577         int n = data->cat_count->cols;
3578         catbuf = (int*)cvStackAlloc(n*sizeof(catbuf[0]));
3579         for( i = 0; i < n; i++ )
3580             catbuf[i] = -1;
3581     }
3582
3583     if( _missing )
3584     {
3585         if( !CV_IS_MAT(_missing) || !CV_IS_MASK_ARR(_missing) ||
3586         !CV_ARE_SIZES_EQ(_missing, _sample) )
3587             CV_ERROR( CV_StsBadArg,
3588         "the missing data mask must be 8-bit vector of the same size as input sample" );
3589         m = _missing->data.ptr;
3590         mstep = CV_IS_MAT_CONT(_missing->type) ? 1 : _missing->step/sizeof(m[0]);
3591     }
3592
3593     vtype = data->var_type->data.i;
3594     vidx = data->var_idx && !preprocessed_input ? data->var_idx->data.i : 0;
3595     cmap = data->cat_map ? data->cat_map->data.i : 0;
3596     cofs = data->cat_ofs ? data->cat_ofs->data.i : 0;
3597
3598     while( node->Tn > pruned_tree_idx && node->left )
3599     {
3600         CvDTreeSplit* split = node->split;
3601         int dir = 0;
3602         for( ; !dir && split != 0; split = split->next )
3603         {
3604             int vi = split->var_idx;
3605             int ci = vtype[vi];
3606             i = vidx ? vidx[vi] : vi;
3607             float val = sample[i*step];
3608             if( m && m[i*mstep] )
3609                 continue;
3610             if( ci < 0 ) // ordered
3611                 dir = val <= split->ord.c ? -1 : 1;
3612             else // categorical
3613             {
3614                 int c;
3615                 if( preprocessed_input )
3616                     c = cvRound(val);
3617                 else
3618                 {
3619                     c = catbuf[ci];
3620                     if( c < 0 )
3621                     {
3622                         int a = c = cofs[ci];
3623                         int b = (ci+1 >= data->cat_ofs->cols) ? data->cat_map->cols : cofs[ci+1];
3624                         
3625                         int ival = cvRound(val);
3626                         if( ival != val )
3627                             CV_ERROR( CV_StsBadArg,
3628                             "one of input categorical variable is not an integer" );
3629                         
3630                         int sh = 0;
3631                         while( a < b )
3632                         {
3633                             sh++;
3634                             c = (a + b) >> 1;
3635                             if( ival < cmap[c] )
3636                                 b = c;
3637                             else if( ival > cmap[c] )
3638                                 a = c+1;
3639                             else
3640                                 break;
3641                         }
3642
3643                         if( c < 0 || ival != cmap[c] )
3644                             continue;
3645
3646                         catbuf[ci] = c -= cofs[ci];
3647                     }
3648                 }
3649                 c = ( (c == 65535) && data->is_buf_16u ) ? -1 : c;
3650                 dir = CV_DTREE_CAT_DIR(c, split->subset);
3651             }
3652
3653             if( split->inversed )
3654                 dir = -dir;
3655         }
3656
3657         if( !dir )
3658         {
3659             double diff = node->right->sample_count - node->left->sample_count;
3660             dir = diff < 0 ? -1 : 1;
3661         }
3662         node = dir < 0 ? node->left : node->right;
3663     }
3664
3665     result = node;
3666
3667     __END__;
3668
3669     return result;
3670 }
3671
3672
3673 CvDTreeNode* CvDTree::predict( const Mat& _sample, const Mat& _missing, bool preprocessed_input ) const
3674 {
3675     CvMat sample = _sample, mmask = _missing;
3676     return predict(&sample, mmask.data.ptr ? &mmask : 0, preprocessed_input);
3677 }
3678
3679
3680 const CvMat* CvDTree::get_var_importance()
3681 {
3682     if( !var_importance )
3683     {
3684         CvDTreeNode* node = root;
3685         double* importance;
3686         if( !node )
3687             return 0;
3688         var_importance = cvCreateMat( 1, data->var_count, CV_64F );
3689         cvZero( var_importance );
3690         importance = var_importance->data.db;
3691
3692         for(;;)
3693         {
3694             CvDTreeNode* parent;
3695             for( ;; node = node->left )
3696             {
3697                 CvDTreeSplit* split = node->split;
3698
3699                 if( !node->left || node->Tn <= pruned_tree_idx )
3700                     break;
3701
3702                 for( ; split != 0; split = split->next )
3703                     importance[split->var_idx] += split->quality;
3704             }
3705
3706             for( parent = node->parent; parent && parent->right == node;
3707                 node = parent, parent = parent->parent )
3708                 ;
3709
3710             if( !parent )
3711                 break;
3712
3713             node = parent->right;
3714         }
3715
3716         cvNormalize( var_importance, var_importance, 1., 0, CV_L1 );
3717     }
3718
3719     return var_importance;
3720 }
3721
3722
3723 void CvDTree::write_split( CvFileStorage* fs, CvDTreeSplit* split ) const
3724 {
3725     int ci;
3726
3727     cvStartWriteStruct( fs, 0, CV_NODE_MAP + CV_NODE_FLOW );
3728     cvWriteInt( fs, "var", split->var_idx );
3729     cvWriteReal( fs, "quality", split->quality );
3730
3731     ci = data->get_var_type(split->var_idx);
3732     if( ci >= 0 ) // split on a categorical var
3733     {
3734         int i, n = data->cat_count->data.i[ci], to_right = 0, default_dir;
3735         for( i = 0; i < n; i++ )
3736             to_right += CV_DTREE_CAT_DIR(i,split->subset) > 0;
3737
3738         // ad-hoc rule when to use inverse categorical split notation
3739         // to achieve more compact and clear representation
3740         default_dir = to_right <= 1 || to_right <= MIN(3, n/2) || to_right <= n/3 ? -1 : 1;
3741
3742         cvStartWriteStruct( fs, default_dir*(split->inversed ? -1 : 1) > 0 ?
3743                             "in" : "not_in", CV_NODE_SEQ+CV_NODE_FLOW );
3744
3745         for( i = 0; i < n; i++ )
3746         {
3747             int dir = CV_DTREE_CAT_DIR(i,split->subset);
3748             if( dir*default_dir < 0 )
3749                 cvWriteInt( fs, 0, i );
3750         }
3751         cvEndWriteStruct( fs );
3752     }
3753     else
3754         cvWriteReal( fs, !split->inversed ? "le" : "gt", split->ord.c );
3755
3756     cvEndWriteStruct( fs );
3757 }
3758
3759
3760 void CvDTree::write_node( CvFileStorage* fs, CvDTreeNode* node ) const
3761 {
3762     CvDTreeSplit* split;
3763
3764     cvStartWriteStruct( fs, 0, CV_NODE_MAP );
3765
3766     cvWriteInt( fs, "depth", node->depth );
3767     cvWriteInt( fs, "sample_count", node->sample_count );
3768     cvWriteReal( fs, "value", node->value );
3769
3770     if( data->is_classifier )
3771         cvWriteInt( fs, "norm_class_idx", node->class_idx );
3772
3773     cvWriteInt( fs, "Tn", node->Tn );
3774     cvWriteInt( fs, "complexity", node->complexity );
3775     cvWriteReal( fs, "alpha", node->alpha );
3776     cvWriteReal( fs, "node_risk", node->node_risk );
3777     cvWriteReal( fs, "tree_risk", node->tree_risk );
3778     cvWriteReal( fs, "tree_error", node->tree_error );
3779
3780     if( node->left )
3781     {
3782         cvStartWriteStruct( fs, "splits", CV_NODE_SEQ );
3783
3784         for( split = node->split; split != 0; split = split->next )
3785             write_split( fs, split );
3786
3787         cvEndWriteStruct( fs );
3788     }
3789
3790     cvEndWriteStruct( fs );
3791 }
3792
3793
3794 void CvDTree::write_tree_nodes( CvFileStorage* fs ) const
3795 {
3796     //CV_FUNCNAME( "CvDTree::write_tree_nodes" );
3797
3798     __BEGIN__;
3799
3800     CvDTreeNode* node = root;
3801
3802     // traverse the tree and save all the nodes in depth-first order
3803     for(;;)
3804     {
3805         CvDTreeNode* parent;
3806         for(;;)
3807         {
3808             write_node( fs, node );
3809             if( !node->left )
3810                 break;
3811             node = node->left;
3812         }
3813
3814         for( parent = node->parent; parent && parent->right == node;
3815             node = parent, parent = parent->parent )
3816             ;
3817
3818         if( !parent )
3819             break;
3820
3821         node = parent->right;
3822     }
3823
3824     __END__;
3825 }
3826
3827
3828 void CvDTree::write( CvFileStorage* fs, const char* name ) const
3829 {
3830     //CV_FUNCNAME( "CvDTree::write" );
3831
3832     __BEGIN__;
3833
3834     cvStartWriteStruct( fs, name, CV_NODE_MAP, CV_TYPE_NAME_ML_TREE );
3835
3836     //get_var_importance();
3837     data->write_params( fs );
3838     //if( var_importance )
3839     //cvWrite( fs, "var_importance", var_importance );
3840     write( fs );
3841
3842     cvEndWriteStruct( fs );
3843
3844     __END__;
3845 }
3846
3847
3848 void CvDTree::write( CvFileStorage* fs ) const
3849 {
3850     //CV_FUNCNAME( "CvDTree::write" );
3851
3852     __BEGIN__;
3853
3854     cvWriteInt( fs, "best_tree_idx", pruned_tree_idx );
3855
3856     cvStartWriteStruct( fs, "nodes", CV_NODE_SEQ );
3857     write_tree_nodes( fs );
3858     cvEndWriteStruct( fs );
3859
3860     __END__;
3861 }
3862
3863
3864 CvDTreeSplit* CvDTree::read_split( CvFileStorage* fs, CvFileNode* fnode )
3865 {
3866     CvDTreeSplit* split = 0;
3867
3868     CV_FUNCNAME( "CvDTree::read_split" );
3869
3870     __BEGIN__;
3871
3872     int vi, ci;
3873
3874     if( !fnode || CV_NODE_TYPE(fnode->tag) != CV_NODE_MAP )
3875         CV_ERROR( CV_StsParseError, "some of the splits are not stored properly" );
3876
3877     vi = cvReadIntByName( fs, fnode, "var", -1 );
3878     if( (unsigned)vi >= (unsigned)data->var_count )
3879         CV_ERROR( CV_StsOutOfRange, "Split variable index is out of range" );
3880
3881     ci = data->get_var_type(vi);
3882     if( ci >= 0 ) // split on categorical var
3883     {
3884         int i, n = data->cat_count->data.i[ci], inversed = 0, val;
3885         CvSeqReader reader;
3886         CvFileNode* inseq;
3887         split = data->new_split_cat( vi, 0 );
3888         inseq = cvGetFileNodeByName( fs, fnode, "in" );
3889         if( !inseq )
3890         {
3891             inseq = cvGetFileNodeByName( fs, fnode, "not_in" );
3892             inversed = 1;
3893         }
3894         if( !inseq ||
3895             (CV_NODE_TYPE(inseq->tag) != CV_NODE_SEQ && CV_NODE_TYPE(inseq->tag) != CV_NODE_INT))
3896             CV_ERROR( CV_StsParseError,
3897             "Either 'in' or 'not_in' tags should be inside a categorical split data" );
3898
3899         if( CV_NODE_TYPE(inseq->tag) == CV_NODE_INT )
3900         {
3901             val = inseq->data.i;
3902             if( (unsigned)val >= (unsigned)n )
3903                 CV_ERROR( CV_StsOutOfRange, "some of in/not_in elements are out of range" );
3904
3905             split->subset[val >> 5] |= 1 << (val & 31);
3906         }
3907         else
3908         {
3909             cvStartReadSeq( inseq->data.seq, &reader );
3910
3911             for( i = 0; i < reader.seq->total; i++ )
3912             {
3913                 CvFileNode* inode = (CvFileNode*)reader.ptr;
3914                 val = inode->data.i;
3915                 if( CV_NODE_TYPE(inode->tag) != CV_NODE_INT || (unsigned)val >= (unsigned)n )
3916                     CV_ERROR( CV_StsOutOfRange, "some of in/not_in elements are out of range" );
3917
3918                 split->subset[val >> 5] |= 1 << (val & 31);
3919                 CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
3920             }
3921         }
3922
3923         // for categorical splits we do not use inversed splits,
3924         // instead we inverse the variable set in the split
3925         if( inversed )
3926             for( i = 0; i < (n + 31) >> 5; i++ )
3927                 split->subset[i] ^= -1;
3928     }
3929     else
3930     {
3931         CvFileNode* cmp_node;
3932         split = data->new_split_ord( vi, 0, 0, 0, 0 );
3933
3934         cmp_node = cvGetFileNodeByName( fs, fnode, "le" );
3935         if( !cmp_node )
3936         {
3937             cmp_node = cvGetFileNodeByName( fs, fnode, "gt" );
3938             split->inversed = 1;
3939         }
3940
3941         split->ord.c = (float)cvReadReal( cmp_node );
3942     }
3943
3944     split->quality = (float)cvReadRealByName( fs, fnode, "quality" );
3945
3946     __END__;
3947
3948     return split;
3949 }
3950
3951
3952 CvDTreeNode* CvDTree::read_node( CvFileStorage* fs, CvFileNode* fnode, CvDTreeNode* parent )
3953 {
3954     CvDTreeNode* node = 0;
3955
3956     CV_FUNCNAME( "CvDTree::read_node" );
3957
3958     __BEGIN__;
3959
3960     CvFileNode* splits;
3961     int i, depth;
3962
3963     if( !fnode || CV_NODE_TYPE(fnode->tag) != CV_NODE_MAP )
3964         CV_ERROR( CV_StsParseError, "some of the tree elements are not stored properly" );
3965
3966     CV_CALL( node = data->new_node( parent, 0, 0, 0 ));
3967     depth = cvReadIntByName( fs, fnode, "depth", -1 );
3968     if( depth != node->depth )
3969         CV_ERROR( CV_StsParseError, "incorrect node depth" );
3970
3971     node->sample_count = cvReadIntByName( fs, fnode, "sample_count" );
3972     node->value = cvReadRealByName( fs, fnode, "value" );
3973     if( data->is_classifier )
3974         node->class_idx = cvReadIntByName( fs, fnode, "norm_class_idx" );
3975
3976     node->Tn = cvReadIntByName( fs, fnode, "Tn" );
3977     node->complexity = cvReadIntByName( fs, fnode, "complexity" );
3978     node->alpha = cvReadRealByName( fs, fnode, "alpha" );
3979     node->node_risk = cvReadRealByName( fs, fnode, "node_risk" );
3980     node->tree_risk = cvReadRealByName( fs, fnode, "tree_risk" );
3981     node->tree_error = cvReadRealByName( fs, fnode, "tree_error" );
3982
3983     splits = cvGetFileNodeByName( fs, fnode, "splits" );
3984     if( splits )
3985     {
3986         CvSeqReader reader;
3987         CvDTreeSplit* last_split = 0;
3988
3989         if( CV_NODE_TYPE(splits->tag) != CV_NODE_SEQ )
3990             CV_ERROR( CV_StsParseError, "splits tag must stored as a sequence" );
3991
3992         cvStartReadSeq( splits->data.seq, &reader );
3993         for( i = 0; i < reader.seq->total; i++ )
3994         {
3995             CvDTreeSplit* split;
3996             CV_CALL( split = read_split( fs, (CvFileNode*)reader.ptr ));
3997             if( !last_split )
3998                 node->split = last_split = split;
3999             else
4000                 last_split = last_split->next = split;
4001
4002             CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
4003         }
4004     }
4005
4006     __END__;
4007
4008     return node;
4009 }
4010
4011
4012 void CvDTree::read_tree_nodes( CvFileStorage* fs, CvFileNode* fnode )
4013 {
4014     CV_FUNCNAME( "CvDTree::read_tree_nodes" );
4015
4016     __BEGIN__;
4017
4018     CvSeqReader reader;
4019     CvDTreeNode _root;
4020     CvDTreeNode* parent = &_root;
4021     int i;
4022     parent->left = parent->right = parent->parent = 0;
4023
4024     cvStartReadSeq( fnode->data.seq, &reader );
4025
4026     for( i = 0; i < reader.seq->total; i++ )
4027     {
4028         CvDTreeNode* node;
4029
4030         CV_CALL( node = read_node( fs, (CvFileNode*)reader.ptr, parent != &_root ? parent : 0 ));
4031         if( !parent->left )
4032             parent->left = node;
4033         else
4034             parent->right = node;
4035         if( node->split )
4036             parent = node;
4037         else
4038         {
4039             while( parent && parent->right )
4040                 parent = parent->parent;
4041         }
4042
4043         CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
4044     }
4045
4046     root = _root.left;
4047
4048     __END__;
4049 }
4050
4051
4052 void CvDTree::read( CvFileStorage* fs, CvFileNode* fnode )
4053 {
4054     CvDTreeTrainData* _data = new CvDTreeTrainData();
4055     _data->read_params( fs, fnode );
4056
4057     read( fs, fnode, _data );
4058     get_var_importance();
4059 }
4060
4061
4062 // a special entry point for reading weak decision trees from the tree ensembles
4063 void CvDTree::read( CvFileStorage* fs, CvFileNode* node, CvDTreeTrainData* _data )
4064 {
4065     CV_FUNCNAME( "CvDTree::read" );
4066
4067     __BEGIN__;
4068
4069     CvFileNode* tree_nodes;
4070
4071     clear();
4072     data = _data;
4073
4074     tree_nodes = cvGetFileNodeByName( fs, node, "nodes" );
4075     if( !tree_nodes || CV_NODE_TYPE(tree_nodes->tag) != CV_NODE_SEQ )
4076         CV_ERROR( CV_StsParseError, "nodes tag is missing" );
4077
4078     pruned_tree_idx = cvReadIntByName( fs, node, "best_tree_idx", -1 );
4079     read_tree_nodes( fs, tree_nodes );
4080
4081     __END__;
4082 }
4083
4084 /* End of file. */