d15643c5de374f1a7527060dd6cfa2ece2735a14
[opencv] / src / ml / mlem.cpp
1 /*M///////////////////////////////////////////////////////////////////////////////////////
2 //
3 //  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
4 //
5 //  By downloading, copying, installing or using the software you agree to this license.
6 //  If you do not agree to this license, do not download, install,
7 //  copy or use the software.
8 //
9 //
10 //                        Intel License Agreement
11 //                For Open Source Computer Vision Library
12 //
13 // Copyright( C) 2000, Intel Corporation, all rights reserved.
14 // Third party copyrights are property of their respective owners.
15 //
16 // Redistribution and use in source and binary forms, with or without modification,
17 // are permitted provided that the following conditions are met:
18 //
19 //   * Redistribution's of source code must retain the above copyright notice,
20 //     this list of conditions and the following disclaimer.
21 //
22 //   * Redistribution's in binary form must reproduce the above copyright notice,
23 //     this list of conditions and the following disclaimer in the documentation
24 //     and/or other materials provided with the distribution.
25 //
26 //   * The name of Intel Corporation may not be used to endorse or promote products
27 //     derived from this software without specific prior written permission.
28 //
29 // This software is provided by the copyright holders and contributors "as is" and
30 // any express or implied warranties, including, but not limited to, the implied
31 // warranties of merchantability and fitness for a particular purpose are disclaimed.
32 // In no event shall the Intel Corporation or contributors be liable for any direct,
33 // indirect, incidental, special, exemplary, or consequential damages
34 //(including, but not limited to, procurement of substitute goods or services;
35 // loss of use, data, or profits; or business interruption) however caused
36 // and on any theory of liability, whether in contract, strict liability,
37 // or tort(including negligence or otherwise) arising in any way out of
38 // the use of this software, even ifadvised of the possibility of such damage.
39 //
40 //M*/
41
42 #include "_ml.h"
43
44
45 /*
46    CvEM:
47  * params.nclusters    - number of clusters to cluster samples to.
48  * means               - calculated by the EM algorithm set of gaussians' means.
49  * log_weight_div_det - auxilary vector that k-th component is equal to
50                         (-2)*ln(weights_k/det(Sigma_k)^0.5),
51                         where <weights_k> is the weight,
52                         <Sigma_k> is the covariation matrice of k-th cluster.
53  * inv_eigen_values   - set of 1*dims matrices, <inv_eigen_values>[k] contains
54                         inversed eigen values of covariation matrice of the k-th cluster.
55                         In the case of <cov_mat_type> == COV_MAT_DIAGONAL,
56                         inv_eigen_values[k] = Sigma_k^(-1).
57  * covs_rotate_mats   - used only if cov_mat_type == COV_MAT_GENERIC, in all the
58                         other cases it is NULL. <covs_rotate_mats>[k] is the orthogonal
59                         matrice, obtained by the SVD-decomposition of Sigma_k.
60    Both <inv_eigen_values> and <covs_rotate_mats> fields are used for representation of
61    covariation matrices and simplifying EM calculations.
62    For fixed k denote
63    u = covs_rotate_mats[k],
64    v = inv_eigen_values[k],
65    w = v^(-1);
66    if <cov_mat_type> == COV_MAT_GENERIC, then Sigma_k = u w u',
67    else                                       Sigma_k = w.
68    Symbol ' means transposition.
69  */
70
71
72 CvEM::CvEM()
73 {
74     means = weights = probs = inv_eigen_values = log_weight_div_det = 0;
75     covs = cov_rotate_mats = 0;
76 }
77
78 CvEM::CvEM( const CvMat* samples, const CvMat* sample_idx,
79             CvEMParams params, CvMat* labels )
80 {
81     means = weights = probs = inv_eigen_values = log_weight_div_det = 0;
82     covs = cov_rotate_mats = 0;
83
84     // just invoke the train() method
85     train(samples, sample_idx, params, labels);
86 }
87
88 CvEM::~CvEM()
89 {
90     clear();
91 }
92
93
94 void CvEM::clear()
95 {
96     int i;
97
98     cvReleaseMat( &means );
99     cvReleaseMat( &weights );
100     cvReleaseMat( &probs );
101     cvReleaseMat( &inv_eigen_values );
102     cvReleaseMat( &log_weight_div_det );
103
104     if( covs || cov_rotate_mats )
105     {
106         for( i = 0; i < params.nclusters; i++ )
107         {
108             if( covs )
109                 cvReleaseMat( &covs[i] );
110             if( cov_rotate_mats )
111                 cvReleaseMat( &cov_rotate_mats[i] );
112         }
113         cvFree( &covs );
114         cvFree( &cov_rotate_mats );
115     }
116 }
117
118
119 void CvEM::set_params( const CvEMParams& _params, const CvVectors& train_data )
120 {
121     CV_FUNCNAME( "CvEM::set_params" );
122
123     __BEGIN__;
124
125     int k;
126
127     params = _params;
128     params.term_crit = cvCheckTermCriteria( params.term_crit, 1e-6, 10000 );
129
130     if( params.cov_mat_type != COV_MAT_SPHERICAL &&
131         params.cov_mat_type != COV_MAT_DIAGONAL &&
132         params.cov_mat_type != COV_MAT_GENERIC )
133         CV_ERROR( CV_StsBadArg, "Unknown covariation matrix type" );
134
135     switch( params.start_step )
136     {
137     case START_M_STEP:
138         if( !params.probs )
139             CV_ERROR( CV_StsNullPtr, "Probabilities must be specified when EM algorithm starts with M-step" );
140         break;
141     case START_E_STEP:
142         if( !params.means )
143             CV_ERROR( CV_StsNullPtr, "Mean's must be specified when EM algorithm starts with E-step" );
144         break;
145     case START_AUTO_STEP:
146         break;
147     default:
148         CV_ERROR( CV_StsBadArg, "Unknown start_step" );
149     }
150
151     if( params.nclusters < 1 )
152         CV_ERROR( CV_StsOutOfRange, "The number of clusters (mixtures) should be > 0" );
153
154     if( params.probs )
155     {
156         const CvMat* p = params.probs;
157         if( !CV_IS_MAT(p) ||
158             (CV_MAT_TYPE(p->type) != CV_32FC1  &&
159             CV_MAT_TYPE(p->type) != CV_64FC1) ||
160             p->rows != train_data.count ||
161             p->cols != params.nclusters )
162             CV_ERROR( CV_StsBadArg, "The array of probabilities must be a valid "
163             "floating-point matrix (CvMat) of 'nsamples' x 'nclusters' size" );
164     }
165
166     if( params.means )
167     {
168         const CvMat* m = params.means;
169         if( !CV_IS_MAT(m) ||
170             (CV_MAT_TYPE(m->type) != CV_32FC1  &&
171             CV_MAT_TYPE(m->type) != CV_64FC1) ||
172             m->rows != params.nclusters ||
173             m->cols != train_data.dims )
174             CV_ERROR( CV_StsBadArg, "The array of mean's must be a valid "
175             "floating-point matrix (CvMat) of 'nsamples' x 'dims' size" );
176     }
177
178     if( params.weights )
179     {
180         const CvMat* w = params.weights;
181         if( !CV_IS_MAT(w) ||
182             (CV_MAT_TYPE(w->type) != CV_32FC1  &&
183             CV_MAT_TYPE(w->type) != CV_64FC1) ||
184             (w->rows != 1 && w->cols != 1) ||
185             w->rows + w->cols - 1 != params.nclusters )
186             CV_ERROR( CV_StsBadArg, "The array of weights must be a valid "
187             "1d floating-point vector (CvMat) of 'nclusters' elements" );
188     }
189
190     if( params.covs )
191         for( k = 0; k < params.nclusters; k++ )
192         {
193             const CvMat* cov = params.covs[k];
194             if( !CV_IS_MAT(cov) ||
195                 (CV_MAT_TYPE(cov->type) != CV_32FC1  &&
196                 CV_MAT_TYPE(cov->type) != CV_64FC1) ||
197                 cov->rows != cov->cols || cov->cols != train_data.dims )
198                 CV_ERROR( CV_StsBadArg,
199                 "Each of covariation matrices must be a valid square "
200                 "floating-point matrix (CvMat) of 'dims' x 'dims'" );
201         }
202
203     __END__;
204 }
205
206
207 /****************************************************************************************/
208 float
209 CvEM::predict( const CvMat* _sample, CvMat* _probs ) const
210 {
211     float* sample_data   = 0;
212     void* buffer = 0;
213     int allocated_buffer = 0;
214     int cls = 0;
215
216     CV_FUNCNAME( "CvEM::predict" );
217     __BEGIN__;
218
219     int i, k, dims;
220     int nclusters;
221     int cov_mat_type = params.cov_mat_type;
222     double opt = FLT_MAX;
223     size_t size;
224     CvMat diff, expo;
225
226     dims = means->cols;
227     nclusters = params.nclusters;
228
229     CV_CALL( cvPreparePredictData( _sample, dims, 0, params.nclusters, _probs, &sample_data ));
230
231 // allocate memory and initializing headers for calculating
232     size = sizeof(double) * (nclusters + dims);
233     if( size <= CV_MAX_LOCAL_SIZE )
234         buffer = cvStackAlloc( size );
235     else
236     {
237         CV_CALL( buffer = cvAlloc( size ));
238         allocated_buffer = 1;
239     }
240     expo = cvMat( 1, nclusters, CV_64FC1, buffer );
241     diff = cvMat( 1, dims, CV_64FC1, (double*)buffer + nclusters );
242
243 // calculate the probabilities
244     for( k = 0; k < nclusters; k++ )
245     {
246         const double* mean_k = (const double*)(means->data.ptr + means->step*k);
247         const double* w = (const double*)(inv_eigen_values->data.ptr + inv_eigen_values->step*k);
248         double cur = log_weight_div_det->data.db[k];
249         CvMat* u = cov_rotate_mats[k];
250         // cov = u w u'  -->  cov^(-1) = u w^(-1) u'
251         if( cov_mat_type == COV_MAT_SPHERICAL )
252         {
253             double w0 = w[0];
254             for( i = 0; i < dims; i++ )
255             {
256                 double val = sample_data[i] - mean_k[i];
257                 cur += val*val*w0;
258             }
259         }
260         else
261         {
262             for( i = 0; i < dims; i++ )
263                 diff.data.db[i] = sample_data[i] - mean_k[i];
264             if( cov_mat_type == COV_MAT_GENERIC )
265                 cvGEMM( &diff, u, 1, 0, 0, &diff, CV_GEMM_B_T );
266             for( i = 0; i < dims; i++ )
267             {
268                 double val = diff.data.db[i];
269                 cur += val*val*w[i];
270             }
271         }
272
273         expo.data.db[k] = cur;
274         if( cur < opt )
275         {
276             cls = k;
277             opt = cur;
278         }
279         /* probability = (2*pi)^(-dims/2)*exp( -0.5 * cur ) */
280     }
281
282     if( _probs )
283     {
284         CV_CALL( cvConvertScale( &expo, &expo, -0.5 ));
285         CV_CALL( cvExp( &expo, &expo ));
286         if( _probs->cols == 1 )
287             CV_CALL( cvReshape( &expo, &expo, 0, nclusters ));
288         CV_CALL( cvConvertScale( &expo, _probs, 1./cvSum( &expo ).val[0] ));
289     }
290
291     __END__;
292
293     if( sample_data != _sample->data.fl )
294         cvFree( &sample_data );
295     if( allocated_buffer )
296         cvFree( &buffer );
297
298     return (float)cls;
299 }
300
301
302
303 bool CvEM::train( const CvMat* _samples, const CvMat* _sample_idx,
304                   CvEMParams _params, CvMat* labels )
305 {
306     bool result = false;
307     CvVectors train_data;
308     CvMat* sample_idx = 0;
309
310     train_data.data.fl = 0;
311     train_data.count = 0;
312
313     CV_FUNCNAME("cvEM");
314
315     __BEGIN__;
316
317     int i, nsamples, nclusters, dims;
318
319     clear();
320
321     CV_CALL( cvPrepareTrainData( "cvEM",
322         _samples, CV_ROW_SAMPLE, 0, CV_VAR_CATEGORICAL,
323         0, _sample_idx, false, (const float***)&train_data.data.fl,
324         &train_data.count, &train_data.dims, &train_data.dims,
325         0, 0, 0, &sample_idx ));
326
327     CV_CALL( set_params( _params, train_data ));
328     nsamples = train_data.count;
329     nclusters = params.nclusters;
330     dims = train_data.dims;
331
332     if( labels && (!CV_IS_MAT(labels) || CV_MAT_TYPE(labels->type) != CV_32SC1 ||
333         (labels->cols != 1 && labels->rows != 1) || labels->cols + labels->rows - 1 != nsamples ))
334         CV_ERROR( CV_StsBadArg,
335         "labels array (when passed) must be a valid 1d integer vector of <sample_count> elements" );
336
337     if( nsamples <= nclusters )
338         CV_ERROR( CV_StsOutOfRange,
339         "The number of samples should be greater than the number of clusters" );
340
341     CV_CALL( log_weight_div_det = cvCreateMat( 1, nclusters, CV_64FC1 ));
342     CV_CALL( probs  = cvCreateMat( nsamples, nclusters, CV_64FC1 ));
343     CV_CALL( means = cvCreateMat( nclusters, dims, CV_64FC1 ));
344     CV_CALL( weights = cvCreateMat( 1, nclusters, CV_64FC1 ));
345     CV_CALL( inv_eigen_values = cvCreateMat( nclusters,
346         params.cov_mat_type == COV_MAT_SPHERICAL ? 1 : dims, CV_64FC1 ));
347     CV_CALL( covs = (CvMat**)cvAlloc( nclusters * sizeof(*covs) ));
348     CV_CALL( cov_rotate_mats = (CvMat**)cvAlloc( nclusters * sizeof(cov_rotate_mats[0]) ));
349
350     for( i = 0; i < nclusters; i++ )
351     {
352         CV_CALL( covs[i] = cvCreateMat( dims, dims, CV_64FC1 ));
353         CV_CALL( cov_rotate_mats[i]  = cvCreateMat( dims, dims, CV_64FC1 ));
354         cvZero( cov_rotate_mats[i] );
355     }
356
357     init_em( train_data );
358     log_likelihood = run_em( train_data );
359     if( log_likelihood <= -DBL_MAX/10000. )
360         EXIT;
361
362     if( labels )
363     {
364         if( nclusters == 1 )
365             cvZero( labels );
366         else
367         {
368             CvMat sample = cvMat( 1, dims, CV_32F );
369             CvMat prob = cvMat( 1, nclusters, CV_64F );
370             int lstep = CV_IS_MAT_CONT(labels->type) ? 1 : labels->step/sizeof(int);
371
372             for( i = 0; i < nsamples; i++ )
373             {
374                 int idx = sample_idx ? sample_idx->data.i[i] : i;
375                 sample.data.ptr = _samples->data.ptr + _samples->step*idx;
376                 prob.data.ptr = probs->data.ptr + probs->step*i;
377
378                 labels->data.i[i*lstep] = cvRound(predict(&sample, &prob));
379             }
380         }
381     }
382
383     result = true;
384
385     __END__;
386
387     if( sample_idx != _sample_idx )
388         cvReleaseMat( &sample_idx );
389
390     cvFree( &train_data.data.ptr );
391
392     return result;
393 }
394
395
396 void CvEM::init_em( const CvVectors& train_data )
397 {
398     CvMat *w = 0, *u = 0, *tcov = 0;
399
400     CV_FUNCNAME( "CvEM::init_em" );
401
402     __BEGIN__;
403
404     double maxval = 0;
405     int i, force_symm_plus = 0;
406     int nclusters = params.nclusters, nsamples = train_data.count, dims = train_data.dims;
407
408     if( params.start_step == START_AUTO_STEP || nclusters == 1 || nclusters == nsamples )
409         init_auto( train_data );
410     else if( params.start_step == START_M_STEP )
411     {
412         for( i = 0; i < nsamples; i++ )
413         {
414             CvMat prob;
415             cvGetRow( params.probs, &prob, i );
416             cvMaxS( &prob, 0., &prob );
417             cvMinMaxLoc( &prob, 0, &maxval );
418             if( maxval < FLT_EPSILON )
419                 cvSet( &prob, cvScalar(1./nclusters) );
420             else
421                 cvNormalize( &prob, &prob, 1., 0, CV_L1 );
422         }
423         EXIT; // do not preprocess covariation matrices,
424               // as in this case they are initialized at the first iteration of EM
425     }
426     else
427     {
428         CV_ASSERT( params.start_step == START_E_STEP && params.means );
429         if( params.weights && params.covs )
430         {
431             cvConvert( params.means, means );
432             cvReshape( weights, weights, 1, params.weights->rows );
433             cvConvert( params.weights, weights );
434             cvReshape( weights, weights, 1, 1 );
435             cvMaxS( weights, 0., weights );
436             cvMinMaxLoc( weights, 0, &maxval );
437             if( maxval < FLT_EPSILON )
438                 cvSet( weights, cvScalar(1./nclusters) );
439             cvNormalize( weights, weights, 1., 0, CV_L1 );
440             for( i = 0; i < nclusters; i++ )
441                 CV_CALL( cvConvert( params.covs[i], covs[i] ));
442             force_symm_plus = 1;
443         }
444         else
445             init_auto( train_data );
446     }
447
448     CV_CALL( tcov = cvCreateMat( dims, dims, CV_64FC1 ));
449     CV_CALL( w = cvCreateMat( dims, dims, CV_64FC1 ));
450     if( params.cov_mat_type != COV_MAT_SPHERICAL )
451         CV_CALL( u = cvCreateMat( dims, dims, CV_64FC1 ));
452
453     for( i = 0; i < nclusters; i++ )
454     {
455         if( force_symm_plus )
456         {
457             cvTranspose( covs[i], tcov );
458             cvAddWeighted( covs[i], 0.5, tcov, 0.5, 0, tcov );
459         }
460         else
461             cvCopy( covs[i], tcov );
462         cvSVD( tcov, w, u, 0, CV_SVD_MODIFY_A + CV_SVD_U_T + CV_SVD_V_T );
463         if( params.cov_mat_type == COV_MAT_SPHERICAL )
464             cvSetIdentity( covs[i], cvScalar(cvTrace(w).val[0]/dims) );
465         /*else if( params.cov_mat_type == COV_MAT_DIAGONAL )
466             cvCopy( w, covs[i] );*/
467         else
468         {
469             // generic case: covs[i] = (u')'*max(w,0)*u'
470             cvGEMM( u, w, 1, 0, 0, tcov, CV_GEMM_A_T );
471             cvGEMM( tcov, u, 1, 0, 0, covs[i], 0 );
472         }
473     }
474
475     __END__;
476
477     cvReleaseMat( &w );
478     cvReleaseMat( &u );
479     cvReleaseMat( &tcov );
480 }
481
482
483 void CvEM::init_auto( const CvVectors& train_data )
484 {
485     CvMat* hdr = 0;
486     const void** vec = 0;
487     CvMat* class_ranges = 0;
488     CvMat* labels = 0;
489
490     CV_FUNCNAME( "CvEM::init_auto" );
491
492     __BEGIN__;
493
494     int nclusters = params.nclusters, nsamples = train_data.count, dims = train_data.dims;
495     int i, j;
496
497     if( nclusters == nsamples )
498     {
499         CvMat src = cvMat( 1, dims, CV_32F );
500         CvMat dst = cvMat( 1, dims, CV_64F );
501         for( i = 0; i < nsamples; i++ )
502         {
503             src.data.ptr = train_data.data.ptr[i];
504             dst.data.ptr = means->data.ptr + means->step*i;
505             cvConvert( &src, &dst );
506             cvZero( covs[i] );
507             cvSetIdentity( cov_rotate_mats[i] );
508         }
509         cvSetIdentity( probs );
510         cvSet( weights, cvScalar(1./nclusters) );
511     }
512     else
513     {
514         int max_count = 0;
515
516         CV_CALL( class_ranges = cvCreateMat( 1, nclusters+1, CV_32SC1 ));
517         if( nclusters > 1 )
518         {
519             CV_CALL( labels = cvCreateMat( 1, nsamples, CV_32SC1 ));
520             kmeans( train_data, nclusters, labels, cvTermCriteria( CV_TERMCRIT_ITER,
521                     params.means ? 1 : 10, 0.5 ), params.means );
522             CV_CALL( cvSortSamplesByClasses( (const float**)train_data.data.fl,
523                                             labels, class_ranges->data.i ));
524         }
525         else
526         {
527             class_ranges->data.i[0] = 0;
528             class_ranges->data.i[1] = nsamples;
529         }
530
531         for( i = 0; i < nclusters; i++ )
532         {
533             int left = class_ranges->data.i[i], right = class_ranges->data.i[i+1];
534             max_count = MAX( max_count, right - left );
535         }
536         CV_CALL( hdr = (CvMat*)cvAlloc( max_count*sizeof(hdr[0]) ));
537         CV_CALL( vec = (const void**)cvAlloc( max_count*sizeof(vec[0]) ));
538         hdr[0] = cvMat( 1, dims, CV_32F );
539         for( i = 0; i < max_count; i++ )
540         {
541             vec[i] = hdr + i;
542             hdr[i] = hdr[0];
543         }
544
545         for( i = 0; i < nclusters; i++ )
546         {
547             int left = class_ranges->data.i[i], right = class_ranges->data.i[i+1];
548             int cluster_size = right - left;
549             CvMat avg;
550
551             if( cluster_size <= 0 )
552                 continue;
553
554             for( j = left; j < right; j++ )
555                 hdr[j - left].data.fl = train_data.data.fl[j];
556
557             CV_CALL( cvGetRow( means, &avg, i ));
558             CV_CALL( cvCalcCovarMatrix( vec, cluster_size, covs[i],
559                 &avg, CV_COVAR_NORMAL | CV_COVAR_SCALE ));
560             weights->data.db[i] = (double)cluster_size/(double)nsamples;
561         }
562     }
563
564     __END__;
565
566     cvReleaseMat( &class_ranges );
567     cvReleaseMat( &labels );
568     cvFree( &hdr );
569     cvFree( &vec );
570 }
571
572
573 void CvEM::kmeans( const CvVectors& train_data, int nclusters, CvMat* labels,
574                    CvTermCriteria termcrit, const CvMat* centers0 )
575 {
576     CvMat* centers = 0;
577     CvMat* old_centers = 0;
578     CvMat* counters = 0;
579
580     CV_FUNCNAME( "CvEM::kmeans" );
581
582     __BEGIN__;
583
584     CvRNG rng = cvRNG(-1);
585     int i, j, k, nsamples, dims;
586     int iter = 0;
587     double max_dist = DBL_MAX;
588
589     termcrit = cvCheckTermCriteria( termcrit, 1e-6, 100 );
590     termcrit.epsilon *= termcrit.epsilon;
591     nsamples = train_data.count;
592     dims = train_data.dims;
593     nclusters = MIN( nclusters, nsamples );
594
595     CV_CALL( centers = cvCreateMat( nclusters, dims, CV_64FC1 ));
596     CV_CALL( old_centers = cvCreateMat( nclusters, dims, CV_64FC1 ));
597     CV_CALL( counters = cvCreateMat( 1, nclusters, CV_32SC1 ));
598     cvZero( old_centers );
599
600     if( centers0 )
601     {
602         CV_CALL( cvConvert( centers0, centers ));
603     }
604     else
605     {
606         for( i = 0; i < nsamples; i++ )
607             labels->data.i[i] = i*nclusters/nsamples;
608         cvRandShuffle( labels, &rng );
609     }
610
611     for( ;; )
612     {
613         CvMat* temp;
614
615         if( iter > 0 || centers0 )
616         {
617             for( i = 0; i < nsamples; i++ )
618             {
619                 const float* s = train_data.data.fl[i];
620                 int k_best = 0;
621                 double min_dist = DBL_MAX;
622
623                 for( k = 0; k < nclusters; k++ )
624                 {
625                     const double* c = (double*)(centers->data.ptr + k*centers->step);
626                     double dist = 0;
627
628                     for( j = 0; j <= dims - 4; j += 4 )
629                     {
630                         double t0 = c[j] - s[j];
631                         double t1 = c[j+1] - s[j+1];
632                         dist += t0*t0 + t1*t1;
633                         t0 = c[j+2] - s[j+2];
634                         t1 = c[j+3] - s[j+3];
635                         dist += t0*t0 + t1*t1;
636                     }
637
638                     for( ; j < dims; j++ )
639                     {
640                         double t = c[j] - s[j];
641                         dist += t*t;
642                     }
643
644                     if( min_dist > dist )
645                     {
646                         min_dist = dist;
647                         k_best = k;
648                     }
649                 }
650
651                 labels->data.i[i] = k_best;
652             }
653         }
654
655         if( ++iter > termcrit.max_iter )
656             break;
657
658         CV_SWAP( centers, old_centers, temp );
659         cvZero( centers );
660         cvZero( counters );
661
662         // update centers
663         for( i = 0; i < nsamples; i++ )
664         {
665             const float* s = train_data.data.fl[i];
666             k = labels->data.i[i];
667             double* c = (double*)(centers->data.ptr + k*centers->step);
668
669             for( j = 0; j <= dims - 4; j += 4 )
670             {
671                 double t0 = c[j] + s[j];
672                 double t1 = c[j+1] + s[j+1];
673
674                 c[j] = t0;
675                 c[j+1] = t1;
676
677                 t0 = c[j+2] + s[j+2];
678                 t1 = c[j+3] + s[j+3];
679
680                 c[j+2] = t0;
681                 c[j+3] = t1;
682             }
683             for( ; j < dims; j++ )
684                 c[j] += s[j];
685             counters->data.i[k]++;
686         }
687
688         if( iter > 1 )
689             max_dist = 0;
690
691         for( k = 0; k < nclusters; k++ )
692         {
693             double* c = (double*)(centers->data.ptr + k*centers->step);
694             if( counters->data.i[k] != 0 )
695             {
696                 double scale = 1./counters->data.i[k];
697                 for( j = 0; j < dims; j++ )
698                     c[j] *= scale;
699             }
700             else
701             {
702                 const float* s;
703                 for( j = 0; j < 10; j++ )
704                 {
705                     i = cvRandInt( &rng ) % nsamples;
706                     if( counters->data.i[labels->data.i[i]] > 1 )
707                         break;
708                 }
709                 s = train_data.data.fl[i];
710                 for( j = 0; j < dims; j++ )
711                     c[j] = s[j];
712             }
713
714             if( iter > 1 )
715             {
716                 double dist = 0;
717                 const double* c_o = (double*)(old_centers->data.ptr + k*old_centers->step);
718                 for( j = 0; j < dims; j++ )
719                 {
720                     double t = c[j] - c_o[j];
721                     dist += t*t;
722                 }
723                 if( max_dist < dist )
724                     max_dist = dist;
725             }
726         }
727
728         if( max_dist < termcrit.epsilon )
729             break;
730     }
731
732     cvZero( counters );
733     for( i = 0; i < nsamples; i++ )
734         counters->data.i[labels->data.i[i]]++;
735
736     // ensure that we do not have empty clusters
737     for( k = 0; k < nclusters; k++ )
738         if( counters->data.i[k] == 0 )
739             for(;;)
740             {
741                 i = cvRandInt(&rng) % nsamples;
742                 j = labels->data.i[i];
743                 if( counters->data.i[j] > 1 )
744                 {
745                     labels->data.i[i] = k;
746                     counters->data.i[j]--;
747                     counters->data.i[k]++;
748                     break;
749                 }
750             }
751
752     __END__;
753
754     cvReleaseMat( &centers );
755     cvReleaseMat( &old_centers );
756     cvReleaseMat( &counters );
757 }
758
759
760 /****************************************************************************************/
761 /* log_weight_div_det[k] = -2*log(weights_k) + log(det(Sigma_k)))
762
763    covs[k] = cov_rotate_mats[k] * cov_eigen_values[k] * (cov_rotate_mats[k])'
764    cov_rotate_mats[k] are orthogonal matrices of eigenvectors and
765    cov_eigen_values[k] are diagonal matrices (represented by 1D vectors) of eigen values.
766
767    The <alpha_ik> is the probability of the vector x_i to belong to the k-th cluster:
768    <alpha_ik> ~ weights_k * exp{ -0.5[ln(det(Sigma_k)) + (x_i - mu_k)' Sigma_k^(-1) (x_i - mu_k)] }
769    We calculate these probabilities here by the equivalent formulae:
770    Denote
771    S_ik = -0.5(log(det(Sigma_k)) + (x_i - mu_k)' Sigma_k^(-1) (x_i - mu_k)) + log(weights_k),
772    M_i = max_k S_ik = S_qi, so that the q-th class is the one where maximum reaches. Then
773    alpha_ik = exp{ S_ik - M_i } / ( 1 + sum_j!=q exp{ S_ji - M_i })
774 */
775 double CvEM::run_em( const CvVectors& train_data )
776 {
777     CvMat* centered_sample = 0;
778     CvMat* covs_item = 0;
779     CvMat* log_det = 0;
780     CvMat* log_weights = 0;
781     CvMat* cov_eigen_values = 0;
782     CvMat* samples = 0;
783     CvMat* sum_probs = 0;
784     log_likelihood = -DBL_MAX;
785
786     CV_FUNCNAME( "CvEM::run_em" );
787     __BEGIN__;
788
789     int nsamples = train_data.count, dims = train_data.dims, nclusters = params.nclusters;
790     double min_variation = FLT_EPSILON;
791     double min_det_value = MAX( DBL_MIN, pow( min_variation, dims ));
792     double likelihood_bias = -CV_LOG2PI * (double)nsamples * (double)dims / 2., _log_likelihood = -DBL_MAX;
793     int start_step = params.start_step;
794
795     int i, j, k, n;
796     int is_general = 0, is_diagonal = 0, is_spherical = 0;
797     double prev_log_likelihood = -DBL_MAX / 1000., det, d;
798     CvMat whdr, iwhdr, diag, *w, *iw;
799     double* w_data;
800     double* sp_data;
801
802     if( nclusters == 1 )
803     {
804         double log_weight;
805         CV_CALL( cvSet( probs, cvScalar(1.)) );
806
807         if( params.cov_mat_type == COV_MAT_SPHERICAL )
808         {
809             d = cvTrace(*covs).val[0]/dims;
810             d = MAX( d, FLT_EPSILON );
811             inv_eigen_values->data.db[0] = 1./d;
812             log_weight = pow( d, dims*0.5 );
813         }
814         else
815         {
816             w_data = inv_eigen_values->data.db;
817
818             if( params.cov_mat_type == COV_MAT_GENERIC )
819                 cvSVD( *covs, inv_eigen_values, *cov_rotate_mats, 0, CV_SVD_U_T );
820             else
821                 cvTranspose( cvGetDiag(*covs, &diag), inv_eigen_values );
822
823             cvMaxS( inv_eigen_values, FLT_EPSILON, inv_eigen_values );
824             for( j = 0, det = 1.; j < dims; j++ )
825                 det *= w_data[j];
826             log_weight = sqrt(det);
827             cvDiv( 0, inv_eigen_values, inv_eigen_values );
828         }
829
830         log_weight_div_det->data.db[0] = -2*log(weights->data.db[0]/log_weight);
831         log_likelihood = DBL_MAX/1000.;
832         EXIT;
833     }
834
835     if( params.cov_mat_type == COV_MAT_GENERIC )
836         is_general  = 1;
837     else if( params.cov_mat_type == COV_MAT_DIAGONAL )
838         is_diagonal = 1;
839     else if( params.cov_mat_type == COV_MAT_SPHERICAL )
840         is_spherical  = 1;
841     /* In the case of <cov_mat_type> == COV_MAT_DIAGONAL, the k-th row of cov_eigen_values
842     contains the diagonal elements (variations). In the case of
843     <cov_mat_type> == COV_MAT_SPHERICAL - the 0-ths elements of the vectors cov_eigen_values[k]
844     are to be equal to the mean of the variations over all the dimensions. */
845
846     CV_CALL( log_det = cvCreateMat( 1, nclusters, CV_64FC1 ));
847     CV_CALL( log_weights = cvCreateMat( 1, nclusters, CV_64FC1 ));
848     CV_CALL( covs_item = cvCreateMat( dims, dims, CV_64FC1 ));
849     CV_CALL( centered_sample = cvCreateMat( 1, dims, CV_64FC1 ));
850     CV_CALL( cov_eigen_values = cvCreateMat( inv_eigen_values->rows, inv_eigen_values->cols, CV_64FC1 ));
851     CV_CALL( samples = cvCreateMat( nsamples, dims, CV_64FC1 ));
852     CV_CALL( sum_probs = cvCreateMat( 1, nclusters, CV_64FC1 ));
853     sp_data = sum_probs->data.db;
854
855     // copy the training data into double-precision matrix
856     for( i = 0; i < nsamples; i++ )
857     {
858         const float* src = train_data.data.fl[i];
859         double* dst = (double*)(samples->data.ptr + samples->step*i);
860
861         for( j = 0; j < dims; j++ )
862             dst[j] = src[j];
863     }
864
865     if( start_step != START_M_STEP )
866     {
867         for( k = 0; k < nclusters; k++ )
868         {
869             if( is_general || is_diagonal )
870             {
871                 w = cvGetRow( cov_eigen_values, &whdr, k );
872                 if( is_general )
873                     cvSVD( covs[k], w, cov_rotate_mats[k], 0, CV_SVD_U_T );
874                 else
875                     cvTranspose( cvGetDiag( covs[k], &diag ), w );
876                 w_data = w->data.db;
877                 for( j = 0, det = 1.; j < dims; j++ )
878                     det *= w_data[j];
879                 if( det < min_det_value )
880                 {
881                     if( start_step == START_AUTO_STEP )
882                         det = min_det_value;
883                     else
884                         EXIT;
885                 }
886                 log_det->data.db[k] = det;
887             }
888             else
889             {
890                 d = cvTrace(covs[k]).val[0]/(double)dims;
891                 if( d < min_variation )
892                 {
893                     if( start_step == START_AUTO_STEP )
894                         d = min_variation;
895                     else
896                         EXIT;
897                 }
898                 cov_eigen_values->data.db[k] = d;
899                 log_det->data.db[k] = d;
900             }
901         }
902
903         cvLog( log_det, log_det );
904         if( is_spherical )
905             cvScale( log_det, log_det, dims );
906     }
907
908     for( n = 0; n < params.term_crit.max_iter; n++ )
909     {
910         if( n > 0 || start_step != START_M_STEP )
911         {
912             // e-step: compute probs_ik from means_k, covs_k and weights_k.
913             CV_CALL(cvLog( weights, log_weights ));
914
915             // S_ik = -0.5[log(det(Sigma_k)) + (x_i - mu_k)' Sigma_k^(-1) (x_i - mu_k)] + log(weights_k)
916             for( k = 0; k < nclusters; k++ )
917             {
918                 CvMat* u = cov_rotate_mats[k];
919                 const double* mean = (double*)(means->data.ptr + means->step*k);
920                 w = cvGetRow( cov_eigen_values, &whdr, k );
921                 iw = cvGetRow( inv_eigen_values, &iwhdr, k );
922                 cvDiv( 0, w, iw );
923
924                 w_data = (double*)(inv_eigen_values->data.ptr + inv_eigen_values->step*k);
925
926                 for( i = 0; i < nsamples; i++ )
927                 {
928                     double *csample = centered_sample->data.db, p = log_det->data.db[k];
929                     const double* sample = (double*)(samples->data.ptr + samples->step*i);
930                     double* pp = (double*)(probs->data.ptr + probs->step*i);
931                     for( j = 0; j < dims; j++ )
932                         csample[j] = sample[j] - mean[j];
933                     if( is_general )
934                         cvGEMM( centered_sample, u, 1, 0, 0, centered_sample, CV_GEMM_B_T );
935                     for( j = 0; j < dims; j++ )
936                         p += csample[j]*csample[j]*w_data[is_spherical ? 0 : j];
937                     pp[k] = -0.5*p + log_weights->data.db[k];
938
939                     // S_ik <- S_ik - max_j S_ij
940                     if( k == nclusters - 1 )
941                     {
942                         double max_val = 0;
943                         for( j = 0; j < nclusters; j++ )
944                             max_val = MAX( max_val, pp[j] );
945                         for( j = 0; j < nclusters; j++ )
946                             pp[j] -= max_val;
947                     }
948                 }
949             }
950
951             CV_CALL(cvExp( probs, probs )); // exp( S_ik )
952             cvZero( sum_probs );
953
954             // alpha_ik = exp( S_ik ) / sum_j exp( S_ij ),
955             // log_likelihood = sum_i log (sum_j exp(S_ij))
956             for( i = 0, _log_likelihood = likelihood_bias; i < nsamples; i++ )
957             {
958                 double* pp = (double*)(probs->data.ptr + probs->step*i), sum = 0;
959                 for( j = 0; j < nclusters; j++ )
960                     sum += pp[j];
961                 sum = 1./MAX( sum, DBL_EPSILON );
962                 for( j = 0; j < nclusters; j++ )
963                 {
964                     double p = pp[j] *= sum;
965                     sp_data[j] += p;
966                 }
967                 _log_likelihood -= log( sum );
968             }
969
970             // check termination criteria
971             if( fabs( (_log_likelihood - prev_log_likelihood) / prev_log_likelihood ) < params.term_crit.epsilon )
972                 break;
973             prev_log_likelihood = _log_likelihood;
974         }
975
976         // m-step: update means_k, covs_k and weights_k from probs_ik
977         cvGEMM( probs, samples, 1, 0, 0, means, CV_GEMM_A_T );
978
979         for( k = 0; k < nclusters; k++ )
980         {
981             double sum = sp_data[k], inv_sum = 1./sum;
982             CvMat* cov = covs[k], _mean, _sample;
983
984             w = cvGetRow( cov_eigen_values, &whdr, k );
985             w_data = w->data.db;
986             cvGetRow( means, &_mean, k );
987             cvGetRow( samples, &_sample, k );
988
989             // update weights_k
990             weights->data.db[k] = sum;
991
992             // update means_k
993             cvScale( &_mean, &_mean, inv_sum );
994
995             // compute covs_k
996             cvZero( cov );
997             cvZero( w );
998
999             for( i = 0; i < nsamples; i++ )
1000             {
1001                 double p = probs->data.db[i*nclusters + k]*inv_sum;
1002                 _sample.data.db = (double*)(samples->data.ptr + samples->step*i);
1003
1004                 if( is_general )
1005                 {
1006                     cvMulTransposed( &_sample, covs_item, 1, &_mean );
1007                     cvScaleAdd( covs_item, cvRealScalar(p), cov, cov );
1008                 }
1009                 else
1010                     for( j = 0; j < dims; j++ )
1011                     {
1012                         double val = _sample.data.db[j] - _mean.data.db[j];
1013                         w_data[is_spherical ? 0 : j] += p*val*val;
1014                     }
1015             }
1016
1017             if( is_spherical )
1018             {
1019                 d = w_data[0]/(double)dims;
1020                 d = MAX( d, min_variation );
1021                 w->data.db[0] = d;
1022                 log_det->data.db[k] = d;
1023             }
1024             else
1025             {
1026                 if( is_general )
1027                     cvSVD( cov, w, cov_rotate_mats[k], 0, CV_SVD_U_T );
1028                 cvMaxS( w, min_variation, w );
1029                 for( j = 0, det = 1.; j < dims; j++ )
1030                     det *= w_data[j];
1031                 log_det->data.db[k] = det;
1032             }
1033         }
1034
1035         cvConvertScale( weights, weights, 1./(double)nsamples, 0 );
1036         cvMaxS( weights, DBL_MIN, weights );
1037
1038         cvLog( log_det, log_det );
1039         if( is_spherical )
1040             cvScale( log_det, log_det, dims );
1041     } // end of iteration process
1042
1043     //log_weight_div_det[k] = -2*log(weights_k/det(Sigma_k))^0.5) = -2*log(weights_k) + log(det(Sigma_k)))
1044     if( log_weight_div_det )
1045     {
1046         cvScale( log_weights, log_weight_div_det, -2 );
1047         cvAdd( log_weight_div_det, log_det, log_weight_div_det );
1048     }
1049
1050     /* Now finalize all the covariation matrices:
1051     1) if <cov_mat_type> == COV_MAT_DIAGONAL we used array of <w> as diagonals.
1052        Now w[k] should be copied back to the diagonals of covs[k];
1053     2) if <cov_mat_type> == COV_MAT_SPHERICAL we used the 0-th element of w[k]
1054        as an average variation in each cluster. The value of the 0-th element of w[k]
1055        should be copied to the all of the diagonal elements of covs[k]. */
1056     if( is_spherical )
1057     {
1058         for( k = 0; k < nclusters; k++ )
1059             cvSetIdentity( covs[k], cvScalar(cov_eigen_values->data.db[k]));
1060     }
1061     else if( is_diagonal )
1062     {
1063         for( k = 0; k < nclusters; k++ )
1064             cvTranspose( cvGetRow( cov_eigen_values, &whdr, k ),
1065                          cvGetDiag( covs[k], &diag ));
1066     }
1067     cvDiv( 0, cov_eigen_values, inv_eigen_values );
1068
1069     log_likelihood = _log_likelihood;
1070
1071     __END__;
1072
1073     cvReleaseMat( &log_det );
1074     cvReleaseMat( &log_weights );
1075     cvReleaseMat( &covs_item );
1076     cvReleaseMat( &centered_sample );
1077     cvReleaseMat( &cov_eigen_values );
1078     cvReleaseMat( &samples );
1079     cvReleaseMat( &sum_probs );
1080
1081     return log_likelihood;
1082 }
1083
1084
1085 int CvEM::get_nclusters() const
1086 {
1087     return params.nclusters;
1088 }
1089
1090 const CvMat* CvEM::get_means() const
1091 {
1092     return means;
1093 }
1094
1095 const CvMat** CvEM::get_covs() const
1096 {
1097     return (const CvMat**)covs;
1098 }
1099
1100 const CvMat* CvEM::get_weights() const
1101 {
1102     return weights;
1103 }
1104
1105 const CvMat* CvEM::get_probs() const
1106 {
1107     return probs;
1108 }
1109
1110 using namespace cv;
1111
1112 CvEM::CvEM( const Mat& samples, const Mat& sample_idx,
1113            CvEMParams params, Mat* labels )
1114 {
1115     means = weights = probs = inv_eigen_values = log_weight_div_det = 0;
1116     covs = cov_rotate_mats = 0;
1117     
1118     // just invoke the train() method
1119     train(samples, sample_idx, params, labels);
1120 }    
1121
1122 bool CvEM::train( const Mat& _samples, const Mat& _sample_idx,
1123                  CvEMParams _params, Mat* _labels )
1124 {
1125     CvMat samples = _samples, sidx = _sample_idx, labels, *plabels = 0;
1126     
1127     if( _labels )
1128     {
1129         int nsamples = sidx.data.ptr ? sidx.rows : samples.rows;
1130         
1131         if( !(_labels->data && _labels->type() == CV_32SC1 &&
1132               (_labels->cols == 1 || _labels->rows == 1) &&
1133               _labels->cols + _labels->rows - 1 == nsamples) )
1134             _labels->create(nsamples, 1, CV_32SC1);
1135         plabels = &(labels = *_labels);
1136     }
1137     return train(&samples, sidx.data.ptr ? &sidx : 0, _params, plabels);
1138 }
1139
1140 float
1141 CvEM::predict( const Mat& _sample, Mat* _probs ) const
1142 {
1143     CvMat sample = _sample, probs, *pprobs = 0;
1144     
1145     if( _probs )
1146     {
1147         int nclusters = params.nclusters;
1148         if(!(_probs->data && (_probs->type() == CV_32F || _probs->type()==CV_64F) &&
1149              (_probs->cols == 1 || _probs->rows == 1) &&
1150              _probs->cols + _probs->rows - 1 == nclusters))
1151             _probs->create(nclusters, 1, _sample.type());
1152         pprobs = &(probs = *_probs);
1153     }
1154     return predict(&sample, pprobs);
1155 }
1156
1157
1158 /* End of file. */