/*M///////////////////////////////////////////////////////////////////////////////////////
//
//  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
//
//  By downloading, copying, installing or using the software you agree to this license.
//  If you do not agree to this license, do not download, install,
//  copy or use the software.
//
//
//                        Intel License Agreement
//
// Copyright (C) 2000, Intel Corporation, all rights reserved.
// Third party copyrights are property of their respective owners.
//
// Redistribution and use in source and binary forms, with or without modification,
// are permitted provided that the following conditions are met:
//
//   * Redistribution's of source code must retain the above copyright notice,
//     this list of conditions and the following disclaimer.
//
//   * Redistribution's in binary form must reproduce the above copyright notice,
//     this list of conditions and the following disclaimer in the documentation
//     and/or other materials provided with the distribution.
//
//   * The name of Intel Corporation may not be used to endorse or promote products
//     derived from this software without specific prior written permission.
//
// This software is provided by the copyright holders and contributors "as is" and
// any express or implied warranties, including, but not limited to, the implied
// warranties of merchantability and fitness for a particular purpose are disclaimed.
// In no event shall the Intel Corporation or contributors be liable for any direct,
// indirect, incidental, special, exemplary, or consequential damages
// (including, but not limited to, procurement of substitute goods or services;
// loss of use, data, or profits; or business interruption) however caused
// and on any theory of liability, whether in contract, strict liability,
// or tort (including negligence or otherwise) arising in any way out of
// the use of this software, even if advised of the possibility of such damage.
//
//M*/
// qsort()-compatible comparator for two ints.
// Returns a negative value, zero or a positive value when *a is less than,
// equal to or greater than *b.
// The result is built from two comparisons rather than from `*a - *b`:
// the subtraction overflows (undefined behaviour, possibly the wrong sign)
// when the operands are far apart, e.g. INT_MIN and 1.
int icvCmpIntegers (const void* a, const void* b)
{
    const int lhs = *(const int*)a;
    const int rhs = *(const int*)b;

    return (lhs > rhs) - (lhs < rhs);
}
/****************************************************************************************\
*                    Cross-validation algorithm implementations                          *
\****************************************************************************************/
52 // Return pointer to trainIdx. Function DOES NOT FILL this matrix!
54 const CvMat* cvCrossValGetTrainIdxMatrix (const CvStatModel* estimateModel)
58 CV_FUNCNAME ("cvCrossValGetTrainIdxMatrix");
61 if (!CV_IS_CROSSVAL(estimateModel))
63 CV_ERROR (CV_StsBadArg, "Pointer point to not CvCrossValidationModel");
66 result = ((CvCrossValidationModel*)estimateModel)->sampleIdxTrain;
71 } // End of cvCrossValGetTrainIdxMatrix
73 /****************************************************************************************/
74 // Return pointer to checkIdx. Function DOES NOT FILL this matrix!
76 const CvMat* cvCrossValGetCheckIdxMatrix (const CvStatModel* estimateModel)
80 CV_FUNCNAME ("cvCrossValGetCheckIdxMatrix");
83 if (!CV_IS_CROSSVAL (estimateModel))
85 CV_ERROR (CV_StsBadArg, "Pointer point to not CvCrossValidationModel");
88 result = ((CvCrossValidationModel*)estimateModel)->sampleIdxEval;
93 } // End of cvCrossValGetCheckIdxMatrix
95 /****************************************************************************************/
96 // Create new Idx-matrix for next classifiers training and return code of result.
97 // Result is 0 if function can't make next step (error input or folds are finished),
98 // it is 1 if all was correct, and it is 2 if current fold wasn't' checked.
100 int cvCrossValNextStep (CvStatModel* estimateModel)
104 CV_FUNCNAME ("cvCrossValGetNextTrainIdx");
107 CvCrossValidationModel* crVal = (CvCrossValidationModel*) estimateModel;
110 if (!CV_IS_CROSSVAL (estimateModel))
112 CV_ERROR (CV_StsBadArg, "Pointer point to not CvCrossValidationModel");
115 fold = ++crVal->current_fold;
117 if (fold >= crVal->folds_all)
119 if (fold == crVal->folds_all)
123 CV_ERROR (CV_StsInternal, "All iterations has end long ago");
127 k = crVal->folds[fold + 1] - crVal->folds[fold];
128 crVal->sampleIdxTrain->data.i = crVal->sampleIdxAll + crVal->folds[fold + 1];
129 crVal->sampleIdxTrain->cols = crVal->samples_all - k;
130 crVal->sampleIdxEval->data.i = crVal->sampleIdxAll + crVal->folds[fold];
131 crVal->sampleIdxEval->cols = k;
133 if (crVal->is_checked)
135 crVal->is_checked = 0;
148 /****************************************************************************************/
149 // Do checking part of loop of cross-validations metod.
151 void cvCrossValCheckClassifier (CvStatModel* estimateModel,
152 const CvStatModel* model,
153 const CvMat* trainData,
155 const CvMat* trainClasses)
157 CV_FUNCNAME ("cvCrossValCheckClassifier ");
160 CvCrossValidationModel* crVal = (CvCrossValidationModel*) estimateModel;
165 float* responses_result;
168 double sum_c, sum_p, sum_pp, sum_cp, sum_cc, sq_err;
170 // Check input data to correct values.
171 if (!CV_IS_CROSSVAL (estimateModel))
173 CV_ERROR (CV_StsBadArg,"First parameter point to not CvCrossValidationModel");
175 if (!CV_IS_STAT_MODEL (model))
177 CV_ERROR (CV_StsBadArg, "Second parameter point to not CvStatModel");
179 if (!CV_IS_MAT (trainData))
181 CV_ERROR (CV_StsBadArg, "Third parameter point to not CvMat");
183 if (!CV_IS_MAT (trainClasses))
185 CV_ERROR (CV_StsBadArg, "Fifth parameter point to not CvMat");
187 if (crVal->is_checked)
189 CV_ERROR (CV_StsInternal, "This iterations already was checked");
193 k = crVal->sampleIdxEval->cols;
194 data = crVal->sampleIdxEval->data.i;
196 // Eval tested feature vectors.
197 CV_CALL (cvStatModelMultiPredict (model, trainData, sample_t_flag,
198 crVal->predict_results, NULL, crVal->sampleIdxEval));
199 // Count number if correct results.
200 responses_result = crVal->predict_results->data.fl;
201 if (crVal->is_regression)
203 sum_c = sum_p = sum_pp = sum_cp = sum_cc = sq_err = 0;
204 if (CV_MAT_TYPE (trainClasses->type) == CV_32FC1)
206 responses_fl = trainClasses->data.fl;
207 step = trainClasses->rows == 1 ? 1 : trainClasses->step / sizeof(float);
208 for (i = 0; i < k; i++)
210 te = responses_result[*data];
211 te1 = responses_fl[*data * step];
225 responses_i = trainClasses->data.i;
226 step = trainClasses->rows == 1 ? 1 : trainClasses->step / sizeof(int);
227 for (i = 0; i < k; i++)
229 te = responses_result[*data];
230 te1 = responses_i[*data * step];
242 // Fixing new internal values of accuracy.
243 crVal->sum_correct += sum_c;
244 crVal->sum_predict += sum_p;
245 crVal->sum_cc += sum_cc;
246 crVal->sum_pp += sum_pp;
247 crVal->sum_cp += sum_cp;
248 crVal->sq_error += sq_err;
252 if (CV_MAT_TYPE (trainClasses->type) == CV_32FC1)
254 responses_fl = trainClasses->data.fl;
255 step = trainClasses->rows == 1 ? 1 : trainClasses->step / sizeof(float);
256 for (i = 0, j = 0; i < k; i++)
258 if (cvRound (responses_result[*data]) == cvRound (responses_fl[*data * step]))
265 responses_i = trainClasses->data.i;
266 step = trainClasses->rows == 1 ? 1 : trainClasses->step / sizeof(int);
267 for (i = 0, j = 0; i < k; i++)
269 if (cvRound (responses_result[*data]) == responses_i[*data * step])
274 // Fixing new internal values of accuracy.
275 crVal->correct_results += j;
277 // Fixing that this fold already checked.
278 crVal->all_results += k;
279 crVal->is_checked = 1;
282 } // End of cvCrossValCheckClassifier
284 /****************************************************************************************/
285 // Return current accuracy.
287 float cvCrossValGetResult (const CvStatModel* estimateModel,
292 CV_FUNCNAME ("cvCrossValGetResult");
296 CvCrossValidationModel* crVal = (CvCrossValidationModel*)estimateModel;
298 if (!CV_IS_CROSSVAL (estimateModel))
300 CV_ERROR (CV_StsBadArg, "Pointer point to not CvCrossValidationModel");
303 if (crVal->all_results)
305 if (crVal->is_regression)
307 result = ((float)crVal->sq_error) / crVal->all_results;
310 te = crVal->all_results * crVal->sum_cp -
311 crVal->sum_correct * crVal->sum_predict;
313 te1 = (crVal->all_results * crVal->sum_cc -
314 crVal->sum_correct * crVal->sum_correct) *
315 (crVal->all_results * crVal->sum_pp -
316 crVal->sum_predict * crVal->sum_predict);
317 *correlation = (float)(te / te1);
323 result = ((float)crVal->correct_results) / crVal->all_results;
332 /****************************************************************************************/
333 // Reset cross-validation EstimateModel to state the same as it was immidiatly after
336 void cvCrossValReset (CvStatModel* estimateModel)
338 CV_FUNCNAME ("cvCrossValReset");
341 CvCrossValidationModel* crVal = (CvCrossValidationModel*)estimateModel;
343 if (!CV_IS_CROSSVAL (estimateModel))
345 CV_ERROR (CV_StsBadArg, "Pointer point to not CvCrossValidationModel");
348 crVal->current_fold = -1;
349 crVal->is_checked = 1;
350 crVal->all_results = 0;
351 crVal->correct_results = 0;
353 crVal->sum_correct = 0;
354 crVal->sum_predict = 0;
362 /****************************************************************************************/
363 // This function is standart CvStatModel field to release cross-validation EstimateModel.
365 void cvReleaseCrossValidationModel (CvStatModel** model)
367 CvCrossValidationModel* pModel;
369 CV_FUNCNAME ("cvReleaseCrossValidationModel");
374 CV_ERROR (CV_StsNullPtr, "");
377 pModel = (CvCrossValidationModel*)*model;
382 if (!CV_IS_CROSSVAL (pModel))
384 CV_ERROR (CV_StsBadArg, "");
387 cvFree (&pModel->sampleIdxAll);
388 cvFree (&pModel->folds);
389 cvReleaseMat (&pModel->sampleIdxEval);
390 cvReleaseMat (&pModel->sampleIdxTrain);
391 cvReleaseMat (&pModel->predict_results);
396 } // End of cvReleaseCrossValidationModel.
398 /****************************************************************************************/
399 // This function create cross-validation EstimateModel.
401 cvCreateCrossValidationEstimateModel(
403 const CvStatModelParams* estimateParams,
404 const CvMat* sampleIdx)
406 CvStatModel* model = NULL;
407 CvCrossValidationModel* crVal = NULL;
409 CV_FUNCNAME ("cvCreateCrossValidationEstimateModel");
415 int samples_selected;
422 rng = cvRNG(cvGetTickCount());
423 cvRandInt (&rng); cvRandInt (&rng); cvRandInt (&rng); cvRandInt (&rng);
424 // Check input parameters.
426 k_fold = ((CvCrossValidationParams*)estimateParams)->k_fold;
429 CV_ERROR (CV_StsBadArg, "Error in parameters of cross-validation (k_fold == 0)!");
431 if (samples_all <= 0)
433 CV_ERROR (CV_StsBadArg, "<samples_all> should be positive!");
436 // Alloc memory and fill standart StatModel's fields.
437 CV_CALL (crVal = (CvCrossValidationModel*)cvCreateStatModel (
438 CV_STAT_MODEL_MAGIC_VAL | CV_CROSSVAL_MAGIC_VAL,
439 sizeof(CvCrossValidationModel),
440 cvReleaseCrossValidationModel,
442 crVal->current_fold = -1;
443 crVal->folds_all = k_fold;
444 if (estimateParams && ((CvCrossValidationParams*)estimateParams)->is_regression)
445 crVal->is_regression = 1;
447 crVal->is_regression = 0;
448 if (estimateParams && ((CvCrossValidationParams*)estimateParams)->rng)
449 prng = ((CvCrossValidationParams*)estimateParams)->rng;
453 // Check and preprocess sample indices.
459 if (!CV_IS_MAT (sampleIdx))
460 CV_ERROR (CV_StsBadArg, "Invalid sampleIdx array");
462 if (sampleIdx->rows != 1 && sampleIdx->cols != 1)
463 CV_ERROR (CV_StsBadSize, "sampleIdx array must be 1-dimensional");
465 s_len = sampleIdx->rows + sampleIdx->cols - 1;
466 s_step = sampleIdx->rows == 1 ?
467 1 : sampleIdx->step / CV_ELEM_SIZE(sampleIdx->type);
469 s_type = CV_MAT_TYPE (sampleIdx->type);
476 uchar* s_data = sampleIdx->data.ptr;
478 // sampleIdx is array of 1's and 0's -
479 // i.e. it is a mask of the selected samples
480 if( s_len != samples_all )
481 CV_ERROR (CV_StsUnmatchedSizes,
482 "Sample mask should contain as many elements as the total number of samples");
484 samples_selected = 0;
485 for (i = 0; i < s_len; i++)
486 samples_selected += s_data[i * s_step] != 0;
488 if (samples_selected == 0)
489 CV_ERROR (CV_StsOutOfRange, "No samples is selected!");
491 s_len = samples_selected;
494 if (s_len > samples_all)
495 CV_ERROR (CV_StsOutOfRange,
496 "sampleIdx array may not contain more elements than the total number of samples");
497 samples_selected = s_len;
500 CV_ERROR (CV_StsUnsupportedFormat, "Unsupported sampleIdx array data type "
501 "(it should be 8uC1, 8sC1 or 32sC1)");
504 // Alloc additional memory for internal Idx and fill it.
505 /*!!*/ CV_CALL (res_s_data = crVal->sampleIdxAll =
506 (int*)cvAlloc (2 * s_len * sizeof(int)));
508 if (s_type < CV_32SC1)
510 uchar* s_data = sampleIdx->data.ptr;
511 for (i = 0; i < s_len; i++)
512 if (s_data[i * s_step])
516 res_s_data = crVal->sampleIdxAll;
520 int* s_data = sampleIdx->data.i;
521 int out_of_order = 0;
523 for (i = 0; i < s_len; i++)
525 res_s_data[i] = s_data[i * s_step];
526 if (i > 0 && res_s_data[i] < res_s_data[i - 1])
531 qsort (res_s_data, s_len, sizeof(res_s_data[0]), icvCmpIntegers);
533 if (res_s_data[0] < 0 ||
534 res_s_data[s_len - 1] >= samples_all)
535 CV_ERROR (CV_StsBadArg, "There are out-of-range sample indices");
536 for (i = 1; i < s_len; i++)
537 if (res_s_data[i] <= res_s_data[i - 1])
538 CV_ERROR (CV_StsBadArg, "There are duplicated");
541 else // if (sampleIdx)
543 // Alloc additional memory for internal Idx and fill it.
545 CV_CALL (res_s_data = crVal->sampleIdxAll = (int*)cvAlloc (2 * s_len * sizeof(int)));
546 for (i = 0; i < s_len; i++)
550 res_s_data = crVal->sampleIdxAll;
551 } // if (sampleIdx) ... else
553 // Resort internal Idx.
554 te_s_data = res_s_data + s_len;
555 for (i = s_len; i > 1; i--)
557 j = cvRandInt (prng) % i;
559 *te_s_data = res_s_data[j];
563 // Duplicate resorted internal Idx.
564 // It will be used to simplify operation of getting trainIdx.
565 te_s_data = res_s_data + s_len;
566 for (i = 0; i < s_len; i++)
568 *te_s_data++ = *res_s_data++;
571 // Cut sampleIdxAll to parts.
576 CV_ERROR (CV_StsBadArg,
577 "Error in parameters of cross-validation ('k_fold' > #samples)!");
579 folds = crVal->folds = (int*) cvAlloc ((k_fold + 1) * sizeof (int));
581 for (i = 1; i < k_fold; i++)
583 *folds++ = cvRound (i * s_len * 1. / k_fold);
586 folds = crVal->folds;
588 crVal->max_fold_size = (s_len - 1) / k_fold + 1;
593 crVal->max_fold_size = k;
596 CV_ERROR (CV_StsBadArg,
597 "Error in parameters of cross-validation (-'k_fold' > #samples)!");
599 crVal->folds_all = k = (s_len - 1) / k + 1;
601 folds = crVal->folds = (int*) cvAlloc ((k + 1) * sizeof (int));
602 for (i = 0; i < k; i++)
604 *folds++ = -i * k_fold;
607 folds = crVal->folds;
610 // Prepare other internal fields to working.
611 CV_CALL (crVal->predict_results = cvCreateMat (1, samples_all, CV_32FC1));
612 CV_CALL (crVal->sampleIdxEval = cvCreateMatHeader (1, 1, CV_32SC1));
613 CV_CALL (crVal->sampleIdxTrain = cvCreateMatHeader (1, 1, CV_32SC1));
614 crVal->sampleIdxEval->cols = 0;
615 crVal->sampleIdxTrain->cols = 0;
616 crVal->samples_all = s_len;
617 crVal->is_checked = 1;
619 crVal->getTrainIdxMat = cvCrossValGetTrainIdxMatrix;
620 crVal->getCheckIdxMat = cvCrossValGetCheckIdxMatrix;
621 crVal->nextStep = cvCrossValNextStep;
622 crVal->check = cvCrossValCheckClassifier;
623 crVal->getResult = cvCrossValGetResult;
624 crVal->reset = cvCrossValReset;
626 model = (CvStatModel*)crVal;
632 cvReleaseCrossValidationModel ((CvStatModel**)&crVal);
636 } // End of cvCreateCrossValidationEstimateModel
/****************************************************************************************\
*                    Extended interface with callbacks for models                        *
\****************************************************************************************/
643 cvCrossValidation (const CvMat* trueData,
645 const CvMat* trueClasses,
646 CvStatModel* (*createClassifier) (const CvMat*,
649 const CvClassifierTrainParams*,
654 const CvClassifierTrainParams* estimateParams,
655 const CvClassifierTrainParams* trainParams,
656 const CvMat* compIdx,
657 const CvMat* sampleIdx,
658 CvStatModel** pCrValModel,
659 const CvMat* typeMask,
660 const CvMat* missedMeasurementMask)
662 CvCrossValidationModel* crVal = NULL;
664 CvStatModel* pClassifier = NULL;
666 CV_FUNCNAME ("cvCrossValidation");
669 const CvMat* trainDataIdx;
672 // checking input data
673 if ((createClassifier) == NULL)
675 CV_ERROR (CV_StsNullPtr, "Null pointer to functiion which create classifier");
677 if (pCrValModel && *pCrValModel && !CV_IS_CROSSVAL(*pCrValModel))
679 CV_ERROR (CV_StsBadArg,
680 "<pCrValModel> point to not cross-validation model");
684 if (pCrValModel && *pCrValModel)
686 crVal = (CvCrossValidationModel*)*pCrValModel;
687 crVal->reset ((CvStatModel*)crVal);
691 samples_all = ((tflag) ? trueData->rows : trueData->cols);
692 CV_CALL (crVal = (CvCrossValidationModel*)
693 cvCreateCrossValidationEstimateModel (samples_all, estimateParams, sampleIdx));
696 CV_CALL (trainDataIdx = crVal->getTrainIdxMat ((CvStatModel*)crVal));
699 for (; crVal->nextStep((CvStatModel*)crVal) != 0; )
701 CV_CALL (pClassifier = createClassifier (trueData, tflag, trueClasses,
702 trainParams, compIdx, trainDataIdx, typeMask, missedMeasurementMask));
703 CV_CALL (crVal->check ((CvStatModel*)crVal, pClassifier,
704 trueData, tflag, trueClasses));
706 pClassifier->release (&pClassifier);
709 // Get result and fill output field.
710 CV_CALL (result = crVal->getResult ((CvStatModel*)crVal, 0));
712 if (pCrValModel && !*pCrValModel)
713 *pCrValModel = (CvStatModel*)crVal;
717 // Free all memory that should be freed.
719 pClassifier->release (&pClassifier);
720 if (crVal && (!pCrValModel || !*pCrValModel))
721 crVal->release ((CvStatModel**)&crVal);
724 } // End of cvCrossValidation