Update to 2.0.0 tree from current Fremantle build
[opencv] / apps / haartraining / cvboost.cpp
1 /*M///////////////////////////////////////////////////////////////////////////////////////
2 //
3 //  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
4 //
5 //  By downloading, copying, installing or using the software you agree to this license.
6 //  If you do not agree to this license, do not download, install,
7 //  copy or use the software.
8 //
9 //
10 //                        Intel License Agreement
11 //                For Open Source Computer Vision Library
12 //
13 // Copyright (C) 2000, Intel Corporation, all rights reserved.
14 // Third party copyrights are property of their respective owners.
15 //
16 // Redistribution and use in source and binary forms, with or without modification,
17 // are permitted provided that the following conditions are met:
18 //
19 //   * Redistribution's of source code must retain the above copyright notice,
20 //     this list of conditions and the following disclaimer.
21 //
22 //   * Redistribution's in binary form must reproduce the above copyright notice,
23 //     this list of conditions and the following disclaimer in the documentation
24 //     and/or other materials provided with the distribution.
25 //
26 //   * The name of Intel Corporation may not be used to endorse or promote products
27 //     derived from this software without specific prior written permission.
28 //
29 // This software is provided by the copyright holders and contributors "as is" and
30 // any express or implied warranties, including, but not limited to, the implied
31 // warranties of merchantability and fitness for a particular purpose are disclaimed.
32 // In no event shall the Intel Corporation or contributors be liable for any direct,
33 // indirect, incidental, special, exemplary, or consequential damages
34 // (including, but not limited to, procurement of substitute goods or services;
35 // loss of use, data, or profits; or business interruption) however caused
36 // and on any theory of liability, whether in contract, strict liability,
37 // or tort (including negligence or otherwise) arising in any way out of
38 // the use of this software, even if advised of the possibility of such damage.
39 //
40 //M*/
41
42 #ifdef HAVE_CONFIG_H
43   #include "cvconfig.h"
44 #endif
45
46 #ifdef HAVE_MALLOC_H
47   #include <malloc.h>
48 #endif
49
50 #ifdef HAVE_MEMORY_H
51   #include <memory.h>
52 #endif
53
54 #ifdef _OPENMP
55   #include <omp.h>
56 #endif /* _OPENMP */
57
58 #include <cstdio>
59 #include <cfloat>
60 #include <cmath>
61 #include <ctime>
62 #include <climits>
63
64 #include "_cvcommon.h"
65 #include "cvclassifier.h"
66
67
68 #define CV_BOOST_IMPL
69
70 typedef struct CvValArray
71 {
72     uchar* data;
73     size_t step;
74 } CvValArray;
75
76 #define CMP_VALUES( idx1, idx2 )                                 \
77     ( *( (float*) (aux->data + ((int) (idx1)) * aux->step ) ) <  \
78       *( (float*) (aux->data + ((int) (idx2)) * aux->step ) ) )
79
80 CV_IMPLEMENT_QSORT_EX( icvSortIndexedValArray_16s, short, CMP_VALUES, CvValArray* )
81
82 CV_IMPLEMENT_QSORT_EX( icvSortIndexedValArray_32s, int,   CMP_VALUES, CvValArray* )
83
84 CV_IMPLEMENT_QSORT_EX( icvSortIndexedValArray_32f, float, CMP_VALUES, CvValArray* )
85
86 CV_BOOST_IMPL
87 void cvGetSortedIndices( CvMat* val, CvMat* idx, int sortcols )
88 {
89     int idxtype = 0;
90     size_t istep = 0;
91     size_t jstep = 0;
92
93     int i = 0;
94     int j = 0;
95
96     CvValArray va;
97
98     CV_Assert( idx != NULL );
99     CV_Assert( val != NULL );
100
101     idxtype = CV_MAT_TYPE( idx->type );
102     CV_Assert( idxtype == CV_16SC1 || idxtype == CV_32SC1 || idxtype == CV_32FC1 );
103     CV_Assert( CV_MAT_TYPE( val->type ) == CV_32FC1 );
104     if( sortcols )
105     {
106         CV_Assert( idx->rows == val->cols );
107         CV_Assert( idx->cols == val->rows );
108         istep = CV_ELEM_SIZE( val->type );
109         jstep = val->step;
110     }
111     else
112     {
113         CV_Assert( idx->rows == val->rows );
114         CV_Assert( idx->cols == val->cols );
115         istep = val->step;
116         jstep = CV_ELEM_SIZE( val->type );
117     }
118
119     va.data = val->data.ptr;
120     va.step = jstep;
121     switch( idxtype )
122     {
123         case CV_16SC1:
124             for( i = 0; i < idx->rows; i++ )
125             {
126                 for( j = 0; j < idx->cols; j++ )
127                 {
128                     CV_MAT_ELEM( *idx, short, i, j ) = (short) j;
129                 }
130                 icvSortIndexedValArray_16s( (short*) (idx->data.ptr + i * idx->step),
131                                             idx->cols, &va );
132                 va.data += istep;
133             }
134             break;
135
136         case CV_32SC1:
137             for( i = 0; i < idx->rows; i++ )
138             {
139                 for( j = 0; j < idx->cols; j++ )
140                 {
141                     CV_MAT_ELEM( *idx, int, i, j ) = j;
142                 }
143                 icvSortIndexedValArray_32s( (int*) (idx->data.ptr + i * idx->step),
144                                             idx->cols, &va );
145                 va.data += istep;
146             }
147             break;
148
149         case CV_32FC1:
150             for( i = 0; i < idx->rows; i++ )
151             {
152                 for( j = 0; j < idx->cols; j++ )
153                 {
154                     CV_MAT_ELEM( *idx, float, i, j ) = (float) j;
155                 }
156                 icvSortIndexedValArray_32f( (float*) (idx->data.ptr + i * idx->step),
157                                             idx->cols, &va );
158                 va.data += istep;
159             }
160             break;
161
162         default:
163             assert( 0 );
164             break;
165     }
166 }
167
168 CV_BOOST_IMPL
169 void cvReleaseStumpClassifier( CvClassifier** classifier )
170 {
171     cvFree( classifier );
172     *classifier = 0;
173 }
174
175 CV_BOOST_IMPL
176 float cvEvalStumpClassifier( CvClassifier* classifier, CvMat* sample )
177 {
178     assert( classifier != NULL );
179     assert( sample != NULL );
180     assert( CV_MAT_TYPE( sample->type ) == CV_32FC1 );
181     
182     if( (CV_MAT_ELEM( (*sample), float, 0,
183             ((CvStumpClassifier*) classifier)->compidx )) <
184         ((CvStumpClassifier*) classifier)->threshold ) 
185         return ((CvStumpClassifier*) classifier)->left;
186     return ((CvStumpClassifier*) classifier)->right;
187 }
188
/*
 * ICV_DEF_FIND_STUMP_THRESHOLD( suffix, type, error )
 *
 * Generates the function icvFindStumpThreshold_<suffix>, which scans one
 * component's values in index order and looks for the split that lowers the
 * combined left + right error.
 *
 * Macro arguments:
 *   suffix - appended to the generated function name.
 *   type   - element type stored in the index array (converted to int).
 *   error  - statements pasted into the scan loop.  They are expected to set
 *            curlerror and currerror, and may read or overwrite the locals
 *            wl, wyl, wyyl, wr, wyr, curleft, curright, wposl, wposr and
 *            *sumwyy.  Those local names are therefore part of the
 *            contract with the callers of this macro.
 *
 * Generated function parameters:
 *   data, datastep   - component values; sample k is the float at
 *                      data + k * datastep.
 *   wdata, wstep     - per-sample weights, addressed the same way.
 *   ydata, ystep     - per-sample responses, addressed the same way.
 *   idxdata, idxstep - `num` sample indices of `type`; the values they refer
 *                      to must be non-decreasing in this order (checked by
 *                      assert only).
 *   lerror, rerror   - in/out: best left and right errors so far.  A
 *                      candidate is accepted only when
 *                      curlerror + currerror < *lerror + *rerror.
 *   threshold        - out: on acceptance, the current value, or for i > 0
 *                      the midpoint between it and the previous distinct value.
 *   left, right      - out: curleft / curright of the accepted candidate.
 *   sumw, sumwy,     - in/out totals of w, w*y and w*y*y over the indexed
 *   sumwyy             samples.  When *sumw == FLT_MAX on entry all three
 *                      are recomputed; otherwise they are used as given.
 *
 * Returns 1 if at least one candidate was accepted (outputs updated),
 * 0 otherwise.
 *
 * Scan: at each distinct value the candidate puts every earlier sample on
 * the left (running sums wl, wyl, wyyl) and the rest on the right
 * (totals minus the running sums).  curleft / curright start as wyl / wl
 * and wyr / wr (0 when the weight is not positive) before `error` runs.
 * The do/while then folds every sample sharing the current value into the
 * running sums and advances i past them; the `--i` that follows
 * compensates for the for-loop's own increment.
 */
189 #define ICV_DEF_FIND_STUMP_THRESHOLD( suffix, type, error )                              \
190 CV_BOOST_IMPL int icvFindStumpThreshold_##suffix(                                              \
191         uchar* data, size_t datastep,                                                    \
192         uchar* wdata, size_t wstep,                                                      \
193         uchar* ydata, size_t ystep,                                                      \
194         uchar* idxdata, size_t idxstep, int num,                                         \
195         float* lerror,                                                                   \
196         float* rerror,                                                                   \
197         float* threshold, float* left, float* right,                                     \
198         float* sumw, float* sumwy, float* sumwyy )                                       \
199 {                                                                                        \
200     int found = 0;                                                                       \
201     float wyl  = 0.0F;                                                                   \
202     float wl   = 0.0F;                                                                   \
203     float wyyl = 0.0F;                                                                   \
204     float wyr  = 0.0F;                                                                   \
205     float wr   = 0.0F;                                                                   \
206                                                                                          \
207     float curleft  = 0.0F;                                                               \
208     float curright = 0.0F;                                                               \
209     float* prevval = NULL;                                                               \
210     float* curval  = NULL;                                                               \
211     float curlerror = 0.0F;                                                              \
212     float currerror = 0.0F;                                                              \
213     float wposl;                                                                         \
214     float wposr;                                                                         \
215                                                                                          \
216     int i = 0;                                                                           \
217     int idx = 0;                                                                         \
218                                                                                          \
219     wposl = wposr = 0.0F;                                                                \
220     if( *sumw == FLT_MAX )                                                               \
221     {                                                                                    \
222         /* calculate sums */                                                             \
223         float *y = NULL;                                                                 \
224         float *w = NULL;                                                                 \
225         float wy = 0.0F;                                                                 \
226                                                                                          \
227         *sumw   = 0.0F;                                                                  \
228         *sumwy  = 0.0F;                                                                  \
229         *sumwyy = 0.0F;                                                                  \
230         for( i = 0; i < num; i++ )                                                       \
231         {                                                                                \
232             idx = (int) ( *((type*) (idxdata + i*idxstep)) );                            \
233             w = (float*) (wdata + idx * wstep);                                          \
234             *sumw += *w;                                                                 \
235             y = (float*) (ydata + idx * ystep);                                          \
236             wy = (*w) * (*y);                                                            \
237             *sumwy += wy;                                                                \
238             *sumwyy += wy * (*y);                                                        \
239         }                                                                                \
240     }                                                                                    \
241                                                                                          \
242     for( i = 0; i < num; i++ )                                                           \
243     {                                                                                    \
244         idx = (int) ( *((type*) (idxdata + i*idxstep)) );                                \
245         curval = (float*) (data + idx * datastep);                                       \
246          /* for debug purpose */                                                         \
247         if( i > 0 ) assert( (*prevval) <= (*curval) );                                   \
248                                                                                          \
249         wyr  = *sumwy - wyl;                                                             \
250         wr   = *sumw  - wl;                                                              \
251                                                                                          \
252         if( wl > 0.0 ) curleft = wyl / wl;                                               \
253         else curleft = 0.0F;                                                             \
254                                                                                          \
255         if( wr > 0.0 ) curright = wyr / wr;                                              \
256         else curright = 0.0F;                                                            \
257                                                                                          \
258         error                                                                            \
259                                                                                          \
260         if( curlerror + currerror < (*lerror) + (*rerror) )                              \
261         {                                                                                \
262             (*lerror) = curlerror;                                                       \
263             (*rerror) = currerror;                                                       \
264             *threshold = *curval;                                                        \
265             if( i > 0 ) {                                                                \
266                 *threshold = 0.5F * (*threshold + *prevval);                             \
267             }                                                                            \
268             *left  = curleft;                                                            \
269             *right = curright;                                                           \
270             found = 1;                                                                   \
271         }                                                                                \
272                                                                                          \
273         do                                                                               \
274         {                                                                                \
275             wl  += *((float*) (wdata + idx * wstep));                                    \
276             wyl += (*((float*) (wdata + idx * wstep)))                                   \
277                 * (*((float*) (ydata + idx * ystep)));                                   \
278             wyyl += *((float*) (wdata + idx * wstep))                                    \
279                 * (*((float*) (ydata + idx * ystep)))                                    \
280                 * (*((float*) (ydata + idx * ystep)));                                   \
281         }                                                                                \
282         while( (++i) < num &&                                                            \
283             ( *((float*) (data + (idx =                                                  \
284                 (int) ( *((type*) (idxdata + i*idxstep))) ) * datastep))                 \
285                 == *curval ) );                                                          \
286         --i;                                                                             \
287         prevval = curval;                                                                \
288     } /* for each value */                                                               \
289                                                                                          \
290     return found;                                                                        \
291 }
292
/*
 * Misclassification error:
 *   err = MIN( wpos, wneg )
 *
 * Per branch, wpos is taken as (w + wy) / 2 and wneg as the remaining
 * weight; the branch value is remapped as x -> (1 + x) / 2.
 * NOTE(review): reading (w + wy) / 2 as "positive weight" presumes the
 * responses are -1 / +1 -- confirm against the callers.
 */
#define ICV_DEF_FIND_STUMP_THRESHOLD_MISC( suffix, type )                                \
    ICV_DEF_FIND_STUMP_THRESHOLD( misc_##suffix, type,                                   \
        /* left branch */                                                                \
        wposl     = 0.5F * ( wl + wyl );                                                 \
        curleft   = 0.5F * ( 1.0F + curleft );                                           \
        curlerror = MIN( wposl, wl - wposl );                                            \
        /* right branch */                                                               \
        wposr     = 0.5F * ( wr + wyr );                                                 \
        curright  = 0.5F * ( 1.0F + curright );                                          \
        currerror = MIN( wposr, wr - wposr );                                            \
    )
305
/*
 * Gini error:
 *   err = 2 * wpos * wneg / (wpos + wneg)
 *
 * Per branch, wpos is taken as (w + wy) / 2 and the branch value is
 * remapped as x -> (1 + x) / 2 before the error is formed as
 * 2 * wpos * (1 - value).
 */
#define ICV_DEF_FIND_STUMP_THRESHOLD_GINI( suffix, type )                                \
    ICV_DEF_FIND_STUMP_THRESHOLD( gini_##suffix, type,                                   \
        /* left branch */                                                                \
        wposl     = 0.5F * ( wl + wyl );                                                 \
        curleft   = 0.5F * ( 1.0F + curleft );                                           \
        curlerror = 2.0F * wposl * ( 1.0F - curleft );                                   \
        /* right branch */                                                               \
        wposr     = 0.5F * ( wr + wyr );                                                 \
        curright  = 0.5F * ( 1.0F + curright );                                          \
        currerror = 2.0F * wposr * ( 1.0F - curright );                                  \
    )
318
/* Margin that keeps the arguments of logf away from zero in the entropy error. */
#define CV_ENTROPY_THRESHOLD FLT_MIN

/*
 * Entropy error:
 *   err = - wpos * log(wpos / (wpos + wneg)) - wneg * log(wneg / (wpos + wneg))
 *
 * Per branch, wpos is taken as (w + wy) / 2 and the branch value is
 * remapped as x -> (1 + x) / 2.  Each logarithmic term is added only when
 * its argument clears CV_ENTROPY_THRESHOLD; otherwise it contributes nothing.
 */
#define ICV_DEF_FIND_STUMP_THRESHOLD_ENTROPY( suffix, type )                             \
    ICV_DEF_FIND_STUMP_THRESHOLD( entropy_##suffix, type,                                \
        /* left branch */                                                                \
        wposl     = 0.5F * ( wl + wyl );                                                 \
        curleft   = 0.5F * ( 1.0F + curleft );                                           \
        curlerror = 0.0F;                                                                \
        if( curleft > CV_ENTROPY_THRESHOLD )                                             \
        {                                                                                \
            curlerror -= wposl * logf( curleft );                                        \
        }                                                                                \
        if( curleft < 1.0F - CV_ENTROPY_THRESHOLD )                                      \
        {                                                                                \
            curlerror -= (wl - wposl) * logf( 1.0F - curleft );                          \
        }                                                                                \
                                                                                         \
        /* right branch */                                                               \
        wposr     = 0.5F * ( wr + wyr );                                                 \
        curright  = 0.5F * ( 1.0F + curright );                                          \
        currerror = 0.0F;                                                                \
        if( curright > CV_ENTROPY_THRESHOLD )                                            \
        {                                                                                \
            currerror -= wposr * logf( curright );                                       \
        }                                                                                \
        if( curright < 1.0F - CV_ENTROPY_THRESHOLD )                                     \
        {                                                                                \
            currerror -= (wr - wposr) * logf( 1.0F - curright );                         \
        }                                                                                \
    )
341
/*
 * Least sum of squares error:
 *   err = sum( w * (y - value)^2 )
 * where `value` is curleft for the left branch and curright for the right
 * one.  It is expanded so that only the running sums of w, w*y and w*y*y
 * are needed; the right-hand w*y*y sum is the total minus the left part.
 */
