1 /*M///////////////////////////////////////////////////////////////////////////////////////
3 // IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
5 // By downloading, copying, installing or using the software you agree to this license.
6 // If you do not agree to this license, do not download, install,
7 // copy or use the software.
10 // Intel License Agreement
12 // Copyright (C) 2000, Intel Corporation, all rights reserved.
13 // Third party copyrights are property of their respective owners.
15 // Redistribution and use in source and binary forms, with or without modification,
16 // are permitted provided that the following conditions are met:
18 // * Redistribution's of source code must retain the above copyright notice,
19 // this list of conditions and the following disclaimer.
21 // * Redistribution's in binary form must reproduce the above copyright notice,
22 // this list of conditions and the following disclaimer in the documentation
23 // and/or other materials provided with the distribution.
25 // * The name of Intel Corporation may not be used to endorse or promote products
26 // derived from this software without specific prior written permission.
28 // This software is provided by the copyright holders and contributors "as is" and
29 // any express or implied warranties, including, but not limited to, the implied
30 // warranties of merchantability and fitness for a particular purpose are disclaimed.
31 // In no event shall the Intel Corporation or contributors be liable for any direct,
32 // indirect, incidental, special, exemplary, or consequential damages
33 // (including, but not limited to, procurement of substitute goods or services;
34 // loss of use, data, or profits; or business interruption) however caused
35 // and on any theory of liability, whether in contract, strict liability,
36 // or tort (including negligence or otherwise) arising in any way out of
37 // the use of this software, even if advised of the possibility of such damage.
43 /****************************************************************************************\
47 The code has been derived from libsvm library (version 2.6)
48 (http://www.csie.ntu.edu.tw/~cjlin/libsvm).
50 Here is the original copyright:
51 ------------------------------------------------------------------------------------------
52 Copyright (c) 2000-2003 Chih-Chung Chang and Chih-Jen Lin
55 Redistribution and use in source and binary forms, with or without
56 modification, are permitted provided that the following conditions
59 1. Redistributions of source code must retain the above copyright
60 notice, this list of conditions and the following disclaimer.
62 2. Redistributions in binary form must reproduce the above copyright
63 notice, this list of conditions and the following disclaimer in the
64 documentation and/or other materials provided with the distribution.
66 3. Neither name of copyright holders nor the names of its contributors
67 may be used to endorse or promote products derived from this software
68 without specific prior written permission.
71 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
72 ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
73 LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
74 A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR
75 CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
76 EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
77 PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
78 PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
79 LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
80 NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
81 SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
82 \****************************************************************************************/
84 #define CV_SVM_MIN_CACHE_SIZE (40 << 20) /* 40Mb */
90 #pragma warning( disable: 4514 ) /* unreferenced inline functions */
95 #define QFLOAT_TYPE CV_32F
97 typedef double Qfloat;
98 #define QFLOAT_TYPE CV_64F
102 bool CvParamGrid::check() const
106 CV_FUNCNAME( "CvParamGrid::check" );
109 if( min_val > max_val )
110 CV_ERROR( CV_StsBadArg, "Lower bound of the grid must be less then the upper one" );
111 if( min_val < DBL_EPSILON )
112 CV_ERROR( CV_StsBadArg, "Lower bound of the grid must be positive" );
113 if( step < 1. + FLT_EPSILON )
114 CV_ERROR( CV_StsBadArg, "Grid step must greater then 1" );
123 CvParamGrid CvSVM::get_default_grid( int param_id )
126 if( param_id == CvSVM::C )
130 grid.step = 5; // total iterations = 5
132 else if( param_id == CvSVM::GAMMA )
136 grid.step = 15; // total iterations = 4
138 else if( param_id == CvSVM::P )
142 grid.step = 7; // total iterations = 4
144 else if( param_id == CvSVM::NU )
148 grid.step = 3; // total iterations = 3
150 else if( param_id == CvSVM::COEF )
154 grid.step = 14; // total iterations = 3
156 else if( param_id == CvSVM::DEGREE )
160 grid.step = 7; // total iterations = 3
163 cvError( CV_StsBadArg, "CvSVM::get_default_grid", "Invalid type of parameter "
164 "(use one of CvSVM::C, CvSVM::GAMMA et al.)", __FILE__, __LINE__ );
168 // SVM training parameters
169 CvSVMParams::CvSVMParams() :
170 svm_type(CvSVM::C_SVC), kernel_type(CvSVM::RBF), degree(0),
171 gamma(1), coef0(0), C(1), nu(0), p(0), class_weights(0)
173 term_crit = cvTermCriteria( CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, 1000, FLT_EPSILON );
177 CvSVMParams::CvSVMParams( int _svm_type, int _kernel_type,
178 double _degree, double _gamma, double _coef0,
179 double _Con, double _nu, double _p,
180 CvMat* _class_weights, CvTermCriteria _term_crit ) :
181 svm_type(_svm_type), kernel_type(_kernel_type),
182 degree(_degree), gamma(_gamma), coef0(_coef0),
183 C(_Con), nu(_nu), p(_p), class_weights(_class_weights), term_crit(_term_crit)
188 /////////////////////////////////////// SVM kernel ///////////////////////////////////////
190 CvSVMKernel::CvSVMKernel()
196 void CvSVMKernel::clear()
203 CvSVMKernel::~CvSVMKernel()
208 CvSVMKernel::CvSVMKernel( const CvSVMParams* _params, Calc _calc_func )
211 create( _params, _calc_func );
215 bool CvSVMKernel::create( const CvSVMParams* _params, Calc _calc_func )
219 calc_func = _calc_func;
222 calc_func = params->kernel_type == CvSVM::RBF ? &CvSVMKernel::calc_rbf :
223 params->kernel_type == CvSVM::POLY ? &CvSVMKernel::calc_poly :
224 params->kernel_type == CvSVM::SIGMOID ? &CvSVMKernel::calc_sigmoid :
225 &CvSVMKernel::calc_linear;
231 void CvSVMKernel::calc_non_rbf_base( int vcount, int var_count, const float** vecs,
232 const float* another, Qfloat* results,
233 double alpha, double beta )
236 for( j = 0; j < vcount; j++ )
238 const float* sample = vecs[j];
240 for( k = 0; k <= var_count - 4; k += 4 )
241 s += sample[k]*another[k] + sample[k+1]*another[k+1] +
242 sample[k+2]*another[k+2] + sample[k+3]*another[k+3];
243 for( ; k < var_count; k++ )
244 s += sample[k]*another[k];
245 results[j] = (Qfloat)(s*alpha + beta);
250 void CvSVMKernel::calc_linear( int vcount, int var_count, const float** vecs,
251 const float* another, Qfloat* results )
253 calc_non_rbf_base( vcount, var_count, vecs, another, results, 1, 0 );
257 void CvSVMKernel::calc_poly( int vcount, int var_count, const float** vecs,
258 const float* another, Qfloat* results )
260 CvMat R = cvMat( 1, vcount, QFLOAT_TYPE, results );
261 calc_non_rbf_base( vcount, var_count, vecs, another, results, params->gamma, params->coef0 );
262 cvPow( &R, &R, params->degree );
266 void CvSVMKernel::calc_sigmoid( int vcount, int var_count, const float** vecs,
267 const float* another, Qfloat* results )
270 calc_non_rbf_base( vcount, var_count, vecs, another, results,
271 -2*params->gamma, -2*params->coef0 );
272 // TODO: speed this up
273 for( j = 0; j < vcount; j++ )
275 Qfloat t = results[j];
276 double e = exp(-fabs(t));
278 results[j] = (Qfloat)((1. - e)/(1. + e));
280 results[j] = (Qfloat)((e - 1.)/(e + 1.));
285 void CvSVMKernel::calc_rbf( int vcount, int var_count, const float** vecs,
286 const float* another, Qfloat* results )
288 CvMat R = cvMat( 1, vcount, QFLOAT_TYPE, results );
289 double gamma = -params->gamma;
292 for( j = 0; j < vcount; j++ )
294 const float* sample = vecs[j];
297 for( k = 0; k <= var_count - 4; k += 4 )
299 double t0 = sample[k] - another[k];
300 double t1 = sample[k+1] - another[k+1];
304 t0 = sample[k+2] - another[k+2];
305 t1 = sample[k+3] - another[k+3];
310 for( ; k < var_count; k++ )
312 double t0 = sample[k] - another[k];
315 results[j] = (Qfloat)(s*gamma);
322 void CvSVMKernel::calc( int vcount, int var_count, const float** vecs,
323 const float* another, Qfloat* results )
325 const Qfloat max_val = (Qfloat)(FLT_MAX*1e-3);
327 (this->*calc_func)( vcount, var_count, vecs, another, results );
328 for( j = 0; j < vcount; j++ )
330 if( results[j] > max_val )
331 results[j] = max_val;
336 // Generalized SMO+SVMlight algorithm
339 // min [0.5(\alpha^T Q \alpha) + b^T \alpha]
341 // y^T \alpha = \delta
343 // 0 <= alpha_i <= Cp for y_i = 1
344 // 0 <= alpha_i <= Cn for y_i = -1
348 // Q, b, y, Cp, Cn, and an initial feasible point \alpha
349 // l is the size of vectors and matrices
350 // eps is the stopping criterion
352 // solution will be put in \alpha, objective value will be put in obj
355 void CvSVMSolver::clear()
362 cvReleaseMemStorage( &storage );
364 select_working_set_func = 0;
373 CvSVMSolver::CvSVMSolver()
380 CvSVMSolver::~CvSVMSolver()
386 CvSVMSolver::CvSVMSolver( int _sample_count, int _var_count, const float** _samples, char* _y,
387 int _alpha_count, double* _alpha, double _Cp, double _Cn,
388 CvMemStorage* _storage, CvSVMKernel* _kernel, GetRow _get_row,
389 SelectWorkingSet _select_working_set, CalcRho _calc_rho )
392 create( _sample_count, _var_count, _samples, _y, _alpha_count, _alpha, _Cp, _Cn,
393 _storage, _kernel, _get_row, _select_working_set, _calc_rho );
397 bool CvSVMSolver::create( int _sample_count, int _var_count, const float** _samples, char* _y,
398 int _alpha_count, double* _alpha, double _Cp, double _Cn,
399 CvMemStorage* _storage, CvSVMKernel* _kernel, GetRow _get_row,
400 SelectWorkingSet _select_working_set, CalcRho _calc_rho )
405 CV_FUNCNAME( "CvSVMSolver::create" );
413 sample_count = _sample_count;
414 var_count = _var_count;
417 alpha_count = _alpha_count;
423 eps = kernel->params->term_crit.epsilon;
424 max_iter = kernel->params->term_crit.max_iter;
425 storage = cvCreateChildMemStorage( _storage );
427 b = (double*)cvMemStorageAlloc( storage, alpha_count*sizeof(b[0]));
428 alpha_status = (char*)cvMemStorageAlloc( storage, alpha_count*sizeof(alpha_status[0]));
429 G = (double*)cvMemStorageAlloc( storage, alpha_count*sizeof(G[0]));
430 for( i = 0; i < 2; i++ )
431 buf[i] = (Qfloat*)cvMemStorageAlloc( storage, sample_count*2*sizeof(buf[i][0]) );
432 svm_type = kernel->params->svm_type;
434 select_working_set_func = _select_working_set;
435 if( !select_working_set_func )
436 select_working_set_func = svm_type == CvSVM::NU_SVC || svm_type == CvSVM::NU_SVR ?
437 &CvSVMSolver::select_working_set_nu_svm : &CvSVMSolver::select_working_set;
439 calc_rho_func = _calc_rho;
441 calc_rho_func = svm_type == CvSVM::NU_SVC || svm_type == CvSVM::NU_SVR ?
442 &CvSVMSolver::calc_rho_nu_svm : &CvSVMSolver::calc_rho;
444 get_row_func = _get_row;
446 get_row_func = params->svm_type == CvSVM::EPS_SVR ||
447 params->svm_type == CvSVM::NU_SVR ? &CvSVMSolver::get_row_svr :
448 params->svm_type == CvSVM::C_SVC ||
449 params->svm_type == CvSVM::NU_SVC ? &CvSVMSolver::get_row_svc :
450 &CvSVMSolver::get_row_one_class;
452 cache_line_size = sample_count*sizeof(Qfloat);
453 // cache size = max(num_of_samples^2*sizeof(Qfloat)*0.25, CV_SVM_MIN_CACHE_SIZE /* 40Mb */)
454 // (assuming that for large training sets ~25% of Q matrix is used)
455 cache_size = MAX( cache_line_size*sample_count/4, CV_SVM_MIN_CACHE_SIZE );
457 // the size of Q matrix row headers
458 rows_hdr_size = sample_count*sizeof(rows[0]);
459 if( rows_hdr_size > storage->block_size )
460 CV_ERROR( CV_StsOutOfRange, "Too small storage block size" );
462 lru_list.prev = lru_list.next = &lru_list;
463 rows = (CvSVMKernelRow*)cvMemStorageAlloc( storage, rows_hdr_size );
464 memset( rows, 0, rows_hdr_size );
474 float* CvSVMSolver::get_row_base( int i, bool* _existed )
476 int i1 = i < sample_count ? i : i - sample_count;
477 CvSVMKernelRow* row = rows + i1;
478 bool existed = row->data != 0;
481 if( existed || cache_size <= 0 )
483 CvSVMKernelRow* del_row = existed ? row : lru_list.prev;
484 data = del_row->data;
487 // delete row from the LRU list
489 del_row->prev->next = del_row->next;
490 del_row->next->prev = del_row->prev;
494 data = (Qfloat*)cvMemStorageAlloc( storage, cache_line_size );
495 cache_size -= cache_line_size;
498 // insert row into the LRU list
500 row->prev = &lru_list;
501 row->next = lru_list.next;
502 row->prev->next = row->next->prev = row;
506 kernel->calc( sample_count, var_count, samples, samples[i1], row->data );
516 float* CvSVMSolver::get_row_svc( int i, float* row, float*, bool existed )
521 int j, len = sample_count;
522 assert( _y && i < sample_count );
526 for( j = 0; j < len; j++ )
527 row[j] = _y[j]*row[j];
531 for( j = 0; j < len; j++ )
532 row[j] = -_y[j]*row[j];
539 float* CvSVMSolver::get_row_one_class( int, float* row, float*, bool )
545 float* CvSVMSolver::get_row_svr( int i, float* row, float* dst, bool )
547 int j, len = sample_count;
548 Qfloat* dst_pos = dst;
549 Qfloat* dst_neg = dst + len;
553 CV_SWAP( dst_pos, dst_neg, temp );
556 for( j = 0; j < len; j++ )
567 float* CvSVMSolver::get_row( int i, float* dst )
569 bool existed = false;
570 float* row = get_row_base( i, &existed );
571 return (this->*get_row_func)( i, row, dst, existed );
575 #undef is_upper_bound
576 #define is_upper_bound(i) (alpha_status[i] > 0)
578 #undef is_lower_bound
579 #define is_lower_bound(i) (alpha_status[i] < 0)
582 #define is_free(i) (alpha_status[i] == 0)
585 #define get_C(i) (C[y[i]>0])
587 #undef update_alpha_status
588 #define update_alpha_status(i) \
589 alpha_status[i] = (char)(alpha[i] >= get_C(i) ? 1 : alpha[i] <= 0 ? -1 : 0)
591 #undef reconstruct_gradient
592 #define reconstruct_gradient() /* empty for now */
595 bool CvSVMSolver::solve_generic( CvSVMSolutionInfo& si )
600 // 1. initialize gradient and alpha status
601 for( i = 0; i < alpha_count; i++ )
603 update_alpha_status(i);
605 if( fabs(G[i]) > 1e200 )
609 for( i = 0; i < alpha_count; i++ )
611 if( !is_lower_bound(i) )
613 const Qfloat *Q_i = get_row( i, buf[0] );
614 double alpha_i = alpha[i];
616 for( j = 0; j < alpha_count; j++ )
617 G[j] += alpha_i*Q_i[j];
621 // 2. optimization loop
624 const Qfloat *Q_i, *Q_j;
626 double old_alpha_i, old_alpha_j, alpha_i, alpha_j;
627 double delta_alpha_i, delta_alpha_j;
630 for( i = 0; i < alpha_count; i++ )
632 if( fabs(G[i]) > 1e+300 )
635 if( fabs(alpha[i]) > 1e16 )
640 if( (this->*select_working_set_func)( i, j ) != 0 || iter++ >= max_iter )
643 Q_i = get_row( i, buf[0] );
644 Q_j = get_row( j, buf[1] );
649 alpha_i = old_alpha_i = alpha[i];
650 alpha_j = old_alpha_j = alpha[j];
654 double denom = Q_i[i]+Q_j[j]+2*Q_i[j];
655 double delta = (-G[i]-G[j])/MAX(fabs(denom),FLT_EPSILON);
656 double diff = alpha_i - alpha_j;
660 if( diff > 0 && alpha_j < 0 )
665 else if( diff <= 0 && alpha_i < 0 )
671 if( diff > C_i - C_j && alpha_i > C_i )
674 alpha_j = C_i - diff;
676 else if( diff <= C_i - C_j && alpha_j > C_j )
679 alpha_i = C_j + diff;
684 double denom = Q_i[i]+Q_j[j]-2*Q_i[j];
685 double delta = (G[i]-G[j])/MAX(fabs(denom),FLT_EPSILON);
686 double sum = alpha_i + alpha_j;
690 if( sum > C_i && alpha_i > C_i )
695 else if( sum <= C_i && alpha_j < 0)
701 if( sum > C_j && alpha_j > C_j )
706 else if( sum <= C_j && alpha_i < 0 )
716 update_alpha_status(i);
717 update_alpha_status(j);
720 delta_alpha_i = alpha_i - old_alpha_i;
721 delta_alpha_j = alpha_j - old_alpha_j;
723 for( k = 0; k < alpha_count; k++ )
724 G[k] += Q_i[k]*delta_alpha_i + Q_j[k]*delta_alpha_j;
728 (this->*calc_rho_func)( si.rho, si.r );
730 // calculate objective value
731 for( i = 0, si.obj = 0; i < alpha_count; i++ )
732 si.obj += alpha[i] * (G[i] + b[i]);
736 si.upper_bound_p = C[1];
737 si.upper_bound_n = C[0];
743 // return 1 if already optimal, return 0 otherwise
745 CvSVMSolver::select_working_set( int& out_i, int& out_j )
747 // return i,j which maximize -grad(f)^T d , under constraint
748 // if alpha_i == C, d != +1
749 // if alpha_i == 0, d != -1
750 double Gmax1 = -DBL_MAX; // max { -grad(f)_i * d | y_i*d = +1 }
753 double Gmax2 = -DBL_MAX; // max { -grad(f)_i * d | y_i*d = -1 }
758 for( i = 0; i < alpha_count; i++ )
762 if( y[i] > 0 ) // y = +1
764 if( !is_upper_bound(i) && (t = -G[i]) > Gmax1 ) // d = +1
769 if( !is_lower_bound(i) && (t = G[i]) > Gmax2 ) // d = -1
777 if( !is_upper_bound(i) && (t = -G[i]) > Gmax2 ) // d = +1
782 if( !is_lower_bound(i) && (t = G[i]) > Gmax1 ) // d = -1
793 return Gmax1 + Gmax2 < eps;
798 CvSVMSolver::calc_rho( double& rho, double& r )
801 double ub = DBL_MAX, lb = -DBL_MAX, sum_free = 0;
803 for( i = 0; i < alpha_count; i++ )
805 double yG = y[i]*G[i];
807 if( is_lower_bound(i) )
814 else if( is_upper_bound(i) )
828 rho = nr_free > 0 ? sum_free/nr_free : (ub + lb)*0.5;
834 CvSVMSolver::select_working_set_nu_svm( int& out_i, int& out_j )
836 // return i,j which maximize -grad(f)^T d , under constraint
837 // if alpha_i == C, d != +1
838 // if alpha_i == 0, d != -1
839 double Gmax1 = -DBL_MAX; // max { -grad(f)_i * d | y_i = +1, d = +1 }
842 double Gmax2 = -DBL_MAX; // max { -grad(f)_i * d | y_i = +1, d = -1 }
845 double Gmax3 = -DBL_MAX; // max { -grad(f)_i * d | y_i = -1, d = +1 }
848 double Gmax4 = -DBL_MAX; // max { -grad(f)_i * d | y_i = -1, d = -1 }
853 for( i = 0; i < alpha_count; i++ )
857 if( y[i] > 0 ) // y == +1
859 if( !is_upper_bound(i) && (t = -G[i]) > Gmax1 ) // d = +1
864 if( !is_lower_bound(i) && (t = G[i]) > Gmax2 ) // d = -1
872 if( !is_upper_bound(i) && (t = -G[i]) > Gmax3 ) // d = +1
877 if( !is_lower_bound(i) && (t = G[i]) > Gmax4 ) // d = -1
885 if( MAX(Gmax1 + Gmax2, Gmax3 + Gmax4) < eps )
888 if( Gmax1 + Gmax2 > Gmax3 + Gmax4 )
903 CvSVMSolver::calc_rho_nu_svm( double& rho, double& r )
905 int nr_free1 = 0, nr_free2 = 0;
906 double ub1 = DBL_MAX, ub2 = DBL_MAX;
907 double lb1 = -DBL_MAX, lb2 = -DBL_MAX;
908 double sum_free1 = 0, sum_free2 = 0;
913 for( i = 0; i < alpha_count; i++ )
918 if( is_lower_bound(i) )
919 ub1 = MIN( ub1, G_i );
920 else if( is_upper_bound(i) )
921 lb1 = MAX( lb1, G_i );
930 if( is_lower_bound(i) )
931 ub2 = MIN( ub2, G_i );
932 else if( is_upper_bound(i) )
933 lb2 = MAX( lb2, G_i );
942 r1 = nr_free1 > 0 ? sum_free1/nr_free1 : (ub1 + lb1)*0.5;
943 r2 = nr_free2 > 0 ? sum_free2/nr_free2 : (ub2 + lb2)*0.5;
951 ///////////////////////// construct and solve various formulations ///////////////////////
954 bool CvSVMSolver::solve_c_svc( int _sample_count, int _var_count, const float** _samples, char* _y,
955 double _Cp, double _Cn, CvMemStorage* _storage,
956 CvSVMKernel* _kernel, double* _alpha, CvSVMSolutionInfo& _si )
960 if( !create( _sample_count, _var_count, _samples, _y, _sample_count,
961 _alpha, _Cp, _Cn, _storage, _kernel, &CvSVMSolver::get_row_svc,
962 &CvSVMSolver::select_working_set, &CvSVMSolver::calc_rho ))
965 for( i = 0; i < sample_count; i++ )
971 if( !solve_generic( _si ))
974 for( i = 0; i < sample_count; i++ )
981 bool CvSVMSolver::solve_nu_svc( int _sample_count, int _var_count, const float** _samples, char* _y,
982 CvMemStorage* _storage, CvSVMKernel* _kernel,
983 double* _alpha, CvSVMSolutionInfo& _si )
986 double sum_pos, sum_neg, inv_r;
988 if( !create( _sample_count, _var_count, _samples, _y, _sample_count,
989 _alpha, 1., 1., _storage, _kernel, &CvSVMSolver::get_row_svc,
990 &CvSVMSolver::select_working_set_nu_svm, &CvSVMSolver::calc_rho_nu_svm ))
993 sum_pos = kernel->params->nu * sample_count * 0.5;
994 sum_neg = kernel->params->nu * sample_count * 0.5;
996 for( i = 0; i < sample_count; i++ )
1000 alpha[i] = MIN(1.0, sum_pos);
1001 sum_pos -= alpha[i];
1005 alpha[i] = MIN(1.0, sum_neg);
1006 sum_neg -= alpha[i];
1011 if( !solve_generic( _si ))
1016 for( i = 0; i < sample_count; i++ )
1017 alpha[i] *= y[i]*inv_r;
1020 _si.obj *= (inv_r*inv_r);
1021 _si.upper_bound_p = inv_r;
1022 _si.upper_bound_n = inv_r;
1028 bool CvSVMSolver::solve_one_class( int _sample_count, int _var_count, const float** _samples,
1029 CvMemStorage* _storage, CvSVMKernel* _kernel,
1030 double* _alpha, CvSVMSolutionInfo& _si )
1033 double nu = _kernel->params->nu;
1035 if( !create( _sample_count, _var_count, _samples, 0, _sample_count,
1036 _alpha, 1., 1., _storage, _kernel, &CvSVMSolver::get_row_one_class,
1037 &CvSVMSolver::select_working_set, &CvSVMSolver::calc_rho ))
1040 y = (char*)cvMemStorageAlloc( storage, sample_count*sizeof(y[0]) );
1041 n = cvRound( nu*sample_count );
1043 for( i = 0; i < sample_count; i++ )
1047 alpha[i] = i < n ? 1 : 0;
1050 if( n < sample_count )
1051 alpha[n] = nu * sample_count - n;
1053 alpha[n-1] = nu * sample_count - (n-1);
1055 return solve_generic(_si);
1059 bool CvSVMSolver::solve_eps_svr( int _sample_count, int _var_count, const float** _samples,
1060 const float* _y, CvMemStorage* _storage,
1061 CvSVMKernel* _kernel, double* _alpha, CvSVMSolutionInfo& _si )
1064 double p = _kernel->params->p, C = _kernel->params->C;
1066 if( !create( _sample_count, _var_count, _samples, 0,
1067 _sample_count*2, 0, C, C, _storage, _kernel, &CvSVMSolver::get_row_svr,
1068 &CvSVMSolver::select_working_set, &CvSVMSolver::calc_rho ))
1071 y = (char*)cvMemStorageAlloc( storage, sample_count*2*sizeof(y[0]) );
1072 alpha = (double*)cvMemStorageAlloc( storage, alpha_count*sizeof(alpha[0]) );
1074 for( i = 0; i < sample_count; i++ )
1080 alpha[i+sample_count] = 0;
1081 b[i+sample_count] = p + _y[i];
1082 y[i+sample_count] = -1;
1085 if( !solve_generic( _si ))
1088 for( i = 0; i < sample_count; i++ )
1089 _alpha[i] = alpha[i] - alpha[i+sample_count];
1095 bool CvSVMSolver::solve_nu_svr( int _sample_count, int _var_count, const float** _samples,
1096 const float* _y, CvMemStorage* _storage,
1097 CvSVMKernel* _kernel, double* _alpha, CvSVMSolutionInfo& _si )
1100 double C = _kernel->params->C, sum;
1102 if( !create( _sample_count, _var_count, _samples, 0,
1103 _sample_count*2, 0, 1., 1., _storage, _kernel, &CvSVMSolver::get_row_svr,
1104 &CvSVMSolver::select_working_set_nu_svm, &CvSVMSolver::calc_rho_nu_svm ))
1107 y = (char*)cvMemStorageAlloc( storage, sample_count*2*sizeof(y[0]) );
1108 alpha = (double*)cvMemStorageAlloc( storage, alpha_count*sizeof(alpha[0]) );
1109 sum = C * _kernel->params->nu * sample_count * 0.5;
1111 for( i = 0; i < sample_count; i++ )
1113 alpha[i] = alpha[i + sample_count] = MIN(sum, C);
1119 b[i + sample_count] = _y[i];
1120 y[i + sample_count] = -1;
1123 if( !solve_generic( _si ))
1126 for( i = 0; i < sample_count; i++ )
1127 _alpha[i] = alpha[i] - alpha[i+sample_count];
1133 //////////////////////////////////////////////////////////////////////////////////////////
1144 default_model_name = "my_svm";
1158 cvFree( &decision_func );
1159 cvReleaseMat( &class_labels );
1160 cvReleaseMat( &class_weights );
1161 cvReleaseMemStorage( &storage );
1162 cvReleaseMat( &var_idx );
1173 CvSVM::CvSVM( const CvMat* _train_data, const CvMat* _responses,
1174 const CvMat* _var_idx, const CvMat* _sample_idx, CvSVMParams _params )
1183 default_model_name = "my_svm";
1185 train( _train_data, _responses, _var_idx, _sample_idx, _params );
1189 int CvSVM::get_support_vector_count() const
1195 const float* CvSVM::get_support_vector(int i) const
1197 return sv && (unsigned)i < (unsigned)sv_total ? sv[i] : 0;
1201 bool CvSVM::set_params( const CvSVMParams& _params )
1205 CV_FUNCNAME( "CvSVM::set_params" );
1209 int kernel_type, svm_type;
1213 kernel_type = params.kernel_type;
1214 svm_type = params.svm_type;
1216 if( kernel_type != LINEAR && kernel_type != POLY &&
1217 kernel_type != SIGMOID && kernel_type != RBF )
1218 CV_ERROR( CV_StsBadArg, "Unknown/unsupported kernel type" );
1220 if( kernel_type == LINEAR )
1222 else if( params.gamma <= 0 )
1223 CV_ERROR( CV_StsOutOfRange, "gamma parameter of the kernel must be positive" );
1225 if( kernel_type != SIGMOID && kernel_type != POLY )
1227 else if( params.coef0 < 0 )
1228 CV_ERROR( CV_StsOutOfRange, "The kernel parameter <coef0> must be positive or zero" );
1230 if( kernel_type != POLY )
1232 else if( params.degree <= 0 )
1233 CV_ERROR( CV_StsOutOfRange, "The kernel parameter <degree> must be positive" );
1235 if( svm_type != C_SVC && svm_type != NU_SVC &&
1236 svm_type != ONE_CLASS && svm_type != EPS_SVR &&
1237 svm_type != NU_SVR )
1238 CV_ERROR( CV_StsBadArg, "Unknown/unsupported SVM type" );
1240 if( svm_type == ONE_CLASS || svm_type == NU_SVC )
1242 else if( params.C <= 0 )
1243 CV_ERROR( CV_StsOutOfRange, "The parameter C must be positive" );
1245 if( svm_type == C_SVC || svm_type == EPS_SVR )
1247 else if( params.nu <= 0 || params.nu >= 1 )
1248 CV_ERROR( CV_StsOutOfRange, "The parameter nu must be between 0 and 1" );
1250 if( svm_type != EPS_SVR )
1252 else if( params.p <= 0 )
1253 CV_ERROR( CV_StsOutOfRange, "The parameter p must be positive" );
1255 if( svm_type != C_SVC )
1256 params.class_weights = 0;
1258 params.term_crit = cvCheckTermCriteria( params.term_crit, DBL_EPSILON, INT_MAX );
1259 params.term_crit.epsilon = MAX( params.term_crit.epsilon, DBL_EPSILON );
1269 void CvSVM::create_kernel()
1271 kernel = new CvSVMKernel(¶ms,0);
1275 void CvSVM::create_solver( )
1277 solver = new CvSVMSolver;
1281 // switching function
1282 bool CvSVM::train1( int sample_count, int var_count, const float** samples,
1283 const void* _responses, double Cp, double Cn,
1284 CvMemStorage* _storage, double* alpha, double& rho )
1288 //CV_FUNCNAME( "CvSVM::train1" );
1292 CvSVMSolutionInfo si;
1293 int svm_type = params.svm_type;
1297 ok = svm_type == C_SVC ? solver->solve_c_svc( sample_count, var_count, samples, (char*)_responses,
1298 Cp, Cn, _storage, kernel, alpha, si ) :
1299 svm_type == NU_SVC ? solver->solve_nu_svc( sample_count, var_count, samples, (char*)_responses,
1300 _storage, kernel, alpha, si ) :
1301 svm_type == ONE_CLASS ? solver->solve_one_class( sample_count, var_count, samples,
1302 _storage, kernel, alpha, si ) :
1303 svm_type == EPS_SVR ? solver->solve_eps_svr( sample_count, var_count, samples, (float*)_responses,
1304 _storage, kernel, alpha, si ) :
1305 svm_type == NU_SVR ? solver->solve_nu_svr( sample_count, var_count, samples, (float*)_responses,
1306 _storage, kernel, alpha, si ) : false;
1316 bool CvSVM::do_train( int svm_type, int sample_count, int var_count, const float** samples,
1317 const CvMat* responses, CvMemStorage* temp_storage, double* alpha )
1321 CV_FUNCNAME( "CvSVM::do_train" );
1325 CvSVMDecisionFunc* df = 0;
1326 const int sample_size = var_count*sizeof(samples[0][0]);
1329 if( svm_type == ONE_CLASS || svm_type == EPS_SVR || svm_type == NU_SVR )
1333 CV_CALL( decision_func = df =
1334 (CvSVMDecisionFunc*)cvAlloc( sizeof(df[0]) ));
1337 if( !train1( sample_count, var_count, samples, svm_type == ONE_CLASS ? 0 :
1338 responses->data.i, 0, 0, temp_storage, alpha, df->rho ))
1341 for( i = 0; i < sample_count; i++ )
1342 sv_count += fabs(alpha[i]) > 0;
1344 sv_total = df->sv_count = sv_count;
1345 CV_CALL( df->alpha = (double*)cvMemStorageAlloc( storage, sv_count*sizeof(df->alpha[0])) );
1346 CV_CALL( sv = (float**)cvMemStorageAlloc( storage, sv_count*sizeof(sv[0])));
1348 for( i = k = 0; i < sample_count; i++ )
1350 if( fabs(alpha[i]) > 0 )
1352 CV_CALL( sv[k] = (float*)cvMemStorageAlloc( storage, sample_size ));
1353 memcpy( sv[k], samples[i], sample_size );
1354 df->alpha[k++] = alpha[i];
1360 int class_count = class_labels->cols;
1362 const float** temp_samples = 0;
1363 int* class_ranges = 0;
1365 assert( svm_type == CvSVM::C_SVC || svm_type == CvSVM::NU_SVC );
1367 if( svm_type == CvSVM::C_SVC && params.class_weights )
1369 const CvMat* cw = params.class_weights;
1371 if( !CV_IS_MAT(cw) || cw->cols != 1 && cw->rows != 1 ||
1372 cw->rows + cw->cols - 1 != class_count ||
1373 CV_MAT_TYPE(cw->type) != CV_32FC1 && CV_MAT_TYPE(cw->type) != CV_64FC1 )
1374 CV_ERROR( CV_StsBadArg, "params.class_weights must be 1d floating-point vector "
1375 "containing as many elements as the number of classes" );
1377 CV_CALL( class_weights = cvCreateMat( cw->rows, cw->cols, CV_64F ));
1378 CV_CALL( cvConvert( cw, class_weights ));
1379 CV_CALL( cvScale( class_weights, class_weights, params.C ));
1382 CV_CALL( decision_func = df = (CvSVMDecisionFunc*)cvAlloc(
1383 (class_count*(class_count-1)/2)*sizeof(df[0])));
1385 CV_CALL( sv_tab = (int*)cvMemStorageAlloc( temp_storage, sample_count*sizeof(sv_tab[0]) ));
1386 memset( sv_tab, 0, sample_count*sizeof(sv_tab[0]) );
1387 CV_CALL( class_ranges = (int*)cvMemStorageAlloc( temp_storage,
1388 (class_count + 1)*sizeof(class_ranges[0])));
1389 CV_CALL( temp_samples = (const float**)cvMemStorageAlloc( temp_storage,
1390 sample_count*sizeof(temp_samples[0])));
1391 CV_CALL( temp_y = (char*)cvMemStorageAlloc( temp_storage, sample_count));
1393 class_ranges[class_count] = 0;
1394 cvSortSamplesByClasses( samples, responses, class_ranges, 0 );
1395 // check that samples from every class are still present during cross-validation
1396 if( class_ranges[class_count] <= 0 )
1397 CV_ERROR( CV_StsBadArg, "While cross-validation one or more of the classes have "
1398 "been fell out of the sample. Try to enlarge <CvSVMParams::k_fold>" );
1400 if( svm_type == NU_SVC )
1402 // check if nu is feasible
1403 for(i = 0; i < class_count; i++ )
1405 int ci = class_ranges[i+1] - class_ranges[i];
1406 for( j = i+1; j< class_count; j++ )
1408 int cj = class_ranges[j+1] - class_ranges[j];
1409 if( params.nu*(ci + cj)*0.5 > MIN( ci, cj ) )
1411 // !!!TODO!!! add some diagnostic
1412 EXIT; // exit immediately; will release the model and return NULL pointer
1418 // train n*(n-1)/2 classifiers
1419 for( i = 0; i < class_count; i++ )
1421 for( j = i+1; j < class_count; j++, df++ )
1423 int si = class_ranges[i], ci = class_ranges[i+1] - si;
1424 int sj = class_ranges[j], cj = class_ranges[j+1] - sj;
1425 double Cp = params.C, Cn = Cp;
1426 int k1 = 0, sv_count = 0;
1428 for( k = 0; k < ci; k++ )
1430 temp_samples[k] = samples[si + k];
1434 for( k = 0; k < cj; k++ )
1436 temp_samples[ci + k] = samples[sj + k];
1437 temp_y[ci + k] = -1;
1442 Cp = class_weights->data.db[i];
1443 Cn = class_weights->data.db[j];
1446 if( !train1( ci + cj, var_count, temp_samples, temp_y,
1447 Cp, Cn, temp_storage, alpha, df->rho ))
1450 for( k = 0; k < ci + cj; k++ )
1451 sv_count += fabs(alpha[k]) > 0;
1453 df->sv_count = sv_count;
1455 CV_CALL( df->alpha = (double*)cvMemStorageAlloc( temp_storage,
1456 sv_count*sizeof(df->alpha[0])));
1457 CV_CALL( df->sv_index = (int*)cvMemStorageAlloc( temp_storage,
1458 sv_count*sizeof(df->sv_index[0])));
1460 for( k = 0; k < ci; k++ )
1462 if( fabs(alpha[k]) > 0 )
1465 df->sv_index[k1] = si + k;
1466 df->alpha[k1++] = alpha[k];
1470 for( k = 0; k < cj; k++ )
1472 if( fabs(alpha[ci + k]) > 0 )
1475 df->sv_index[k1] = sj + k;
1476 df->alpha[k1++] = alpha[ci + k];
1482 // allocate support vectors and initialize sv_tab
1483 for( i = 0, k = 0; i < sample_count; i++ )
1490 CV_CALL( sv = (float**)cvMemStorageAlloc( storage, sv_total*sizeof(sv[0])));
1492 for( i = 0, k = 0; i < sample_count; i++ )
1496 CV_CALL( sv[k] = (float*)cvMemStorageAlloc( storage, sample_size ));
1497 memcpy( sv[k], samples[i], sample_size );
1502 df = (CvSVMDecisionFunc*)decision_func;
1505 for( i = 0; i < class_count; i++ )
1507 for( j = i+1; j < class_count; j++, df++ )
1509 for( k = 0; k < df->sv_count; k++ )
1511 df->sv_index[k] = sv_tab[df->sv_index[k]]-1;
1512 assert( (unsigned)df->sv_index[k] < (unsigned)sv_total );
// CvSVM::train -- trains the model on the given data with the given parameters.
// Visible flow: store the parameters via set_params(), convert the inputs with
// cvPrepareTrainData(), size and create the memory storage, allocate the alpha
// buffer and hand the actual optimisation off to do_train().
// NOTE(review): this listing omits some lines of the function (braces, some
// declarations, the return statement); the comments cover only what is shown.
1525 bool CvSVM::train( const CvMat* _train_data, const CvMat* _responses,
1526 const CvMat* _var_idx, const CvMat* _sample_idx, CvSVMParams _params )
1529 CvMat* responses = 0;
1530 CvMemStorage* temp_storage = 0;
1531 const float** samples = 0;
1533 CV_FUNCNAME( "CvSVM::train" );
1537 int svm_type, sample_count, var_count, sample_size;
// initial storage block size (1 << 16 bytes); enlarged below when the data needs more
1538 int block_size = 1 << 16;
1542 CV_CALL( set_params( _params ));
1544 svm_type = _params.svm_type;
1546 /* Prepare training data and related parameters */
// Responses are not passed for ONE_CLASS; they are declared categorical for
// C_SVC/NU_SVC and ordered for every other SVM type.
1547 CV_CALL( cvPrepareTrainData( "CvSVM::train", _train_data, CV_ROW_SAMPLE,
1548 svm_type != CvSVM::ONE_CLASS ? _responses : 0,
1549 svm_type == CvSVM::C_SVC ||
1550 svm_type == CvSVM::NU_SVC ? CV_VAR_CATEGORICAL :
1551 CV_VAR_ORDERED, _var_idx, _sample_idx,
1552 false, &samples, &sample_count, &var_count, &var_all,
1553 &responses, &class_labels, &var_idx ));
// size in bytes of one prepared sample (var_count floats)
1556 sample_size = var_count*sizeof(samples[0][0]);
1558 // make the storage block size large enough to fit all
1559 // the temporary vectors and output support vectors.
1560 block_size = MAX( block_size, sample_count*(int)sizeof(CvSVMKernelRow));
1561 block_size = MAX( block_size, sample_count*2*(int)sizeof(double) + 1024 );
1562 block_size = MAX( block_size, sample_size*2 + 1024 );
// temp_storage is a child of the model storage; it is released before returning
1564 CV_CALL( storage = cvCreateMemStorage(block_size));
1565 CV_CALL( temp_storage = cvCreateChildMemStorage(storage));
// one alpha coefficient per training sample
1566 CV_CALL( alpha = (double*)cvMemStorageAlloc(temp_storage, sample_count*sizeof(double)));
// the statement executed when do_train() fails is not visible in this listing
1571 if( !do_train( svm_type, sample_count, var_count, samples, responses, temp_storage, alpha ))
1574 ok = true; // model has been trained successfully
// release the temporaries created above
1580 cvReleaseMemStorage( &temp_storage );
1581 cvReleaseMat( &responses );
// failure path: taken when an error status is set or training did not succeed
// (the action itself is not visible here)
1584 if( cvGetErrStatus() < 0 || !ok )
// CvSVM::train_auto -- chooses SVM parameters by grid search with k-fold
// cross-validation and then trains the final model with the best combination.
//
// Parameters (as used by the visible code):
//   _train_data, _responses, _var_idx, _sample_idx -- forwarded to
//       cvPrepareTrainData() (or to train() for ONE_CLASS).
//   _params  -- initial parameters; stored with set_params().
//   k_fold   -- number of cross-validation folds; the error message below
//               states it must be > 1.
//   *_grid   -- search grid per parameter. A grid whose step is <= 1 is
//               collapsed to the single value currently stored in params.
//
// The grid loops multiply each parameter by its grid step while it is below
// the grid's max_val. For each combination every fold is trained on the
// samples outside the fold and the error is accumulated on the fold itself.
//
// NOTE(review): this listing omits some lines of the function (braces, a few
// assignments, loop heads, the return statement); comments describe only what
// is visible.
//
// Fix in this revision: the copy of the responses that follow the test fold
// now sizes its byte count with resp_elem_size (the size of one response)
// instead of sizeof(samples[0]) (the size of a sample pointer). With pointers
// wider than a response element the old count ran past both buffers.
1590 bool CvSVM::train_auto( const CvMat* _train_data, const CvMat* _responses,
1591 const CvMat* _var_idx, const CvMat* _sample_idx, CvSVMParams _params, int k_fold,
1592 CvParamGrid C_grid, CvParamGrid gamma_grid, CvParamGrid p_grid,
1593 CvParamGrid nu_grid, CvParamGrid coef_grid, CvParamGrid degree_grid )
1596 CvMat* responses = 0;
1597 CvMat* responses_local = 0;
1598 CvMemStorage* temp_storage = 0;
1599 const float** samples = 0;
1600 const float** samples_local = 0;
1602 CV_FUNCNAME( "CvSVM::train_auto" );
1605 int svm_type, sample_count, var_count, sample_size;
1606 int block_size = 1 << 16;
// the generator used for the sample permutation is seeded with the constant -1
1609 CvRNG rng = cvRNG(-1);
1611 // all steps are logarithmic and must be > 1
1612 double degree_step = 10, g_step = 10, coef_step = 10, C_step = 10, nu_step = 10, p_step = 10;
1613 double gamma = 0, C = 0, degree = 0, coef = 0, p = 0, nu = 0;
1614 double best_degree = 0, best_gamma = 0, best_coef = 0, best_C = 0, best_nu = 0, best_p = 0;
1615 float min_error = FLT_MAX, error;
// ONE_CLASS is handed to plain train(); what follows that call is not visible here
1617 if( _params.svm_type == CvSVM::ONE_CLASS )
1619 if(!train( _train_data, _responses, _var_idx, _sample_idx, _params ))
1627 CV_ERROR( CV_StsBadArg, "Parameter <k_fold> must be > 1" );
1629 CV_CALL(set_params( _params ));
1630 svm_type = _params.svm_type;
1632 // All the parameters except, possibly, <coef0> are positive.
1633 // <coef0> is nonnegative
// For every grid: a step <= 1 collapses the grid to the current parameter
// value; the grid is then validated with check().
1634 if( C_grid.step <= 1 )
1636 C_grid.min_val = C_grid.max_val = params.C;
1640 CV_CALL(C_grid.check());
1642 if( gamma_grid.step <= 1 )
1644 gamma_grid.min_val = gamma_grid.max_val = params.gamma;
1645 gamma_grid.step = 10;
1648 CV_CALL(gamma_grid.check());
1650 if( p_grid.step <= 1 )
1652 p_grid.min_val = p_grid.max_val = params.p;
1656 CV_CALL(p_grid.check());
1658 if( nu_grid.step <= 1 )
1660 nu_grid.min_val = nu_grid.max_val = params.nu;
1664 CV_CALL(nu_grid.check());
1666 if( coef_grid.step <= 1 )
1668 coef_grid.min_val = coef_grid.max_val = params.coef0;
1669 coef_grid.step = 10;
1672 CV_CALL(coef_grid.check());
1674 if( degree_grid.step <= 1 )
1676 degree_grid.min_val = degree_grid.max_val = params.degree;
1677 degree_grid.step = 10;
1680 CV_CALL(degree_grid.check());
1682 // these parameters are not used:
// (their grids are collapsed so the corresponding loops run once)
1683 if( params.kernel_type != CvSVM::POLY )
1684 degree_grid.min_val = degree_grid.max_val = params.degree;
1685 if( params.kernel_type == CvSVM::LINEAR )
1686 gamma_grid.min_val = gamma_grid.max_val = params.gamma;
1687 if( params.kernel_type != CvSVM::POLY && params.kernel_type != CvSVM::SIGMOID )
1688 coef_grid.min_val = coef_grid.max_val = params.coef0;
1689 if( svm_type == CvSVM::NU_SVC || svm_type == CvSVM::ONE_CLASS )
1690 C_grid.min_val = C_grid.max_val = params.C;
1691 if( svm_type == CvSVM::C_SVC || svm_type == CvSVM::EPS_SVR )
1692 nu_grid.min_val = nu_grid.max_val = params.nu;
1693 if( svm_type != CvSVM::EPS_SVR )
1694 p_grid.min_val = p_grid.max_val = params.p;
1696 CV_ASSERT( g_step > 1 && degree_step > 1 && coef_step > 1);
1697 CV_ASSERT( p_step > 1 && C_step > 1 && nu_step > 1 );
1699 /* Prepare training data and related parameters */
// Same conversion as in train(): no responses for ONE_CLASS, categorical
// responses for C_SVC/NU_SVC, ordered otherwise.
1700 CV_CALL(cvPrepareTrainData( "CvSVM::train_auto", _train_data, CV_ROW_SAMPLE,
1701 svm_type != CvSVM::ONE_CLASS ? _responses : 0,
1702 svm_type == CvSVM::C_SVC ||
1703 svm_type == CvSVM::NU_SVC ? CV_VAR_CATEGORICAL :
1704 CV_VAR_ORDERED, _var_idx, _sample_idx,
1705 false, &samples, &sample_count, &var_count, &var_all,
1706 &responses, &class_labels, &var_idx ));
// size in bytes of one prepared sample (var_count floats)
1708 sample_size = var_count*sizeof(samples[0][0]);
1710 // make the storage block size large enough to fit all
1711 // the temporary vectors and output support vectors.
1712 block_size = MAX( block_size, sample_count*(int)sizeof(CvSVMKernelRow));
1713 block_size = MAX( block_size, sample_count*2*(int)sizeof(double) + 1024 );
1714 block_size = MAX( block_size, sample_size*2 + 1024 );
1716 CV_CALL(storage = cvCreateMemStorage(block_size));
1717 CV_CALL(temp_storage = cvCreateChildMemStorage(storage));
1718 CV_CALL(alpha = (double*)cvMemStorageAlloc(temp_storage, sample_count*sizeof(double)));
// Fold geometry: every fold but the last tests on sample_count/k_fold samples;
// the last fold additionally takes the remainder of the division.
1724 const int testset_size = sample_count/k_fold;
1725 const int trainset_size = sample_count - testset_size;
1726 const int last_testset_size = sample_count - testset_size*(k_fold-1);
1727 const int last_trainset_size = sample_count - last_testset_size;
1728 const bool is_regression = (svm_type == EPS_SVR) || (svm_type == NU_SVR);
// size in bytes of one response element
1730 size_t resp_elem_size = CV_ELEM_SIZE(responses->type);
1731 size_t size = 2*last_trainset_size*sizeof(samples[0]);
// per-fold training set: an array of sample pointers and a response row
1733 samples_local = (const float**) cvAlloc( size );
1734 memset( samples_local, 0, size );
1736 responses_local = cvCreateMat( 1, trainset_size, CV_MAT_TYPE(responses->type) );
1737 cvZero( responses_local );
1739 // randomly permute samples and responses
// sample_count random pair swaps; the matching responses are swapped too
// (as float or as int -- the selecting condition is not visible in this listing)
1740 for( i = 0; i < sample_count; i++ )
1742 int i1 = cvRandInt( &rng ) % sample_count;
1743 int i2 = cvRandInt( &rng ) % sample_count;
1748 CV_SWAP( samples[i1], samples[i2], temp );
1750 CV_SWAP( responses->data.fl[i1], responses->data.fl[i2], t );
1752 CV_SWAP( responses->data.i[i1], responses->data.i[i2], y );
// Grid loops: each visible parameter starts at its grid minimum and is stored
// in params before the inner loops run.
1759 gamma = gamma_grid.min_val;
1762 params.gamma = gamma;
1767 nu = nu_grid.min_val;
1771 coef = coef_grid.min_val;
1774 params.coef0 = coef;
1775 degree = degree_grid.min_val;
1778 params.degree = degree;
// cursors over the test samples / true responses, advanced while testing each fold
1780 float** test_samples_ptr = (float**)samples;
1781 uchar* true_resp = responses->data.ptr;
1782 int test_size = testset_size;
1783 int train_size = trainset_size;
1786 for( k = 0; k < k_fold; k++ )
// training samples = the samples before the test fold + the samples after it
1788 memcpy( samples_local, samples, sizeof(samples[0])*test_size*k );
1789 memcpy( samples_local + test_size*k, test_samples_ptr + test_size,
1790 sizeof(samples[0])*(sample_count - testset_size*(k+1)) );
// the same split for the responses; all byte counts are in units of resp_elem_size
1792 memcpy( responses_local->data.ptr, responses->data.ptr, resp_elem_size*test_size*k );
1793 memcpy( responses_local->data.ptr + resp_elem_size*test_size*k,
1794 true_resp + resp_elem_size*test_size,
1795 resp_elem_size*(sample_count - testset_size*(k+1)) );
// the last fold uses the remainder-adjusted sizes
1797 if( k == k_fold - 1 )
1799 test_size = last_testset_size;
1800 train_size = last_trainset_size;
1801 responses_local->cols = last_trainset_size;
1804 // Train SVM on <train_size> samples
1805 if( !do_train( svm_type, train_size, var_count,
1806 (const float**)samples_local, responses_local, temp_storage, alpha ) )
1809 // Compute test set error on <test_size> samples
// `s` is a 1 x var_count float header re-pointed at each test sample in turn
1810 CvMat s = cvMat( 1, var_count, CV_32FC1 );
1811 for( i = 0; i < test_size; i++, true_resp += resp_elem_size, test_samples_ptr++ )
1814 s.data.fl = *test_samples_ptr;
1815 resp = predict( &s );
// regression: squared error; classification: 1 for a wrong label, 0 otherwise
1816 error += is_regression ? powf( resp - *(float*)true_resp, 2 )
1817 : ((int)resp != *(int*)true_resp);
// remember the combination with the smallest accumulated error
1820 if( min_error > error )
1823 best_degree = degree;
// advance each grid parameter multiplicatively; loop while below the grid maximum
1830 degree *= degree_grid.step;
1832 while( degree < degree_grid.max_val );
1833 coef *= coef_grid.step;
1835 while( coef < coef_grid.max_val );
1838 while( nu < nu_grid.max_val );
1841 while( p < p_grid.max_val );
1842 gamma *= gamma_grid.step;
1844 while( gamma < gamma_grid.max_val );
1847 while( C < C_grid.max_val );
// turn the accumulated error into a per-sample value
1850 min_error /= (float) sample_count;
// install the best combination and train the final model on all samples
1853 params.nu = best_nu;
1855 params.gamma = best_gamma;
1856 params.degree = best_degree;
1857 params.coef0 = best_coef;
1859 CV_CALL(ok = do_train( svm_type, sample_count, var_count, samples, responses, temp_storage, alpha ));
// release everything allocated for the search
1865 cvReleaseMemStorage( &temp_storage );
1866 cvReleaseMat( &responses );
1867 cvReleaseMat( &responses_local );
1869 cvFree( &samples_local );
// failure path: taken when an error status is set or training did not succeed
// (the action itself is not visible here)
1871 if( cvGetErrStatus() < 0 || !ok )
// CvSVM::predict -- evaluates the trained model on a single sample.
// Regression types return the decision-function value, ONE_CLASS returns
// 1 or 0 depending on its sign, and the classification types return the label
// of the class that collects the most pairwise votes.
// NOTE(review): this listing omits some lines of the function (braces, a few
// declarations and statements, the return); comments cover only what is shown.
1877 float CvSVM::predict( const CvMat* sample ) const
1879 bool local_alloc = 0;
1881 float* row_sample = 0;
1884 CV_FUNCNAME( "CvSVM::predict" );
1889 int var_count, buf_sz;
// raised for an untrained model (the guarding condition is not visible here)
1892 CV_ERROR( CV_StsBadArg, "The SVM should be trained first" );
// number of classes: taken from class_labels when present, otherwise 1 for
// ONE_CLASS and 0 for the remaining types
1894 class_count = class_labels ? class_labels->cols :
1895 params.svm_type == ONE_CLASS ? 1 : 0;
// obtain the sample as a plain float row in row_sample
1897 CV_CALL( cvPreparePredictData( sample, var_all, var_idx,
1898 class_count, 0, &row_sample ));
1900 var_count = get_var_count();
// work buffer: sv_total kernel values followed by (class_count+1) ints
1902 buf_sz = sv_total*sizeof(buffer[0]) + (class_count+1)*sizeof(int);
// small buffers come from cvStackAlloc; the cvAlloc call below is presumably
// the branch for larger ones (the `else` is not visible in this listing)
1903 if( buf_sz <= CV_MAX_LOCAL_SIZE )
1905 CV_CALL( buffer = (Qfloat*)cvStackAlloc( buf_sz ));
1909 CV_CALL( buffer = (Qfloat*)cvAlloc( buf_sz ));
// --- regression and one-class: a single decision function ---
1911 if( params.svm_type == EPS_SVR ||
1912 params.svm_type == NU_SVR ||
1913 params.svm_type == ONE_CLASS )
1915 CvSVMDecisionFunc* df = (CvSVMDecisionFunc*)decision_func;
1916 int i, sv_count = df->sv_count;
1917 double sum = -df->rho;
// sum = -rho + sum_i( kernel(sv_i, sample) * alpha_i )
1919 kernel->calc( sv_count, var_count, (const float**)sv, row_sample, buffer );
1920 for( i = 0; i < sv_count; i++ )
1921 sum += buffer[i]*df->alpha[i];
1923 result = params.svm_type == ONE_CLASS ? (float)(sum > 0) : (float)sum;
// --- classification: every pair of classes (i, j) casts one vote ---
1925 else if( params.svm_type == C_SVC ||
1926 params.svm_type == NU_SVC )
1928 CvSVMDecisionFunc* df = (CvSVMDecisionFunc*)decision_func;
// the vote counters live in the same buffer, right after the kernel values
1929 int* vote = (int*)(buffer + sv_total);
1932 memset( vote, 0, class_count*sizeof(vote[0]));
// kernel responses are computed once for all support vectors and then
// looked up through each decision function's sv_index
1933 kernel->calc( sv_total, var_count, (const float**)sv, row_sample, buffer );
1935 for( i = 0; i < class_count; i++ )
1937 for( j = i+1; j < class_count; j++, df++ )
1939 double sum = -df->rho;
1940 int sv_count = df->sv_count;
1941 for( k = 0; k < sv_count; k++ )
1942 sum += df->alpha[k]*buffer[df->sv_index[k]];
// positive value votes for class i, otherwise for class j
1944 vote[sum > 0 ? i : j]++;
// scan for the class with the most votes (the update of k is not visible here)
1948 for( i = 1, k = 0; i < class_count; i++ )
1950 if( vote[i] > vote[k] )
1954 result = (float)(class_labels->data.i[k]);
1957 CV_ERROR( CV_StsBadArg, "INTERNAL ERROR: Unknown SVM type, "
1958 "the SVM structure is probably corrupted" );
// free row_sample unless it is the input matrix's own data pointer
1962 if( sample && (!CV_IS_MAT(sample) || sample->data.fl != row_sample) )
1963 cvFree( &row_sample );
// CvSVM::write_params -- writes the SVM parameters to the file storage:
// "svm_type", a "kernel" map (type plus the kernel parameters that apply),
// the applicable "C"/"nu"/"p" values and a "term_criteria" map.
// Known types are written by name; the integer writes handle the case where
// no name is known.
// NOTE(review): some lines (braces, `else` keywords) are absent from this listing.
1972 void CvSVM::write_params( CvFileStorage* fs )
1974 //CV_FUNCNAME( "CvSVM::write_params" );
1978 int svm_type = params.svm_type;
1979 int kernel_type = params.kernel_type;
// symbolic names; 0 when the stored value is not one of the known constants
1981 const char* svm_type_str =
1982 svm_type == CvSVM::C_SVC ? "C_SVC" :
1983 svm_type == CvSVM::NU_SVC ? "NU_SVC" :
1984 svm_type == CvSVM::ONE_CLASS ? "ONE_CLASS" :
1985 svm_type == CvSVM::EPS_SVR ? "EPS_SVR" :
1986 svm_type == CvSVM::NU_SVR ? "NU_SVR" : 0;
1987 const char* kernel_type_str =
1988 kernel_type == CvSVM::LINEAR ? "LINEAR" :
1989 kernel_type == CvSVM::POLY ? "POLY" :
1990 kernel_type == CvSVM::RBF ? "RBF" :
1991 kernel_type == CvSVM::SIGMOID ? "SIGMOID" : 0;
// presumably string-or-int like the kernel type below (the selecting
// condition is not visible in this listing)
1994 cvWriteString( fs, "svm_type", svm_type_str );
1996 cvWriteInt( fs, "svm_type", svm_type );
1999 cvStartWriteStruct( fs, "kernel", CV_NODE_MAP + CV_NODE_FLOW );
2001 if( kernel_type_str )
2002 cvWriteString( fs, "type", kernel_type_str );
2004 cvWriteInt( fs, "type", kernel_type );
// each kernel parameter is written for the kernel types listed in its
// condition, and always when the kernel type has no known name
2006 if( kernel_type == CvSVM::POLY || !kernel_type_str )
2007 cvWriteReal( fs, "degree", params.degree );
2009 if( kernel_type != CvSVM::LINEAR || !kernel_type_str )
2010 cvWriteReal( fs, "gamma", params.gamma );
2012 if( kernel_type == CvSVM::POLY || kernel_type == CvSVM::SIGMOID || !kernel_type_str )
2013 cvWriteReal( fs, "coef0", params.coef0 );
2015 cvEndWriteStruct(fs);
// same rule for the SVM-type dependent parameters
2017 if( svm_type == CvSVM::C_SVC || svm_type == CvSVM::EPS_SVR ||
2018 svm_type == CvSVM::NU_SVR || !svm_type_str )
2019 cvWriteReal( fs, "C", params.C );
2021 if( svm_type == CvSVM::NU_SVC || svm_type == CvSVM::ONE_CLASS ||
2022 svm_type == CvSVM::NU_SVR || !svm_type_str )
2023 cvWriteReal( fs, "nu", params.nu );
2025 if( svm_type == CvSVM::EPS_SVR || !svm_type_str )
2026 cvWriteReal( fs, "p", params.p );
// only the termination criteria enabled in term_crit.type are written
2028 cvStartWriteStruct( fs, "term_criteria", CV_NODE_MAP + CV_NODE_FLOW );
2029 if( params.term_crit.type & CV_TERMCRIT_EPS )
2030 cvWriteReal( fs, "epsilon", params.term_crit.epsilon );
2031 if( params.term_crit.type & CV_TERMCRIT_ITER )
2032 cvWriteInt( fs, "iterations", params.term_crit.max_iter );
2033 cvEndWriteStruct( fs );
// CvSVM::write -- serialises the trained model under the node `name`:
// variable/class counts, class labels and weights, the variable index,
// the joint list of support vectors and every decision function
// (sv_count, rho, alpha and, for multi-class models, the support-vector indices).
// NOTE(review): some lines (braces, guarding conditions, a probable call that
// writes the parameters) are absent from this listing.
2039 void CvSVM::write( CvFileStorage* fs, const char* name )
2041 CV_FUNCNAME( "CvSVM::write" );
2045 int i, var_count = get_var_count(), df_count, class_count;
2046 const CvSVMDecisionFunc* df = decision_func;
2048 cvStartWriteStruct( fs, name, CV_NODE_MAP, CV_TYPE_NAME_ML_SVM );
2052 cvWriteInt( fs, "var_all", var_all );
2053 cvWriteInt( fs, "var_count", var_count );
// number of classes: from class_labels when present, otherwise 1 for
// ONE_CLASS and 0 for the remaining types
2055 class_count = class_labels ? class_labels->cols :
2056 params.svm_type == CvSVM::ONE_CLASS ? 1 : 0;
2060 cvWriteInt( fs, "class_count", class_count );
2063 cvWrite( fs, "class_labels", class_labels );
2066 cvWrite( fs, "class_weights", class_weights );
2070 cvWrite( fs, "var_idx", var_idx );
2072 // write the joint collection of support vectors
2073 cvWriteInt( fs, "sv_total", sv_total );
2074 cvStartWriteStruct( fs, "support_vectors", CV_NODE_SEQ );
// each support vector is a flow sequence of var_count floats
2075 for( i = 0; i < sv_total; i++ )
2077 cvStartWriteStruct( fs, 0, CV_NODE_SEQ + CV_NODE_FLOW );
2078 cvWriteRawData( fs, sv[i], var_count, "f" );
2079 cvEndWriteStruct( fs );
2082 cvEndWriteStruct( fs );
2084 // write decision functions
// one function per pair of classes, or a single one when class_count <= 1
2085 df_count = class_count > 1 ? class_count*(class_count-1)/2 : 1;
2088 cvStartWriteStruct( fs, "decision_functions", CV_NODE_SEQ );
2089 for( i = 0; i < df_count; i++ )
2091 int sv_count = df[i].sv_count;
2092 cvStartWriteStruct( fs, 0, CV_NODE_MAP );
2093 cvWriteInt( fs, "sv_count", sv_count );
2094 cvWriteReal( fs, "rho", df[i].rho );
2095 cvStartWriteStruct( fs, "alpha", CV_NODE_SEQ+CV_NODE_FLOW );
2096 cvWriteRawData( fs, df[i].alpha, df[i].sv_count, "d" );
2097 cvEndWriteStruct( fs );
// the index list is written for multi-class models only
2098 if( class_count > 1 )
2100 cvStartWriteStruct( fs, "index", CV_NODE_SEQ+CV_NODE_FLOW );
2101 cvWriteRawData( fs, df[i].sv_index, df[i].sv_count, "i" );
2102 cvEndWriteStruct( fs );
// presumably the branch for class_count <= 1, where the single function is
// expected to use every support vector -- the `else` is not visible here
2105 CV_ASSERT( sv_count == sv_total );
2106 cvEndWriteStruct( fs );
2108 cvEndWriteStruct( fs );
2109 cvEndWriteStruct( fs );
// CvSVM::read_params -- reads the SVM parameters from `svm_node` and installs
// them with set_params(). The SVM type and kernel type may be stored either
// as integers or as their symbolic names; an unrecognised name maps to -1.
// Numeric parameters missing from the file default to 0.
// NOTE(review): some lines (braces, `else` keywords, null checks, the
// assignment targets of the two name-to-constant chains) are absent from this
// listing.
2115 void CvSVM::read_params( CvFileStorage* fs, CvFileNode* svm_node )
2117 CV_FUNCNAME( "CvSVM::read_params" );
2121 int svm_type, kernel_type;
2122 CvSVMParams _params;
2124 CvFileNode* tmp_node = cvGetFileNodeByName( fs, svm_node, "svm_type" );
2125 CvFileNode* kernel_node;
2127 CV_ERROR( CV_StsBadArg, "svm_type tag is not found" );
// integer node: take the value as is; otherwise translate the name
2129 if( CV_NODE_TYPE(tmp_node->tag) == CV_NODE_INT )
2130 svm_type = cvReadInt( tmp_node, -1 );
2133 const char* svm_type_str = cvReadString( tmp_node, "" );
2135 strcmp( svm_type_str, "C_SVC" ) == 0 ? CvSVM::C_SVC :
2136 strcmp( svm_type_str, "NU_SVC" ) == 0 ? CvSVM::NU_SVC :
2137 strcmp( svm_type_str, "ONE_CLASS" ) == 0 ? CvSVM::ONE_CLASS :
2138 strcmp( svm_type_str, "EPS_SVR" ) == 0 ? CvSVM::EPS_SVR :
2139 strcmp( svm_type_str, "NU_SVR" ) == 0 ? CvSVM::NU_SVR : -1;
2142 CV_ERROR( CV_StsParseError, "Missing of invalid SVM type" );
2145 kernel_node = cvGetFileNodeByName( fs, svm_node, "kernel" );
2147 CV_ERROR( CV_StsParseError, "SVM kernel tag is not found" );
2149 tmp_node = cvGetFileNodeByName( fs, kernel_node, "type" );
2151 CV_ERROR( CV_StsParseError, "SVM kernel type tag is not found" );
// same integer-or-name handling for the kernel type
2153 if( CV_NODE_TYPE(tmp_node->tag) == CV_NODE_INT )
2154 kernel_type = cvReadInt( tmp_node, -1 );
2157 const char* kernel_type_str = cvReadString( tmp_node, "" );
2159 strcmp( kernel_type_str, "LINEAR" ) == 0 ? CvSVM::LINEAR :
2160 strcmp( kernel_type_str, "POLY" ) == 0 ? CvSVM::POLY :
2161 strcmp( kernel_type_str, "RBF" ) == 0 ? CvSVM::RBF :
2162 strcmp( kernel_type_str, "SIGMOID" ) == 0 ? CvSVM::SIGMOID : -1;
2164 if( kernel_type < 0 )
2165 CV_ERROR( CV_StsParseError, "Missing of invalid SVM kernel type" );
2168 _params.svm_type = svm_type;
2169 _params.kernel_type = kernel_type;
// kernel parameters live under the "kernel" node, the others directly under svm_node
2170 _params.degree = cvReadRealByName( fs, kernel_node, "degree", 0 );
2171 _params.gamma = cvReadRealByName( fs, kernel_node, "gamma", 0 );
2172 _params.coef0 = cvReadRealByName( fs, kernel_node, "coef0", 0 );
2174 _params.C = cvReadRealByName( fs, svm_node, "C", 0 );
2175 _params.nu = cvReadRealByName( fs, svm_node, "nu", 0 );
2176 _params.p = cvReadRealByName( fs, svm_node, "p", 0 );
2177 _params.class_weights = 0;
2179 tmp_node = cvGetFileNodeByName( fs, svm_node, "term_criteria" );
// a criterion is enabled only when its value is present (read back as >= 0)
2182 _params.term_crit.epsilon = cvReadRealByName( fs, tmp_node, "epsilon", -1. );
2183 _params.term_crit.max_iter = cvReadIntByName( fs, tmp_node, "iterations", -1 );
2184 _params.term_crit.type = (_params.term_crit.epsilon >= 0 ? CV_TERMCRIT_EPS : 0) +
2185 (_params.term_crit.max_iter >= 0 ? CV_TERMCRIT_ITER : 0);
// default criteria: 1000 iterations and FLT_EPSILON accuracy -- presumably
// used when the "term_criteria" node is absent (the `else` is not visible here)
2188 _params.term_crit = cvTermCriteria( CV_TERMCRIT_EPS + CV_TERMCRIT_ITER, 1000, FLT_EPSILON );
2190 set_params( _params );
2196 void CvSVM::read( CvFileStorage* fs, CvFileNode* svm_node )
2198 const double not_found_dbl = DBL_MAX;
2200 CV_FUNCNAME( "CvSVM::read" );
2204 int i, var_count, df_count, class_count;
2205 int block_size = 1 << 16, sv_size;
2206 CvFileNode *sv_node, *df_node;
2207 CvSVMDecisionFunc* df;
2211 CV_ERROR( CV_StsParseError, "The requested element is not found" );
2215 // read SVM parameters
2216 read_params( fs, svm_node );
2218 // and top-level data
2219 sv_total = cvReadIntByName( fs, svm_node, "sv_total", -1 );
2220 var_all = cvReadIntByName( fs, svm_node, "var_all", -1 );
2221 var_count = cvReadIntByName( fs, svm_node, "var_count", var_all );
2222 class_count = cvReadIntByName( fs, svm_node, "class_count", 0 );
2224 if( sv_total <= 0 || var_all <= 0 || var_count <= 0 || var_count > var_all || class_count < 0 )
2225 CV_ERROR( CV_StsParseError, "SVM model data is invalid, check sv_count, var_* and class_count tags" );
2227 CV_CALL( class_labels = (CvMat*)cvReadByName( fs, svm_node, "class_labels" ));
2228 CV_CALL( class_weights = (CvMat*)cvReadByName( fs, svm_node, "class_weights" ));
2229 CV_CALL( var_idx = (CvMat*)cvReadByName( fs, svm_node, "comp_idx" ));
2231 if( class_count > 1 && (!class_labels ||
2232 !CV_IS_MAT(class_labels) || class_labels->cols != class_count))
2233 CV_ERROR( CV_StsParseError, "Array of class labels is missing or invalid" );
2235 if( var_count < var_all && (!var_idx || !CV_IS_MAT(var_idx) || var_idx->cols != var_count) )
2236 CV_ERROR( CV_StsParseError, "var_idx array is missing or invalid" );
2238 // read support vectors
2239 sv_node = cvGetFileNodeByName( fs, svm_node, "support_vectors" );
2240 if( !sv_node || !CV_NODE_IS_SEQ(sv_node->tag))
2241 CV_ERROR( CV_StsParseError, "Missing or invalid sequence of support vectors" );
2243 block_size = MAX( block_size, sv_total*(int)sizeof(CvSVMKernelRow));
2244 block_size = MAX( block_size, sv_total*2*(int)sizeof(double));
2245 block_size = MAX( block_size, var_all*(int)sizeof(double));
2246 CV_CALL( storage = cvCreateMemStorage( block_size ));
2247 CV_CALL( sv = (float**)cvMemStorageAlloc( storage,
2248 sv_total*sizeof(sv[0]) ));
2250 CV_CALL( cvStartReadSeq( sv_node->data.seq, &reader, 0 ));
2251 sv_size = var_count*sizeof(sv[0][0]);
2253 for( i = 0; i < sv_total; i++ )
2255 CvFileNode* sv_elem = (CvFileNode*)reader.ptr;
2256 CV_ASSERT( var_count == 1 || (CV_NODE_IS_SEQ(sv_elem->tag) &&
2257 sv_elem->data.seq->total == var_count) );
2259 CV_CALL( sv[i] = (float*)cvMemStorageAlloc( storage, sv_size ));
2260 CV_CALL( cvReadRawData( fs, sv_elem, sv[i], "f" ));
2261 CV_NEXT_SEQ_ELEM( sv_node->data.seq->elem_size, reader );
2264 // read decision functions
2265 df_count = class_count > 1 ? class_count*(class_count-1)/2 : 1;
2266 df_node = cvGetFileNodeByName( fs, svm_node, "decision_functions" );
2267 if( !df_node || !CV_NODE_IS_SEQ(df_node->tag) ||
2268 df_node->data.seq->total != df_count )
2269 CV_ERROR( CV_StsParseError, "decision_functions is missing or is not a collection "
2270 "or has a wrong number of elements" );
2272 CV_CALL( df = decision_func = (CvSVMDecisionFunc*)cvAlloc( df_count*sizeof(df[0]) ));
2273 cvStartReadSeq( df_node->data.seq, &reader, 0 );
2275 for( i = 0; i < df_count; i++ )
2277 CvFileNode* df_elem = (CvFileNode*)reader.ptr;
2278 CvFileNode* alpha_node = cvGetFileNodeByName( fs, df_elem, "alpha" );
2280 int sv_count = cvReadIntByName( fs, df_elem, "sv_count", -1 );
2282 CV_ERROR( CV_StsParseError, "sv_count is missing or non-positive" );
2283 df[i].sv_count = sv_count;
2285 df[i].rho = cvReadRealByName( fs, df_elem, "rho", not_found_dbl );
2286 if( fabs(df[i].rho - not_found_dbl) < DBL_EPSILON )
2287 CV_ERROR( CV_StsParseError, "rho is missing" );
2290 CV_ERROR( CV_StsParseError, "alpha is missing in the decision function" );
2292 CV_CALL( df[i].alpha = (double*)cvMemStorageAlloc( storage,
2293 sv_count*sizeof(df[i].alpha[0])));
2294 CV_ASSERT( sv_count == 1 || CV_NODE_IS_SEQ(alpha_node->tag) &&
2295 alpha_node->data.seq->total == sv_count );
2296 CV_CALL( cvReadRawData( fs, alpha_node, df[i].alpha, "d" ));
2298 if( class_count > 1 )
2300 CvFileNode* index_node = cvGetFileNodeByName( fs, df_elem, "index" );
2302 CV_ERROR( CV_StsParseError, "index is missing in the decision function" );
2303 CV_CALL( df[i].sv_index = (int*)cvMemStorageAlloc( storage,
2304 sv_count*sizeof(df[i].sv_index[0])));
2305 CV_ASSERT( sv_count == 1 || CV_NODE_IS_SEQ(index_node->tag) &&
2306 index_node->data.seq->total == sv_count );
2307 CV_CALL( cvReadRawData( fs, index_node, df[i].sv_index, "i" ));
2312 CV_NEXT_SEQ_ELEM( df_node->data.seq->elem_size, reader );
// icvCloneSVM -- makes a deep copy of a CvSVMModel: parameters, class labels
// and weights, the component index, all support vectors and all decision
// functions. The copy gets its own memory storage, created with the source
// storage's block size. If an error status is set at the end, the partly
// built copy is released.
// NOTE(review): some lines (return type, braces, a few declarations and
// guards, the return statement) are absent from this listing.
2323 icvCloneSVM( const void* _src )
2325 CvSVMModel* dst = 0;
2327 CV_FUNCNAME( "icvCloneSVM" );
2331 const CvSVMModel* src = (const CvSVMModel*)_src;
2332 int var_count, class_count;
2333 int i, sv_total, df_count;
2336 if( !CV_IS_SVM(src) )
2337 CV_ERROR( !src ? CV_StsNullPtr : CV_StsBadArg, "Input pointer is NULL or invalid" );
2339 // 0. create initial CvSVMModel structure
2340 CV_CALL( dst = icvCreateSVM() );
// copy the parameters, but do not share the source's weight arrays
2341 dst->params = src->params;
2342 dst->params.weight_labels = 0;
2343 dst->params.weights = 0;
2345 dst->var_all = src->var_all;
2346 if( src->class_labels )
2347 dst->class_labels = cvCloneMat( src->class_labels );
2348 if( src->class_weights )
2349 dst->class_weights = cvCloneMat( src->class_weights );
2351 dst->comp_idx = cvCloneMat( src->comp_idx );
// without a component index every one of the var_all variables is used
2353 var_count = src->comp_idx ? src->comp_idx->cols : src->var_all;
2354 class_count = src->class_labels ? src->class_labels->cols :
2355 src->params.svm_type == CvSVM::ONE_CLASS ? 1 : 0;
2356 sv_total = dst->sv_total = src->sv_total;
2357 CV_CALL( dst->storage = cvCreateMemStorage( src->storage->block_size ));
2358 CV_CALL( dst->sv = (float**)cvMemStorageAlloc( dst->storage,
2359 sv_total*sizeof(dst->sv[0]) ));
// size in bytes of one support vector (var_count floats)
2361 sv_size = var_count*sizeof(dst->sv[0][0]);
2363 for( i = 0; i < sv_total; i++ )
2365 CV_CALL( dst->sv[i] = (float*)cvMemStorageAlloc( dst->storage, sv_size ));
2366 memcpy( dst->sv[i], src->sv[i], sv_size );
// one decision function per pair of classes, or a single one when class_count <= 1
2369 df_count = class_count > 1 ? class_count*(class_count-1)/2 : 1;
2371 CV_CALL( dst->decision_func = cvAlloc( df_count*sizeof(CvSVMDecisionFunc) ));
2373 for( i = 0; i < df_count; i++ )
2375 const CvSVMDecisionFunc *sdf =
2376 (const CvSVMDecisionFunc*)src->decision_func+i;
2377 CvSVMDecisionFunc *ddf =
2378 (CvSVMDecisionFunc*)dst->decision_func+i;
2379 int sv_count = sdf->sv_count;
2380 ddf->sv_count = sv_count;
2381 ddf->rho = sdf->rho;
2382 CV_CALL( ddf->alpha = (double*)cvMemStorageAlloc( dst->storage,
2383 sv_count*sizeof(ddf->alpha[0])));
2384 memcpy( ddf->alpha, sdf->alpha, sv_count*sizeof(ddf->alpha[0]));
// support-vector indices are copied for multi-class models only
2386 if( class_count > 1 )
2388 CV_CALL( ddf->sv_index = (int*)cvMemStorageAlloc( dst->storage,
2389 sv_count*sizeof(ddf->sv_index[0])));
2390 memcpy( ddf->sv_index, sdf->sv_index, sv_count*sizeof(ddf->sv_index[0]));
// on error, do not hand out a half-initialised copy
2398 if( cvGetErrStatus() < 0 && dst )
2399 icvReleaseSVM( &dst );
// icvRegisterSVMType -- zero-fills a type-info record, points it at the SVM
// callbacks (is_instance, release, read, write, clone) and the SVM type
// name, and registers it with cvRegisterType().
// NOTE(review): the declaration of `info`, the braces and the return
// statement are absent from this listing.
2404 static int icvRegisterSVMType()
2407 memset( &info, 0, sizeof(info) );
2410 info.header_size = sizeof( info );
2411 info.is_instance = icvIsSVM;
2412 info.release = (CvReleaseFunc)icvReleaseSVM;
2413 info.read = icvReadSVM;
2414 info.write = icvWriteSVM;
2415 info.clone = icvCloneSVM;
2416 info.type_name = CV_TYPE_NAME_ML_SVM;
2417 cvRegisterType( &info );
// File-scope static whose initialiser calls icvRegisterSVMType(), so the SVM
// type is registered as part of this variable's initialisation.
2423 static int svm = icvRegisterSVMType();
2425 /* The function trains SVM model with optimal parameters, obtained by using cross-validation.
2426 The parameters to be estimated should be indicated by setting their values to FLT_MAX.
2427 The optimal parameters are saved in <model_params> */
// Visible flow: start from built-in default grids, override or collapse each
// grid from the corresponding *_grid argument, collapse the grids of
// parameters the chosen kernel / SVM type does not use, run six nested loops
// that score every combination with cvCrossValidation(), write the best
// combination back into <model_params> and train the final model with it.
// NOTE(review): some lines (braces, null checks, `else` keywords, a few
// declarations and assignments) are absent from this listing, and the end of
// the function is not visible; comments describe only what is shown.
2428 CV_IMPL CvStatModel*
2429 cvTrainSVM_CrossValidation( const CvMat* train_data, int tflag,
2430 const CvMat* responses,
2431 CvStatModelParams* model_params,
2432 const CvStatModelParams* cross_valid_params,
2433 const CvMat* comp_idx,
2434 const CvMat* sample_idx,
2435 const CvParamGrid* degree_grid,
2436 const CvParamGrid* gamma_grid,
2437 const CvParamGrid* coef_grid,
2438 const CvParamGrid* C_grid,
2439 const CvParamGrid* nu_grid,
2440 const CvParamGrid* p_grid )
2442 CvStatModel* svm = 0;
2444 CV_FUNCNAME("cvTainSVMCrossValidation");
// default multiplicative steps and search ranges, used unless overridden below
2447 double degree_step = 7,
2452 p_step = 7; // all steps must be > 1
2453 double degree_begin = 0.01, degree_end = 2;
2454 double g_begin = 1e-5, g_end = 0.5;
2455 double coef_begin = 0.1, coef_end = 300;
2456 double C_begin = 0.1, C_end = 6000;
2457 double nu_begin = 0.01, nu_end = 0.4;
2458 double p_begin = 0.01, p_end = 100;
2460 double rate = 0, gamma = 0, C = 0, degree = 0, coef = 0, p = 0, nu = 0;
2462 double best_rate = 0;
2463 double best_degree = degree_begin;
2464 double best_gamma = g_begin;
2465 double best_coef = coef_begin;
2466 double best_C = C_begin;
2467 double best_nu = nu_begin;
2468 double best_p = p_begin;
2470 CvSVMModelParams svm_params, *psvm_params;
2471 CvCrossValidationParams* cv_params = (CvCrossValidationParams*)cross_valid_params;
2472 int svm_type, kernel;
2476 CV_ERROR( CV_StsBadArg, "" );
2478 CV_ERROR( CV_StsBadArg, "" );
// svm_params is a working copy used during the search; psvm_params points at
// the caller's structure, which receives the best values at the end
2480 svm_params = *(CvSVMModelParams*)model_params;
2481 psvm_params = (CvSVMModelParams*)model_params;
2482 svm_type = svm_params.svm_type;
2483 kernel = svm_params.kernel_type;
// replace non-positive parameter values in the working copy with fallbacks
2485 svm_params.degree = svm_params.degree > 0 ? svm_params.degree : 1;
2486 svm_params.gamma = svm_params.gamma > 0 ? svm_params.gamma : 1;
2487 svm_params.coef0 = svm_params.coef0 > 0 ? svm_params.coef0 : 1e-6;
2488 svm_params.C = svm_params.C > 0 ? svm_params.C : 1;
2489 svm_params.nu = svm_params.nu > 0 ? svm_params.nu : 1;
2490 svm_params.p = svm_params.p > 0 ? svm_params.p : 1;
// Each grid argument follows the same pattern: a grid that is not all zeros
// is validated (min <= max, step > 1) and replaces the defaults; the single
// "begin = end = current value" assignment after it is presumably the
// all-zero branch (the `else` is not visible in this listing).
2494 if( !(degree_grid->max_val == 0 && degree_grid->min_val == 0 &&
2495 degree_grid->step == 0) )
2497 if( degree_grid->min_val > degree_grid->max_val )
2498 CV_ERROR( CV_StsBadArg,
2499 "low bound of grid should be less then the upper one");
2500 if( degree_grid->step <= 1 )
2501 CV_ERROR( CV_StsBadArg, "grid step should be greater 1" );
2502 degree_begin = degree_grid->min_val;
2503 degree_end = degree_grid->max_val;
2504 degree_step = degree_grid->step;
2508 degree_begin = degree_end = svm_params.degree;
2512 if( !(gamma_grid->max_val == 0 && gamma_grid->min_val == 0 &&
2513 gamma_grid->step == 0) )
2515 if( gamma_grid->min_val > gamma_grid->max_val )
2516 CV_ERROR( CV_StsBadArg,
2517 "low bound of grid should be less then the upper one");
2518 if( gamma_grid->step <= 1 )
2519 CV_ERROR( CV_StsBadArg, "grid step should be greater 1" );
2520 g_begin = gamma_grid->min_val;
2521 g_end = gamma_grid->max_val;
2522 g_step = gamma_grid->step;
2526 g_begin = g_end = svm_params.gamma;
2530 if( !(coef_grid->max_val == 0 && coef_grid->min_val == 0 &&
2531 coef_grid->step == 0) )
2533 if( coef_grid->min_val > coef_grid->max_val )
2534 CV_ERROR( CV_StsBadArg,
2535 "low bound of grid should be less then the upper one");
2536 if( coef_grid->step <= 1 )
2537 CV_ERROR( CV_StsBadArg, "grid step should be greater 1" );
2538 coef_begin = coef_grid->min_val;
2539 coef_end = coef_grid->max_val;
2540 coef_step = coef_grid->step;
2544 coef_begin = coef_end = svm_params.coef0;
2548 if( !(C_grid->max_val == 0 && C_grid->min_val == 0 && C_grid->step == 0))
2550 if( C_grid->min_val > C_grid->max_val )
2551 CV_ERROR( CV_StsBadArg,
2552 "low bound of grid should be less then the upper one");
2553 if( C_grid->step <= 1 )
2554 CV_ERROR( CV_StsBadArg, "grid step should be greater 1" );
2555 C_begin = C_grid->min_val;
2556 C_end = C_grid->max_val;
2557 C_step = C_grid->step;
2561 C_begin = C_end = svm_params.C;
2565 if(!(nu_grid->max_val == 0 && nu_grid->min_val == 0 && nu_grid->step==0))
2567 if( nu_grid->min_val > nu_grid->max_val )
2568 CV_ERROR( CV_StsBadArg,
2569 "low bound of grid should be less then the upper one");
2570 if( nu_grid->step <= 1 )
2571 CV_ERROR( CV_StsBadArg, "grid step should be greater 1" );
2572 nu_begin = nu_grid->min_val;
2573 nu_end = nu_grid->max_val;
2574 nu_step = nu_grid->step;
2578 nu_begin = nu_end = svm_params.nu;
2582 if( !(p_grid->max_val == 0 && p_grid->min_val == 0 && p_grid->step == 0))
2584 if( p_grid->min_val > p_grid->max_val )
2585 CV_ERROR( CV_StsBadArg,
2586 "low bound of grid should be less then the upper one");
2587 if( p_grid->step <= 1 )
2588 CV_ERROR( CV_StsBadArg, "grid step should be greater 1" );
2589 p_begin = p_grid->min_val;
2590 p_end = p_grid->max_val;
2591 p_step = p_grid->step;
2595 p_begin = p_end = svm_params.p;
2597 // these parameters are not used:
// (their ranges are collapsed so the corresponding loops run once)
2598 if( kernel != CvSVM::POLY )
2599 degree_begin = degree_end = svm_params.degree;
2601 if( kernel == CvSVM::LINEAR )
2602 g_begin = g_end = svm_params.gamma;
2604 if( kernel != CvSVM::POLY && kernel != CvSVM::SIGMOID )
2605 coef_begin = coef_end = svm_params.coef0;
2607 if( svm_type == CvSVM::NU_SVC || svm_type == CvSVM::ONE_CLASS )
2608 C_begin = C_end = svm_params.C;
2610 if( svm_type == CvSVM::C_SVC || svm_type == CvSVM::EPS_SVR )
2611 nu_begin = nu_end = svm_params.nu;
2613 if( svm_type != CvSVM::EPS_SVR )
2614 p_begin = p_end = svm_params.p;
// for regression a lower rate is better, so the search starts from FLT_MAX;
// otherwise a higher rate is better and it starts from 0
2616 is_regression = cv_params->is_regression;
2617 best_rate = is_regression ? FLT_MAX : 0;
2619 assert( g_step > 1 && degree_step > 1 && coef_step > 1);
2620 assert( p_step > 1 && C_step > 1 && nu_step > 1 );
// exhaustive search: every loop multiplies its parameter by the grid step
2622 for( degree = degree_begin; degree <= degree_end; degree *= degree_step )
2624 svm_params.degree = degree;
2625 //printf("degree = %.3f\n", degree );
2626 for( gamma= g_begin; gamma <= g_end; gamma *= g_step )
2628 svm_params.gamma = gamma;
2629 //printf(" gamma = %.3f\n", gamma );
2630 for( coef = coef_begin; coef <= coef_end; coef *= coef_step )
2632 svm_params.coef0 = coef;
2633 //printf(" coef = %.3f\n", coef );
2634 for( C = C_begin; C <= C_end; C *= C_step )
2637 //printf(" C = %.3f\n", C );
2638 for( nu = nu_begin; nu <= nu_end; nu *= nu_step )
2641 //printf(" nu = %.3f\n", nu );
2642 for( p = p_begin; p <= p_end; p *= p_step )
2646 //printf(" p = %.3f\n", p );
// score the current combination by cross-validating cvTrainSVM
2648 CV_CALL(rate = cvCrossValidation( train_data, tflag, responses, &cvTrainSVM,
2649 cross_valid_params, (CvStatModelParams*)&svm_params, comp_idx, sample_idx ));
// "well": strictly better than the best so far (higher for classification,
// lower for regression); an equal rate also wins when its C is smaller
2651 well = rate > best_rate && !is_regression || rate < best_rate && is_regression;
2652 if( well || (rate == best_rate && C < best_C) )
2655 best_degree = degree;
2662 //printf(" rate = %.2f\n", rate );
2669 //printf("The best:\nrate = %.2f%% degree = %f gamma = %f coef = %f c = %f nu = %f p = %f\n",
2670 // best_rate, best_degree, best_gamma, best_coef, best_C, best_nu, best_p );
// report the best combination through the caller's parameter structure
2672 psvm_params->C = best_C;
2673 psvm_params->nu = best_nu;
2674 psvm_params->p = best_p;
2675 psvm_params->gamma = best_gamma;
2676 psvm_params->degree = best_degree;
2677 psvm_params->coef0 = best_coef;
// final model: trained with <model_params>, which now holds the best values
2679 CV_CALL(svm = cvTrainSVM( train_data, tflag, responses, model_params, comp_idx, sample_idx ));