Move the sources to trunk
[opencv] / ml / src / mlem.cpp
1 /*M///////////////////////////////////////////////////////////////////////////////////////
2 //
3 //  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
4 //
5 //  By downloading, copying, installing or using the software you agree to this license.
6 //  If you do not agree to this license, do not download, install,
7 //  copy or use the software.
8 //
9 //
10 //                        Intel License Agreement
11 //                For Open Source Computer Vision Library
12 //
13 // Copyright( C) 2000, Intel Corporation, all rights reserved.
14 // Third party copyrights are property of their respective owners.
15 //
16 // Redistribution and use in source and binary forms, with or without modification,
17 // are permitted provided that the following conditions are met:
18 //
19 //   * Redistribution's of source code must retain the above copyright notice,
20 //     this list of conditions and the following disclaimer.
21 //
22 //   * Redistribution's in binary form must reproduce the above copyright notice,
23 //     this list of conditions and the following disclaimer in the documentation
24 //     and/or other materials provided with the distribution.
25 //
26 //   * The name of Intel Corporation may not be used to endorse or promote products
27 //     derived from this software without specific prior written permission.
28 //
29 // This software is provided by the copyright holders and contributors "as is" and
30 // any express or implied warranties, including, but not limited to, the implied
31 // warranties of merchantability and fitness for a particular purpose are disclaimed.
32 // In no event shall the Intel Corporation or contributors be liable for any direct,
33 // indirect, incidental, special, exemplary, or consequential damages
34 //(including, but not limited to, procurement of substitute goods or services;
35 // loss of use, data, or profits; or business interruption) however caused
36 // and on any theory of liability, whether in contract, strict liability,
37 // or tort(including negligence or otherwise) arising in any way out of
38 // the use of this software, even ifadvised of the possibility of such damage.
39 //
40 //M*/
41
42 #include "_ml.h"
43
44
45 /*
46    CvEM:
47  * params.nclusters    - number of clusters to cluster samples to.
48  * means               - calculated by the EM algorithm set of gaussians' means.
49  * log_weight_div_det - auxilary vector that k-th component is equal to
50                         (-2)*ln(weights_k/det(Sigma_k)^0.5),
51                         where <weights_k> is the weight,
52                         <Sigma_k> is the covariation matrice of k-th cluster.
53  * inv_eigen_values   - set of 1*dims matrices, <inv_eigen_values>[k] contains
54                         inversed eigen values of covariation matrice of the k-th cluster.
55                         In the case of <cov_mat_type> == COV_MAT_DIAGONAL,
56                         inv_eigen_values[k] = Sigma_k^(-1).
57  * covs_rotate_mats   - used only if cov_mat_type == COV_MAT_GENERIC, in all the
58                         other cases it is NULL. <covs_rotate_mats>[k] is the orthogonal
59                         matrice, obtained by the SVD-decomposition of Sigma_k.
60    Both <inv_eigen_values> and <covs_rotate_mats> fields are used for representation of
61    covariation matrices and simplifying EM calculations.
62    For fixed k denote
63    u = covs_rotate_mats[k],
64    v = inv_eigen_values[k],
65    w = v^(-1);
66    if <cov_mat_type> == COV_MAT_GENERIC, then Sigma_k = u w u',
67    else                                       Sigma_k = w.
68    Symbol ' means transposition.
69  */
70
71
72 CvEM::CvEM()
73 {
74     means = weights = probs = inv_eigen_values = log_weight_div_det = 0;
75     covs = cov_rotate_mats = 0;
76 }
77
78
79 CvEM::~CvEM()
80 {
81     clear();
82 }
83
84
85 void CvEM::clear()
86 {
87     int i;
88     
89     cvReleaseMat( &means );
90     cvReleaseMat( &weights );
91     cvReleaseMat( &probs );
92     cvReleaseMat( &inv_eigen_values );
93     cvReleaseMat( &log_weight_div_det );
94
95     if( covs || cov_rotate_mats )
96     {
97         for( i = 0; i < params.nclusters; i++ )
98         {
99             if( covs )
100                 cvReleaseMat( &covs[i] );
101             if( cov_rotate_mats )
102                 cvReleaseMat( &cov_rotate_mats[i] );
103         }
104         cvFree( &covs );
105         cvFree( &cov_rotate_mats );
106     }
107 }
108
109
110 void CvEM::set_params( const CvEMParams& _params, const CvVectors& train_data )
111 {
112     CV_FUNCNAME( "CvEM::set_params" );
113
114     __BEGIN__;
115     
116     int k;
117
118     params = _params;
119     params.term_crit = cvCheckTermCriteria( params.term_crit, 1e-6, 10000 );
120
121     if( params.cov_mat_type != COV_MAT_SPHERICAL &&
122         params.cov_mat_type != COV_MAT_DIAGONAL &&
123         params.cov_mat_type != COV_MAT_GENERIC )
124         CV_ERROR( CV_StsBadArg, "Unknown covariation matrix type" );
125
126     switch( params.start_step )
127     {
128     case START_M_STEP:
129         if( !params.probs )
130             CV_ERROR( CV_StsNullPtr, "Probabilities must be specified when EM algorithm starts with M-step" );
131         break;
132     case START_E_STEP:
133         if( !params.means )
134             CV_ERROR( CV_StsNullPtr, "Mean's must be specified when EM algorithm starts with E-step" );
135         break;
136     case START_AUTO_STEP:
137         break;
138     default:
139         CV_ERROR( CV_StsBadArg, "Unknown start_step" );
140     }
141
142     if( params.nclusters < 1 )
143         CV_ERROR( CV_StsOutOfRange, "The number of clusters (mixtures) should be > 0" );
144
145     if( params.probs )
146     {
147         const CvMat* p = params.weights;
148         if( !CV_IS_MAT(p) ||
149             CV_MAT_TYPE(p->type) != CV_32FC1  &&
150             CV_MAT_TYPE(p->type) != CV_64FC1 ||
151             p->rows != train_data.count ||
152             p->cols != params.nclusters )
153             CV_ERROR( CV_StsBadArg, "The array of probabilities must be a valid "
154             "floating-point matrix (CvMat) of 'nsamples' x 'nclusters' size" );
155     }
156
157     if( params.means )
158     {
159         const CvMat* m = params.means;
160         if( !CV_IS_MAT(m) ||
161             CV_MAT_TYPE(m->type) != CV_32FC1  &&
162             CV_MAT_TYPE(m->type) != CV_64FC1 ||
163             m->rows != params.nclusters ||
164             m->cols != train_data.dims )
165             CV_ERROR( CV_StsBadArg, "The array of mean's must be a valid "
166             "floating-point matrix (CvMat) of 'nsamples' x 'dims' size" );
167     }
168
169     if( params.weights )
170     {
171         const CvMat* w = params.weights;
172         if( !CV_IS_MAT(w) ||
173             CV_MAT_TYPE(w->type) != CV_32FC1  &&
174             CV_MAT_TYPE(w->type) != CV_64FC1 ||
175             w->rows != 1 && w->cols != 1 ||
176             w->rows + w->cols - 1 != params.nclusters )
177             CV_ERROR( CV_StsBadArg, "The array of weights must be a valid "
178             "1d floating-point vector (CvMat) of 'nclusters' elements" );
179     }
180
181     if( params.covs )
182         for( k = 0; k < params.nclusters; k++ )
183         {
184             const CvMat* cov = params.covs[k];
185             if( !CV_IS_MAT(cov) ||
186                 CV_MAT_TYPE(cov->type) != CV_32FC1  &&
187                 CV_MAT_TYPE(cov->type) != CV_64FC1 ||
188                 cov->rows != cov->cols || cov->cols != train_data.dims )
189                 CV_ERROR( CV_StsBadArg,
190                 "Each of covariation matrices must be a valid square "
191                 "floating-point matrix (CvMat) of 'dims' x 'dims'" );
192         }
193
194     __END__;
195 }
196
197
198 /****************************************************************************************/
199 float
200 CvEM::predict( const CvMat* _sample, CvMat* _probs ) const
201 {
202     float* sample_data   = 0;
203     void* buffer = 0;
204     int allocated_buffer = 0;
205     int cls = 0;
206
207     CV_FUNCNAME( "CvEM::predict" );
208     __BEGIN__;
209
210     int i, k, dims;
211     int nclusters;
212     int cov_mat_type = params.cov_mat_type;
213     double opt = FLT_MAX;
214     size_t size;
215     CvMat diff, expo;
216
217     dims = means->cols;
218     nclusters = params.nclusters;
219
220     CV_CALL( cvPreparePredictData( _sample, dims, 0, params.nclusters, _probs, &sample_data ));
221
222 // allocate memory and initializing headers for calculating
223     size = sizeof(double) * (nclusters + dims);
224     if( size <= CV_MAX_LOCAL_SIZE )
225         buffer = cvStackAlloc( size );
226     else
227     {
228         CV_CALL( buffer = cvAlloc( size ));
229         allocated_buffer = 1;
230     }
231     expo = cvMat( 1, nclusters, CV_64FC1, buffer );
232     diff = cvMat( 1, dims, CV_64FC1, (double*)buffer + nclusters );
233
234 // calculate the probabilities
235     for( k = 0; k < nclusters; k++ )
236     {
237         const double* mean_k = (const double*)(means->data.ptr + means->step*k);
238         const double* w = (const double*)(inv_eigen_values->data.ptr + inv_eigen_values->step*k);
239         double cur = log_weight_div_det->data.db[k];
240         CvMat* u = cov_rotate_mats[k];
241         // cov = u w u'  -->  cov^(-1) = u w^(-1) u'
242         if( cov_mat_type == COV_MAT_SPHERICAL )
243         {
244             double w0 = w[0];
245             for( i = 0; i < dims; i++ )
246             {
247                 double val = sample_data[i] - mean_k[i];
248                 cur += val*val*w0;
249             }
250         }
251         else
252         {
253             for( i = 0; i < dims; i++ )
254                 diff.data.db[i] = sample_data[i] - mean_k[i];
255             if( cov_mat_type == COV_MAT_GENERIC )
256                 cvGEMM( &diff, u, 1, 0, 0, &diff, CV_GEMM_B_T );
257             for( i = 0; i < dims; i++ )
258             {
259                 double val = diff.data.db[i];
260                 cur += val*val*w[i];
261             }
262         }
263
264         expo.data.db[k] = cur;
265         if( cur < opt )
266         {
267             cls = k;
268             opt = cur;
269         }
270         /* probability = (2*pi)^(-dims/2)*exp( -0.5 * cur ) */
271     }
272
273     if( _probs )
274     {
275         CV_CALL( cvConvertScale( &expo, &expo, -0.5 ));
276         CV_CALL( cvExp( &expo, &expo ));
277         if( _probs->cols == 1 )
278             CV_CALL( cvReshape( &expo, &expo, 0, nclusters ));
279         CV_CALL( cvConvertScale( &expo, _probs, 1./cvSum( &expo ).val[0] ));
280     }
281
282     __END__;
283
284     if( sample_data != _sample->data.fl )
285         cvFree( &sample_data );
286     if( allocated_buffer )
287         cvFree( &buffer );
288
289     return (float)cls;
290 }
291
292
293
294 bool CvEM::train( const CvMat* _samples, const CvMat* _sample_idx,
295                   CvEMParams _params, CvMat* labels )
296 {
297     bool result = false;
298     CvVectors train_data;
299     CvMat* sample_idx = 0;
300
301     train_data.data.fl = 0;
302     train_data.count = 0;
303
304     CV_FUNCNAME("cvEM");
305
306     __BEGIN__;
307
308     int i, nsamples, nclusters, dims;
309
310     clear();
311
312     CV_CALL( cvPrepareTrainData( "cvEM",
313         _samples, CV_ROW_SAMPLE, 0, CV_VAR_CATEGORICAL,
314         0, _sample_idx, false, (const float***)&train_data.data.fl,
315         &train_data.count, &train_data.dims, &train_data.dims,
316         0, 0, 0, &sample_idx ));
317
318     CV_CALL( set_params( _params, train_data ));
319     nsamples = train_data.count;
320     nclusters = params.nclusters;
321     dims = train_data.dims;
322
323     if( labels && (!CV_IS_MAT(labels) || CV_MAT_TYPE(labels->type) != CV_32SC1 ||
324         labels->cols != 1 && labels->rows != 1 || labels->cols + labels->rows - 1 != nsamples ))
325         CV_ERROR( CV_StsBadArg,
326         "labels array (when passed) must be a valid 1d integer vector of <sample_count> elements" );
327
328     if( nsamples <= nclusters )
329         CV_ERROR( CV_StsOutOfRange,
330         "The number of samples should be greater than the number of clusters" );
331
332     CV_CALL( log_weight_div_det = cvCreateMat( 1, nclusters, CV_64FC1 ));
333     CV_CALL( probs  = cvCreateMat( nsamples, nclusters, CV_64FC1 ));
334     CV_CALL( means = cvCreateMat( nclusters, dims, CV_64FC1 ));
335     CV_CALL( weights = cvCreateMat( 1, nclusters, CV_64FC1 ));
336     CV_CALL( inv_eigen_values = cvCreateMat( nclusters,
337         params.cov_mat_type == COV_MAT_SPHERICAL ? 1 : dims, CV_64FC1 ));
338     CV_CALL( covs = (CvMat**)cvAlloc( nclusters * sizeof(*covs) ));
339     CV_CALL( cov_rotate_mats = (CvMat**)cvAlloc( nclusters * sizeof(cov_rotate_mats[0]) ));
340
341     for( i = 0; i < nclusters; i++ )
342     {
343         CV_CALL( covs[i] = cvCreateMat( dims, dims, CV_64FC1 ));
344         CV_CALL( cov_rotate_mats[i]  = cvCreateMat( dims, dims, CV_64FC1 ));
345         cvZero( cov_rotate_mats[i] );
346     }
347
348     init_em( train_data );
349     log_likelihood = run_em( train_data );
350     if( log_likelihood <= -DBL_MAX/10000. )
351         EXIT;
352
353     if( labels )
354     {
355         if( nclusters == 1 )
356             cvZero( labels );
357         else
358         {
359             CvMat sample = cvMat( 1, dims, CV_32F );
360             CvMat prob = cvMat( 1, nclusters, CV_64F );
361             int lstep = CV_IS_MAT_CONT(labels->type) ? 1 : labels->step/sizeof(int);
362             
363             for( i = 0; i < nsamples; i++ )
364             {
365                 int idx = sample_idx ? sample_idx->data.i[i] : i;
366                 sample.data.ptr = _samples->data.ptr + _samples->step*idx;
367                 prob.data.ptr = probs->data.ptr + probs->step*i;
368
369                 labels->data.i[i*lstep] = cvRound(predict(&sample, &prob));
370             }
371         }
372     }
373
374     result = true;
375
376     __END__;
377
378     if( sample_idx != _sample_idx )
379         cvReleaseMat( &sample_idx );
380
381     cvFree( &train_data.data.ptr );
382
383     return result;
384 }
385
386
387 void CvEM::init_em( const CvVectors& train_data )
388 {
389     CvMat *w = 0, *u = 0, *tcov = 0;
390     
391     CV_FUNCNAME( "CvEM::init_em" );
392
393     __BEGIN__;
394
395     double maxval = 0;
396     int i, force_symm_plus = 0;
397     int nclusters = params.nclusters, nsamples = train_data.count, dims = train_data.dims;
398
399     if( params.start_step == START_AUTO_STEP || nclusters == 1 || nclusters == nsamples )
400         init_auto( train_data );
401     else if( params.start_step == START_M_STEP )
402     {
403         for( i = 0; i < nsamples; i++ )
404         {
405             CvMat prob;
406             cvGetRow( params.probs, &prob, i );
407             cvMaxS( &prob, 0., &prob );
408             cvMinMaxLoc( &prob, 0, &maxval );
409             if( maxval < FLT_EPSILON )
410                 cvSet( &prob, cvScalar(1./nclusters) );
411             else
412                 cvNormalize( &prob, &prob, 1., 0, CV_L1 );
413         }
414         EXIT; // do not preprocess covariation matrices,
415               // as in this case they are initialized at the first iteration of EM
416     }
417     else
418     {
419         CV_ASSERT( params.start_step == START_E_STEP && params.means );
420         if( params.weights && params.covs )
421         {
422             cvConvert( params.means, means );
423             cvReshape( weights, weights, 1, params.weights->rows );
424             cvConvert( params.weights, weights );
425             cvReshape( weights, weights, 1, 1 );
426             cvMaxS( weights, 0., weights );
427             cvMinMaxLoc( weights, 0, &maxval );
428             if( maxval < FLT_EPSILON )
429                 cvSet( &weights, cvScalar(1./nclusters) );
430             cvNormalize( weights, weights, 1., 0, CV_L1 );
431             for( i = 0; i < nclusters; i++ )
432                 CV_CALL( cvConvert( params.covs[i], covs[i] ));
433             force_symm_plus = 1;
434         }
435         else
436             init_auto( train_data );
437     }
438
439     CV_CALL( tcov = cvCreateMat( dims, dims, CV_64FC1 ));
440     CV_CALL( w = cvCreateMat( dims, dims, CV_64FC1 ));
441     if( params.cov_mat_type == COV_MAT_GENERIC )
442         CV_CALL( u = cvCreateMat( dims, dims, CV_64FC1 ));
443
444     for( i = 0; i < nclusters; i++ )
445     {
446         if( force_symm_plus )
447         {
448             cvTranspose( covs[i], tcov );
449             cvAddWeighted( covs[i], 0.5, tcov, 0.5, 0, tcov );
450         }
451         else
452             cvCopy( covs[i], tcov );
453         cvSVD( tcov, w, u, 0, CV_SVD_MODIFY_A + CV_SVD_U_T + CV_SVD_V_T );
454         if( params.cov_mat_type == COV_MAT_SPHERICAL )
455             cvSetIdentity( covs[i], cvScalar(cvTrace(w).val[0]/dims) );
456         else if( params.cov_mat_type == COV_MAT_DIAGONAL )
457             cvCopy( w, covs[i] );
458         else
459         {
460             // generic case: covs[i] = (u')'*max(w,0)*u'
461             cvGEMM( u, w, 1, 0, 0, tcov, CV_GEMM_A_T );
462             cvGEMM( tcov, u, 1, 0, 0, covs[i], 0 );
463         }
464     }
465
466     __END__;
467
468     cvReleaseMat( &w );
469     cvReleaseMat( &u );
470     cvReleaseMat( &tcov );
471 }
472
473
474 void CvEM::init_auto( const CvVectors& train_data )
475 {
476     CvMat* hdr = 0;
477     const void** vec = 0;
478     CvMat* class_ranges = 0;
479     CvMat* labels = 0;
480
481     CV_FUNCNAME( "CvEM::init_auto" );
482
483     __BEGIN__;
484
485     int nclusters = params.nclusters, nsamples = train_data.count, dims = train_data.dims;
486     int i, j;
487
488     if( nclusters == nsamples )
489     {
490         CvMat src = cvMat( 1, dims, CV_32F );
491         CvMat dst = cvMat( 1, dims, CV_64F );
492         for( i = 0; i < nsamples; i++ )
493         {
494             src.data.ptr = train_data.data.ptr[i];
495             dst.data.ptr = means->data.ptr + means->step*i;
496             cvConvert( &src, &dst );
497             cvZero( covs[i] );
498             cvSetIdentity( cov_rotate_mats[i] );
499         }
500         cvSetIdentity( probs );
501         cvSet( weights, cvScalar(1./nclusters) );
502     }
503     else
504     {
505         int max_count = 0;
506         
507         CV_CALL( class_ranges = cvCreateMat( 1, nclusters+1, CV_32SC1 ));
508         if( nclusters > 1 )
509         {
510             CV_CALL( labels = cvCreateMat( 1, nsamples, CV_32SC1 ));
511             kmeans( train_data, nclusters, labels, cvTermCriteria( CV_TERMCRIT_ITER,
512                     params.means ? 1 : 10, 0.5 ), params.means );
513             CV_CALL( cvSortSamplesByClasses( (const float**)train_data.data.fl,
514                                             labels, class_ranges->data.i ));
515         }
516         else
517         {
518             class_ranges->data.i[0] = 0;
519             class_ranges->data.i[1] = nsamples;
520         }
521
522         for( i = 0; i < nclusters; i++ )
523         {
524             int left = class_ranges->data.i[i], right = class_ranges->data.i[i+1];
525             max_count = MAX( max_count, right - left );
526         }
527         CV_CALL( hdr = (CvMat*)cvAlloc( max_count*sizeof(hdr[0]) ));
528         CV_CALL( vec = (const void**)cvAlloc( max_count*sizeof(vec[0]) ));
529         hdr[0] = cvMat( 1, dims, CV_32F );
530         for( i = 0; i < max_count; i++ )
531         {
532             vec[i] = hdr + i;
533             hdr[i] = hdr[0];
534         }
535                 
536         for( i = 0; i < nclusters; i++ )
537         {
538             int left = class_ranges->data.i[i], right = class_ranges->data.i[i+1];
539             int cluster_size = right - left;
540             CvMat avg;
541             
542             if( cluster_size <= 0 )
543                 continue;
544
545             for( j = left; j < right; j++ )
546                 hdr[j - left].data.fl = train_data.data.fl[j];
547             
548             CV_CALL( cvGetRow( means, &avg, i ));
549             CV_CALL( cvCalcCovarMatrix( vec, cluster_size, covs[i],
550                 &avg, CV_COVAR_NORMAL | CV_COVAR_SCALE ));
551             weights->data.db[i] = (double)cluster_size/(double)nsamples;
552         }
553     }
554
555     __END__;
556
557     cvReleaseMat( &class_ranges );
558     cvReleaseMat( &labels );
559     cvFree( &hdr );
560     cvFree( &vec );
561 }
562
563
564 void CvEM::kmeans( const CvVectors& train_data, int nclusters, CvMat* labels,
565                    CvTermCriteria termcrit, const CvMat* centers0 )
566 {
567     CvMat* centers = 0;
568     CvMat* old_centers = 0;
569     CvMat* counters = 0;
570     
571     CV_FUNCNAME( "CvEM::kmeans" );
572
573     __BEGIN__;
574
575     CvRNG rng = cvRNG(-1);
576     int i, j, k, nsamples, dims;
577     int iter = 0;
578     double max_dist = DBL_MAX;
579
580     termcrit = cvCheckTermCriteria( termcrit, 1e-6, 100 );
581     termcrit.epsilon *= termcrit.epsilon;
582     nsamples = train_data.count;
583     dims = train_data.dims;
584     nclusters = MIN( nclusters, nsamples );
585
586     CV_CALL( centers = cvCreateMat( nclusters, dims, CV_64FC1 ));
587     CV_CALL( old_centers = cvCreateMat( nclusters, dims, CV_64FC1 ));
588     CV_CALL( counters = cvCreateMat( 1, nclusters, CV_32SC1 ));
589     cvZero( old_centers );
590
591     if( centers0 )
592     {
593         CV_CALL( cvConvert( centers0, centers ));
594     }
595     else
596     {
597         for( i = 0; i < nsamples; i++ )
598             labels->data.i[i] = i*nclusters/nsamples;
599         cvRandShuffle( labels, &rng );
600     }
601
602     for( ;; )
603     {
604         CvMat* temp;
605
606         if( iter > 0 || centers0 )
607         {
608             for( i = 0; i < nsamples; i++ )
609             {
610                 const float* s = train_data.data.fl[i];
611                 int k_best = 0;
612                 double min_dist = DBL_MAX;
613
614                 for( k = 0; k < nclusters; k++ )
615                 {
616                     const double* c = (double*)(centers->data.ptr + k*centers->step);
617                     double dist = 0;
618                 
619                     for( j = 0; j <= dims - 4; j += 4 )
620                     {
621                         double t0 = c[j] - s[j];
622                         double t1 = c[j+1] - s[j+1];
623                         dist += t0*t0 + t1*t1;
624                         t0 = c[j+2] - s[j+2];
625                         t1 = c[j+3] - s[j+3];
626                         dist += t0*t0 + t1*t1;
627                     }
628
629                     for( ; j < dims; j++ )
630                     {
631                         double t = c[j] - s[j];
632                         dist += t*t;
633                     }
634                 
635                     if( min_dist > dist )
636                     {
637                         min_dist = dist;
638                         k_best = k;
639                     }
640                 }
641
642                 labels->data.i[i] = k_best;
643             }
644         }
645
646         if( ++iter > termcrit.max_iter )
647             break;
648
649         CV_SWAP( centers, old_centers, temp );
650         cvZero( centers );
651         cvZero( counters );
652
653         // update centers
654         for( i = 0; i < nsamples; i++ )
655         {
656             const float* s = train_data.data.fl[i];
657             k = labels->data.i[i];
658             double* c = (double*)(centers->data.ptr + k*centers->step);
659
660             for( j = 0; j <= dims - 4; j += 4 )
661             {
662                 double t0 = c[j] + s[j];
663                 double t1 = c[j+1] + s[j+1];
664
665                 c[j] = t0;
666                 c[j+1] = t1;
667
668                 t0 = c[j+2] + s[j+2];
669                 t1 = c[j+3] + s[j+3];
670
671                 c[j+2] = t0;
672                 c[j+3] = t1;
673             }
674             for( ; j < dims; j++ )
675                 c[j] += s[j];
676             counters->data.i[k]++;
677         }
678
679         if( iter > 1 )
680             max_dist = 0;
681
682         for( k = 0; k < nclusters; k++ )
683         {
684             double* c = (double*)(centers->data.ptr + k*centers->step);
685             if( counters->data.i[k] != 0 )
686             {
687                 double scale = 1./counters->data.i[k];
688                 for( j = 0; j < dims; j++ )
689                     c[j] *= scale;
690             }
691             else
692             {
693                 const float* s;
694                 for( j = 0; j < 10; j++ )
695                 {
696                     i = cvRandInt( &rng ) % nsamples;
697                     if( counters->data.i[labels->data.i[i]] > 1 )
698                         break;
699                 }
700                 s = train_data.data.fl[i];
701                 for( j = 0; j < dims; j++ )
702                     c[j] = s[j];
703             }
704             
705             if( iter > 1 )
706             {
707                 double dist = 0;
708                 const double* c_o = (double*)(old_centers->data.ptr + k*old_centers->step);
709                 for( j = 0; j < dims; j++ )
710                 {
711                     double t = c[j] - c_o[j];
712                     dist += t*t;
713                 }
714                 if( max_dist < dist )
715                     max_dist = dist;
716             }
717         }
718
719         if( max_dist < termcrit.epsilon )
720             break;
721     }
722
723     cvZero( counters );
724     for( i = 0; i < nsamples; i++ )
725         counters->data.i[labels->data.i[i]]++;
726
727     // ensure that we do not have empty clusters
728     for( k = 0; k < nclusters; k++ )
729         if( counters->data.i[k] == 0 )
730             for(;;)
731             {
732                 i = cvRandInt(&rng) % nsamples;
733                 j = labels->data.i[i];
734                 if( counters->data.i[j] > 1 )
735                 {
736                     labels->data.i[i] = k;
737                     counters->data.i[j]--;
738                     counters->data.i[k]++;
739                     break;
740                 }
741             }
742
743     __END__;
744
745     cvReleaseMat( &centers );
746     cvReleaseMat( &old_centers );
747     cvReleaseMat( &counters );
748 }
749
750
751 /****************************************************************************************/
752 /* log_weight_div_det[k] = -2*log(weights_k) + log(det(Sigma_k)))
753
754    covs[k] = cov_rotate_mats[k] * cov_eigen_values[k] * (cov_rotate_mats[k])'
755    cov_rotate_mats[k] are orthogonal matrices of eigenvectors and
756    cov_eigen_values[k] are diagonal matrices (represented by 1D vectors) of eigen values.
757
758    The <alpha_ik> is the probability of the vector x_i to belong to the k-th cluster:
759    <alpha_ik> ~ weights_k * exp{ -0.5[ln(det(Sigma_k)) + (x_i - mu_k)' Sigma_k^(-1) (x_i - mu_k)] }
760    We calculate these probabilities here by the equivalent formulae:
761    Denote
762    S_ik = -0.5(log(det(Sigma_k)) + (x_i - mu_k)' Sigma_k^(-1) (x_i - mu_k)) + log(weights_k),
763    M_i = max_k S_ik = S_qi, so that the q-th class is the one where maximum reaches. Then
764    alpha_ik = exp{ S_ik - M_i } / ( 1 + sum_j!=q exp{ S_ji - M_i })
765 */
766 double CvEM::run_em( const CvVectors& train_data )
767 {
768     CvMat* centered_sample = 0;
769     CvMat* covs_item = 0;
770     CvMat* log_det = 0;
771     CvMat* log_weights = 0;
772     CvMat* cov_eigen_values = 0;
773     CvMat* samples = 0;
774     CvMat* sum_probs = 0;
775     log_likelihood = -DBL_MAX;
776
777     CV_FUNCNAME( "CvEM::run_em" );
778     __BEGIN__;
779
780     int nsamples = train_data.count, dims = train_data.dims, nclusters = params.nclusters;
781     double min_variation = FLT_EPSILON;
782     double min_det_value = MAX( DBL_MIN, pow( min_variation, dims ));
783     double likelihood_bias = -CV_LOG2PI * (double)nsamples * (double)dims / 2., _log_likelihood = -DBL_MAX;
784     int start_step = params.start_step;
785     
786     int i, j, k, n;
787     int is_general = 0, is_diagonal = 0, is_spherical = 0;
788     double prev_log_likelihood = -DBL_MAX / 1000., det, d;
789     CvMat whdr, iwhdr, diag, *w, *iw;
790     double* w_data;
791     double* sp_data;
792
793     if( nclusters == 1 )
794     {
795         double log_weight;
796         CV_CALL( cvSet( probs, cvScalar(1.)) );
797
798         if( params.cov_mat_type == COV_MAT_SPHERICAL )
799         {
800             d = cvTrace(*covs).val[0]/dims;
801             d = MAX( d, FLT_EPSILON );
802             inv_eigen_values->data.db[0] = 1./d;
803             log_weight = pow( d, dims*0.5 );
804         }
805         else
806         {
807             w_data = inv_eigen_values->data.db;
808             
809             if( params.cov_mat_type == COV_MAT_GENERIC )
810                 cvSVD( *covs, inv_eigen_values, *cov_rotate_mats, 0, CV_SVD_U_T );
811             else
812                 cvTranspose( cvGetDiag(*covs, &diag), inv_eigen_values );
813        
814             cvMaxS( inv_eigen_values, FLT_EPSILON, inv_eigen_values );
815             for( j = 0, det = 1.; j < dims; j++ )
816                 det *= w_data[j];
817             log_weight = sqrt(det);
818             cvDiv( 0, inv_eigen_values, inv_eigen_values );
819         }
820
821         log_weight_div_det->data.db[0] = -2*log(weights->data.db[0]/log_weight);
822         log_likelihood = DBL_MAX/1000.;
823         EXIT;
824     }
825
826     if( params.cov_mat_type == COV_MAT_GENERIC )
827         is_general  = 1;
828     else if( params.cov_mat_type == COV_MAT_DIAGONAL )
829         is_diagonal = 1;
830     else if( params.cov_mat_type == COV_MAT_SPHERICAL )
831         is_spherical  = 1;
832     /* In the case of <cov_mat_type> == COV_MAT_DIAGONAL, the k-th row of cov_eigen_values
833     contains the diagonal elements (variations). In the case of
834     <cov_mat_type> == COV_MAT_SPHERICAL - the 0-ths elements of the vectors cov_eigen_values[k]
835     are to be equal to the mean of the variations over all the dimensions. */
836
837     CV_CALL( log_det = cvCreateMat( 1, nclusters, CV_64FC1 ));
838     CV_CALL( log_weights = cvCreateMat( 1, nclusters, CV_64FC1 ));
839     CV_CALL( covs_item = cvCreateMat( dims, dims, CV_64FC1 ));
840     CV_CALL( centered_sample = cvCreateMat( 1, dims, CV_64FC1 ));
841     CV_CALL( cov_eigen_values = cvCreateMat( inv_eigen_values->rows, inv_eigen_values->cols, CV_64FC1 ));
842     CV_CALL( samples = cvCreateMat( nsamples, dims, CV_64FC1 ));
843     CV_CALL( sum_probs = cvCreateMat( 1, nclusters, CV_64FC1 ));
844     sp_data = sum_probs->data.db;
845
846     // copy the training data into double-precision matrix
847     for( i = 0; i < nsamples; i++ )
848     {
849         const float* src = train_data.data.fl[i];
850         double* dst = (double*)(samples->data.ptr + samples->step*i);
851
852         for( j = 0; j < dims; j++ )
853             dst[j] = src[j];
854     }
855
856     if( start_step != START_M_STEP )
857     {
858         for( k = 0; k < nclusters; k++ )
859         {
860             if( is_general || is_diagonal )
861             {
862                 w = cvGetRow( cov_eigen_values, &whdr, k );
863                 if( is_general )
864                     cvSVD( covs[k], w, cov_rotate_mats[k], 0, CV_SVD_U_T );
865                 else
866                     cvTranspose( cvGetDiag( covs[k], &diag ), w );
867                 w_data = w->data.db;
868                 for( j = 0, det = 1.; j < dims; j++ )
869                     det *= w_data[j];
870                 if( det < min_det_value )
871                 {
872                     if( start_step == START_AUTO_STEP )
873                         det = min_det_value;
874                     else
875                         EXIT;
876                 }
877                 log_det->data.db[k] = det;
878             }
879             else
880             {
881                 d = cvTrace(covs[k]).val[0]/(double)dims;
882                 if( d < min_variation )
883                 {
884                     if( start_step == START_AUTO_STEP )
885                         d = min_variation;
886                     else
887                         EXIT;
888                 }
889                 cov_eigen_values->data.db[k] = d;
890                 log_det->data.db[k] = d;
891             }
892         }
893     
894         cvLog( log_det, log_det );
895         if( is_spherical )
896             cvScale( log_det, log_det, dims );
897     }
898
899     for( n = 0; n < params.term_crit.max_iter; n++ )
900     {
901         if( n > 0 || start_step != START_M_STEP )
902         {
903             // e-step: compute probs_ik from means_k, covs_k and weights_k.
904             CV_CALL(cvLog( weights, log_weights ));
905
906             // S_ik = -0.5[log(det(Sigma_k)) + (x_i - mu_k)' Sigma_k^(-1) (x_i - mu_k)] + log(weights_k)
907             for( k = 0; k < nclusters; k++ )
908             {
909                 CvMat* u = cov_rotate_mats[k];
910                 const double* mean = (double*)(means->data.ptr + means->step*k);
911                 w = cvGetRow( cov_eigen_values, &whdr, k );
912                 iw = cvGetRow( inv_eigen_values, &iwhdr, k );
913                 cvDiv( 0, w, iw );
914
915                 w_data = (double*)(inv_eigen_values->data.ptr + inv_eigen_values->step*k);
916
917                 for( i = 0; i < nsamples; i++ )
918                 {
919                     double *csample = centered_sample->data.db, p = log_det->data.db[k];
920                     const double* sample = (double*)(samples->data.ptr + samples->step*i);
921                     double* pp = (double*)(probs->data.ptr + probs->step*i);
922                     for( j = 0; j < dims; j++ )
923                         csample[j] = sample[j] - mean[j];
924                     if( is_general )
925                         cvGEMM( centered_sample, u, 1, 0, 0, centered_sample, CV_GEMM_B_T );
926                     for( j = 0; j < dims; j++ )
927                         p += csample[j]*csample[j]*w_data[is_spherical ? 0 : j];
928                     pp[k] = -0.5*p + log_weights->data.db[k];
929
930                     // S_ik <- S_ik - max_j S_ij
931                     if( k == nclusters - 1 )
932                     {
933                         double max_val = 0;
934                         for( j = 0; j < nclusters; j++ )
935                             max_val = MAX( max_val, pp[j] );
936                         for( j = 0; j < nclusters; j++ )
937                             pp[j] -= max_val;
938                     }
939                 }
940             }
941
942             CV_CALL(cvExp( probs, probs )); // exp( S_ik )
943             cvZero( sum_probs );
944
945             // alpha_ik = exp( S_ik ) / sum_j exp( S_ij ),
946             // log_likelihood = sum_i log (sum_j exp(S_ij))
947             for( i = 0, _log_likelihood = likelihood_bias; i < nsamples; i++ )
948             {
949                 double* pp = (double*)(probs->data.ptr + probs->step*i), sum = 0;
950                 for( j = 0; j < nclusters; j++ )
951                     sum += pp[j];
952                 sum = 1./MAX( sum, DBL_EPSILON );
953                 for( j = 0; j < nclusters; j++ )
954                 {
955                     double p = pp[j] *= sum;
956                     sp_data[j] += p;
957                 }
958                 _log_likelihood -= log( sum );
959             }
960
961             // check termination criteria
962             if( fabs( (_log_likelihood - prev_log_likelihood) / prev_log_likelihood ) < params.term_crit.epsilon )
963                 break;
964             prev_log_likelihood = _log_likelihood;
965         }
966
967         // m-step: update means_k, covs_k and weights_k from probs_ik
968         cvGEMM( probs, samples, 1, 0, 0, means, CV_GEMM_A_T );
969
970         for( k = 0; k < nclusters; k++ )
971         {
972             double sum = sp_data[k], inv_sum = 1./sum;
973             CvMat* cov = covs[k], _mean, _sample;
974             
975             w = cvGetRow( cov_eigen_values, &whdr, k );
976             w_data = w->data.db;
977             cvGetRow( means, &_mean, k );
978             cvGetRow( samples, &_sample, k );
979
980             // update weights_k
981             weights->data.db[k] = sum;
982             
983             // update means_k
984             cvScale( &_mean, &_mean, inv_sum );
985
986             // compute covs_k
987             cvZero( cov );
988             cvZero( w );
989
990             for( i = 0; i < nsamples; i++ )
991             {
992                 double p = probs->data.db[i*nclusters + k]*inv_sum;
993                 _sample.data.db = (double*)(samples->data.ptr + samples->step*i);
994
995                 if( is_general )
996                 {
997                     cvMulTransposed( &_sample, covs_item, 1, &_mean );
998                     cvScaleAdd( covs_item, cvRealScalar(p), cov, cov );
999                 }
1000                 else
1001                     for( j = 0; j < dims; j++ )
1002                     {
1003                         double val = _sample.data.db[j] - _mean.data.db[j];
1004                         w_data[is_spherical ? 0 : j] += p*val*val;
1005                     }
1006             }
1007             
1008             if( is_spherical )
1009             {
1010                 d = w_data[0]/(double)dims;
1011                 d = MAX( d, min_variation );
1012                 w->data.db[0] = d;
1013                 log_det->data.db[k] = d;
1014             }
1015             else
1016             {
1017                 if( is_general )
1018                     cvSVD( cov, w, cov_rotate_mats[k], 0, CV_SVD_U_T );
1019                 cvMaxS( w, min_variation, w );
1020                 for( j = 0, det = 1.; j < dims; j++ )
1021                     det *= w_data[j];
1022                 log_det->data.db[k] = det;
1023             }
1024         }
1025
1026         cvConvertScale( weights, weights, 1./(double)nsamples, 0 );
1027         cvMaxS( weights, DBL_MIN, weights );
1028
1029         cvLog( log_det, log_det );
1030         if( is_spherical )
1031             cvScale( log_det, log_det, dims );
1032     } // end of iteration process
1033
1034     //log_weight_div_det[k] = -2*log(weights_k/det(Sigma_k))^0.5) = -2*log(weights_k) + log(det(Sigma_k)))
1035     if( log_weight_div_det )
1036     {
1037         cvScale( log_weights, log_weight_div_det, -2 );
1038         cvAdd( log_weight_div_det, log_det, log_weight_div_det );
1039     }
1040     
1041     /* Now finalize all the covariation matrices:
1042     1) if <cov_mat_type> == COV_MAT_DIAGONAL we used array of <w> as diagonals.
1043        Now w[k] should be copied back to the diagonals of covs[k];
1044     2) if <cov_mat_type> == COV_MAT_SPHERICAL we used the 0-th element of w[k]
1045        as an average variation in each cluster. The value of the 0-th element of w[k]
1046        should be copied to the all of the diagonal elements of covs[k]. */
1047     if( is_spherical )
1048     {
1049         for( k = 0; k < nclusters; k++ )
1050             cvSetIdentity( covs[k], cvScalar(cov_eigen_values->data.db[k]));
1051     }
1052     else if( is_diagonal )
1053     {
1054         for( k = 0; k < nclusters; k++ )
1055             cvTranspose( cvGetRow( cov_eigen_values, &whdr, k ),
1056                          cvGetDiag( covs[k], &diag ));
1057     }
1058     cvDiv( 0, cov_eigen_values, inv_eigen_values );
1059
1060     log_likelihood = _log_likelihood;
1061
1062     __END__;
1063
1064     cvReleaseMat( &log_det );
1065     cvReleaseMat( &log_weights );
1066     cvReleaseMat( &covs_item );
1067     cvReleaseMat( &centered_sample );
1068     cvReleaseMat( &cov_eigen_values );
1069     cvReleaseMat( &samples );
1070     cvReleaseMat( &sum_probs );
1071
1072     return log_likelihood;
1073 }
1074
1075
1076 int CvEM::get_nclusters() const
1077 {
1078     return params.nclusters;
1079 }
1080
1081 const CvMat* CvEM::get_means() const
1082 {
1083     return means;
1084 }
1085
1086 const CvMat** CvEM::get_covs() const
1087 {
1088     return (const CvMat**)covs;
1089 }
1090
1091 const CvMat* CvEM::get_weights() const
1092 {
1093     return weights;
1094 }
1095
1096 const CvMat* CvEM::get_probs() const
1097 {
1098     return probs;
1099 }
1100
1101 /* End of file. */