#define ICV_DEF_FIND_STUMP_THRESHOLD_SQ( suffix, type )                                  \
    ICV_DEF_FIND_STUMP_THRESHOLD( sq_##suffix, type,                                     \
        /* right branch: err = sum( w * (y - curright)^2 ) */                            \
        currerror = (*sumwyy) - wyyl + curright * curright * wr - 2.0F * curright * wyr; \
        /* left branch:  err = sum( w * (y - curleft)^2 )  */                            \
        curlerror = wyyl + curleft * curleft * wl - 2.0F * curleft * wyl;                \
    )
350
351 ICV_DEF_FIND_STUMP_THRESHOLD_MISC( 16s, short )
352
353 ICV_DEF_FIND_STUMP_THRESHOLD_MISC( 32s, int )
354
355 ICV_DEF_FIND_STUMP_THRESHOLD_MISC( 32f, float )
356
357
358 ICV_DEF_FIND_STUMP_THRESHOLD_GINI( 16s, short )
359
360 ICV_DEF_FIND_STUMP_THRESHOLD_GINI( 32s, int )
361
362 ICV_DEF_FIND_STUMP_THRESHOLD_GINI( 32f, float )
363
364
365 ICV_DEF_FIND_STUMP_THRESHOLD_ENTROPY( 16s, short )
366
367 ICV_DEF_FIND_STUMP_THRESHOLD_ENTROPY( 32s, int )
368
369 ICV_DEF_FIND_STUMP_THRESHOLD_ENTROPY( 32f, float )
370
371
372 ICV_DEF_FIND_STUMP_THRESHOLD_SQ( 16s, short )
373
374 ICV_DEF_FIND_STUMP_THRESHOLD_SQ( 32s, int )
375
376 ICV_DEF_FIND_STUMP_THRESHOLD_SQ( 32f, float )
377
378 typedef int (*CvFindThresholdFunc)( uchar* data, size_t datastep,
379                                     uchar* wdata, size_t wstep,
380                                     uchar* ydata, size_t ystep,
381                                     uchar* idxdata, size_t idxstep, int num,
382                                     float* lerror,
383                                     float* rerror,
384                                     float* threshold, float* left, float* right,
385                                     float* sumw, float* sumwy, float* sumwyy );
386
387 CvFindThresholdFunc findStumpThreshold_16s[4] = {
388         icvFindStumpThreshold_misc_16s,
389         icvFindStumpThreshold_gini_16s,
390         icvFindStumpThreshold_entropy_16s,
391         icvFindStumpThreshold_sq_16s
392     };
393
394 CvFindThresholdFunc findStumpThreshold_32s[4] = {
395         icvFindStumpThreshold_misc_32s,
396         icvFindStumpThreshold_gini_32s,
397         icvFindStumpThreshold_entropy_32s,
398         icvFindStumpThreshold_sq_32s
399     };
400
401 CvFindThresholdFunc findStumpThreshold_32f[4] = {
402         icvFindStumpThreshold_misc_32f,
403         icvFindStumpThreshold_gini_32f,
404         icvFindStumpThreshold_entropy_32f,
405         icvFindStumpThreshold_sq_32f
406     };
407
408 CV_BOOST_IMPL
409 CvClassifier* cvCreateStumpClassifier( CvMat* trainData,
410                       int flags,
411                       CvMat* trainClasses,
412                       CvMat* /*typeMask*/,
413                       CvMat* missedMeasurementsMask,
414                       CvMat* compIdx,
415                       CvMat* sampleIdx,
416                       CvMat* weights,
417                       CvClassifierTrainParams* trainParams
418                     )
419 {
420     CvStumpClassifier* stump = NULL;
421     int m = 0; /* number of samples */
422     int n = 0; /* number of components */
423     uchar* data = NULL;
424     int cstep   = 0;
425     int sstep   = 0;
426     uchar* ydata = NULL;
427     int ystep    = 0;
428     uchar* idxdata = NULL;
429     int idxstep    = 0;
430     int l = 0; /* number of indices */     
431     uchar* wdata = NULL;
432     int wstep    = 0;
433
434     int* idx = NULL;
435     int i = 0;
436     
437     float sumw   = FLT_MAX;
438     float sumwy  = FLT_MAX;
439     float sumwyy = FLT_MAX;
440
441     CV_Assert( trainData != NULL );
442     CV_Assert( CV_MAT_TYPE( trainData->type ) == CV_32FC1 );
443     CV_Assert( trainClasses != NULL );
444     CV_Assert( CV_MAT_TYPE( trainClasses->type ) == CV_32FC1 );
445     CV_Assert( missedMeasurementsMask == NULL );
446     CV_Assert( compIdx == NULL );
447     CV_Assert( weights != NULL );
448     CV_Assert( CV_MAT_TYPE( weights->type ) == CV_32FC1 );
449     CV_Assert( trainParams != NULL );
450
451     data = trainData->data.ptr;
452     if( CV_IS_ROW_SAMPLE( flags ) )
453     {
454         cstep = CV_ELEM_SIZE( trainData->type );
455         sstep = trainData->step;
456         m = trainData->rows;
457         n = trainData->cols;
458     }
459     else
460     {
461         sstep = CV_ELEM_SIZE( trainData->type );
462         cstep = trainData->step;
463         m = trainData->cols;
464         n = trainData->rows;
465     }
466
467     ydata = trainClasses->data.ptr;
468     if( trainClasses->rows == 1 )
469     {
470         assert( trainClasses->cols == m );
471         ystep = CV_ELEM_SIZE( trainClasses->type );
472     }
473     else
474     {
475         assert( trainClasses->rows == m );
476         ystep = trainClasses->step;
477     }
478
479     wdata = weights->data.ptr;
480     if( weights->rows == 1 )
481     {
482         assert( weights->cols == m );
483         wstep = CV_ELEM_SIZE( weights->type );
484     }
485     else
486     {
487         assert( weights->rows == m );
488         wstep = weights->step;
489     }
490
491     l = m;
492     if( sampleIdx != NULL )
493     {
494         assert( CV_MAT_TYPE( sampleIdx->type ) == CV_32FC1 );
495
496         idxdata = sampleIdx->data.ptr;
497         if( sampleIdx->rows == 1 )
498         {
499             l = sampleIdx->cols;
500             idxstep = CV_ELEM_SIZE( sampleIdx->type );
501         }
502         else
503         {
504             l = sampleIdx->rows;
505             idxstep = sampleIdx->step;
506         }
507         assert( l <= m );
508     }
509
510     idx = (int*) cvAlloc( l * sizeof( int ) );
511     stump = (CvStumpClassifier*) cvAlloc( sizeof( CvStumpClassifier) );
512
513     /* START */
514     memset( (void*) stump, 0, sizeof( CvStumpClassifier ) );
515
516     stump->eval = cvEvalStumpClassifier;
517     stump->tune = NULL;
518     stump->save = NULL;
519     stump->release = cvReleaseStumpClassifier;
520
521     stump->lerror = FLT_MAX;
522     stump->rerror = FLT_MAX;
523     stump->left  = 0.0F;
524     stump->right = 0.0F;
525
526     /* copy indices */
527     if( sampleIdx != NULL )
528     {
529         for( i = 0; i < l; i++ )
530         {
531             idx[i] = (int) *((float*) (idxdata + i*idxstep));
532         }
533     }
534     else
535     {
536         for( i = 0; i < l; i++ )
537         {
538             idx[i] = i;
539         }
540     }
541
542     for( i = 0; i < n; i++ )
543     {
544         CvValArray va;
545
546         va.data = data + i * ((size_t) cstep);
547         va.step = sstep;
548         icvSortIndexedValArray_32s( idx, l, &va );
549         if( findStumpThreshold_32s[(int) ((CvStumpTrainParams*) trainParams)->error]
550               ( data + i * ((size_t) cstep), sstep,
551                 wdata, wstep, ydata, ystep, (uchar*) idx, sizeof( int ), l,
552                 &(stump->lerror), &(stump->rerror),
553                 &(stump->threshold), &(stump->left), &(stump->right), 
554                 &sumw, &sumwy, &sumwyy ) )
555         {
556             stump->compidx = i;
557         }
558     } /* for each component */
559
560     /* END */
561
562     cvFree( &idx );
563
564     if( ((CvStumpTrainParams*) trainParams)->type == CV_CLASSIFICATION_CLASS )
565     {
566         stump->left = 2.0F * (stump->left >= 0.5F) - 1.0F;
567         stump->right = 2.0F * (stump->right >= 0.5F) - 1.0F;
568     }
569
570     return (CvClassifier*) stump;
571 }
572
573 /*
574  * cvCreateMTStumpClassifier
575  *
576  * Multithreaded stump classifier constructor
577  * Includes huge train data support through callback function
578  */
579 CV_BOOST_IMPL
580 CvClassifier* cvCreateMTStumpClassifier( CvMat* trainData,
581                       int flags,
582                       CvMat* trainClasses,
583                       CvMat* /*typeMask*/,
584                       CvMat* missedMeasurementsMask,
585                       CvMat* compIdx,
586                       CvMat* sampleIdx,
587                       CvMat* weights,
588                       CvClassifierTrainParams* trainParams )
589 {
590     CvStumpClassifier* stump = NULL;
591     int m = 0; /* number of samples */
592     int n = 0; /* number of components */
593     uchar* data = NULL;
594     size_t cstep   = 0;
595     size_t sstep   = 0;
596     int    datan   = 0; /* num components */
597     uchar* ydata = NULL;
598     size_t ystep = 0;
599     uchar* idxdata = NULL;
600     size_t idxstep = 0;
601     int    l = 0; /* number of indices */     
602     uchar* wdata = NULL;
603     size_t wstep = 0;
604
605     uchar* sorteddata = NULL;
606     int    sortedtype    = 0;
607     size_t sortedcstep   = 0; /* component step */
608     size_t sortedsstep   = 0; /* sample step */
609     int    sortedn       = 0; /* num components */
610     int    sortedm       = 0; /* num samples */
611
612     char* filter = NULL;
613     int i = 0;
614     
615     int compidx = 0;
616     int stumperror;
617     int portion;
618
619     /* private variables */
620     CvMat mat;
621     CvValArray va;
622     float lerror;
623     float rerror;
624     float left;
625     float right;
626     float threshold;
627     int optcompidx;
628
629     float sumw;
630     float sumwy;
631     float sumwyy;
632
633     int t_compidx;
634     int t_n;
635     
636     int ti;
637     int tj;
638     int tk;
639
640     uchar* t_data;
641     size_t t_cstep;
642     size_t t_sstep;
643
644     size_t matcstep;
645     size_t matsstep;
646
647     int* t_idx;
648     /* end private variables */
649
650     CV_Assert( trainParams != NULL );
651     CV_Assert( trainClasses != NULL );
652     CV_Assert( CV_MAT_TYPE( trainClasses->type ) == CV_32FC1 );
653     CV_Assert( missedMeasurementsMask == NULL );
654     CV_Assert( compIdx == NULL );
655
656     stumperror = (int) ((CvMTStumpTrainParams*) trainParams)->error;
657
658     ydata = trainClasses->data.ptr;
659     if( trainClasses->rows == 1 )
660     {
661         m = trainClasses->cols;
662         ystep = CV_ELEM_SIZE( trainClasses->type );
663     }
664     else
665     {
666         m = trainClasses->rows;
667         ystep = trainClasses->step;
668     }
669
670     wdata = weights->data.ptr;
671     if( weights->rows == 1 )
672     {
673         CV_Assert( weights->cols == m );
674         wstep = CV_ELEM_SIZE( weights->type );
675     }
676     else
677     {
678         CV_Assert( weights->rows == m );
679         wstep = weights->step;
680     }
681
682     if( ((CvMTStumpTrainParams*) trainParams)->sortedIdx != NULL )
683     {
684         sortedtype =
685             CV_MAT_TYPE( ((CvMTStumpTrainParams*) trainParams)->sortedIdx->type );
686         assert( sortedtype == CV_16SC1 || sortedtype == CV_32SC1
687                 || sortedtype == CV_32FC1 );
688         sorteddata = ((CvMTStumpTrainParams*) trainParams)->sortedIdx->data.ptr;
689         sortedsstep = CV_ELEM_SIZE( sortedtype );
690         sortedcstep = ((CvMTStumpTrainParams*) trainParams)->sortedIdx->step;
691         sortedn = ((CvMTStumpTrainParams*) trainParams)->sortedIdx->rows;
692         sortedm = ((CvMTStumpTrainParams*) trainParams)->sortedIdx->cols;
693     }
694
695     if( trainData == NULL )
696     {
697         assert( ((CvMTStumpTrainParams*) trainParams)->getTrainData != NULL );
698         n = ((CvMTStumpTrainParams*) trainParams)->numcomp;
699         assert( n > 0 );
700     }
701     else
702     {
703         assert( CV_MAT_TYPE( trainData->type ) == CV_32FC1 );
704         data = trainData->data.ptr;
705         if( CV_IS_ROW_SAMPLE( flags ) )
706         {
707             cstep = CV_ELEM_SIZE( trainData->type );
708             sstep = trainData->step;
709             assert( m == trainData->rows );
710             datan = n = trainData->cols;
711         }
712         else
713         {
714             sstep = CV_ELEM_SIZE( trainData->type );
715             cstep = trainData->step;
716             assert( m == trainData->cols );
717             datan = n = trainData->rows;
718         }
719         if( ((CvMTStumpTrainParams*) trainParams)->getTrainData != NULL )
720         {
721             n = ((CvMTStumpTrainParams*) trainParams)->numcomp;
722         }        
723     }
724     assert( datan <= n );
725
726     if( sampleIdx != NULL )
727     {
728         assert( CV_MAT_TYPE( sampleIdx->type ) == CV_32FC1 );
729         idxdata = sampleIdx->data.ptr;
730         idxstep = ( sampleIdx->rows == 1 )
731             ? CV_ELEM_SIZE( sampleIdx->type ) : sampleIdx->step;
732         l = ( sampleIdx->rows == 1 ) ? sampleIdx->cols : sampleIdx->rows;
733
734         if( sorteddata != NULL )
735         {
736             filter = (char*) cvAlloc( sizeof( char ) * m );
737             memset( (void*) filter, 0, sizeof( char ) * m );
738             for( i = 0; i < l; i++ )
739             {
740                 filter[(int) *((float*) (idxdata + i * idxstep))] = (char) 1;
741             }
742         }
743     }
744     else
745     {
746         l = m;
747     }
748
749     stump = (CvStumpClassifier*) cvAlloc( sizeof( CvStumpClassifier) );
750
751     /* START */
752     memset( (void*) stump, 0, sizeof( CvStumpClassifier ) );
753
754     portion = ((CvMTStumpTrainParams*)trainParams)->portion;
755     
756     if( portion < 1 )
757     {
758         /* auto portion */
759         portion = n;
760         #ifdef _OPENMP
761         portion /= omp_get_max_threads();        
762         #endif /* _OPENMP */        
763     }
764
765     stump->eval = cvEvalStumpClassifier;
766     stump->tune = NULL;
767     stump->save = NULL;
768     stump->release = cvReleaseStumpClassifier;
769
770     stump->lerror = FLT_MAX;
771     stump->rerror = FLT_MAX;
772     stump->left  = 0.0F;
773     stump->right = 0.0F;
774
775     compidx = 0;
776     #ifdef _OPENMP
777     #pragma omp parallel private(mat, va, lerror, rerror, left, right, threshold, \
778                                  optcompidx, sumw, sumwy, sumwyy, t_compidx, t_n, \
779                                  ti, tj, tk, t_data, t_cstep, t_sstep, matcstep,  \
780                                  matsstep, t_idx)
781     #endif /* _OPENMP */
782     {
783         lerror = FLT_MAX;
784         rerror = FLT_MAX;
785         left  = 0.0F;
786         right = 0.0F;
787         threshold = 0.0F;
788         optcompidx = 0;
789
790         sumw   = FLT_MAX;
791         sumwy  = FLT_MAX;
792         sumwyy = FLT_MAX;
793
794         t_compidx = 0;
795         t_n = 0;
796         
797         ti = 0;
798         tj = 0;
799         tk = 0;
800
801         t_data = NULL;
802         t_cstep = 0;
803         t_sstep = 0;
804
805         matcstep = 0;
806         matsstep = 0;
807
808         t_idx = NULL;
809
810         mat.data.ptr = NULL;
811         
812         if( datan < n )
813         {
814             /* prepare matrix for callback */
815             if( CV_IS_ROW_SAMPLE( flags ) )
816             {
817                 mat = cvMat( m, portion, CV_32FC1, 0 );
818                 matcstep = CV_ELEM_SIZE( mat.type );
819                 matsstep = mat.step;
820             }
821             else
822             {
823                 mat = cvMat( portion, m, CV_32FC1, 0 );
824                 matcstep = mat.step;
825                 matsstep = CV_ELEM_SIZE( mat.type );
826             }
827             mat.data.ptr = (uchar*) cvAlloc( sizeof( float ) * mat.rows * mat.cols );
828         }
829
830         if( filter != NULL || sortedn < n )
831         {
832             t_idx = (int*) cvAlloc( sizeof( int ) * m );
833             if( sortedn == 0 || filter == NULL )
834             {
835                 if( idxdata != NULL )
836                 {
837                     for( ti = 0; ti < l; ti++ )
838                     {
839                         t_idx[ti] = (int) *((float*) (idxdata + ti * idxstep));
840                     }
841                 }
842                 else
843                 {
844                     for( ti = 0; ti < l; ti++ )
845                     {
846                         t_idx[ti] = ti;
847                     }
848                 }                
849             }
850         }
851
852         #ifdef _OPENMP
853         #pragma omp critical(c_compidx)
854         #endif /* _OPENMP */
855         {
856             t_compidx = compidx;
857             compidx += portion;
858         }
859         while( t_compidx < n )
860         {
861             t_n = portion;
862             if( t_compidx < datan )
863             {
864                 t_n = ( t_n < (datan - t_compidx) ) ? t_n : (datan - t_compidx);
865                 t_data = data;
866                 t_cstep = cstep;
867                 t_sstep = sstep;
868             }
869             else
870             {
871                 t_n = ( t_n < (n - t_compidx) ) ? t_n : (n - t_compidx);
872                 t_cstep = matcstep;
873                 t_sstep = matsstep;
874                 t_data = mat.data.ptr - t_compidx * ((size_t) t_cstep );
875
876                 /* calculate components */
877                 ((CvMTStumpTrainParams*)trainParams)->getTrainData( &mat,
878                         sampleIdx, compIdx, t_compidx, t_n,
879                         ((CvMTStumpTrainParams*)trainParams)->userdata );
880             }
881
882             if( sorteddata != NULL )
883             {
884                 if( filter != NULL )
885                 {
886                     /* have sorted indices and filter */
887                     switch( sortedtype )
888                     {
889                         case CV_16SC1:
890                             for( ti = t_compidx; ti < MIN( sortedn, t_compidx + t_n ); ti++ )
891                             {
892                                 tk = 0;
893                                 for( tj = 0; tj < sortedm; tj++ )
894                                 {
895                                     int curidx = (int) ( *((short*) (sorteddata
896                                             + ti * sortedcstep + tj * sortedsstep)) );
897                                     if( filter[curidx] != 0 )
898                                     {
899                                         t_idx[tk++] = curidx;
900                                     }
901                                 }
902                                 if( findStumpThreshold_32s[stumperror]( 
903                                         t_data + ti * t_cstep, t_sstep,
904                                         wdata, wstep, ydata, ystep,
905                                         (uchar*) t_idx, sizeof( int ), tk,
906                                         &lerror, &rerror,
907                                         &threshold, &left, &right, 
908                                         &sumw, &sumwy, &sumwyy ) )
909                                 {
910                                     optcompidx = ti;
911                                 }
912                             }
913                             break;
914                         case CV_32SC1:
915                             for( ti = t_compidx; ti < MIN( sortedn, t_compidx + t_n ); ti++ )
916                             {
917                                 tk = 0;
918                                 for( tj = 0; tj < sortedm; tj++ )
919                                 {
920                                     int curidx = (int) ( *((int*) (sorteddata
921                                             + ti * sortedcstep + tj * sortedsstep)) );
922                                     if( filter[curidx] != 0 )
923                                     {
924                                         t_idx[tk++] = curidx;
925                                     }
926                                 }
927                                 if( findStumpThreshold_32s[stumperror]( 
928                                         t_data + ti * t_cstep, t_sstep,
929                                         wdata, wstep, ydata, ystep,
930                                         (uchar*) t_idx, sizeof( int ), tk,
931                                         &lerror, &rerror,
932                                         &threshold, &left, &right, 
933                                         &sumw, &sumwy, &sumwyy ) )
934                                 {
935                                     optcompidx = ti;
936                                 }
937                             }
938                             break;
939                         case CV_32FC1:
940                             for( ti = t_compidx; ti < MIN( sortedn, t_compidx + t_n ); ti++ )
941                             {
942                                 tk = 0;
943                                 for( tj = 0; tj < sortedm; tj++ )
944                                 {
945                                     int curidx = (int) ( *((float*) (sorteddata
946                                             + ti * sortedcstep + tj * sortedsstep)) );
947                                     if( filter[curidx] != 0 )
948                                     {
949                                         t_idx[tk++] = curidx;
950                                     }
951                                 }
952                                 if( findStumpThreshold_32s[stumperror]( 
953                                         t_data + ti * t_cstep, t_sstep,
954                                         wdata, wstep, ydata, ystep,
955                                         (uchar*) t_idx, sizeof( int ), tk,
956                                         &lerror, &rerror,
957                                         &threshold, &left, &right, 
958                                         &sumw, &sumwy, &sumwyy ) )
959                                 {
960                                     optcompidx = ti;
961                                 }
962                             }
963                             break;
964                         default:
965                             assert( 0 );
966                             break;
967                     }
968                 }
969                 else
970                 {
971                     /* have sorted indices */
972                     switch( sortedtype )
973                     {
974                         case CV_16SC1:
975                             for( ti = t_compidx; ti < MIN( sortedn, t_compidx + t_n ); ti++ )
976                             {
977                                 if( findStumpThreshold_16s[stumperror]( 
978                                         t_data + ti * t_cstep, t_sstep,
979                                         wdata, wstep, ydata, ystep,
980                                         sorteddata + ti * sortedcstep, sortedsstep, sortedm,
981                                         &lerror, &rerror,
982                                         &threshold, &left, &right, 
983                                         &sumw, &sumwy, &sumwyy ) )
984                                 {
985                                     optcompidx = ti;
986                                 }
987                             }
988                             break;
989                         case CV_32SC1:
990                             for( ti = t_compidx; ti < MIN( sortedn, t_compidx + t_n ); ti++ )
991                             {
992                                 if( findStumpThreshold_32s[stumperror]( 
993                                         t_data + ti * t_cstep, t_sstep,
994                                         wdata, wstep, ydata, ystep,
995                                         sorteddata + ti * sortedcstep, sortedsstep, sortedm,
996                                         &lerror, &rerror,
997                                         &threshold, &left, &right, 
998                                         &sumw, &sumwy, &sumwyy ) )
999                                 {
1000                                     optcompidx = ti;
1001                                 }
1002                             }
1003                             break;
1004                         case CV_32FC1:
1005                             for( ti = t_compidx; ti < MIN( sortedn, t_compidx + t_n ); ti++ )
1006                             {
1007                                 if( findStumpThreshold_32f[stumperror]( 
1008                                         t_data + ti * t_cstep, t_sstep,
1009                                         wdata, wstep, ydata, ystep,
1010                                         sorteddata + ti * sortedcstep, sortedsstep, sortedm,
1011                                         &lerror, &rerror,
1012                                         &threshold, &left, &right, 
1013                                         &sumw, &sumwy, &sumwyy ) )
1014                                 {
1015                                     optcompidx = ti;
1016                                 }
1017                             }
1018                             break;
1019                         default:
1020                             assert( 0 );
1021                             break;
1022                     }
1023                 }
1024             }
1025
1026             ti = MAX( t_compidx, MIN( sortedn, t_compidx + t_n ) );
1027             for( ; ti < t_compidx + t_n; ti++ )
1028             {
1029                 va.data = t_data + ti * t_cstep;
1030                 va.step = t_sstep;
1031                 icvSortIndexedValArray_32s( t_idx, l, &va );
1032                 if( findStumpThreshold_32s[stumperror]( 
1033                         t_data + ti * t_cstep, t_sstep,
1034                         wdata, wstep, ydata, ystep,
1035                         (uchar*)t_idx, sizeof( int ), l,
1036                         &lerror, &rerror,
1037                         &threshold, &left, &right, 
1038                         &sumw, &sumwy, &sumwyy ) )
1039                 {
1040                     optcompidx = ti;
1041                 }
1042             }
1043             #ifdef _OPENMP
1044             #pragma omp critical(c_compidx)
1045             #endif /* _OPENMP */
1046             {
1047                 t_compidx = compidx;
1048                 compidx += portion;
1049             }
1050         } /* while have training data */
1051
1052         /* get the best classifier */
1053         #ifdef _OPENMP
1054         #pragma omp critical(c_beststump)
1055         #endif /* _OPENMP */
1056         {
1057             if( lerror + rerror < stump->lerror + stump->rerror )
1058             {
1059                 stump->lerror    = lerror;
1060                 stump->rerror    = rerror;
1061                 stump->compidx   = optcompidx;
1062                 stump->threshold = threshold;
1063                 stump->left      = left;
1064                 stump->right     = right;
1065             }
1066         }
1067
1068         /* free allocated memory */
1069         if( mat.data.ptr != NULL )
1070         {
1071             cvFree( &(mat.data.ptr) );
1072         }
1073         if( t_idx != NULL )
1074         {
1075             cvFree( &t_idx );
1076         }
1077     } /* end of parallel region */
1078
1079     /* END */
1080
1081     /* free allocated memory */
1082     if( filter != NULL )
1083     {
1084         cvFree( &filter );
1085     }
1086
1087     if( ((CvMTStumpTrainParams*) trainParams)->type == CV_CLASSIFICATION_CLASS )
1088     {
1089         stump->left = 2.0F * (stump->left >= 0.5F) - 1.0F;
1090         stump->right = 2.0F * (stump->right >= 0.5F) - 1.0F;
1091     }
1092
1093     return (CvClassifier*) stump;
1094 }
1095
1096 CV_BOOST_IMPL
1097 float cvEvalCARTClassifier( CvClassifier* classifier, CvMat* sample )
1098 {
1099     CV_FUNCNAME( "cvEvalCARTClassifier" );
1100
1101     int idx = 0;
1102
1103     __BEGIN__;
1104
1105
1106     CV_ASSERT( classifier != NULL );
1107     CV_ASSERT( sample != NULL );
1108     CV_ASSERT( CV_MAT_TYPE( sample->type ) == CV_32FC1 );
1109     CV_ASSERT( sample->rows == 1 || sample->cols == 1 );
1110
1111     if( sample->rows == 1 )
1112     {
1113         do
1114         {
1115             if( (CV_MAT_ELEM( (*sample), float, 0,
1116                     ((CvCARTClassifier*) classifier)->compidx[idx] )) <
1117                 ((CvCARTClassifier*) classifier)->threshold[idx] ) 
1118             {
1119                 idx = ((CvCARTClassifier*) classifier)->left[idx];
1120             }
1121             else
1122             {
1123                 idx = ((CvCARTClassifier*) classifier)->right[idx];
1124             }
1125         } while( idx > 0 );
1126     }
1127     else
1128     {
1129         do
1130         {
1131             if( (CV_MAT_ELEM( (*sample), float,
1132                     ((CvCARTClassifier*) classifier)->compidx[idx], 0 )) <
1133                 ((CvCARTClassifier*) classifier)->threshold[idx] ) 
1134             {
1135                 idx = ((CvCARTClassifier*) classifier)->left[idx];
1136             }
1137             else
1138             {
1139                 idx = ((CvCARTClassifier*) classifier)->right[idx];
1140             }
1141         } while( idx > 0 );
1142     } 
1143
1144     __END__;
1145
1146     return ((CvCARTClassifier*) classifier)->val[-idx];
1147 }
1148
1149 CV_BOOST_IMPL
1150 float cvEvalCARTClassifierIdx( CvClassifier* classifier, CvMat* sample )
1151 {
1152     CV_FUNCNAME( "cvEvalCARTClassifierIdx" );
1153
1154     int idx = 0;
1155
1156     __BEGIN__;
1157
1158
1159     CV_ASSERT( classifier != NULL );
1160     CV_ASSERT( sample != NULL );
1161     CV_ASSERT( CV_MAT_TYPE( sample->type ) == CV_32FC1 );
1162     CV_ASSERT( sample->rows == 1 || sample->cols == 1 );
1163
1164     if( sample->rows == 1 )
1165     {
1166         do
1167         {
1168             if( (CV_MAT_ELEM( (*sample), float, 0,
1169                     ((CvCARTClassifier*) classifier)->compidx[idx] )) <
1170                 ((CvCARTClassifier*) classifier)->threshold[idx] ) 
1171             {
1172                 idx = ((CvCARTClassifier*) classifier)->left[idx];
1173             }
1174             else
1175             {
1176                 idx = ((CvCARTClassifier*) classifier)->right[idx];
1177             }
1178         } while( idx > 0 );
1179     }
1180     else
1181     {
1182         do
1183         {
1184             if( (CV_MAT_ELEM( (*sample), float,
1185                     ((CvCARTClassifier*) classifier)->compidx[idx], 0 )) <
1186                 ((CvCARTClassifier*) classifier)->threshold[idx] ) 
1187             {
1188                 idx = ((CvCARTClassifier*) classifier)->left[idx];
1189             }
1190             else
1191             {
1192                 idx = ((CvCARTClassifier*) classifier)->right[idx];
1193             }
1194         } while( idx > 0 );
1195     } 
1196
1197     __END__;
1198
1199     return (float) (-idx);
1200 }
1201
1202 CV_BOOST_IMPL
1203 void cvReleaseCARTClassifier( CvClassifier** classifier )
1204 {
1205     cvFree( classifier );
1206     *classifier = NULL;
1207 }
1208
1209 void CV_CDECL icvDefaultSplitIdx_R( int compidx, float threshold,
1210                                     CvMat* idx, CvMat** left, CvMat** right,
1211                                     void* userdata )
1212 {
1213     CvMat* trainData = (CvMat*) userdata;
1214     int i = 0;
1215
1216     *left = cvCreateMat( 1, trainData->rows, CV_32FC1 );
1217     *right = cvCreateMat( 1, trainData->rows, CV_32FC1 );
1218     (*left)->cols = (*right)->cols = 0;
1219     if( idx == NULL )
1220     {
1221         for( i = 0; i < trainData->rows; i++ )
1222         {
1223             if( CV_MAT_ELEM( *trainData, float, i, compidx ) < threshold )
1224             {
1225                 (*left)->data.fl[(*left)->cols++] = (float) i;
1226             }
1227             else
1228             {
1229                 (*right)->data.fl[(*right)->cols++] = (float) i;
1230             }
1231         }
1232     }
1233     else
1234     {
1235         uchar* idxdata;
1236         int idxnum;
1237         int idxstep;
1238         int index;
1239
1240         idxdata = idx->data.ptr;
1241         idxnum = (idx->rows == 1) ? idx->cols : idx->rows;
1242         idxstep = (idx->rows == 1) ? CV_ELEM_SIZE( idx->type ) : idx->step;
1243         for( i = 0; i < idxnum; i++ )
1244         {
1245             index = (int) *((float*) (idxdata + i * idxstep));
1246             if( CV_MAT_ELEM( *trainData, float, index, compidx ) < threshold )
1247             {
1248                 (*left)->data.fl[(*left)->cols++] = (float) index;
1249             }
1250             else
1251             {
1252                 (*right)->data.fl[(*right)->cols++] = (float) index;
1253             }
1254         }
1255     }
1256 }
1257
1258 void CV_CDECL icvDefaultSplitIdx_C( int compidx, float threshold,
1259                                     CvMat* idx, CvMat** left, CvMat** right,
1260                                     void* userdata )
1261 {
1262     CvMat* trainData = (CvMat*) userdata;
1263     int i = 0;
1264
1265     *left = cvCreateMat( 1, trainData->cols, CV_32FC1 );
1266     *right = cvCreateMat( 1, trainData->cols, CV_32FC1 );
1267     (*left)->cols = (*right)->cols = 0;
1268     if( idx == NULL )
1269     {
1270         for( i = 0; i < trainData->cols; i++ )
1271         {
1272             if( CV_MAT_ELEM( *trainData, float, compidx, i ) < threshold )
1273             {
1274                 (*left)->data.fl[(*left)->cols++] = (float) i;
1275             }
1276             else
1277             {
1278                 (*right)->data.fl[(*right)->cols++] = (float) i;
1279             }
1280         }
1281     }
1282     else
1283     {
1284         uchar* idxdata;
1285         int idxnum;
1286         int idxstep;
1287         int index;
1288
1289         idxdata = idx->data.ptr;
1290         idxnum = (idx->rows == 1) ? idx->cols : idx->rows;
1291         idxstep = (idx->rows == 1) ? CV_ELEM_SIZE( idx->type ) : idx->step;
1292         for( i = 0; i < idxnum; i++ )
1293         {
1294             index = (int) *((float*) (idxdata + i * idxstep));
1295             if( CV_MAT_ELEM( *trainData, float, compidx, index ) < threshold )
1296             {
1297                 (*left)->data.fl[(*left)->cols++] = (float) index;
1298             }
1299             else
1300             {
1301                 (*right)->data.fl[(*right)->cols++] = (float) index;
1302             }
1303         }
1304     }
1305 }
1306
1307 /* internal structure used in CART creation */
1308 typedef struct CvCARTNode
1309 {
1310     CvMat* sampleIdx;
1311     CvStumpClassifier* stump;
1312     int parent;
1313     int leftflag;
1314     float errdrop;
1315 } CvCARTNode;
1316
1317 CV_BOOST_IMPL
1318 CvClassifier* cvCreateCARTClassifier( CvMat* trainData,
1319                      int flags,
1320                      CvMat* trainClasses,
1321                      CvMat* typeMask,
1322                      CvMat* missedMeasurementsMask,
1323                      CvMat* compIdx,
1324                      CvMat* sampleIdx,
1325                      CvMat* weights,
1326                      CvClassifierTrainParams* trainParams )
1327 {
1328     CvCARTClassifier* cart = NULL;
1329     size_t datasize = 0;
1330     int count = 0;
1331     int i = 0;
1332     int j = 0;
1333     
1334     CvCARTNode* intnode = NULL;
1335     CvCARTNode* list = NULL;
1336     int listcount = 0;
1337     CvMat* lidx = NULL;
1338     CvMat* ridx = NULL;
1339     
1340     float maxerrdrop = 0.0F;
1341     int idx = 0;
1342
1343     void (*splitIdxCallback)( int compidx, float threshold,
1344                               CvMat* idx, CvMat** left, CvMat** right,
1345                               void* userdata );
1346     void* userdata;
1347
1348     count = ((CvCARTTrainParams*) trainParams)->count;
1349     
1350     assert( count > 0 );
1351
1352     datasize = sizeof( *cart ) + (sizeof( float ) + 3 * sizeof( int )) * count + 
1353         sizeof( float ) * (count + 1);
1354     
1355     cart = (CvCARTClassifier*) cvAlloc( datasize );
1356     memset( cart, 0, datasize );
1357     
1358     cart->count = count;
1359     
1360     cart->eval = cvEvalCARTClassifier;
1361     cart->save = NULL;
1362     cart->release = cvReleaseCARTClassifier;
1363
1364     cart->compidx = (int*) (cart + 1);
1365     cart->threshold = (float*) (cart->compidx + count);
1366     cart->left  = (int*) (cart->threshold + count);
1367     cart->right = (int*) (cart->left + count);
1368     cart->val = (float*) (cart->right + count);
1369
1370     datasize = sizeof( CvCARTNode ) * (count + count);
1371     intnode = (CvCARTNode*) cvAlloc( datasize );
1372     memset( intnode, 0, datasize );
1373     list = (CvCARTNode*) (intnode + count);
1374
1375     splitIdxCallback = ((CvCARTTrainParams*) trainParams)->splitIdx;
1376     userdata = ((CvCARTTrainParams*) trainParams)->userdata;
1377     if( splitIdxCallback == NULL )
1378     {
1379         splitIdxCallback = ( CV_IS_ROW_SAMPLE( flags ) )
1380             ? icvDefaultSplitIdx_R : icvDefaultSplitIdx_C;
1381         userdata = trainData;
1382     }
1383
1384     /* create root of the tree */
1385     intnode[0].sampleIdx = sampleIdx;
1386     intnode[0].stump = (CvStumpClassifier*)
1387         ((CvCARTTrainParams*) trainParams)->stumpConstructor( trainData, flags,
1388             trainClasses, typeMask, missedMeasurementsMask, compIdx, sampleIdx, weights,
1389             ((CvCARTTrainParams*) trainParams)->stumpTrainParams );
1390     cart->left[0] = cart->right[0] = 0;
1391
1392     /* build tree */
1393     listcount = 0;
1394     for( i = 1; i < count; i++ )
1395     {
1396         /* split last added node */
1397         splitIdxCallback( intnode[i-1].stump->compidx, intnode[i-1].stump->threshold,
1398             intnode[i-1].sampleIdx, &lidx, &ridx, userdata );
1399         
1400         if( intnode[i-1].stump->lerror != 0.0F )
1401         {
1402             list[listcount].sampleIdx = lidx;
1403             list[listcount].stump = (CvStumpClassifier*)
1404                 ((CvCARTTrainParams*) trainParams)->stumpConstructor( trainData, flags,
1405                     trainClasses, typeMask, missedMeasurementsMask, compIdx,
1406                     list[listcount].sampleIdx,
1407                     weights, ((CvCARTTrainParams*) trainParams)->stumpTrainParams );
1408             list[listcount].errdrop = intnode[i-1].stump->lerror
1409                 - (list[listcount].stump->lerror + list[listcount].stump->rerror);
1410             list[listcount].leftflag = 1;
1411             list[listcount].parent = i-1;
1412             listcount++;
1413         }
1414         else
1415         {
1416             cvReleaseMat( &lidx );
1417         }
1418         if( intnode[i-1].stump->rerror != 0.0F )
1419         {
1420             list[listcount].sampleIdx = ridx;
1421             list[listcount].stump = (CvStumpClassifier*)
1422                 ((CvCARTTrainParams*) trainParams)->stumpConstructor( trainData, flags,
1423                     trainClasses, typeMask, missedMeasurementsMask, compIdx,
1424                     list[listcount].sampleIdx,
1425                     weights, ((CvCARTTrainParams*) trainParams)->stumpTrainParams );
1426             list[listcount].errdrop = intnode[i-1].stump->rerror
1427                 - (list[listcount].stump->lerror + list[listcount].stump->rerror);
1428             list[listcount].leftflag = 0;
1429             list[listcount].parent = i-1;
1430             listcount++;
1431         }
1432         else
1433         {
1434             cvReleaseMat( &ridx );
1435         }
1436         
1437         if( listcount == 0 ) break;
1438
1439         /* find the best node to be added to the tree */
1440         idx = 0;
1441         maxerrdrop = list[idx].errdrop;
1442         for( j = 1; j < listcount; j++ )
1443         {
1444             if( list[j].errdrop > maxerrdrop )
1445             {
1446                 idx = j;
1447                 maxerrdrop = list[j].errdrop;
1448             }
1449         }
1450         intnode[i] = list[idx];
1451         if( list[idx].leftflag )
1452         {
1453             cart->left[list[idx].parent] = i;
1454         }
1455         else
1456         {
1457             cart->right[list[idx].parent] = i;
1458         }
1459         if( idx != (listcount - 1) )
1460         {
1461             list[idx] = list[listcount - 1];
1462         }
1463         listcount--;
1464     }
1465
1466     /* fill <cart> fields */
1467     j = 0;
1468     cart->count = 0;
1469     for( i = 0; i < count && (intnode[i].stump != NULL); i++ )
1470     {
1471         cart->count++;
1472         cart->compidx[i] = intnode[i].stump->compidx;
1473         cart->threshold[i] = intnode[i].stump->threshold;
1474         
1475         /* leaves */
1476         if( cart->left[i] <= 0 )
1477         {
1478             cart->left[i] = -j;
1479             cart->val[j] = intnode[i].stump->left;
1480             j++;
1481         }
1482         if( cart->right[i] <= 0 )
1483         {
1484             cart->right[i] = -j;
1485             cart->val[j] = intnode[i].stump->right;
1486             j++;
1487         }
1488     }
1489     
1490     /* CLEAN UP */
1491     for( i = 0; i < count && (intnode[i].stump != NULL); i++ )
1492     {
1493         intnode[i].stump->release( (CvClassifier**) &(intnode[i].stump) );
1494         if( i != 0 )
1495         {
1496             cvReleaseMat( &(intnode[i].sampleIdx) );
1497         }
1498     }
1499     for( i = 0; i < listcount; i++ )
1500     {
1501         list[i].stump->release( (CvClassifier**) &(list[i].stump) );
1502         cvReleaseMat( &(list[i].sampleIdx) );
1503     }
1504     
1505     cvFree( &intnode );
1506
1507     return (CvClassifier*) cart;
1508 }
1509
1510 /****************************************************************************************\
1511 *                                        Boosting                                        *
1512 \****************************************************************************************/
1513
1514 typedef struct CvBoostTrainer
1515 {
1516     CvBoostType type;
1517     int count;             /* (idx) ? number_of_indices : number_of_samples */
1518     int* idx;
1519     float* F;
1520 } CvBoostTrainer;
1521
/*
 * cvBoostStartTraining, cvBoostNextWeakClassifier, cvBoostEndTraining
 *
 * These functions perform training of a 2-class boosting classifier
 * using ANY appropriate weak classifier
 */
