1 /*M///////////////////////////////////////////////////////////////////////////////////////
3 // IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
5 // By downloading, copying, installing or using the software you agree to this license.
6 // If you do not agree to this license, do not download, install,
7 // copy or use the software.
10 // Intel License Agreement
12 // Copyright (C) 2000, Intel Corporation, all rights reserved.
13 // Third party copyrights are property of their respective owners.
15 // Redistribution and use in source and binary forms, with or without modification,
16 // are permitted provided that the following conditions are met:
18 // * Redistribution's of source code must retain the above copyright notice,
19 // this list of conditions and the following disclaimer.
21 // * Redistribution's in binary form must reproduce the above copyright notice,
22 // this list of conditions and the following disclaimer in the documentation
23 // and/or other materials provided with the distribution.
25 // * The name of Intel Corporation may not be used to endorse or promote products
26 // derived from this software without specific prior written permission.
28 // This software is provided by the copyright holders and contributors "as is" and
29 // any express or implied warranties, including, but not limited to, the implied
30 // warranties of merchantability and fitness for a particular purpose are disclaimed.
31 // In no event shall the Intel Corporation or contributors be liable for any direct,
32 // indirect, incidental, special, exemplary, or consequential damages
33 // (including, but not limited to, procurement of substitute goods or services;
34 // loss of use, data, or profits; or business interruption) however caused
35 // and on any theory of liability, whether in contract, strict liability,
36 // or tort (including negligence or otherwise) arising in any way out of
37 // the use of this software, even if advised of the possibility of such damage.
// Default constructor. The visible lines zero var_count/var_all and set
// default_model_name to "my_nb".
// NOTE(review): the remaining body lines (braces and, presumably, the
// zero-initialization of the other model pointers) are not shown in this
// listing — confirm against the full source.
43 CvNormalBayesClassifier::CvNormalBayesClassifier()
45 var_count = var_all = 0;
55 default_model_name = "my_nb";
// Releases the matrices owned by the model: for every class index, the entries
// of the six per-class arrays (count, sum, productsum, avg, inv_eigen_values,
// cov_rotate_mats), followed by cls_labels and var_idx.
// NOTE(review): the loop bound dereferences cls_labels->cols, so it must only
// run on a trained/loaded model; the guarding line(s) between 59 and 63 are not
// shown here — confirm. The pointer array itself (allocated in train() as one
// block of 6*nclasses pointers based at `count`) and `c` are not released in
// the visible lines — confirm they are freed in the omitted lines.
59 void CvNormalBayesClassifier::clear()
63 for( int cls = 0; cls < cls_labels->cols; cls++ )
65 cvReleaseMat( &count[cls] );
66 cvReleaseMat( &sum[cls] );
67 cvReleaseMat( &productsum[cls] );
68 cvReleaseMat( &avg[cls] );
69 cvReleaseMat( &inv_eigen_values[cls] );
70 cvReleaseMat( &cov_rotate_mats[cls] );
74 cvReleaseMat( &cls_labels );
75 cvReleaseMat( &var_idx );
// Destructor.
// NOTE(review): the body is not visible in this listing; presumably it calls
// clear() to release the model's matrices — confirm.
81 CvNormalBayesClassifier::~CvNormalBayesClassifier()
// Training constructor: initializes the model like the default constructor,
// then immediately trains it via train(_train_data, _responses, _var_idx,
// _sample_idx), i.e. with train()'s default `update` argument (declared in the
// header — presumably false; confirm).
// train()'s bool result is discarded here, so a failure during construction is
// only observable through OpenCV's error mechanism (CV_ERROR inside train()).
87 CvNormalBayesClassifier::CvNormalBayesClassifier(
88 const CvMat* _train_data, const CvMat* _responses,
89 const CvMat* _var_idx, const CvMat* _sample_idx )
91 var_count = var_all = 0;
101 default_model_name = "my_nb";
103 train( _train_data, _responses, _var_idx, _sample_idx );
// Trains the normal Bayes model, or updates an existing one with more data.
//
// Input is prepared by cvPrepareTrainData: row samples (CV_ROW_SAMPLE),
// categorical responses, optional variable/sample index masks.
// Per class, the sums and the upper triangle of sum(x_i * x_j) are accumulated
// over the samples (count[cls] holds per-variable sample counts; its increment
// is on a line not shown in this listing). From these the method derives:
//   avg[cls]              - per-variable mean (0 where the count is 0);
//   cov                   - sample covariance, divided by (n - 1) when n > 1;
//   cov_rotate_mats[cls]  - eigenvectors of cov (cvSVD with CV_SVD_U_T);
//   inv_eigen_values[cls] - 1 / max(eigenvalue, FLT_EPSILON);
//   c[cls]                - log of the product of the clamped eigenvalues,
//                           i.e. the log-determinant term used by predict().
// When updating an existing model, the new data must match the original
// var_count / var_all / var_idx (else CV_StsBadArg) and use exactly the same
// class labels (else CV_StsNotImplemented).
// Returns a bool success flag (`result`; its assignment and the failure
// handling after line 269 are not shown in this listing — confirm).
107 bool CvNormalBayesClassifier::train( const CvMat* _train_data, const CvMat* _responses,
108 const CvMat* _var_idx, const CvMat* _sample_idx, bool update )
110 const float min_variation = FLT_EPSILON;
112 CvMat* responses = 0;
113 const float** train_data = 0;
114 CvMat* __cls_labels = 0;
115 CvMat* __var_idx = 0;
118 CV_FUNCNAME( "CvNormalBayesClassifier::train" );
122 int cls, nsamples = 0, _var_count = 0, _var_all = 0, nclasses = 0;
124 const int* responses_data;
126 CV_CALL( cvPrepareTrainData( 0,
127 _train_data, CV_ROW_SAMPLE, _responses, CV_VAR_CATEGORICAL,
128 _var_idx, _sample_idx, false, &train_data,
129 &nsamples, &_var_count, &_var_all, &responses,
130 &__cls_labels, &__var_idx ));
134 const size_t mat_size = sizeof(CvMat*);
// Fresh model: take ownership of the prepared labels and allocate all six
// per-class pointer arrays as one block based at `count`.
140 cls_labels = __cls_labels;
141 __var_idx = __cls_labels = 0;
142 var_count = _var_count;
145 nclasses = cls_labels->cols;
146 data_size = nclasses*6*mat_size;
148 CV_CALL( count = (CvMat**)cvAlloc( data_size ));
149 memset( count, 0, data_size );
151 sum = count + nclasses;
152 productsum = sum + nclasses;
153 avg = productsum + nclasses;
154 inv_eigen_values= avg + nclasses;
155 cov_rotate_mats = inv_eigen_values + nclasses;
157 CV_CALL( c = cvCreateMat( 1, nclasses, CV_64FC1 ));
159 for( cls = 0; cls < nclasses; cls++ )
161 CV_CALL(count[cls] = cvCreateMat( 1, var_count, CV_32SC1 ));
162 CV_CALL(sum[cls] = cvCreateMat( 1, var_count, CV_64FC1 ));
163 CV_CALL(productsum[cls] = cvCreateMat( var_count, var_count, CV_64FC1 ));
164 CV_CALL(avg[cls] = cvCreateMat( 1, var_count, CV_64FC1 ));
165 CV_CALL(inv_eigen_values[cls] = cvCreateMat( 1, var_count, CV_64FC1 ));
166 CV_CALL(cov_rotate_mats[cls] = cvCreateMat( var_count, var_count, CV_64FC1 ));
167 CV_CALL(cvZero( count[cls] ));
168 CV_CALL(cvZero( sum[cls] ));
169 CV_CALL(cvZero( productsum[cls] ));
170 CV_CALL(cvZero( avg[cls] ));
171 CV_CALL(cvZero( inv_eigen_values[cls] ));
172 CV_CALL(cvZero( cov_rotate_mats[cls] ));
177 // check that the new training data has the same dimensionality etc.
178 if( _var_count != var_count || _var_all != var_all || !(!_var_idx && !var_idx ||
179 _var_idx && var_idx && cvNorm(_var_idx,var_idx,CV_C) < DBL_EPSILON) )
180 CV_ERROR( CV_StsBadArg,
181 "The new training data is inconsistent with the original training data" );
183 if( cls_labels->cols != __cls_labels->cols ||
184 cvNorm(cls_labels, __cls_labels, CV_C) > DBL_EPSILON )
185 CV_ERROR( CV_StsNotImplemented,
186 "In the current implementation the new training data must have absolutely "
187 "the same set of class labels as used in the original training data" );
189 nclasses = cls_labels->cols;
192 responses_data = responses->data.i;
193 CV_CALL( cov = cvCreateMat( _var_count, _var_count, CV_64FC1 ));
195 /* process train data (count, sum , productsum) */
196 for( s = 0; s < nsamples; s++ )
198 cls = responses_data[s];
199 int* count_data = count[cls]->data.i;
200 double* sum_data = sum[cls]->data.db;
201 double* prod_data = productsum[cls]->data.db;
202 const float* train_vec = train_data[s];
// Row c1 of productsum: only columns c2 >= c1 are accumulated (upper triangle).
204 for( c1 = 0; c1 < _var_count; c1++, prod_data += _var_count )
206 double val1 = train_vec[c1];
207 sum_data[c1] += val1;
209 for( c2 = c1; c2 < _var_count; c2++ )
210 prod_data[c2] += train_vec[c2]*val1;
214 /* calculate avg, covariance matrix, c */
215 for( cls = 0; cls < nclasses; cls++ )
219 CvMat* w = inv_eigen_values[cls];
220 int* count_data = count[cls]->data.i;
221 double* avg_data = avg[cls]->data.db;
222 double* sum1 = sum[cls]->data.db;
// Mirror the accumulated upper triangle into the lower one.
224 cvCompleteSymm( productsum[cls], 0 );
226 for( j = 0; j < _var_count; j++ )
228 int n = count_data[j];
229 avg_data[j] = n ? sum1[j] / n : 0.;
232 count_data = count[cls]->data.i;
233 avg_data = avg[cls]->data.db;
234 sum1 = sum[cls]->data.db;
// Lower triangle of the covariance:
//   (n-1) * cov(i,j) = S(i,j) - avg_i*sum_j - avg_j*sum_i + avg_i*avg_j*n
236 for( i = 0; i < _var_count; i++ )
238 double* avg2_data = avg[cls]->data.db;
239 double* sum2 = sum[cls]->data.db;
240 double* prod_data = productsum[cls]->data.db + i*_var_count;
241 double* cov_data = cov->data.db + i*_var_count;
// Fixed: was sum1[j]. For i == 0, j still equals _var_count after the
// averaging loop, so that read one element past the end of sum1 and corrupted
// cov(0,0). The formula needs sum_i.
242 double s1val = sum1[i];
243 double avg1 = avg_data[i];
244 int count = count_data[i];
246 for( j = 0; j <= i; j++ )
248 double avg2 = avg2_data[j];
249 double cov_val = prod_data[j] - avg1 * sum2[j] - avg2 * s1val + avg1 * avg2 * count;
250 cov_val = (count > 1) ? cov_val / (count - 1) : cov_val;
251 cov_data[j] = cov_val;
// Copy the computed lower triangle into the upper one, then eigendecompose.
255 CV_CALL( cvCompleteSymm( cov, 1 ));
256 CV_CALL( cvSVD( cov, w, cov_rotate_mats[cls], 0, CV_SVD_U_T ));
257 CV_CALL( cvMaxS( w, min_variation, w ));
// Accumulate log(det) as a sum of logs rather than log(product). With many
// variables the product of eigenvalues underflows to 0 (log -> -inf) or
// overflows to +inf, which breaks predict()'s class scores. Every eigenvalue is
// >= min_variation > 0 here, so each log is finite. det is reset in the
// for-init so the result does not depend on its initializer (not shown;
// NOTE(review): assumes det is declared as double).
258 for( j = 0, det = 0; j < _var_count; j++ )
259 det += log( w->data.db[j] ); // det now holds log(det)
261 CV_CALL( cvDiv( NULL, w, w ));
262 c->data.db[cls] = det;
269 if( !result || cvGetErrStatus() < 0 )
272 cvReleaseMat( &cov );
273 cvReleaseMat( &__cls_labels );
274 cvReleaseMat( &__var_idx );
275 cvFree( &train_data );
// Classifies every row of `samples` (CV_32FC1, exactly var_all columns).
// For sample x and class i it computes
//     cur = c[i] + sum_j d_j^2 * inv_eigen_values[i][j],
// where d = (avg[i] - x) rotated by cov_rotate_mats[i]. That is the class
// log-determinant term plus the squared Mahalanobis distance; a smaller cur is
// more likely (probability ~ exp(-0.5 * cur)).
// `results` must be a CV_32SC1 or CV_32FC1 vector with one element per sample,
// and is required when samples has more than one row. It receives the
// predicted label taken from cls_labels.
// NOTE(review): the best-class selection, the return statement and the
// declarations of `_probs`, `expo`, `vidx`, `buffer` and `ival` are not visible
// in this listing. Presumably the method returns the label of the last sample,
// and `_probs`, when non-null, receives normalized class probabilities — confirm.
281 float CvNormalBayesClassifier::predict( const CvMat* samples, CvMat* results ) const
285 int allocated_buffer = 0;
287 CV_FUNCNAME( "CvNormalBayesClassifier::predict" );
// NOTE(review): opt presumably tracks the minimum cur (selection lines not
// shown). Starting it at FLT_MAX rather than DBL_MAX means that if every class
// score is >= FLT_MAX (or NaN), cls may stay -1, and line 367 would then read
// cls_labels->data.i[-1] — confirm/guard.
291 int i, j, k, cls = -1, _var_count, nclasses;
292 double opt = FLT_MAX;
294 int rtype = 0, rstep = 0, size;
297 nclasses = cls_labels->cols;
298 _var_count = avg[0]->cols;
300 if( !CV_IS_MAT(samples) || CV_MAT_TYPE(samples->type) != CV_32FC1 || samples->cols != var_all )
301 CV_ERROR( CV_StsBadArg,
302 "The input samples must be 32f matrix with the number of columns = var_all" );
304 if( samples->rows > 1 && !results )
305 CV_ERROR( CV_StsNullPtr,
306 "When the number of input samples is >1, the output vector of results must be passed" );
// results: a row or column vector of int or float, one element per sample.
310 if( !CV_IS_MAT(results) || CV_MAT_TYPE(results->type) != CV_32FC1 &&
311 CV_MAT_TYPE(results->type) != CV_32SC1 ||
312 results->cols != 1 && results->rows != 1 ||
313 results->cols + results->rows - 1 != samples->rows )
314 CV_ERROR( CV_StsBadArg, "The output array must be integer or floating-point vector "
315 "with the number of elements = number of rows in the input matrix" );
317 rtype = CV_MAT_TYPE(results->type);
// Element stride between consecutive results (1 for a continuous vector).
318 rstep = CV_IS_MAT_CONT(results->type) ? 1 : results->step/CV_ELEM_SIZE(rtype);
// vidx maps model variable j to its column in the sample. Line 350 also handles
// vidx == 0 (presumably the guard on the omitted line 321 — confirm).
322 vidx = var_idx->data.i;
324 // allocate memory and initializing headers for calculating
// Scratch: var_count doubles for `diff` plus nclasses doubles (presumably for
// the per-class scores used for probabilities — confirm). Small buffers are
// placed on the stack; larger ones are heap-allocated and freed at the end.
325 size = sizeof(double) * (nclasses + var_count);
326 if( size <= CV_MAX_LOCAL_SIZE )
327 buffer = cvStackAlloc( size );
330 CV_CALL( buffer = cvAlloc( size ));
331 allocated_buffer = 1;
334 diff = cvMat( 1, var_count, CV_64FC1, buffer );
336 for( k = 0; k < samples->rows; k++ )
340 for( i = 0; i < nclasses; i++ )
342 double cur = c->data.db[i];
343 CvMat* u = cov_rotate_mats[i];
344 CvMat* w = inv_eigen_values[i];
345 const double* avg_data = avg[i]->data.db;
346 const float* x = (const float*)(samples->data.ptr + samples->step*k);
348 // cov = u w u' --> cov^(-1) = u w^(-1) u'
349 for( j = 0; j < _var_count; j++ )
350 diff.data.db[j] = avg_data[j] - x[vidx ? vidx[j] : j];
352 CV_CALL(cvGEMM( &diff, u, 1, 0, 0, &diff, CV_GEMM_B_T ));
353 for( j = 0; j < _var_count; j++ )
355 double d = diff.data.db[j];
356 cur += d*d*w->data.db[j];
364 /* probability = exp( -0.5 * cur ) */
// Store the winning class label as int or float, according to results' type.
367 ival = cls_labels->data.i[cls];
370 if( rtype == CV_32SC1 )
371 results->data.i[k*rstep] = ival;
373 results->data.fl[k*rstep] = (float)ival;
// Convert scores to probabilities: exp(-0.5 * cur), reshaped to a column when
// _probs is a column vector, then normalized to sum to 1.
380 CV_CALL( cvConvertScale( &expo, &expo, -0.5 ));
381 CV_CALL( cvExp( &expo, &expo ));
382 if( _probs->cols == 1 )
383 CV_CALL( cvReshape( &expo, &expo, 1, nclasses ));
384 CV_CALL( cvConvertScale( &expo, _probs, 1./cvSum( &expo ).val[0] ));
390 if( allocated_buffer )
// Serializes the model into `fs` as a map node named `name` with type name
// CV_TYPE_NAME_ML_NBAYES. Fields, in order:
//   var_count, var_all, var_idx, cls_labels,
//   then one sequence each for count, sum, productsum, avg,
//   inv_eigen_values and cov_rotate_mats (nclasses matrices in class order),
//   and finally `c`, the per-class log-determinant row vector.
// read() looks up these same field names and requires every sequence to have
// exactly nclasses elements, so this layout must stay in sync with it.
// NOTE(review): the outer cvStartWriteStruct/cvEndWriteStruct calls are not
// wrapped in CV_CALL, so a failure in the final cvEndWriteStruct is not checked
// inside this function.
397 void CvNormalBayesClassifier::write( CvFileStorage* fs, const char* name )
399 CV_FUNCNAME( "CvNormalBayesClassifier::write" );
405 nclasses = cls_labels->cols;
407 cvStartWriteStruct( fs, name, CV_NODE_MAP, CV_TYPE_NAME_ML_NBAYES );
409 CV_CALL( cvWriteInt( fs, "var_count", var_count ));
410 CV_CALL( cvWriteInt( fs, "var_all", var_all ));
413 CV_CALL( cvWrite( fs, "var_idx", var_idx ));
414 CV_CALL( cvWrite( fs, "cls_labels", cls_labels ));
416 CV_CALL( cvStartWriteStruct( fs, "count", CV_NODE_SEQ ));
417 for( i = 0; i < nclasses; i++ )
418 CV_CALL( cvWrite( fs, NULL, count[i] ));
419 CV_CALL( cvEndWriteStruct( fs ));
421 CV_CALL( cvStartWriteStruct( fs, "sum", CV_NODE_SEQ ));
422 for( i = 0; i < nclasses; i++ )
423 CV_CALL( cvWrite( fs, NULL, sum[i] ));
424 CV_CALL( cvEndWriteStruct( fs ));
426 CV_CALL( cvStartWriteStruct( fs, "productsum", CV_NODE_SEQ ));
427 for( i = 0; i < nclasses; i++ )
428 CV_CALL( cvWrite( fs, NULL, productsum[i] ));
429 CV_CALL( cvEndWriteStruct( fs ));
431 CV_CALL( cvStartWriteStruct( fs, "avg", CV_NODE_SEQ ));
432 for( i = 0; i < nclasses; i++ )
433 CV_CALL( cvWrite( fs, NULL, avg[i] ));
434 CV_CALL( cvEndWriteStruct( fs ));
436 CV_CALL( cvStartWriteStruct( fs, "inv_eigen_values", CV_NODE_SEQ ));
437 for( i = 0; i < nclasses; i++ )
438 CV_CALL( cvWrite( fs, NULL, inv_eigen_values[i] ));
439 CV_CALL( cvEndWriteStruct( fs ));
441 CV_CALL( cvStartWriteStruct( fs, "cov_rotate_mats", CV_NODE_SEQ ));
442 for( i = 0; i < nclasses; i++ )
443 CV_CALL( cvWrite( fs, NULL, cov_rotate_mats[i] ));
444 CV_CALL( cvEndWriteStruct( fs ));
446 CV_CALL( cvWrite( fs, "c", c ));
448 cvEndWriteStruct( fs );
454 void CvNormalBayesClassifier::read( CvFileStorage* fs, CvFileNode* root_node )
457 CV_FUNCNAME( "CvNormalBayesClassifier::read" );
469 CV_CALL( var_count = cvReadIntByName( fs, root_node, "var_count", -1 ));
470 CV_CALL( var_all = cvReadIntByName( fs, root_node, "var_all", -1 ));
471 CV_CALL( var_idx = (CvMat*)cvReadByName( fs, root_node, "var_idx" ));
472 CV_CALL( cls_labels = (CvMat*)cvReadByName( fs, root_node, "cls_labels" ));
474 CV_ERROR( CV_StsParseError, "No \"cls_labels\" in NBayes classifier" );
475 if( cls_labels->cols < 1 )
476 CV_ERROR( CV_StsBadArg, "Number of classes is less 1" );
478 CV_ERROR( CV_StsParseError,
479 "The field \"var_count\" of NBayes classifier is missing" );
480 nclasses = cls_labels->cols;
482 data_size = nclasses*6*sizeof(CvMat*);
483 CV_CALL( count = (CvMat**)cvAlloc( data_size ));
484 memset( count, 0, data_size );
486 sum = count + nclasses;
487 productsum = sum + nclasses;
488 avg = productsum + nclasses;
489 inv_eigen_values = avg + nclasses;
490 cov_rotate_mats = inv_eigen_values + nclasses;
492 CV_CALL( node = cvGetFileNodeByName( fs, root_node, "count" ));
493 seq = node->data.seq;
494 if( !CV_NODE_IS_SEQ(node->tag) || seq->total != nclasses)
495 CV_ERROR( CV_StsBadArg, "" );
496 CV_CALL( cvStartReadSeq( seq, &reader, 0 ));
497 for( i = 0; i < nclasses; i++ )
499 CV_CALL( count[i] = (CvMat*)cvRead( fs, (CvFileNode*)reader.ptr ));
500 CV_NEXT_SEQ_ELEM( seq->elem_size, reader );
503 CV_CALL( node = cvGetFileNodeByName( fs, root_node, "sum" ));
504 seq = node->data.seq;
505 if( !CV_NODE_IS_SEQ(node->tag) || seq->total != nclasses)
506 CV_ERROR( CV_StsBadArg, "" );
507 CV_CALL( cvStartReadSeq( seq, &reader, 0 ));
508 for( i = 0; i < nclasses; i++ )
510 CV_CALL( sum[i] = (CvMat*)cvRead( fs, (CvFileNode*)reader.ptr ));
511 CV_NEXT_SEQ_ELEM( seq->elem_size, reader );
514 CV_CALL( node = cvGetFileNodeByName( fs, root_node, "productsum" ));
515 seq = node->data.seq;
516 if( !CV_NODE_IS_SEQ(node->tag) || seq->total != nclasses)
517 CV_ERROR( CV_StsBadArg, "" );
518 CV_CALL( cvStartReadSeq( seq, &reader, 0 ));
519 for( i = 0; i < nclasses; i++ )
521 CV_CALL( productsum[i] = (CvMat*)cvRead( fs, (CvFileNode*)reader.ptr ));
522 CV_NEXT_SEQ_ELEM( seq->elem_size, reader );
525 CV_CALL( node = cvGetFileNodeByName( fs, root_node, "avg" ));
526 seq = node->data.seq;
527 if( !CV_NODE_IS_SEQ(node->tag) || seq->total != nclasses)
528 CV_ERROR( CV_StsBadArg, "" );
529 CV_CALL( cvStartReadSeq( seq, &reader, 0 ));
530 for( i = 0; i < nclasses; i++ )
532 CV_CALL( avg[i] = (CvMat*)cvRead( fs, (CvFileNode*)reader.ptr ));
533 CV_NEXT_SEQ_ELEM( seq->elem_size, reader );
536 CV_CALL( node = cvGetFileNodeByName( fs, root_node, "inv_eigen_values" ));
537 seq = node->data.seq;
538 if( !CV_NODE_IS_SEQ(node->tag) || seq->total != nclasses)
539 CV_ERROR( CV_StsBadArg, "" );
540 CV_CALL( cvStartReadSeq( seq, &reader, 0 ));
541 for( i = 0; i < nclasses; i++ )
543 CV_CALL( inv_eigen_values[i] = (CvMat*)cvRead( fs, (CvFileNode*)reader.ptr ));
544 CV_NEXT_SEQ_ELEM( seq->elem_size, reader );
547 CV_CALL( node = cvGetFileNodeByName( fs, root_node, "cov_rotate_mats" ));
548 seq = node->data.seq;
549 if( !CV_NODE_IS_SEQ(node->tag) || seq->total != nclasses)
550 CV_ERROR( CV_StsBadArg, "" );
551 CV_CALL( cvStartReadSeq( seq, &reader, 0 ));
552 for( i = 0; i < nclasses; i++ )
554 CV_CALL( cov_rotate_mats[i] = (CvMat*)cvRead( fs, (CvFileNode*)reader.ptr ));
555 CV_NEXT_SEQ_ELEM( seq->elem_size, reader );
558 CV_CALL( c = (CvMat*)cvReadByName( fs, root_node, "c" ));