Move the sources to trunk
[opencv] / ml / src / mlnbayes.cpp
1 /*M///////////////////////////////////////////////////////////////////////////////////////
2 //
3 //  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
4 //
5 //  By downloading, copying, installing or using the software you agree to this license.
6 //  If you do not agree to this license, do not download, install,
7 //  copy or use the software.
8 //
9 //
10 //            Intel License Agreement
11 //
12 // Copyright (C) 2000, Intel Corporation, all rights reserved.
13 // Third party copyrights are property of their respective owners.
14 //
15 // Redistribution and use in source and binary forms, with or without modification,
16 // are permitted provided that the following conditions are met:
17 //
18 //   * Redistribution's of source code must retain the above copyright notice,
19 //     this list of conditions and the following disclaimer.
20 //
21 //   * Redistribution's in binary form must reproduce the above copyright notice,
22 //     this list of conditions and the following disclaimer in the documentation
23 //     and/or other materials provided with the distribution.
24 //
25 //   * The name of Intel Corporation may not be used to endorse or promote products
26 //     derived from this software without specific prior written permission.
27 //
28 // This software is provided by the copyright holders and contributors "as is" and
29 // any express or implied warranties, including, but not limited to, the implied
30 // warranties of merchantability and fitness for a particular purpose are disclaimed.
31 // In no event shall the Intel Corporation or contributors be liable for any direct,
32 // indirect, incidental, special, exemplary, or consequential damages
33 // (including, but not limited to, procurement of substitute goods or services;
34 // loss of use, data, or profits; or business interruption) however caused
35 // and on any theory of liability, whether in contract, strict liability,
36 // or tort (including negligence or otherwise) arising in any way out of
37 // the use of this software, even if advised of the possibility of such damage.
38 //
39 //M*/
40
41 #include "_ml.h"
42
43 CvNormalBayesClassifier::CvNormalBayesClassifier()
44 {
45     var_count = var_all = 0;
46     var_idx = 0;
47     cls_labels = 0;
48     count = 0;
49     sum = 0;
50     productsum = 0;
51     avg = 0;
52     inv_eigen_values = 0;
53     cov_rotate_mats = 0;
54     c = 0;
55     default_model_name = "my_nb";
56 }
57
58
59 void CvNormalBayesClassifier::clear()
60 {
61     if( cls_labels )
62     {
63         for( int cls = 0; cls < cls_labels->cols; cls++ )
64         {
65             cvReleaseMat( &count[cls] );
66             cvReleaseMat( &sum[cls] );
67             cvReleaseMat( &productsum[cls] );
68             cvReleaseMat( &avg[cls] );
69             cvReleaseMat( &inv_eigen_values[cls] );
70             cvReleaseMat( &cov_rotate_mats[cls] );
71         }
72     }
73     
74     cvReleaseMat( &cls_labels );
75     cvReleaseMat( &var_idx );
76     cvReleaseMat( &c );
77     cvFree( &count );
78 }
79
80
81 CvNormalBayesClassifier::~CvNormalBayesClassifier()
82 {
83     clear();
84 }
85
86
87 CvNormalBayesClassifier::CvNormalBayesClassifier(
88     const CvMat* _train_data, const CvMat* _responses,
89     const CvMat* _var_idx, const CvMat* _sample_idx )
90 {
91     var_count = var_all = 0;
92     var_idx = 0;
93     cls_labels = 0;
94     count = 0;
95     sum = 0;
96     productsum = 0;
97     avg = 0;
98     inv_eigen_values = 0;
99     cov_rotate_mats = 0;
100     c = 0;
101     default_model_name = "my_nb";
102
103     train( _train_data, _responses, _var_idx, _sample_idx );
104 }
105
106
107 bool CvNormalBayesClassifier::train( const CvMat* _train_data, const CvMat* _responses,
108                             const CvMat* _var_idx, const CvMat* _sample_idx, bool update )
109 {
110     const float min_variation = FLT_EPSILON;
111     bool result = false;
112     CvMat* responses   = 0;
113     const float** train_data = 0;
114     CvMat* __cls_labels = 0;
115     CvMat* __var_idx = 0;
116     CvMat* cov = 0;
117     
118     CV_FUNCNAME( "CvNormalBayesClassifier::train" );
119
120     __BEGIN__;
121
122     int cls, nsamples = 0, _var_count = 0, _var_all = 0, nclasses = 0;
123     int s, c1, c2;
124     const int* responses_data;
125     
126     CV_CALL( cvPrepareTrainData( 0,
127         _train_data, CV_ROW_SAMPLE, _responses, CV_VAR_CATEGORICAL,
128         _var_idx, _sample_idx, false, &train_data,
129         &nsamples, &_var_count, &_var_all, &responses,
130         &__cls_labels, &__var_idx ));
131
132     if( !update )
133     {
134         const size_t mat_size = sizeof(CvMat*);
135         size_t data_size;
136
137         clear();
138
139         var_idx = __var_idx;
140         cls_labels = __cls_labels;
141         __var_idx = __cls_labels = 0;
142         var_count = _var_count;
143         var_all = _var_all;
144
145         nclasses = cls_labels->cols;
146         data_size = nclasses*6*mat_size;
147
148         CV_CALL( count = (CvMat**)cvAlloc( data_size ));
149         memset( count, 0, data_size );
150
151         sum             = count      + nclasses;
152         productsum      = sum        + nclasses;
153         avg             = productsum + nclasses;
154         inv_eigen_values= avg        + nclasses;
155         cov_rotate_mats = inv_eigen_values         + nclasses;
156         
157         CV_CALL( c = cvCreateMat( 1, nclasses, CV_64FC1 ));
158
159         for( cls = 0; cls < nclasses; cls++ )
160         {
161             CV_CALL(count[cls]            = cvCreateMat( 1, var_count, CV_32SC1 ));
162             CV_CALL(sum[cls]              = cvCreateMat( 1, var_count, CV_64FC1 ));
163             CV_CALL(productsum[cls]       = cvCreateMat( var_count, var_count, CV_64FC1 ));
164             CV_CALL(avg[cls]              = cvCreateMat( 1, var_count, CV_64FC1 ));
165             CV_CALL(inv_eigen_values[cls] = cvCreateMat( 1, var_count, CV_64FC1 ));
166             CV_CALL(cov_rotate_mats[cls]  = cvCreateMat( var_count, var_count, CV_64FC1 ));
167             CV_CALL(cvZero( count[cls] ));
168             CV_CALL(cvZero( sum[cls] ));
169             CV_CALL(cvZero( productsum[cls] ));
170             CV_CALL(cvZero( avg[cls] ));
171             CV_CALL(cvZero( inv_eigen_values[cls] ));
172             CV_CALL(cvZero( cov_rotate_mats[cls] ));
173         }
174     }
175     else
176     {
177         // check that the new training data has the same dimensionality etc.
178         if( _var_count != var_count || _var_all != var_all || !(!_var_idx && !var_idx ||
179             _var_idx && var_idx && cvNorm(_var_idx,var_idx,CV_C) < DBL_EPSILON) )
180             CV_ERROR( CV_StsBadArg,
181             "The new training data is inconsistent with the original training data" );
182
183         if( cls_labels->cols != __cls_labels->cols ||
184             cvNorm(cls_labels, __cls_labels, CV_C) > DBL_EPSILON )
185             CV_ERROR( CV_StsNotImplemented,
186             "In the current implementation the new training data must have absolutely "
187             "the same set of class labels as used in the original training data" );
188
189         nclasses = cls_labels->cols;
190     }
191
192     responses_data = responses->data.i;
193     CV_CALL( cov = cvCreateMat( _var_count, _var_count, CV_64FC1 ));
194
195     /* process train data (count, sum , productsum) */
196     for( s = 0; s < nsamples; s++ )
197     {
198         cls = responses_data[s];
199         int* count_data = count[cls]->data.i;
200         double* sum_data = sum[cls]->data.db;
201         double* prod_data = productsum[cls]->data.db;
202         const float* train_vec = train_data[s];
203         
204         for( c1 = 0; c1 < _var_count; c1++, prod_data += _var_count )
205         {
206             double val1 = train_vec[c1];
207             sum_data[c1] += val1;
208             count_data[c1]++;
209             for( c2 = c1; c2 < _var_count; c2++ )
210                 prod_data[c2] += train_vec[c2]*val1;
211         }
212     }
213
214     /* calculate avg, covariance matrix, c */
215     for( cls = 0; cls < nclasses; cls++ )
216     {
217         double det = 1;
218         int i, j;
219         CvMat* w = inv_eigen_values[cls];
220         int* count_data = count[cls]->data.i;
221         double* avg_data = avg[cls]->data.db;
222         double* sum1 = sum[cls]->data.db;
223
224         cvCompleteSymm( productsum[cls], 0 );
225
226         for( j = 0; j < _var_count; j++ )
227         {
228             int n = count_data[j];
229             avg_data[j] = n ? sum1[j] / n : 0.;
230         }
231
232         count_data = count[cls]->data.i;
233         avg_data = avg[cls]->data.db;
234         sum1 = sum[cls]->data.db;
235
236         for( i = 0; i < _var_count; i++ )
237         {
238             double* avg2_data = avg[cls]->data.db;
239             double* sum2 = sum[cls]->data.db;
240             double* prod_data = productsum[cls]->data.db + i*_var_count;
241             double* cov_data = cov->data.db + i*_var_count;
242             double s1val = sum1[j];
243             double avg1 = avg_data[i];
244             int count = count_data[i];
245
246             for( j = 0; j <= i; j++ )
247             {
248                 double avg2 = avg2_data[j];
249                 double cov_val = prod_data[j] - avg1 * sum2[j] - avg2 * s1val + avg1 * avg2 * count;
250                 cov_val = (count > 1) ? cov_val / (count - 1) : cov_val;
251                 cov_data[j] = cov_val;
252             }
253         }
254
255         CV_CALL( cvCompleteSymm( cov, 1 ));
256         CV_CALL( cvSVD( cov, w, cov_rotate_mats[cls], 0, CV_SVD_U_T ));
257         CV_CALL( cvMaxS( w, min_variation, w ));
258         for( j = 0; j < _var_count; j++ )
259             det *= w->data.db[j];
260
261         CV_CALL( cvDiv( NULL, w, w ));
262         c->data.db[cls] = log( det );
263     }
264
265     result = true;
266
267     __END__;
268
269     if( !result || cvGetErrStatus() < 0 )
270         clear();
271
272     cvReleaseMat( &cov );
273     cvReleaseMat( &__cls_labels );
274     cvReleaseMat( &__var_idx );
275     cvFree( &train_data );
276
277     return result;
278 }
279
280
281 float CvNormalBayesClassifier::predict( const CvMat* samples, CvMat* results ) const
282 {
283     float value = 0;
284     void* buffer = 0;
285     int allocated_buffer = 0;
286
287     CV_FUNCNAME( "CvNormalBayesClassifier::predict" );
288     
289     __BEGIN__;
290
291     int i, j, k, cls = -1, _var_count, nclasses;
292     double opt = FLT_MAX;
293     CvMat diff;
294     int rtype = 0, rstep = 0, size;
295     const int* vidx = 0;
296
297     nclasses = cls_labels->cols;
298     _var_count = avg[0]->cols;
299
300     if( !CV_IS_MAT(samples) || CV_MAT_TYPE(samples->type) != CV_32FC1 || samples->cols != var_all )
301         CV_ERROR( CV_StsBadArg,
302         "The input samples must be 32f matrix with the number of columns = var_all" );
303
304     if( samples->rows > 1 && !results )
305         CV_ERROR( CV_StsNullPtr,
306         "When the number of input samples is >1, the output vector of results must be passed" );
307
308     if( results )
309     {
310         if( !CV_IS_MAT(results) || CV_MAT_TYPE(results->type) != CV_32FC1 &&
311         CV_MAT_TYPE(results->type) != CV_32SC1 ||
312         results->cols != 1 && results->rows != 1 ||
313         results->cols + results->rows - 1 != samples->rows )
314         CV_ERROR( CV_StsBadArg, "The output array must be integer or floating-point vector "
315         "with the number of elements = number of rows in the input matrix" );
316
317         rtype = CV_MAT_TYPE(results->type);
318         rstep = CV_IS_MAT_CONT(results->type) ? 1 : results->step/CV_ELEM_SIZE(rtype);
319     }
320     
321     if( var_idx )
322         vidx = var_idx->data.i;
323
324 // allocate memory and initializing headers for calculating
325     size = sizeof(double) * (nclasses + var_count);
326     if( size <= CV_MAX_LOCAL_SIZE )
327         buffer = cvStackAlloc( size );
328     else
329     {
330         CV_CALL( buffer = cvAlloc( size ));
331         allocated_buffer = 1;
332     }
333     
334     diff = cvMat( 1, var_count, CV_64FC1, buffer );
335
336     for( k = 0; k < samples->rows; k++ )
337     {
338         int ival;
339         
340         for( i = 0; i < nclasses; i++ )
341         {
342             double cur = c->data.db[i];
343             CvMat* u = cov_rotate_mats[i];
344             CvMat* w = inv_eigen_values[i];
345             const double* avg_data = avg[i]->data.db;
346             const float* x = (const float*)(samples->data.ptr + samples->step*k);
347
348             // cov = u w u'  -->  cov^(-1) = u w^(-1) u'
349             for( j = 0; j < _var_count; j++ )
350                 diff.data.db[j] = avg_data[j] - x[vidx ? vidx[j] : j];
351
352             CV_CALL(cvGEMM( &diff, u, 1, 0, 0, &diff, CV_GEMM_B_T ));
353             for( j = 0; j < _var_count; j++ )
354             {
355                 double d = diff.data.db[j];
356                 cur += d*d*w->data.db[j];
357             }
358
359             if( cur < opt )
360             {
361                 cls = i;
362                 opt = cur;
363             }
364             /* probability = exp( -0.5 * cur ) */
365         }
366
367         ival = cls_labels->data.i[cls];
368         if( results )
369         {
370             if( rtype == CV_32SC1 )
371                 results->data.i[k*rstep] = ival;
372             else
373                 results->data.fl[k*rstep] = (float)ival;
374         }
375         if( k == 0 )
376             value = (float)ival;
377
378         /*if( _probs )
379         {
380             CV_CALL( cvConvertScale( &expo, &expo, -0.5 ));
381             CV_CALL( cvExp( &expo, &expo ));
382             if( _probs->cols == 1 )
383                 CV_CALL( cvReshape( &expo, &expo, 1, nclasses ));
384             CV_CALL( cvConvertScale( &expo, _probs, 1./cvSum( &expo ).val[0] ));
385         }*/
386     }
387
388     __END__;
389
390     if( allocated_buffer )
391         cvFree( &buffer );
392
393     return value;
394 }
395
396
397 void CvNormalBayesClassifier::write( CvFileStorage* fs, const char* name )
398 {
399     CV_FUNCNAME( "CvNormalBayesClassifier::write" );
400
401     __BEGIN__;
402
403     int nclasses, i;
404
405     nclasses = cls_labels->cols;
406
407     cvStartWriteStruct( fs, name, CV_NODE_MAP, CV_TYPE_NAME_ML_NBAYES );
408
409     CV_CALL( cvWriteInt( fs, "var_count", var_count ));
410     CV_CALL( cvWriteInt( fs, "var_all", var_all ));
411
412     if( var_idx )
413         CV_CALL( cvWrite( fs, "var_idx", var_idx ));
414     CV_CALL( cvWrite( fs, "cls_labels", cls_labels ));
415
416     CV_CALL( cvStartWriteStruct( fs, "count", CV_NODE_SEQ ));
417     for( i = 0; i < nclasses; i++ )
418         CV_CALL( cvWrite( fs, NULL, count[i] ));
419     CV_CALL( cvEndWriteStruct( fs ));
420
421     CV_CALL( cvStartWriteStruct( fs, "sum", CV_NODE_SEQ ));
422     for( i = 0; i < nclasses; i++ )
423         CV_CALL( cvWrite( fs, NULL, sum[i] ));
424     CV_CALL( cvEndWriteStruct( fs ));
425  
426     CV_CALL( cvStartWriteStruct( fs, "productsum", CV_NODE_SEQ ));
427     for( i = 0; i < nclasses; i++ )
428         CV_CALL( cvWrite( fs, NULL, productsum[i] ));
429     CV_CALL( cvEndWriteStruct( fs ));
430
431     CV_CALL( cvStartWriteStruct( fs, "avg", CV_NODE_SEQ ));
432     for( i = 0; i < nclasses; i++ )
433         CV_CALL( cvWrite( fs, NULL, avg[i] ));
434     CV_CALL( cvEndWriteStruct( fs ));
435
436     CV_CALL( cvStartWriteStruct( fs, "inv_eigen_values", CV_NODE_SEQ ));
437     for( i = 0; i < nclasses; i++ )
438         CV_CALL( cvWrite( fs, NULL, inv_eigen_values[i] ));
439     CV_CALL( cvEndWriteStruct( fs ));
440
441     CV_CALL( cvStartWriteStruct( fs, "cov_rotate_mats", CV_NODE_SEQ ));
442     for( i = 0; i < nclasses; i++ )
443         CV_CALL( cvWrite( fs, NULL, cov_rotate_mats[i] ));
444     CV_CALL( cvEndWriteStruct( fs ));
445
446     CV_CALL( cvWrite( fs, "c", c ));
447
448     cvEndWriteStruct( fs );
449
450     __END__;
451 }
452
453
454 void CvNormalBayesClassifier::read( CvFileStorage* fs, CvFileNode* root_node )
455 {
456     bool ok = false;
457     CV_FUNCNAME( "CvNormalBayesClassifier::read" );
458
459     __BEGIN__;
460
461     int nclasses, i;
462     size_t data_size;
463     CvFileNode* node;
464     CvSeq* seq;
465     CvSeqReader reader;
466     
467     clear();
468
469     CV_CALL( var_count = cvReadIntByName( fs, root_node, "var_count", -1 ));
470     CV_CALL( var_all = cvReadIntByName( fs, root_node, "var_all", -1 ));
471     CV_CALL( var_idx = (CvMat*)cvReadByName( fs, root_node, "var_idx" ));
472     CV_CALL( cls_labels = (CvMat*)cvReadByName( fs, root_node, "cls_labels" ));
473     if( !cls_labels )
474         CV_ERROR( CV_StsParseError, "No \"cls_labels\" in NBayes classifier" );
475     if( cls_labels->cols < 1 )
476         CV_ERROR( CV_StsBadArg, "Number of classes is less 1" );
477     if( var_count <= 0 )
478         CV_ERROR( CV_StsParseError,
479         "The field \"var_count\" of NBayes classifier is missing" );
480     nclasses = cls_labels->cols;
481
482     data_size = nclasses*6*sizeof(CvMat*);
483     CV_CALL( count = (CvMat**)cvAlloc( data_size ));
484     memset( count, 0, data_size );
485
486     sum = count + nclasses;
487     productsum  = sum  + nclasses;
488     avg = productsum + nclasses;
489     inv_eigen_values = avg + nclasses;
490     cov_rotate_mats = inv_eigen_values + nclasses;
491
492     CV_CALL( node = cvGetFileNodeByName( fs, root_node, "count" ));
493     seq = node->data.seq;
494     if( !CV_NODE_IS_SEQ(node->tag) || seq->total != nclasses)
495         CV_ERROR( CV_StsBadArg, "" );
496     CV_CALL( cvStartReadSeq( seq, &reader, 0 ));
497     for( i = 0; i < nclasses; i++ )
498     {
499         CV_CALL( count[i] = (CvMat*)cvRead( fs, (CvFileNode*)reader.ptr ));
500         CV_NEXT_SEQ_ELEM( seq->elem_size, reader );
501     }
502
503     CV_CALL( node = cvGetFileNodeByName( fs, root_node, "sum" ));
504     seq = node->data.seq;
505     if( !CV_NODE_IS_SEQ(node->tag) || seq->total != nclasses)
506         CV_ERROR( CV_StsBadArg, "" );
507     CV_CALL( cvStartReadSeq( seq, &reader, 0 ));
508     for( i = 0; i < nclasses; i++ )
509     {
510         CV_CALL( sum[i] = (CvMat*)cvRead( fs, (CvFileNode*)reader.ptr ));
511         CV_NEXT_SEQ_ELEM( seq->elem_size, reader );
512     }
513
514     CV_CALL( node = cvGetFileNodeByName( fs, root_node, "productsum" ));
515     seq = node->data.seq;
516     if( !CV_NODE_IS_SEQ(node->tag) || seq->total != nclasses)
517         CV_ERROR( CV_StsBadArg, "" );
518     CV_CALL( cvStartReadSeq( seq, &reader, 0 ));
519     for( i = 0; i < nclasses; i++ )
520     {
521         CV_CALL( productsum[i] = (CvMat*)cvRead( fs, (CvFileNode*)reader.ptr ));
522         CV_NEXT_SEQ_ELEM( seq->elem_size, reader );
523     }
524
525     CV_CALL( node = cvGetFileNodeByName( fs, root_node, "avg" ));
526     seq = node->data.seq;
527     if( !CV_NODE_IS_SEQ(node->tag) || seq->total != nclasses)
528         CV_ERROR( CV_StsBadArg, "" );
529     CV_CALL( cvStartReadSeq( seq, &reader, 0 ));
530     for( i = 0; i < nclasses; i++ )
531     {
532         CV_CALL( avg[i] = (CvMat*)cvRead( fs, (CvFileNode*)reader.ptr ));
533         CV_NEXT_SEQ_ELEM( seq->elem_size, reader );
534     }
535
536     CV_CALL( node = cvGetFileNodeByName( fs, root_node, "inv_eigen_values" ));
537     seq = node->data.seq;
538     if( !CV_NODE_IS_SEQ(node->tag) || seq->total != nclasses)
539         CV_ERROR( CV_StsBadArg, "" );
540     CV_CALL( cvStartReadSeq( seq, &reader, 0 ));
541     for( i = 0; i < nclasses; i++ )
542     {
543         CV_CALL( inv_eigen_values[i] = (CvMat*)cvRead( fs, (CvFileNode*)reader.ptr ));
544         CV_NEXT_SEQ_ELEM( seq->elem_size, reader );
545     }
546
547     CV_CALL( node = cvGetFileNodeByName( fs, root_node, "cov_rotate_mats" ));
548     seq = node->data.seq;
549     if( !CV_NODE_IS_SEQ(node->tag) || seq->total != nclasses)
550         CV_ERROR( CV_StsBadArg, "" );
551     CV_CALL( cvStartReadSeq( seq, &reader, 0 ));
552     for( i = 0; i < nclasses; i++ )
553     {
554         CV_CALL( cov_rotate_mats[i] = (CvMat*)cvRead( fs, (CvFileNode*)reader.ptr ));
555         CV_NEXT_SEQ_ELEM( seq->elem_size, reader );
556     }
557
558     CV_CALL( c = (CvMat*)cvReadByName( fs, root_node, "c" ));
559
560     ok = true;
561
562     __END__;
563
564     if( !ok )
565         clear();
566 }
567
568 /* End of file. */
569