1528
1529 CV_BOOST_IMPL
1530 CvBoostTrainer* icvBoostStartTraining( CvMat* trainClasses,
1531                                        CvMat* weakTrainVals,
1532                                        CvMat* /*weights*/,
1533                                        CvMat* sampleIdx,
1534                                        CvBoostType type )
1535 {
1536     uchar* ydata;
1537     int ystep;
1538     int m;
1539     uchar* traindata;
1540     int trainstep;
1541     int trainnum;
1542     int i;
1543     int idx;
1544
1545     size_t datasize;
1546     CvBoostTrainer* ptr;
1547
1548     int idxnum;
1549     int idxstep;
1550     uchar* idxdata;
1551
1552     assert( trainClasses != NULL );
1553     assert( CV_MAT_TYPE( trainClasses->type ) == CV_32FC1 );
1554     assert( weakTrainVals != NULL );
1555     assert( CV_MAT_TYPE( weakTrainVals->type ) == CV_32FC1 );
1556
1557     CV_MAT2VEC( *trainClasses, ydata, ystep, m );
1558     CV_MAT2VEC( *weakTrainVals, traindata, trainstep, trainnum );
1559
1560     assert( m == trainnum );
1561
1562     idxnum = 0;
1563     idxstep = 0;
1564     idxdata = NULL;
1565     if( sampleIdx )
1566     {
1567         CV_MAT2VEC( *sampleIdx, idxdata, idxstep, idxnum );
1568     }
1569         
1570     datasize = sizeof( *ptr ) + sizeof( *ptr->idx ) * idxnum;
1571     ptr = (CvBoostTrainer*) cvAlloc( datasize );
1572     memset( ptr, 0, datasize );
1573     ptr->F = NULL;
1574     ptr->idx = NULL;
1575
1576     ptr->count = m;
1577     ptr->type = type;
1578     
1579     if( idxnum > 0 )
1580     {
1581         CvScalar s;
1582
1583         ptr->idx = (int*) (ptr + 1);
1584         ptr->count = idxnum;
1585         for( i = 0; i < ptr->count; i++ )
1586         {
1587             cvRawDataToScalar( idxdata + i*idxstep, CV_MAT_TYPE( sampleIdx->type ), &s );
1588             ptr->idx[i] = (int) s.val[0];
1589         }
1590     }
1591     for( i = 0; i < ptr->count; i++ )
1592     {
1593         idx = (ptr->idx) ? ptr->idx[i] : i;
1594
1595         *((float*) (traindata + idx * trainstep)) = 
1596             2.0F * (*((float*) (ydata + idx * ystep))) - 1.0F;
1597     }
1598
1599     return ptr;
1600 }
1601
1602 /*
1603  *
1604  * Discrete AdaBoost functions
1605  *
1606  */
1607 CV_BOOST_IMPL
1608 float icvBoostNextWeakClassifierDAB( CvMat* weakEvalVals,
1609                                      CvMat* trainClasses,
1610                                      CvMat* /*weakTrainVals*/,
1611                                      CvMat* weights,
1612                                      CvBoostTrainer* trainer )
1613 {
1614     uchar* evaldata;
1615     int evalstep;
1616     int m;
1617     uchar* ydata;
1618     int ystep;
1619     int ynum;
1620     uchar* wdata;
1621     int wstep;
1622     int wnum;
1623
1624     float sumw;
1625     float err;
1626     int i;
1627     int idx;
1628
1629     CV_Assert( weakEvalVals != NULL );
1630     CV_Assert( CV_MAT_TYPE( weakEvalVals->type ) == CV_32FC1 );
1631     CV_Assert( trainClasses != NULL );
1632     CV_Assert( CV_MAT_TYPE( trainClasses->type ) == CV_32FC1 );
1633     CV_Assert( weights != NULL );
1634     CV_Assert( CV_MAT_TYPE( weights ->type ) == CV_32FC1 );
1635
1636     CV_MAT2VEC( *weakEvalVals, evaldata, evalstep, m );
1637     CV_MAT2VEC( *trainClasses, ydata, ystep, ynum );
1638     CV_MAT2VEC( *weights, wdata, wstep, wnum );
1639
1640     assert( m == ynum );
1641     assert( m == wnum );
1642
1643     sumw = 0.0F;
1644     err = 0.0F;
1645     for( i = 0; i < trainer->count; i++ )
1646     {
1647         idx = (trainer->idx) ? trainer->idx[i] : i;
1648
1649         sumw += *((float*) (wdata + idx*wstep));
1650         err += (*((float*) (wdata + idx*wstep))) *
1651             ( (*((float*) (evaldata + idx*evalstep))) != 
1652                 2.0F * (*((float*) (ydata + idx*ystep))) - 1.0F );
1653     }
1654     err /= sumw;
1655     err = -cvLogRatio( err );
1656     
1657     for( i = 0; i < trainer->count; i++ )
1658     {
1659         idx = (trainer->idx) ? trainer->idx[i] : i;
1660
1661         *((float*) (wdata + idx*wstep)) *= expf( err * 
1662             ((*((float*) (evaldata + idx*evalstep))) != 
1663                 2.0F * (*((float*) (ydata + idx*ystep))) - 1.0F) );
1664         sumw += *((float*) (wdata + idx*wstep));
1665     }
1666     for( i = 0; i < trainer->count; i++ )
1667     {
1668         idx = (trainer->idx) ? trainer->idx[i] : i;
1669
1670         *((float*) (wdata + idx * wstep)) /= sumw;
1671     }
1672     
1673     return err;
1674 }
1675
1676 /*
1677  *
1678  * Real AdaBoost functions
1679  *
1680  */
1681 CV_BOOST_IMPL
1682 float icvBoostNextWeakClassifierRAB( CvMat* weakEvalVals,
1683                                      CvMat* trainClasses,
1684                                      CvMat* /*weakTrainVals*/,
1685                                      CvMat* weights,
1686                                      CvBoostTrainer* trainer )
1687 {
1688     uchar* evaldata;
1689     int evalstep;
1690     int m;
1691     uchar* ydata;
1692     int ystep;
1693     int ynum;
1694     uchar* wdata;
1695     int wstep;
1696     int wnum;
1697
1698     float sumw;
1699     int i, idx;
1700
1701     CV_Assert( weakEvalVals != NULL );
1702     CV_Assert( CV_MAT_TYPE( weakEvalVals->type ) == CV_32FC1 );
1703     CV_Assert( trainClasses != NULL );
1704     CV_Assert( CV_MAT_TYPE( trainClasses->type ) == CV_32FC1 );
1705     CV_Assert( weights != NULL );
1706     CV_Assert( CV_MAT_TYPE( weights ->type ) == CV_32FC1 );
1707
1708     CV_MAT2VEC( *weakEvalVals, evaldata, evalstep, m );
1709     CV_MAT2VEC( *trainClasses, ydata, ystep, ynum );
1710     CV_MAT2VEC( *weights, wdata, wstep, wnum );
1711
1712     CV_Assert( m == ynum );
1713     CV_Assert( m == wnum );
1714
1715
1716     sumw = 0.0F;
1717     for( i = 0; i < trainer->count; i++ )
1718     {
1719         idx = (trainer->idx) ? trainer->idx[i] : i;
1720
1721         *((float*) (wdata + idx*wstep)) *= expf( (-(*((float*) (ydata + idx*ystep))) + 0.5F)
1722             * cvLogRatio( *((float*) (evaldata + idx*evalstep)) ) );
1723         sumw += *((float*) (wdata + idx*wstep));
1724     }
1725     for( i = 0; i < trainer->count; i++ )
1726     {
1727         idx = (trainer->idx) ? trainer->idx[i] : i;
1728
1729         *((float*) (wdata + idx*wstep)) /= sumw;
1730     }
1731     
1732     return 1.0F;
1733 }
1734
1735 /*
1736  *
1737  * LogitBoost functions
1738  *
1739  */
1740 #define CV_LB_PROB_THRESH      0.01F
1741 #define CV_LB_WEIGHT_THRESHOLD 0.0001F
1742
1743 CV_BOOST_IMPL
1744 void icvResponsesAndWeightsLB( int num, uchar* wdata, int wstep,
1745                                uchar* ydata, int ystep,
1746                                uchar* fdata, int fstep,
1747                                uchar* traindata, int trainstep,
1748                                int* indices )
1749 {
1750     int i, idx;
1751     float p;
1752
1753     for( i = 0; i < num; i++ )
1754     {
1755         idx = (indices) ? indices[i] : i;
1756
1757         p = 1.0F / (1.0F + expf( -(*((float*) (fdata + idx*fstep)))) );
1758         *((float*) (wdata + idx*wstep)) = MAX( p * (1.0F - p), CV_LB_WEIGHT_THRESHOLD );
1759         if( *((float*) (ydata + idx*ystep)) == 1.0F )
1760         {
1761             *((float*) (traindata + idx*trainstep)) = 
1762                 1.0F / (MAX( p, CV_LB_PROB_THRESH ));
1763         }
1764         else
1765         {
1766             *((float*) (traindata + idx*trainstep)) = 
1767                 -1.0F / (MAX( 1.0F - p, CV_LB_PROB_THRESH ));
1768         }
1769     }
1770 }
1771
1772 CV_BOOST_IMPL
1773 CvBoostTrainer* icvBoostStartTrainingLB( CvMat* trainClasses,
1774                                          CvMat* weakTrainVals,
1775                                          CvMat* weights,
1776                                          CvMat* sampleIdx,
1777                                          CvBoostType type )
1778 {
1779     size_t datasize;
1780     CvBoostTrainer* ptr;
1781
1782     uchar* ydata;
1783     int ystep;
1784     int m;
1785     uchar* traindata;
1786     int trainstep;
1787     int trainnum;
1788     uchar* wdata;
1789     int wstep;
1790     int wnum;
1791     int i;
1792
1793     int idxnum;
1794     int idxstep;
1795     uchar* idxdata;
1796
1797     assert( trainClasses != NULL );
1798     assert( CV_MAT_TYPE( trainClasses->type ) == CV_32FC1 );
1799     assert( weakTrainVals != NULL );
1800     assert( CV_MAT_TYPE( weakTrainVals->type ) == CV_32FC1 );
1801     assert( weights != NULL );
1802     assert( CV_MAT_TYPE( weights->type ) == CV_32FC1 );
1803
1804     CV_MAT2VEC( *trainClasses, ydata, ystep, m );
1805     CV_MAT2VEC( *weakTrainVals, traindata, trainstep, trainnum );
1806     CV_MAT2VEC( *weights, wdata, wstep, wnum );
1807
1808     assert( m == trainnum );
1809     assert( m == wnum );
1810
1811
1812     idxnum = 0;
1813     idxstep = 0;
1814     idxdata = NULL;
1815     if( sampleIdx )
1816     {
1817         CV_MAT2VEC( *sampleIdx, idxdata, idxstep, idxnum );
1818     }
1819         
1820     datasize = sizeof( *ptr ) + sizeof( *ptr->F ) * m + sizeof( *ptr->idx ) * idxnum;
1821     ptr = (CvBoostTrainer*) cvAlloc( datasize );
1822     memset( ptr, 0, datasize );
1823     ptr->F = (float*) (ptr + 1);
1824     ptr->idx = NULL;
1825
1826     ptr->count = m;
1827     ptr->type = type;
1828     
1829     if( idxnum > 0 )
1830     {
1831         CvScalar s;
1832
1833         ptr->idx = (int*) (ptr->F + m);
1834         ptr->count = idxnum;
1835         for( i = 0; i < ptr->count; i++ )
1836         {
1837             cvRawDataToScalar( idxdata + i*idxstep, CV_MAT_TYPE( sampleIdx->type ), &s );
1838             ptr->idx[i] = (int) s.val[0];
1839         }
1840     }
1841
1842     for( i = 0; i < m; i++ )
1843     {
1844         ptr->F[i] = 0.0F;
1845     }
1846
1847     icvResponsesAndWeightsLB( ptr->count, wdata, wstep, ydata, ystep,
1848                               (uchar*) ptr->F, sizeof( *ptr->F ),
1849                               traindata, trainstep, ptr->idx );
1850
1851     return ptr;
1852 }
1853
1854 CV_BOOST_IMPL
1855 float icvBoostNextWeakClassifierLB( CvMat* weakEvalVals,
1856                                     CvMat* trainClasses,
1857                                     CvMat* weakTrainVals,
1858                                     CvMat* weights,
1859                                     CvBoostTrainer* trainer )
1860 {
1861     uchar* evaldata;
1862     int evalstep;
1863     int m;
1864     uchar* ydata;
1865     int ystep;
1866     int ynum;
1867     uchar* traindata;
1868     int trainstep;
1869     int trainnum;
1870     uchar* wdata;
1871     int wstep;
1872     int wnum;
1873     int i, idx;
1874
1875     assert( weakEvalVals != NULL );
1876     assert( CV_MAT_TYPE( weakEvalVals->type ) == CV_32FC1 );
1877     assert( trainClasses != NULL );
1878     assert( CV_MAT_TYPE( trainClasses->type ) == CV_32FC1 );
1879     assert( weakTrainVals != NULL );
1880     assert( CV_MAT_TYPE( weakTrainVals->type ) == CV_32FC1 );
1881     assert( weights != NULL );
1882     assert( CV_MAT_TYPE( weights ->type ) == CV_32FC1 );
1883
1884     CV_MAT2VEC( *weakEvalVals, evaldata, evalstep, m );
1885     CV_MAT2VEC( *trainClasses, ydata, ystep, ynum );
1886     CV_MAT2VEC( *weakTrainVals, traindata, trainstep, trainnum );
1887     CV_MAT2VEC( *weights, wdata, wstep, wnum );
1888
1889     assert( m == ynum );
1890     assert( m == wnum );
1891     assert( m == trainnum );
1892     //assert( m == trainer->count );
1893
1894     for( i = 0; i < trainer->count; i++ )
1895     {
1896         idx = (trainer->idx) ? trainer->idx[i] : i;
1897
1898         trainer->F[idx] += *((float*) (evaldata + idx * evalstep));
1899     }
1900     
1901     icvResponsesAndWeightsLB( trainer->count, wdata, wstep, ydata, ystep,
1902                               (uchar*) trainer->F, sizeof( *trainer->F ),
1903                               traindata, trainstep, trainer->idx );
1904
1905     return 1.0F;
1906 }
1907
1908 /*
1909  *
1910  * Gentle AdaBoost
1911  *
1912  */
1913 CV_BOOST_IMPL
1914 float icvBoostNextWeakClassifierGAB( CvMat* weakEvalVals,
1915                                      CvMat* trainClasses,
1916                                      CvMat* /*weakTrainVals*/,
1917                                      CvMat* weights,
1918                                      CvBoostTrainer* trainer )
1919 {
1920     uchar* evaldata;
1921     int evalstep;
1922     int m;
1923     uchar* ydata;
1924     int ystep;
1925     int ynum;
1926     uchar* wdata;
1927     int wstep;
1928     int wnum;
1929
1930     int i, idx;
1931     float sumw;
1932
1933     CV_Assert( weakEvalVals != NULL );
1934     CV_Assert( CV_MAT_TYPE( weakEvalVals->type ) == CV_32FC1 );
1935     CV_Assert( trainClasses != NULL );
1936     CV_Assert( CV_MAT_TYPE( trainClasses->type ) == CV_32FC1 );
1937     CV_Assert( weights != NULL );
1938     CV_Assert( CV_MAT_TYPE( weights->type ) == CV_32FC1 );
1939
1940     CV_MAT2VEC( *weakEvalVals, evaldata, evalstep, m );
1941     CV_MAT2VEC( *trainClasses, ydata, ystep, ynum );
1942     CV_MAT2VEC( *weights, wdata, wstep, wnum );
1943
1944     assert( m == ynum );
1945     assert( m == wnum );
1946
1947     sumw = 0.0F;
1948     for( i = 0; i < trainer->count; i++ )
1949     {
1950         idx = (trainer->idx) ? trainer->idx[i] : i;
1951
1952         *((float*) (wdata + idx*wstep)) *= 
1953             expf( -(*((float*) (evaldata + idx*evalstep)))
1954                   * ( 2.0F * (*((float*) (ydata + idx*ystep))) - 1.0F ) );
1955         sumw += *((float*) (wdata + idx*wstep));
1956     }
1957     
1958     for( i = 0; i < trainer->count; i++ )
1959     {
1960         idx = (trainer->idx) ? trainer->idx[i] : i;
1961
1962         *((float*) (wdata + idx*wstep)) /= sumw;
1963     }
1964
1965     return 1.0F;
1966 }
1967
1968 typedef CvBoostTrainer* (*CvBoostStartTraining)( CvMat* trainClasses,
1969                                                  CvMat* weakTrainVals,
1970                                                  CvMat* weights,
1971                                                  CvMat* sampleIdx,
1972                                                  CvBoostType type );
1973
1974 typedef float (*CvBoostNextWeakClassifier)( CvMat* weakEvalVals,
1975                                             CvMat* trainClasses,
1976                                             CvMat* weakTrainVals,
1977                                             CvMat* weights,
1978                                             CvBoostTrainer* data );
1979
1980 CvBoostStartTraining startTraining[4] = {
1981         icvBoostStartTraining,
1982         icvBoostStartTraining,
1983         icvBoostStartTrainingLB,
1984         icvBoostStartTraining
1985     };
1986
1987 CvBoostNextWeakClassifier nextWeakClassifier[4] = {
1988         icvBoostNextWeakClassifierDAB,
1989         icvBoostNextWeakClassifierRAB,
1990         icvBoostNextWeakClassifierLB,
1991         icvBoostNextWeakClassifierGAB
1992     };
1993
1994 /*
1995  *
1996  * Dispatchers
1997  *
1998  */
1999 CV_BOOST_IMPL
2000 CvBoostTrainer* cvBoostStartTraining( CvMat* trainClasses,
2001                                       CvMat* weakTrainVals,
2002                                       CvMat* weights,
2003                                       CvMat* sampleIdx,
2004                                       CvBoostType type )
2005 {
2006     return startTraining[type]( trainClasses, weakTrainVals, weights, sampleIdx, type );
2007 }
2008
2009 CV_BOOST_IMPL
2010 void cvBoostEndTraining( CvBoostTrainer** trainer )
2011 {
2012     cvFree( trainer );
2013     *trainer = NULL;
2014 }
2015
2016 CV_BOOST_IMPL
2017 float cvBoostNextWeakClassifier( CvMat* weakEvalVals,
2018                                  CvMat* trainClasses,
2019                                  CvMat* weakTrainVals,
2020                                  CvMat* weights,
2021                                  CvBoostTrainer* trainer )
2022 {
2023     return nextWeakClassifier[trainer->type]( weakEvalVals, trainClasses,
2024         weakTrainVals, weights, trainer    );
2025 }
2026
2027 /****************************************************************************************\
2028 *                                    Boosted tree models                                 *
2029 \****************************************************************************************/
2030
2031 typedef struct CvBtTrainer
2032 {
2033     /* {{ external */    
2034     CvMat* trainData;
2035     int flags;
2036     
2037     CvMat* trainClasses;
2038     int m;
2039     uchar* ydata;
2040     int ystep;
2041
2042     CvMat* sampleIdx;
2043     int numsamples;
2044     
2045     float param[2];
2046     CvBoostType type;
2047     int numclasses;
2048     /* }} external */
2049
2050     CvMTStumpTrainParams stumpParams;
2051     CvCARTTrainParams  cartParams;
2052
2053     float* f;          /* F_(m-1) */
2054     CvMat* y;          /* yhat    */
2055     CvMat* weights;
2056     CvBoostTrainer* boosttrainer;
2057 } CvBtTrainer;
2058
2059 /*
2060  * cvBtStart, cvBtNext, cvBtEnd
2061  *
2062  * These functions perform iterative training of
2063  * 2-class (CV_DABCLASS - CV_GABCLASS, CV_L2CLASS), K-class (CV_LKCLASS) classifier
2064  * or fit regression model (CV_LSREG, CV_LADREG, CV_MREG)
2065  * using decision tree as a weak classifier.
2066  */
2067
2068 typedef void (*CvZeroApproxFunc)( float* approx, CvBtTrainer* trainer );
2069
2070 /* Mean zero approximation */
2071 void icvZeroApproxMean( float* approx, CvBtTrainer* trainer )
2072 {
2073     int i;
2074     int idx;
2075
2076     approx[0] = 0.0F;
2077     for( i = 0; i < trainer->numsamples; i++ )
2078     {
2079         idx = icvGetIdxAt( trainer->sampleIdx, i );
2080         approx[0] += *((float*) (trainer->ydata + idx * trainer->ystep));
2081     }
2082     approx[0] /= (float) trainer->numsamples;
2083 }
2084
2085 /*
2086  * Median zero approximation
2087  */
2088 void icvZeroApproxMed( float* approx, CvBtTrainer* trainer )
2089 {
2090     int i;
2091     int idx;
2092
2093     for( i = 0; i < trainer->numsamples; i++ )
2094     {
2095         idx = icvGetIdxAt( trainer->sampleIdx, i );
2096         trainer->f[i] = *((float*) (trainer->ydata + idx * trainer->ystep));
2097     }
2098     
2099     icvSort_32f( trainer->f, trainer->numsamples, 0 );
2100     approx[0] = trainer->f[trainer->numsamples / 2];
2101 }
2102
2103 /*
2104  * 0.5 * log( mean(y) / (1 - mean(y)) ) where y in {0, 1}
2105  */
2106 void icvZeroApproxLog( float* approx, CvBtTrainer* trainer )
2107 {
2108     float y_mean;
2109
2110     icvZeroApproxMean( &y_mean, trainer );
2111     approx[0] = 0.5F * cvLogRatio( y_mean );
2112 }
2113
2114 /*
2115  * 0 zero approximation
2116  */
2117 void icvZeroApprox0( float* approx, CvBtTrainer* trainer )
2118 {
2119     int i;
2120
2121     for( i = 0; i < trainer->numclasses; i++ )
2122     {
2123         approx[i] = 0.0F;
2124     }
2125 }
2126
2127 static CvZeroApproxFunc icvZeroApproxFunc[] =
2128 {
2129     icvZeroApprox0,    /* CV_DABCLASS */
2130     icvZeroApprox0,    /* CV_RABCLASS */
2131     icvZeroApprox0,    /* CV_LBCLASS  */
2132     icvZeroApprox0,    /* CV_GABCLASS */
2133     icvZeroApproxLog,  /* CV_L2CLASS  */
2134     icvZeroApprox0,    /* CV_LKCLASS  */
2135     icvZeroApproxMean, /* CV_LSREG    */
2136     icvZeroApproxMed,  /* CV_LADREG   */
2137     icvZeroApproxMed,  /* CV_MREG     */
2138 };
2139
2140 CV_BOOST_IMPL
2141 void cvBtNext( CvCARTClassifier** trees, CvBtTrainer* trainer );
2142
2143 CV_BOOST_IMPL
2144 CvBtTrainer* cvBtStart( CvCARTClassifier** trees,
2145                         CvMat* trainData,
2146                         int flags,
2147                         CvMat* trainClasses,
2148                         CvMat* sampleIdx,
2149                         int numsplits,
2150                         CvBoostType type,
2151                         int numclasses,
2152                         float* param )
2153 {
2154     CvBtTrainer* ptr = 0;
2155
2156     CV_FUNCNAME( "cvBtStart" );
2157
2158     __BEGIN__;
2159
2160     size_t data_size;
2161     float* zero_approx;
2162     int m;
2163     int i, j;
2164     
2165     if( trees == NULL )
2166     {
2167         CV_ERROR( CV_StsNullPtr, "Invalid trees parameter" );
2168     }
2169     
2170     if( type < CV_DABCLASS || type > CV_MREG ) 
2171     {
2172         CV_ERROR( CV_StsUnsupportedFormat, "Unsupported type parameter" );
2173     }
2174     if( type == CV_LKCLASS )
2175     {
2176         CV_ASSERT( numclasses >= 2 );
2177     }
2178     else
2179     {
2180         numclasses = 1;
2181     }
2182
2183     m = MAX( trainClasses->rows, trainClasses->cols );
2184     ptr = NULL;
2185     data_size = sizeof( *ptr );
2186     if( type > CV_GABCLASS )
2187     {
2188         data_size += m * numclasses * sizeof( *(ptr->f) );
2189     }
2190     CV_CALL( ptr = (CvBtTrainer*) cvAlloc( data_size ) );
2191     memset( ptr, 0, data_size );
2192     ptr->f = (float*) (ptr + 1);
2193
2194     ptr->trainData = trainData;
2195     ptr->flags = flags;
2196     ptr->trainClasses = trainClasses;
2197     CV_MAT2VEC( *trainClasses, ptr->ydata, ptr->ystep, ptr->m );
2198     
2199     memset( &(ptr->cartParams), 0, sizeof( ptr->cartParams ) );
2200     memset( &(ptr->stumpParams), 0, sizeof( ptr->stumpParams ) );
2201
2202     switch( type )
2203     {
2204         case CV_DABCLASS:
2205             ptr->stumpParams.error = CV_MISCLASSIFICATION;
2206             ptr->stumpParams.type  = CV_CLASSIFICATION_CLASS;
2207             break;
2208         case CV_RABCLASS:
2209             ptr->stumpParams.error = CV_GINI;
2210             ptr->stumpParams.type  = CV_CLASSIFICATION;
2211             break;
2212         default:
2213             ptr->stumpParams.error = CV_SQUARE;
2214             ptr->stumpParams.type  = CV_REGRESSION;
2215     }
2216     ptr->cartParams.count = numsplits;
2217     ptr->cartParams.stumpTrainParams = (CvClassifierTrainParams*) &(ptr->stumpParams);
2218     ptr->cartParams.stumpConstructor = cvCreateMTStumpClassifier;
2219
2220     ptr->param[0] = param[0];
2221     ptr->param[1] = param[1];
2222     ptr->type = type;
2223     ptr->numclasses = numclasses;
2224
2225     CV_CALL( ptr->y = cvCreateMat( 1, m, CV_32FC1 ) );
2226     ptr->sampleIdx = sampleIdx;
2227     ptr->numsamples = ( sampleIdx == NULL ) ? ptr->m
2228                              : MAX( sampleIdx->rows, sampleIdx->cols );
2229     
2230     ptr->weights = cvCreateMat( 1, m, CV_32FC1 );
2231     cvSet( ptr->weights, cvScalar( 1.0 ) );    
2232     
2233     if( type <= CV_GABCLASS )
2234     {
2235         ptr->boosttrainer = cvBoostStartTraining( ptr->trainClasses, ptr->y,
2236             ptr->weights, NULL, type );
2237
2238         CV_CALL( cvBtNext( trees, ptr ) );
2239     }
2240     else
2241     {
2242         data_size = sizeof( *zero_approx ) * numclasses;
2243         CV_CALL( zero_approx = (float*) cvAlloc( data_size ) );
2244         icvZeroApproxFunc[type]( zero_approx, ptr );
2245         for( i = 0; i < m; i++ )
2246         {
2247             for( j = 0; j < numclasses; j++ )
2248             {
2249                 ptr->f[i * numclasses + j] = zero_approx[j];
2250             }
2251         }
2252
2253         CV_CALL( cvBtNext( trees, ptr ) );
2254
2255         for( i = 0; i < numclasses; i++ )
2256         {
2257             for( j = 0; j <= trees[i]->count; j++ )
2258             {
2259                 trees[i]->val[j] += zero_approx[i];
2260             }
2261         }    
2262         CV_CALL( cvFree( &zero_approx ) );
2263     }
2264
2265     __END__;
2266
2267     return ptr;
2268 }
2269
2270 void icvBtNext_LSREG( CvCARTClassifier** trees, CvBtTrainer* trainer )
2271 {
2272     int i;
2273
2274     /* yhat_i = y_i - F_(m-1)(x_i) */
2275     for( i = 0; i < trainer->m; i++ )
2276     {
2277         trainer->y->data.fl[i] = 
2278             *((float*) (trainer->ydata + i * trainer->ystep)) - trainer->f[i];
2279     }
2280
2281     trees[0] = (CvCARTClassifier*) cvCreateCARTClassifier( trainer->trainData,
2282         trainer->flags,
2283         trainer->y, NULL, NULL, NULL, trainer->sampleIdx, trainer->weights,
2284         (CvClassifierTrainParams*) &trainer->cartParams );
2285 }
2286
2287
2288 void icvBtNext_LADREG( CvCARTClassifier** trees, CvBtTrainer* trainer )
2289 {
2290     CvCARTClassifier* ptr;
2291     int i, j;
2292     CvMat sample;
2293     int sample_step;
2294     uchar* sample_data;
2295     int index;
2296     
2297     int data_size;
2298     int* idx;
2299     float* resp;
2300     int respnum;
2301     float val;
2302
2303     data_size = trainer->m * sizeof( *idx );
2304     idx = (int*) cvAlloc( data_size );
2305     data_size = trainer->m * sizeof( *resp );
2306     resp = (float*) cvAlloc( data_size );
2307
2308     /* yhat_i = sign(y_i - F_(m-1)(x_i)) */
2309     for( i = 0; i < trainer->numsamples; i++ )
2310     {
2311         index = icvGetIdxAt( trainer->sampleIdx, i );
2312         trainer->y->data.fl[index] = (float)
2313              CV_SIGN( *((float*) (trainer->ydata + index * trainer->ystep))
2314                      - trainer->f[index] );
2315     }
2316
2317     ptr = (CvCARTClassifier*) cvCreateCARTClassifier( trainer->trainData, trainer->flags,
2318         trainer->y, NULL, NULL, NULL, trainer->sampleIdx, trainer->weights,
2319         (CvClassifierTrainParams*) &trainer->cartParams );
2320
2321     CV_GET_SAMPLE( *trainer->trainData, trainer->flags, 0, sample );
2322     CV_GET_SAMPLE_STEP( *trainer->trainData, trainer->flags, sample_step );
2323     sample_data = sample.data.ptr;
2324     for( i = 0; i < trainer->numsamples; i++ )
2325     {
2326         index = icvGetIdxAt( trainer->sampleIdx, i );
2327         sample.data.ptr = sample_data + index * sample_step;
2328         idx[index] = (int) cvEvalCARTClassifierIdx( (CvClassifier*) ptr, &sample );
2329     }
2330     for( j = 0; j <= ptr->count; j++ )
2331     {
2332         respnum = 0;
2333         for( i = 0; i < trainer->numsamples; i++ )
2334         {
2335             index = icvGetIdxAt( trainer->sampleIdx, i );
2336             if( idx[index] == j )
2337             {
2338                 resp[respnum++] = *((float*) (trainer->ydata + index * trainer->ystep))
2339                                   - trainer->f[index];
2340             }
2341         }
2342         if( respnum > 0 )
2343         {
2344             icvSort_32f( resp, respnum, 0 );
2345             val = resp[respnum / 2];
2346         }
2347         else
2348         {
2349             val = 0.0F;
2350         }
2351         ptr->val[j] = val;
2352     }
2353
2354     cvFree( &idx );
2355     cvFree( &resp );
2356     
2357     trees[0] = ptr;
2358 }
2359
2360
2361 void icvBtNext_MREG( CvCARTClassifier** trees, CvBtTrainer* trainer )
2362 {
2363     CvCARTClassifier* ptr;
2364     int i, j;
2365     CvMat sample;
2366     int sample_step;
2367     uchar* sample_data;
2368     
2369     int data_size;
2370     int* idx;
2371     float* resid;
2372     float* resp;
2373     int respnum;
2374     float rhat;
2375     float val;
2376     float delta;
2377     int index;
2378
2379     data_size = trainer->m * sizeof( *idx );
2380     idx = (int*) cvAlloc( data_size );
2381     data_size = trainer->m * sizeof( *resp );
2382     resp = (float*) cvAlloc( data_size );
2383     data_size = trainer->m * sizeof( *resid );
2384     resid = (float*) cvAlloc( data_size );
2385
2386     /* resid_i = (y_i - F_(m-1)(x_i)) */
2387     for( i = 0; i < trainer->numsamples; i++ )
2388     {
2389         index = icvGetIdxAt( trainer->sampleIdx, i );
2390         resid[index] = *((float*) (trainer->ydata + index * trainer->ystep))
2391                        - trainer->f[index];
2392         /* for delta */
2393         resp[i] = (float) fabs( resid[index] );
2394     }
2395     
2396     /* delta = quantile_alpha{abs(resid_i)} */
2397     icvSort_32f( resp, trainer->numsamples, 0 );
2398     delta = resp[(int)(trainer->param[1] * (trainer->numsamples - 1))];
2399
2400     /* yhat_i */
2401     for( i = 0; i < trainer->numsamples; i++ )
2402     {
2403         index = icvGetIdxAt( trainer->sampleIdx, i );
2404         trainer->y->data.fl[index] = MIN( delta, ((float) fabs( resid[index] )) ) *
2405                                  CV_SIGN( resid[index] );
2406     }
2407     
2408     ptr = (CvCARTClassifier*) cvCreateCARTClassifier( trainer->trainData, trainer->flags,
2409         trainer->y, NULL, NULL, NULL, trainer->sampleIdx, trainer->weights,
2410         (CvClassifierTrainParams*) &trainer->cartParams );
2411
2412     CV_GET_SAMPLE( *trainer->trainData, trainer->flags, 0, sample );
2413     CV_GET_SAMPLE_STEP( *trainer->trainData, trainer->flags, sample_step );
2414     sample_data = sample.data.ptr;
2415     for( i = 0; i < trainer->numsamples; i++ )
2416     {
2417         index = icvGetIdxAt( trainer->sampleIdx, i );
2418         sample.data.ptr = sample_data + index * sample_step;
2419         idx[index] = (int) cvEvalCARTClassifierIdx( (CvClassifier*) ptr, &sample );
2420     }
2421     for( j = 0; j <= ptr->count; j++ )
2422     {
2423         respnum = 0;
2424
2425         for( i = 0; i < trainer->numsamples; i++ )
2426         {
2427             index = icvGetIdxAt( trainer->sampleIdx, i );
2428             if( idx[index] == j )
2429             {
2430                 resp[respnum++] = *((float*) (trainer->ydata + index * trainer->ystep))
2431                                   - trainer->f[index];
2432             }
2433         }
2434         if( respnum > 0 )
2435         {
2436             /* rhat = median(y_i - F_(m-1)(x_i)) */
2437             icvSort_32f( resp, respnum, 0 );
2438             rhat = resp[respnum / 2];
2439             
2440             /* val = sum{sign(r_i - rhat_i) * min(delta, abs(r_i - rhat_i)}
2441              * r_i = y_i - F_(m-1)(x_i)
2442              */
2443             val = 0.0F;
2444             for( i = 0; i < respnum; i++ )
2445             {
2446                 val += CV_SIGN( resp[i] - rhat )
2447                        * MIN( delta, (float) fabs( resp[i] - rhat ) );
2448             }
2449
2450             val = rhat + val / (float) respnum;
2451         }
2452         else
2453         {
2454             val = 0.0F;
2455         }
2456
2457         ptr->val[j] = val;
2458
2459     }
2460
2461     cvFree( &resid );
2462     cvFree( &resp );
2463     cvFree( &idx );
2464     
2465     trees[0] = ptr;
2466 }
2467
/*
 * Numeric limits for the two-class log-likelihood step.
 * CV_LOG_VAL_MAX is the cut-off icvBtNext_L2CLASS tests an exponent against
 * before calling exp().  Where CV_VAL_MAX is used is not visible in this
 * part of the file.
 *
 * Earlier, much larger settings, kept for reference:
 *   CV_VAL_MAX      1e304
 *   CV_LOG_VAL_MAX  700.0
 */
#define CV_VAL_MAX 1e+8

#define CV_LOG_VAL_MAX 18.0
2475
2476 void icvBtNext_L2CLASS( CvCARTClassifier** trees, CvBtTrainer* trainer )
2477 {
2478     CvCARTClassifier* ptr;
2479     int i, j;
2480     CvMat sample;
2481     int sample_step;
2482     uchar* sample_data;
2483     
2484     int data_size;
2485     int* idx;
2486     int respnum;
2487     float val;
2488     double val_f;
2489
2490     float sum_weights;
2491     float* weights;
2492     float* sorted_weights;
2493     CvMat* trimmed_idx;
2494     CvMat* sample_idx;
2495     int index;
2496     int trimmed_num;
2497
2498     data_size = trainer->m * sizeof( *idx );
2499     idx = (int*) cvAlloc( data_size );
2500
2501     data_size = trainer->m * sizeof( *weights );
2502     weights = (float*) cvAlloc( data_size );
2503     data_size = trainer->m * sizeof( *sorted_weights );
2504     sorted_weights = (float*) cvAlloc( data_size );
2505     
2506     /* yhat_i = (4 * y_i - 2) / ( 1 + exp( (4 * y_i - 2) * F_(m-1)(x_i) ) ).
2507      *   y_i in {0, 1}
2508      */
2509     sum_weights = 0.0F;
2510     for( i = 0; i < trainer->numsamples; i++ )
2511     {
2512         index = icvGetIdxAt( trainer->sampleIdx, i );
2513         val = 4.0F * (*((float*) (trainer->ydata + index * trainer->ystep))) - 2.0F;
2514         val_f = val * trainer->f[index];
2515         val_f = ( val_f < CV_LOG_VAL_MAX ) ? exp( val_f ) : CV_LOG_VAL_MAX;
2516         val = (float) ( (double) val / ( 1.0 + val_f ) );
2517         trainer->y->data.fl[index] = val;
2518         val = (float) fabs( val );
2519         weights[index] = val * (2.0F - val);
2520         sorted_weights[i] = weights[index];
2521         sum_weights += sorted_weights[i];
2522     }
2523     
2524     trimmed_idx = NULL;
2525     sample_idx = trainer->sampleIdx;
2526     trimmed_num = trainer->numsamples;
2527     if( trainer->param[1] < 1.0F )
2528     {
2529         /* perform weight trimming */
2530         
2531         float threshold;
2532         int count;
2533         
2534         icvSort_32f( sorted_weights, trainer->numsamples, 0 );
2535
2536         sum_weights *= (1.0F - trainer->param[1]);
2537         
2538         i = -1;
2539         do { sum_weights -= sorted_weights[++i]; }
2540         while( sum_weights > 0.0F && i < (trainer->numsamples - 1) );
2541         
2542         threshold = sorted_weights[i];
2543
2544         while( i > 0 && sorted_weights[i-1] == threshold ) i--;
2545
2546         if( i > 0 )
2547         {
2548             trimmed_num = trainer->numsamples - i;            
2549             trimmed_idx = cvCreateMat( 1, trimmed_num, CV_32FC1 );
2550             count = 0;
2551             for( i = 0; i < trainer->numsamples; i++ )
2552             {
2553                 index = icvGetIdxAt( trainer->sampleIdx, i );
2554                 if( weights[index] >= threshold )
2555                 {
2556                     CV_MAT_ELEM( *trimmed_idx, float, 0, count ) = (float) index;
2557                     count++;
2558                 }
2559             }
2560             
2561             assert( count == trimmed_num );
2562
2563             sample_idx = trimmed_idx;
2564
2565             printf( "Used samples %%: %g\n", 
2566                 (float) trimmed_num / (float) trainer->numsamples * 100.0F );
2567         }
2568     }
2569
2570     ptr = (CvCARTClassifier*) cvCreateCARTClassifier( trainer->trainData, trainer->flags,
2571         trainer->y, NULL, NULL, NULL, sample_idx, trainer->weights,
2572         (CvClassifierTrainParams*) &trainer->cartParams );
2573
2574     CV_GET_SAMPLE( *trainer->trainData, trainer->flags, 0, sample );
2575     CV_GET_SAMPLE_STEP( *trainer->trainData, trainer->flags, sample_step );
2576     sample_data = sample.data.ptr;
2577     for( i = 0; i < trimmed_num; i++ )
2578     {
2579         index = icvGetIdxAt( sample_idx, i );
2580         sample.data.ptr = sample_data + index * sample_step;
2581         idx[index] = (int) cvEvalCARTClassifierIdx( (CvClassifier*) ptr, &sample );
2582     }
2583     for( j = 0; j <= ptr->count; j++ )
2584     {
2585         respnum = 0;
2586         val = 0.0F;
2587         sum_weights = 0.0F;
2588         for( i = 0; i < trimmed_num; i++ )
2589         {
2590             index = icvGetIdxAt( sample_idx, i );
2591             if( idx[index] == j )
2592             {
2593                 val += trainer->y->data.fl[index];
2594                 sum_weights += weights[index];
2595                 respnum++;
2596             }
2597         }
2598         if( sum_weights > 0.0F )
2599         {
2600             val /= sum_weights;
2601         }
2602         else
2603         {
2604             val = 0.0F;
2605         }
2606         ptr->val[j] = val;
2607     }
2608     
2609     if( trimmed_idx != NULL ) cvReleaseMat( &trimmed_idx );
2610     cvFree( &sorted_weights );
2611     cvFree( &weights );
2612     cvFree( &idx );
2613     
2614     trees[0] = ptr;
2615 }
2616
2617 void icvBtNext_LKCLASS( CvCARTClassifier** trees, CvBtTrainer* trainer )
2618 {
2619     int i, j, k, kk, num;
2620     CvMat sample;
2621     int sample_step;
2622     uchar* sample_data;
2623     
2624     int data_size;
2625     int* idx;
2626     int respnum;
2627     float val;
2628
2629     float sum_weights;
2630     float* weights;
2631     float* sorted_weights;
2632     CvMat* trimmed_idx;
2633     CvMat* sample_idx;
2634     int index;
2635     int trimmed_num;
2636     double sum_exp_f;
2637     double exp_f;
2638     double f_k;
2639
2640     data_size = trainer->m * sizeof( *idx );
2641     idx = (int*) cvAlloc( data_size );
2642     data_size = trainer->m * sizeof( *weights );
2643     weights = (float*) cvAlloc( data_size );
2644     data_size = trainer->m * sizeof( *sorted_weights );
2645     sorted_weights = (float*) cvAlloc( data_size );
2646     trimmed_idx = cvCreateMat( 1, trainer->numsamples, CV_32FC1 );
2647
2648     for( k = 0; k < trainer->numclasses; k++ )
2649     {
2650         /* yhat_i = y_i - p_k(x_i), y_i in {0, 1}      */
2651         /* p_k(x_i) = exp(f_k(x_i)) / (sum_exp_f(x_i)) */
2652         sum_weights = 0.0F;
2653         for( i = 0; i < trainer->numsamples; i++ )
2654         {
2655             index = icvGetIdxAt( trainer->sampleIdx, i );
2656             /* p_k(x_i) = 1 / (1 + sum(exp(f_kk(x_i) - f_k(x_i)))), kk != k */
2657             num = index * trainer->numclasses;
2658             f_k = (double) trainer->f[num + k];
2659             sum_exp_f = 1.0;
2660             for( kk = 0; kk < trainer->numclasses; kk++ )
2661             {
2662                 if( kk == k ) continue;
2663                 exp_f = (double) trainer->f[num + kk] - f_k;
2664                 exp_f = (exp_f < CV_LOG_VAL_MAX) ? exp( exp_f ) : CV_VAL_MAX;
2665                 if( exp_f == CV_VAL_MAX || exp_f >= (CV_VAL_MAX - sum_exp_f) )
2666                 {
2667                     sum_exp_f = CV_VAL_MAX;
2668                     break;
2669                 }
2670                 sum_exp_f += exp_f;
2671             }
2672
2673             val = (float) ( (*((float*) (trainer->ydata + index * trainer->ystep))) 
2674                             == (float) k );
2675             val -= (float) ( (sum_exp_f == CV_VAL_MAX) ? 0.0 : ( 1.0 / sum_exp_f ) );
2676
2677             assert( val >= -1.0F );
2678             assert( val <= 1.0F );
2679
2680             trainer->y->data.fl[index] = val;
2681             val = (float) fabs( val );
2682             weights[index] = val * (1.0F - val);
2683             sorted_weights[i] = weights[index];
2684             sum_weights += sorted_weights[i];
2685         }
2686
2687         sample_idx = trainer->sampleIdx;
2688         trimmed_num = trainer->numsamples;
2689         if( trainer->param[1] < 1.0F )
2690         {
2691             /* perform weight trimming */
2692         
2693             float threshold;
2694             int count;
2695         
2696             icvSort_32f( sorted_weights, trainer->numsamples, 0 );
2697
2698             sum_weights *= (1.0F - trainer->param[1]);
2699         
2700             i = -1;
2701             do { sum_weights -= sorted_weights[++i]; }
2702             while( sum_weights > 0.0F && i < (trainer->numsamples - 1) );
2703         
2704             threshold = sorted_weights[i];
2705
2706             while( i > 0 && sorted_weights[i-1] == threshold ) i--;
2707
2708             if( i > 0 )
2709             {
2710                 trimmed_num = trainer->numsamples - i;            
2711                 trimmed_idx->cols = trimmed_num;
2712                 count = 0;
2713                 for( i = 0; i < trainer->numsamples; i++ )
2714                 {
2715                     index = icvGetIdxAt( trainer->sampleIdx, i );
2716                     if( weights[index] >= threshold )
2717                     {
2718                         CV_MAT_ELEM( *trimmed_idx, float, 0, count ) = (float) index;
2719                         count++;
2720                     }
2721                 }
2722             
2723                 assert( count == trimmed_num );
2724
2725                 sample_idx = trimmed_idx;
2726
2727                 printf( "k: %d Used samples %%: %g\n", k, 
2728                     (float) trimmed_num / (float) trainer->numsamples * 100.0F );
2729             }
2730         } /* weight trimming */
2731
2732         trees[k] = (CvCARTClassifier*) cvCreateCARTClassifier( trainer->trainData,
2733             trainer->flags, trainer->y, NULL, NULL, NULL, sample_idx, trainer->weights,
2734             (CvClassifierTrainParams*) &trainer->cartParams );
2735
2736         CV_GET_SAMPLE( *trainer->trainData, trainer->flags, 0, sample );
2737         CV_GET_SAMPLE_STEP( *trainer->trainData, trainer->flags, sample_step );
2738         sample_data = sample.data.ptr;
2739         for( i = 0; i < trimmed_num; i++ )
2740         {
2741             index = icvGetIdxAt( sample_idx, i );
2742             sample.data.ptr = sample_data + index * sample_step;
2743             idx[index] = (int) cvEvalCARTClassifierIdx( (CvClassifier*) trees[k],
2744                                                         &sample );
2745         }
2746         for( j = 0; j <= trees[k]->count; j++ )
2747         {
2748             respnum = 0;
2749             val = 0.0F;
2750             sum_weights = 0.0F;
2751             for( i = 0; i < trimmed_num; i++ )
2752             {
2753                 index = icvGetIdxAt( sample_idx, i );
2754                 if( idx[index] == j )
2755                 {
2756                     val += trainer->y->data.fl[index];
2757                     sum_weights += weights[index];
2758                     respnum++;
2759                 }
2760             }
2761             if( sum_weights > 0.0F )
2762             {
2763                 val = ((float) (trainer->numclasses - 1)) * val /
2764                       ((float) (trainer->numclasses)) / sum_weights;
2765             }
2766             else
2767             {
2768                 val = 0.0F;
2769             }
2770             trees[k]->val[j] = val;
2771         }
2772     } /* for each class */
2773     
2774     cvReleaseMat( &trimmed_idx );
2775     cvFree( &sorted_weights );
2776     cvFree( &weights );
2777     cvFree( &idx );
2778 }
2779
2780
2781 void icvBtNext_XXBCLASS( CvCARTClassifier** trees, CvBtTrainer* trainer )
2782 {
2783     float alpha;
2784     int i;
2785     CvMat* weak_eval_vals;
2786     CvMat* sample_idx;
2787     int num_samples;
2788     CvMat sample;
2789     uchar* sample_data;
2790     int sample_step;
2791
2792     weak_eval_vals = cvCreateMat( 1, trainer->m, CV_32FC1 );
2793
2794     sample_idx = cvTrimWeights( trainer->weights, trainer->sampleIdx,
2795                                 trainer->param[1] );
2796     num_samples = ( sample_idx == NULL )
2797         ? trainer->m : MAX( sample_idx->rows, sample_idx->cols );
2798
2799     printf( "Used samples %%: %g\n", 
2800         (float) num_samples / (float) trainer->numsamples * 100.0F );
2801
2802     trees[0] = (CvCARTClassifier*) cvCreateCARTClassifier( trainer->trainData,
2803         trainer->flags, trainer->y, NULL, NULL, NULL,
2804         sample_idx, trainer->weights,
2805         (CvClassifierTrainParams*) &trainer->cartParams );
2806     
2807     /* evaluate samples */
2808     CV_GET_SAMPLE( *trainer->trainData, trainer->flags, 0, sample );
2809     CV_GET_SAMPLE_STEP( *trainer->trainData, trainer->flags, sample_step );
2810     sample_data = sample.data.ptr;
2811     
2812     for( i = 0; i < trainer->m; i++ )
2813     {
2814         sample.data.ptr = sample_data + i * sample_step;
2815         weak_eval_vals->data.fl[i] = trees[0]->eval( (CvClassifier*) trees[0], &sample );
2816     }
2817
2818     alpha = cvBoostNextWeakClassifier( weak_eval_vals, trainer->trainClasses,
2819         trainer->y, trainer->weights, trainer->boosttrainer );
2820     
2821     /* multiply tree by alpha */
2822     for( i = 0; i <= trees[0]->count; i++ )
2823     {
2824         trees[0]->val[i] *= alpha;
2825     }
2826     if( trainer->type == CV_RABCLASS )
2827     {
2828         for( i = 0; i <= trees[0]->count; i++ )
2829         {
2830             trees[0]->val[i] = cvLogRatio( trees[0]->val[i] );
2831         }
2832     }
2833     
2834     if( sample_idx != NULL && sample_idx != trainer->sampleIdx )
2835     {
2836         cvReleaseMat( &sample_idx );
2837     }
2838     cvReleaseMat( &weak_eval_vals );
2839 }
2840
2841 typedef void (*CvBtNextFunc)( CvCARTClassifier** trees, CvBtTrainer* trainer );
2842
2843 static CvBtNextFunc icvBtNextFunc[] =
2844 {
2845     icvBtNext_XXBCLASS,
2846     icvBtNext_XXBCLASS,
2847     icvBtNext_XXBCLASS,
2848     icvBtNext_XXBCLASS,
2849     icvBtNext_L2CLASS,
2850     icvBtNext_LKCLASS,
2851     icvBtNext_LSREG,
2852     icvBtNext_LADREG,
2853     icvBtNext_MREG
2854 };
2855
2856 CV_BOOST_IMPL
2857 void cvBtNext( CvCARTClassifier** trees, CvBtTrainer* trainer )
2858 {
2859     int i, j;
2860     int index;
2861     CvMat sample;
2862     int sample_step;
2863     uchar* sample_data;
2864
2865     icvBtNextFunc[trainer->type]( trees, trainer );        
2866
2867     /* shrinkage */
2868     if( trainer->param[0] != 1.0F )
2869     {
2870         for( j = 0; j < trainer->numclasses; j++ )
2871         {
2872             for( i = 0; i <= trees[j]->count; i++ )
2873             {
2874                 trees[j]->val[i] *= trainer->param[0];
2875             }
2876         }
2877     }
2878
2879     if( trainer->type > CV_GABCLASS )
2880     {
2881         /* update F_(m-1) */
2882         CV_GET_SAMPLE( *(trainer->trainData), trainer->flags, 0, sample );
2883         CV_GET_SAMPLE_STEP( *(trainer->trainData), trainer->flags, sample_step );
2884         sample_data = sample.data.ptr;
2885         for( i = 0; i < trainer->numsamples; i++ )
2886         {
2887             index = icvGetIdxAt( trainer->sampleIdx, i );
2888             sample.data.ptr = sample_data + index * sample_step;
2889             for( j = 0; j < trainer->numclasses; j++ )
2890             {            
2891                 trainer->f[index * trainer->numclasses + j] += 
2892                     trees[j]->eval( (CvClassifier*) (trees[j]), &sample );
2893             }
2894         }
2895     }
2896 }
2897
2898 CV_BOOST_IMPL
2899 void cvBtEnd( CvBtTrainer** trainer )
2900 {
2901     CV_FUNCNAME( "cvBtEnd" );
2902     
2903     __BEGIN__;
2904     
2905     if( trainer == NULL || (*trainer) == NULL )
2906     {
2907         CV_ERROR( CV_StsNullPtr, "Invalid trainer parameter" );
2908     }
2909     
2910     if( (*trainer)->y != NULL )
2911     {
2912         CV_CALL( cvReleaseMat( &((*trainer)->y) ) );
2913     }
2914     if( (*trainer)->weights != NULL )
2915     {
2916         CV_CALL( cvReleaseMat( &((*trainer)->weights) ) );
2917     }
2918     if( (*trainer)->boosttrainer != NULL )
2919     {
2920         CV_CALL( cvBoostEndTraining( &((*trainer)->boosttrainer) ) );
2921     }
2922     CV_CALL( cvFree( trainer ) );
2923
2924     __END__;
2925 }
2926
2927 /****************************************************************************************\
2928 *                         Boosted tree model as a classifier                             *
2929 \****************************************************************************************/
2930
2931 CV_BOOST_IMPL
2932 float cvEvalBtClassifier( CvClassifier* classifier, CvMat* sample )
2933 {
2934     float val;
2935
2936     CV_FUNCNAME( "cvEvalBtClassifier" );
2937
2938     __BEGIN__;
2939     
2940     int i;
2941
2942     val = 0.0F;
2943     if( CV_IS_TUNABLE( classifier->flags ) )
2944     {
2945         CvSeqReader reader;
2946         CvCARTClassifier* tree;
2947
2948         CV_CALL( cvStartReadSeq( ((CvBtClassifier*) classifier)->seq, &reader ) );
2949         for( i = 0; i < ((CvBtClassifier*) classifier)->numiter; i++ )
2950         {
2951             CV_READ_SEQ_ELEM( tree, reader );
2952             val += tree->eval( (CvClassifier*) tree, sample );
2953         }
2954     }
2955     else
2956     {
2957         CvCARTClassifier** ptree;
2958
2959         ptree = ((CvBtClassifier*) classifier)->trees;
2960         for( i = 0; i < ((CvBtClassifier*) classifier)->numiter; i++ )
2961         {
2962             val += (*ptree)->eval( (CvClassifier*) (*ptree), sample );
2963             ptree++;
2964         }
2965     }
2966
2967     __END__;
2968
2969     return val;
2970 }
2971
2972 CV_BOOST_IMPL
2973 float cvEvalBtClassifier2( CvClassifier* classifier, CvMat* sample )
2974 {
2975     float val;
2976
2977     CV_FUNCNAME( "cvEvalBtClassifier2" );
2978
2979     __BEGIN__;
2980     
2981     CV_CALL( val = cvEvalBtClassifier( classifier, sample ) );
2982
2983     __END__;
2984
2985     return (float) (val >= 0.0F);
2986 }
2987
2988 CV_BOOST_IMPL
2989 float cvEvalBtClassifierK( CvClassifier* classifier, CvMat* sample )
2990 {
2991     int cls = 0;
2992
2993     CV_FUNCNAME( "cvEvalBtClassifierK" );
2994
2995     __BEGIN__;
2996     
2997     int i, k;
2998     float max_val;
2999     int numclasses;
3000
3001     float* vals;
3002     size_t data_size;
3003
3004     numclasses = ((CvBtClassifier*) classifier)->numclasses;
3005     data_size = sizeof( *vals ) * numclasses;
3006     CV_CALL( vals = (float*) cvAlloc( data_size ) );
3007     memset( vals, 0, data_size );
3008
3009     if( CV_IS_TUNABLE( classifier->flags ) )
3010     {
3011         CvSeqReader reader;
3012         CvCARTClassifier* tree;
3013
3014         CV_CALL( cvStartReadSeq( ((CvBtClassifier*) classifier)->seq, &reader ) );
3015         for( i = 0; i < ((CvBtClassifier*) classifier)->numiter; i++ )
3016         {
3017             for( k = 0; k < numclasses; k++ )
3018             {
3019                 CV_READ_SEQ_ELEM( tree, reader );
3020                 vals[k] += tree->eval( (CvClassifier*) tree, sample );
3021             }
3022         }
3023
3024     }
3025     else
3026     {
3027         CvCARTClassifier** ptree;
3028
3029         ptree = ((CvBtClassifier*) classifier)->trees;
3030         for( i = 0; i < ((CvBtClassifier*) classifier)->numiter; i++ )
3031         {
3032             for( k = 0; k < numclasses; k++ )
3033             {
3034                 vals[k] += (*ptree)->eval( (CvClassifier*) (*ptree), sample );
3035                 ptree++;
3036             }
3037         }
3038     }
3039
3040     max_val = vals[cls];
3041     for( k = 1; k < numclasses; k++ )
3042     {
3043         if( vals[k] > max_val )
3044         {
3045             max_val = vals[k];
3046             cls = k;
3047         }
3048     }
3049
3050     CV_CALL( cvFree( &vals ) );
3051
3052     __END__;
3053
3054     return (float) cls;
3055 }
3056
3057 typedef float (*CvEvalBtClassifier)( CvClassifier* classifier, CvMat* sample );
3058
3059 static CvEvalBtClassifier icvEvalBtClassifier[] =
3060 {
3061     cvEvalBtClassifier2,
3062     cvEvalBtClassifier2,
3063     cvEvalBtClassifier2,
3064     cvEvalBtClassifier2,
3065     cvEvalBtClassifier2,
3066     cvEvalBtClassifierK,
3067     cvEvalBtClassifier,
3068     cvEvalBtClassifier,
3069     cvEvalBtClassifier
3070 };
3071
3072 CV_BOOST_IMPL
3073 int cvSaveBtClassifier( CvClassifier* classifier, const char* filename )
3074 {
3075     CV_FUNCNAME( "cvSaveBtClassifier" );
3076
3077     __BEGIN__;
3078
3079     FILE* file;
3080     int i, j;
3081     CvSeqReader reader;
3082     memset(&reader, 0, sizeof(reader));
3083     CvCARTClassifier* tree;
3084
3085     CV_ASSERT( classifier );
3086     CV_ASSERT( filename );
3087     
3088     if( !icvMkDir( filename ) || (file = fopen( filename, "w" )) == 0 )
3089     {
3090         CV_ERROR( CV_StsError, "Unable to create file" );
3091     }
3092
3093     if( CV_IS_TUNABLE( classifier->flags ) )
3094     {
3095         CV_CALL( cvStartReadSeq( ((CvBtClassifier*) classifier)->seq, &reader ) );
3096     }
3097     fprintf( file, "%d %d\n%d\n%d\n", (int) ((CvBtClassifier*) classifier)->type,
3098                                       ((CvBtClassifier*) classifier)->numclasses,
3099                                       ((CvBtClassifier*) classifier)->numfeatures,
3100                                       ((CvBtClassifier*) classifier)->numiter );
3101     
3102     for( i = 0; i < ((CvBtClassifier*) classifier)->numclasses *
3103                     ((CvBtClassifier*) classifier)->numiter; i++ )
3104     {
3105         if( CV_IS_TUNABLE( classifier->flags ) )
3106         {
3107             CV_READ_SEQ_ELEM( tree, reader );
3108         }
3109         else
3110         {
3111             tree = ((CvBtClassifier*) classifier)->trees[i];
3112         }
3113
3114         fprintf( file, "%d\n", tree->count );
3115         for( j = 0; j < tree->count; j++ )
3116         {
3117             fprintf( file, "%d %g %d %d\n", tree->compidx[j],
3118                                             tree->threshold[j],
3119                                             tree->left[j],
3120                                             tree->right[j] );
3121         }
3122         for( j = 0; j <= tree->count; j++ )
3123         {
3124             fprintf( file, "%g ", tree->val[j] );
3125         }
3126         fprintf( file, "\n" );
3127     }
3128
3129     fclose( file );
3130
3131     __END__;
3132
3133     return 1;
3134 }
3135
3136
3137 CV_BOOST_IMPL
3138 void cvReleaseBtClassifier( CvClassifier** ptr )
3139 {
3140     CV_FUNCNAME( "cvReleaseBtClassifier" );
3141
3142     __BEGIN__;
3143
3144     int i;
3145
3146     if( ptr == NULL || *ptr == NULL )
3147     {
3148         CV_ERROR( CV_StsNullPtr, "" );
3149     }
3150     if( CV_IS_TUNABLE( (*ptr)->flags ) )
3151     {
3152         CvSeqReader reader;
3153         CvCARTClassifier* tree;
3154
3155         CV_CALL( cvStartReadSeq( ((CvBtClassifier*) *ptr)->seq, &reader ) );
3156         for( i = 0; i < ((CvBtClassifier*) *ptr)->numclasses *
3157                         ((CvBtClassifier*) *ptr)->numiter; i++ )
3158         {
3159             CV_READ_SEQ_ELEM( tree, reader );
3160             tree->release( (CvClassifier**) (&tree) );
3161         }
3162         CV_CALL( cvReleaseMemStorage( &(((CvBtClassifier*) *ptr)->seq->storage) ) );
3163     }
3164     else
3165     {
3166         CvCARTClassifier** ptree;
3167
3168         ptree = ((CvBtClassifier*) *ptr)->trees;
3169         for( i = 0; i < ((CvBtClassifier*) *ptr)->numclasses *
3170                         ((CvBtClassifier*) *ptr)->numiter; i++ )
3171         {
3172             (*ptree)->release( (CvClassifier**) ptree );
3173             ptree++;
3174         }
3175     }
3176
3177     CV_CALL( cvFree( ptr ) );
3178     *ptr = NULL;
3179
3180     __END__;
3181 }
3182
3183 void cvTuneBtClassifier( CvClassifier* classifier, CvMat*, int flags,
3184                          CvMat*, CvMat* , CvMat*, CvMat*, CvMat* )
3185 {
3186     CV_FUNCNAME( "cvTuneBtClassifier" );
3187
3188     __BEGIN__;
3189
3190     size_t data_size;
3191
3192     if( CV_IS_TUNABLE( flags ) )
3193     {
3194         if( !CV_IS_TUNABLE( classifier->flags ) )
3195         {
3196             CV_ERROR( CV_StsUnsupportedFormat,
3197                       "Classifier does not support tune function" );
3198         }
3199         else
3200         {
3201             /* tune classifier */
3202             CvCARTClassifier** trees;
3203
3204             printf( "Iteration %d\n", ((CvBtClassifier*) classifier)->numiter + 1 );
3205
3206             data_size = sizeof( *trees ) * ((CvBtClassifier*) classifier)->numclasses;
3207             CV_CALL( trees = (CvCARTClassifier**) cvAlloc( data_size ) );
3208             CV_CALL( cvBtNext( trees,
3209                 (CvBtTrainer*) ((CvBtClassifier*) classifier)->trainer ) );
3210             CV_CALL( cvSeqPushMulti( ((CvBtClassifier*) classifier)->seq,
3211                 trees, ((CvBtClassifier*) classifier)->numclasses ) );
3212             CV_CALL( cvFree( &trees ) );
3213             ((CvBtClassifier*) classifier)->numiter++;
3214         }
3215     }
3216     else
3217     {
3218         if( CV_IS_TUNABLE( classifier->flags ) )
3219         {
3220             /* convert */
3221             void* ptr;
3222
3223             assert( ((CvBtClassifier*) classifier)->seq->total ==
3224                         ((CvBtClassifier*) classifier)->numiter *
3225                         ((CvBtClassifier*) classifier)->numclasses );
3226
3227             data_size = sizeof( ((CvBtClassifier*) classifier)->trees[0] ) *
3228                 ((CvBtClassifier*) classifier)->seq->total;
3229             CV_CALL( ptr = cvAlloc( data_size ) );
3230             CV_CALL( cvCvtSeqToArray( ((CvBtClassifier*) classifier)->seq, ptr ) );
3231             CV_CALL( cvReleaseMemStorage( 
3232                     &(((CvBtClassifier*) classifier)->seq->storage) ) );
3233             ((CvBtClassifier*) classifier)->trees = (CvCARTClassifier**) ptr;
3234             classifier->flags &= ~CV_TUNABLE;
3235             CV_CALL( cvBtEnd( (CvBtTrainer**)
3236                 &(((CvBtClassifier*) classifier)->trainer )) );
3237             ((CvBtClassifier*) classifier)->trainer = NULL;
3238         }
3239     }
3240
3241     __END__;
3242 }
3243
3244 CvBtClassifier* icvAllocBtClassifier( CvBoostType type, int flags, int numclasses,
3245                                       int numiter )
3246 {
3247     CvBtClassifier* ptr;
3248     size_t data_size;
3249
3250     assert( numclasses >= 1 );
3251     assert( numiter >= 0 );
3252     assert( ( numclasses == 1 ) || (type == CV_LKCLASS) );
3253
3254     data_size = sizeof( *ptr );
3255     ptr = (CvBtClassifier*) cvAlloc( data_size );
3256     memset( ptr, 0, data_size );
3257
3258     if( CV_IS_TUNABLE( flags ) )
3259     {
3260         ptr->seq = cvCreateSeq( 0, sizeof( *(ptr->seq) ), sizeof( *(ptr->trees) ),
3261                                 cvCreateMemStorage() );
3262         ptr->numiter = 0;
3263     }
3264     else
3265     {
3266         data_size = numclasses * numiter * sizeof( *(ptr->trees) );
3267         ptr->trees = (CvCARTClassifier**) cvAlloc( data_size );
3268         memset( ptr->trees, 0, data_size );
3269
3270         ptr->numiter = numiter;
3271     }
3272
3273     ptr->flags = flags;
3274     ptr->numclasses = numclasses;
3275     ptr->type = type;
3276
3277     ptr->eval = icvEvalBtClassifier[(int) type];
3278     ptr->tune = cvTuneBtClassifier;
3279     ptr->save = cvSaveBtClassifier;
3280     ptr->release = cvReleaseBtClassifier;
3281
3282     return ptr;
3283 }
3284
3285 CV_BOOST_IMPL
3286 CvClassifier* cvCreateBtClassifier( CvMat* trainData,
3287                                     int flags,
3288                                     CvMat* trainClasses,
3289                                     CvMat* typeMask,
3290                                     CvMat* missedMeasurementsMask,
3291                                     CvMat* compIdx,
3292                                     CvMat* sampleIdx,
3293                                     CvMat* weights,
3294                                     CvClassifierTrainParams* trainParams )
3295 {
3296     CvBtClassifier* ptr = 0;
3297
3298     CV_FUNCNAME( "cvCreateBtClassifier" );
3299
3300     __BEGIN__;
3301     CvBoostType type;
3302     int num_classes;
3303     int num_iter;
3304     int i;
3305     CvCARTClassifier** trees;
3306     size_t data_size;
3307
3308     CV_ASSERT( trainData != NULL );
3309     CV_ASSERT( trainClasses != NULL );
3310     CV_ASSERT( typeMask == NULL );
3311     CV_ASSERT( missedMeasurementsMask == NULL );
3312     CV_ASSERT( compIdx == NULL );
3313     CV_ASSERT( weights == NULL );
3314     CV_ASSERT( trainParams != NULL );
3315
3316     type = ((CvBtClassifierTrainParams*) trainParams)->type;
3317     
3318     if( type >= CV_DABCLASS && type <= CV_GABCLASS && sampleIdx )
3319     {
3320         CV_ERROR( CV_StsBadArg, "Sample indices are not supported for this type" );
3321     }
3322
3323     if( type == CV_LKCLASS )
3324     {
3325         double min_val;
3326         double max_val;
3327
3328         cvMinMaxLoc( trainClasses, &min_val, &max_val );
3329         num_classes = (int) (max_val + 1.0);
3330         
3331         CV_ASSERT( num_classes >= 2 );
3332     }
3333     else
3334     {
3335         num_classes = 1;
3336     }
3337     num_iter = ((CvBtClassifierTrainParams*) trainParams)->numiter;
3338     
3339     CV_ASSERT( num_iter > 0 );
3340
3341     ptr = icvAllocBtClassifier( type, CV_TUNABLE | flags, num_classes, num_iter );
3342     ptr->numfeatures = (CV_IS_ROW_SAMPLE( flags )) ? trainData->cols : trainData->rows;
3343     
3344     i = 0;
3345
3346     printf( "Iteration %d\n", 1 );
3347
3348     data_size = sizeof( *trees ) * ptr->numclasses;
3349     CV_CALL( trees = (CvCARTClassifier**) cvAlloc( data_size ) );
3350
3351     CV_CALL( ptr->trainer = cvBtStart( trees, trainData, flags, trainClasses, sampleIdx,
3352         ((CvBtClassifierTrainParams*) trainParams)->numsplits, type, num_classes,
3353         &(((CvBtClassifierTrainParams*) trainParams)->param[0]) ) );
3354
3355     CV_CALL( cvSeqPushMulti( ptr->seq, trees, ptr->numclasses ) );
3356     CV_CALL( cvFree( &trees ) );
3357     ptr->numiter++;
3358     
3359     for( i = 1; i < num_iter; i++ )
3360     {
3361         ptr->tune( (CvClassifier*) ptr, NULL, CV_TUNABLE, NULL, NULL, NULL, NULL, NULL );
3362     }
3363     if( !CV_IS_TUNABLE( flags ) )
3364     {
3365         /* convert */
3366         ptr->tune( (CvClassifier*) ptr, NULL, 0, NULL, NULL, NULL, NULL, NULL );
3367     }
3368
3369     __END__;
3370
3371     return (CvClassifier*) ptr;
3372 }
3373
3374 CV_BOOST_IMPL
3375 CvClassifier* cvCreateBtClassifierFromFile( const char* filename )
3376 {
3377     CvBtClassifier* ptr = 0;
3378
3379     CV_FUNCNAME( "cvCreateBtClassifierFromFile" );
3380     
3381     __BEGIN__;
3382
3383     FILE* file;
3384     int i, j;
3385     int data_size;
3386     int num_classifiers;
3387     int num_features;
3388     int num_classes;
3389     int type;
3390
3391     CV_ASSERT( filename != NULL );
3392
3393     ptr = NULL;
3394     file = fopen( filename, "r" );
3395     if( !file )
3396     {
3397         CV_ERROR( CV_StsError, "Unable to open file" );
3398     }
3399     
3400     fscanf( file, "%d %d %d %d", &type, &num_classes, &num_features, &num_classifiers );
3401
3402     CV_ASSERT( type >= (int) CV_DABCLASS && type <= (int) CV_MREG );
3403     CV_ASSERT( num_features > 0 );
3404     CV_ASSERT( num_classifiers > 0 );
3405
3406     if( (CvBoostType) type != CV_LKCLASS )
3407     {
3408         num_classes = 1;
3409     }
3410     ptr = icvAllocBtClassifier( (CvBoostType) type, 0, num_classes, num_classifiers );
3411     ptr->numfeatures = num_features;
3412     
3413     for( i = 0; i < num_classes * num_classifiers; i++ )
3414     {
3415         int count;
3416         CvCARTClassifier* tree;
3417
3418         fscanf( file, "%d", &count );
3419
3420         data_size = sizeof( *tree )
3421             + count * ( sizeof( *(tree->compidx) ) + sizeof( *(tree->threshold) ) +
3422                         sizeof( *(tree->right) ) + sizeof( *(tree->left) ) )
3423             + (count + 1) * ( sizeof( *(tree->val) ) );
3424         CV_CALL( tree = (CvCARTClassifier*) cvAlloc( data_size ) );
3425         memset( tree, 0, data_size );
3426         tree->eval = cvEvalCARTClassifier;
3427         tree->tune = NULL;
3428         tree->save = NULL;
3429         tree->release = cvReleaseCARTClassifier;
3430         tree->compidx = (int*) ( tree + 1 );
3431         tree->threshold = (float*) ( tree->compidx + count );
3432         tree->left = (int*) ( tree->threshold + count );
3433         tree->right = (int*) ( tree->left + count );
3434         tree->val = (float*) ( tree->right + count );
3435
3436         tree->count = count;
3437         for( j = 0; j < tree->count; j++ )
3438         {
3439             fscanf( file, "%d %g %d %d", &(tree->compidx[j]),
3440                                          &(tree->threshold[j]),
3441                                          &(tree->left[j]),
3442                                          &(tree->right[j]) );
3443         }
3444         for( j = 0; j <= tree->count; j++ )
3445         {
3446             fscanf( file, "%g", &(tree->val[j]) );
3447         }
3448         ptr->trees[i] = tree;
3449     }
3450
3451     fclose( file );
3452
3453     __END__;
3454
3455     return (CvClassifier*) ptr;
3456 }
3457
3458 /****************************************************************************************\
3459 *                                    Utility functions                                   *
3460 \****************************************************************************************/
3461
3462 CV_BOOST_IMPL
3463 CvMat* cvTrimWeights( CvMat* weights, CvMat* idx, float factor )
3464 {
3465     CvMat* ptr = 0;
3466
3467     CV_FUNCNAME( "cvTrimWeights" );
3468     __BEGIN__;
3469     int i, index, num;
3470     float sum_weights;
3471     uchar* wdata;
3472     size_t wstep;
3473     int wnum;
3474     float threshold;
3475     int count;
3476     float* sorted_weights;
3477
3478     CV_ASSERT( CV_MAT_TYPE( weights->type ) == CV_32FC1 );
3479
3480     ptr = idx;
3481     sorted_weights = NULL;
3482
3483     if( factor > 0.0F && factor < 1.0F )
3484     {
3485         size_t data_size;
3486
3487         CV_MAT2VEC( *weights, wdata, wstep, wnum );
3488         num = ( idx == NULL ) ? wnum : MAX( idx->rows, idx->cols );
3489
3490         data_size = num * sizeof( *sorted_weights );
3491         sorted_weights = (float*) cvAlloc( data_size );
3492         memset( sorted_weights, 0, data_size );
3493
3494         sum_weights = 0.0F;
3495         for( i = 0; i < num; i++ )
3496         {
3497             index = icvGetIdxAt( idx, i );
3498             sorted_weights[i] = *((float*) (wdata + index * wstep));
3499             sum_weights += sorted_weights[i];
3500         }
3501
3502         icvSort_32f( sorted_weights, num, 0 );
3503
3504         sum_weights *= (1.0F - factor);
3505
3506         i = -1;
3507         do { sum_weights -= sorted_weights[++i]; }
3508         while( sum_weights > 0.0F && i < (num - 1) );
3509
3510         threshold = sorted_weights[i];
3511
3512         while( i > 0 && sorted_weights[i-1] == threshold ) i--;
3513
3514         if( i > 0 || ( idx != NULL && CV_MAT_TYPE( idx->type ) != CV_32FC1 ) )
3515         {
3516             CV_CALL( ptr = cvCreateMat( 1, num - i, CV_32FC1 ) );
3517             count = 0;
3518             for( i = 0; i < num; i++ )
3519             {
3520                 index = icvGetIdxAt( idx, i );
3521                 if( *((float*) (wdata + index * wstep)) >= threshold )
3522                 {
3523                     CV_MAT_ELEM( *ptr, float, 0, count ) = (float) index;
3524                     count++;
3525                 }
3526             }
3527         
3528             assert( count == ptr->cols );
3529         }
3530         cvFree( &sorted_weights );
3531     }
3532
3533     __END__;
3534
3535     return ptr;
3536 }
3537
3538
3539 CV_BOOST_IMPL
3540 void cvReadTrainData( const char* filename, int flags,
3541                       CvMat** trainData,
3542                       CvMat** trainClasses )
3543 {
3544
3545     CV_FUNCNAME( "cvReadTrainData" );
3546
3547     __BEGIN__;
3548
3549     FILE* file;
3550     int m, n;
3551     int i, j;
3552     float val;
3553
3554     if( filename == NULL )
3555     {
3556         CV_ERROR( CV_StsNullPtr, "filename must be specified" );
3557     }
3558     if( trainData == NULL )
3559     {
3560         CV_ERROR( CV_StsNullPtr, "trainData must be not NULL" );
3561     }
3562     if( trainClasses == NULL )
3563     {
3564         CV_ERROR( CV_StsNullPtr, "trainClasses must be not NULL" );
3565     }
3566     
3567     *trainData = NULL;
3568     *trainClasses = NULL;
3569     file = fopen( filename, "r" );
3570     if( !file )
3571     {
3572         CV_ERROR( CV_StsError, "Unable to open file" );
3573     }
3574
3575     fscanf( file, "%d %d", &m, &n );
3576
3577     if( CV_IS_ROW_SAMPLE( flags ) )
3578     {
3579         CV_CALL( *trainData = cvCreateMat( m, n, CV_32FC1 ) );
3580     }
3581     else
3582     {
3583         CV_CALL( *trainData = cvCreateMat( n, m, CV_32FC1 ) );
3584     }
3585     
3586     CV_CALL( *trainClasses = cvCreateMat( 1, m, CV_32FC1 ) );
3587
3588     for( i = 0; i < m; i++ )
3589     {
3590         for( j = 0; j < n; j++ )
3591         {
3592             fscanf( file, "%f", &val );
3593             if( CV_IS_ROW_SAMPLE( flags ) )
3594             {
3595                 CV_MAT_ELEM( **trainData, float, i, j ) = val;
3596             }
3597             else
3598             {
3599                 CV_MAT_ELEM( **trainData, float, j, i ) = val;
3600             }
3601         }
3602         fscanf( file, "%f", &val );
3603         CV_MAT_ELEM( **trainClasses, float, 0, i ) = val;
3604     }
3605
3606     fclose( file );
3607
3608     __END__;
3609     
3610 }
3611
3612 CV_BOOST_IMPL
3613 void cvWriteTrainData( const char* filename, int flags,
3614                        CvMat* trainData, CvMat* trainClasses, CvMat* sampleIdx )
3615 {
3616     CV_FUNCNAME( "cvWriteTrainData" );
3617
3618     __BEGIN__;
3619
3620     FILE* file;
3621     int m, n;
3622     int i, j;
3623     int clsrow;
3624     int count;
3625     int idx;
3626     CvScalar sc;
3627
3628     if( filename == NULL )
3629     {
3630         CV_ERROR( CV_StsNullPtr, "filename must be specified" );
3631     }
3632     if( trainData == NULL || CV_MAT_TYPE( trainData->type ) != CV_32FC1 )
3633     {
3634         CV_ERROR( CV_StsUnsupportedFormat, "Invalid trainData" );
3635     }
3636     if( CV_IS_ROW_SAMPLE( flags ) )
3637     {
3638         m = trainData->rows;
3639         n = trainData->cols;
3640     }
3641     else
3642     {
3643         n = trainData->rows;
3644         m = trainData->cols;
3645     }
3646     if( trainClasses == NULL || CV_MAT_TYPE( trainClasses->type ) != CV_32FC1 ||
3647         MIN( trainClasses->rows, trainClasses->cols ) != 1 )
3648     {
3649         CV_ERROR( CV_StsUnsupportedFormat, "Invalid trainClasses" );
3650     }
3651     clsrow = (trainClasses->rows == 1);
3652     if( m != ( (clsrow) ? trainClasses->cols : trainClasses->rows ) )
3653     {
3654         CV_ERROR( CV_StsUnmatchedSizes, "Incorrect trainData and trainClasses sizes" );
3655     }
3656     
3657     if( sampleIdx != NULL )
3658     {
3659         count = (sampleIdx->rows == 1) ? sampleIdx->cols : sampleIdx->rows;
3660     }
3661     else
3662     {
3663         count = m;
3664     }
3665     
3666
3667     file = fopen( filename, "w" );
3668     if( !file )
3669     {
3670         CV_ERROR( CV_StsError, "Unable to create file" );
3671     }
3672
3673     fprintf( file, "%d %d\n", count, n );
3674
3675     for( i = 0; i < count; i++ )
3676     {
3677         if( sampleIdx )
3678         {
3679             if( sampleIdx->rows == 1 )
3680             {
3681                 sc = cvGet2D( sampleIdx, 0, i );
3682             }
3683             else
3684             {
3685                 sc = cvGet2D( sampleIdx, i, 0 );
3686             }
3687             idx = (int) sc.val[0];
3688         }
3689         else
3690         {
3691             idx = i;
3692         }
3693         for( j = 0; j < n; j++ )
3694         {
3695             fprintf( file, "%g ", ( (CV_IS_ROW_SAMPLE( flags ))
3696                                     ? CV_MAT_ELEM( *trainData, float, idx, j ) 
3697                                     : CV_MAT_ELEM( *trainData, float, j, idx ) ) );
3698         }
3699         fprintf( file, "%g\n", ( (clsrow)
3700                                 ? CV_MAT_ELEM( *trainClasses, float, 0, idx )
3701                                 : CV_MAT_ELEM( *trainClasses, float, idx, 0 ) ) );
3702     }
3703
3704     fclose( file );
3705     
3706     __END__;
3707 }
3708
3709
3710 #define ICV_RAND_SHUFFLE( suffix, type )                                                 \
3711 void icvRandShuffle_##suffix( uchar* data, size_t step, int num )                        \
3712 {                                                                                        \
3713     CvRandState state;                                                                   \
3714     time_t seed;                                                                         \
3715     type tmp;                                                                            \
3716     int i;                                                                               \
3717     float rn;                                                                            \
3718                                                                                          \
3719     time( &seed );                                                                       \
3720                                                                                          \
3721     cvRandInit( &state, (double) 0, (double) 0, (int)seed );                             \
3722     for( i = 0; i < (num-1); i++ )                                                       \
3723     {                                                                                    \
3724         rn = ((float) cvRandNext( &state )) / (1.0F + UINT_MAX);                         \
3725         CV_SWAP( *((type*)(data + i * step)),                                            \
3726                  *((type*)(data + ( i + (int)( rn * (num - i ) ) )* step)),              \
3727                  tmp );                                                                  \
3728     }                                                                                    \
3729 }
3730
3731 ICV_RAND_SHUFFLE( 8U, uchar )
3732
3733 ICV_RAND_SHUFFLE( 16S, short )
3734
3735 ICV_RAND_SHUFFLE( 32S, int )
3736
3737 ICV_RAND_SHUFFLE( 32F, float )
3738
3739 CV_BOOST_IMPL
3740 void cvRandShuffleVec( CvMat* mat )
3741 {
3742     CV_FUNCNAME( "cvRandShuffle" );
3743
3744     __BEGIN__;
3745
3746     uchar* data;
3747     size_t step;
3748     int num;
3749
3750     if( (mat == NULL) || !CV_IS_MAT( mat ) || MIN( mat->rows, mat->cols ) != 1 )
3751     {
3752         CV_ERROR( CV_StsUnsupportedFormat, "" );
3753     }
3754
3755     CV_MAT2VEC( *mat, data, step, num );
3756     switch( CV_MAT_TYPE( mat->type ) )
3757     {
3758         case CV_8UC1:
3759             icvRandShuffle_8U( data, step, num);
3760             break;
3761         case CV_16SC1:
3762             icvRandShuffle_16S( data, step, num);
3763             break;
3764         case CV_32SC1:
3765             icvRandShuffle_32S( data, step, num);
3766             break;
3767         case CV_32FC1:
3768             icvRandShuffle_32F( data, step, num);
3769             break;
3770         default:
3771             CV_ERROR( CV_StsUnsupportedFormat, "" );
3772     }
3773
3774     __END__;
3775 }
3776
3777 /* End of file. */