Move the sources to trunk
[opencv] / ml / src / mlsvm.cpp
1 /*M///////////////////////////////////////////////////////////////////////////////////////
2 //
3 //  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
4 //
5 //  By downloading, copying, installing or using the software you agree to this license.
6 //  If you do not agree to this license, do not download, install,
7 //  copy or use the software.
8 //
9 //
10 //                        Intel License Agreement
11 //
12 // Copyright (C) 2000, Intel Corporation, all rights reserved.
13 // Third party copyrights are property of their respective owners.
14 //
15 // Redistribution and use in source and binary forms, with or without modification,
16 // are permitted provided that the following conditions are met:
17 //
18 //   * Redistribution's of source code must retain the above copyright notice,
19 //     this list of conditions and the following disclaimer.
20 //
21 //   * Redistribution's in binary form must reproduce the above copyright notice,
22 //     this list of conditions and the following disclaimer in the documentation
23 //     and/or other materials provided with the distribution.
24 //
25 //   * The name of Intel Corporation may not be used to endorse or promote products
26 //     derived from this software without specific prior written permission.
27 //
28 // This software is provided by the copyright holders and contributors "as is" and
29 // any express or implied warranties, including, but not limited to, the implied
30 // warranties of merchantability and fitness for a particular purpose are disclaimed.
31 // In no event shall the Intel Corporation or contributors be liable for any direct,
32 // indirect, incidental, special, exemplary, or consequential damages
33 // (including, but not limited to, procurement of substitute goods or services;
34 // loss of use, data, or profits; or business interruption) however caused
35 // and on any theory of liability, whether in contract, strict liability,
36 // or tort (including negligence or otherwise) arising in any way out of
37 // the use of this software, even if advised of the possibility of such damage.
38 //
39 //M*/
40
41 #include "_ml.h"
42
43 /****************************************************************************************\
44                                 COPYRIGHT NOTICE
45                                 ----------------
46
47   The code has been derived from libsvm library (version 2.6)
48   (http://www.csie.ntu.edu.tw/~cjlin/libsvm).
49
50   Here is the orignal copyright:
51 ------------------------------------------------------------------------------------------
52     Copyright (c) 2000-2003 Chih-Chung Chang and Chih-Jen Lin
53     All rights reserved.
54
55     Redistribution and use in source and binary forms, with or without
56     modification, are permitted provided that the following conditions
57     are met:
58
59     1. Redistributions of source code must retain the above copyright
60     notice, this list of conditions and the following disclaimer.
61
62     2. Redistributions in binary form must reproduce the above copyright
63     notice, this list of conditions and the following disclaimer in the
64     documentation and/or other materials provided with the distribution.
65
66     3. Neither name of copyright holders nor the names of its contributors
67     may be used to endorse or promote products derived from this software
68     without specific prior written permission.
69
70
71     THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
72     ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
73     LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
74     A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE REGENTS OR
75     CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
76     EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
77     PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
78     PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
79     LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
80     NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
81     SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
82 \****************************************************************************************/
83
84 #define CV_SVM_MIN_CACHE_SIZE  (40 << 20)  /* 40Mb */
85
86 #include <stdarg.h>
87 #include <ctype.h>
88
89 #if _MSC_VER >= 1200
90 #pragma warning( disable: 4514 ) /* unreferenced inline functions */
91 #endif
92
93 #if 1
94 typedef float Qfloat;
95 #define QFLOAT_TYPE CV_32F
96 #else
97 typedef double Qfloat;
98 #define QFLOAT_TYPE CV_64F
99 #endif
100
101 // Param Grid 
102 bool CvParamGrid::check() const
103 {
104     bool ok = false;
105
106     CV_FUNCNAME( "CvParamGrid::check" );
107     __BEGIN__;
108
109     if( min_val > max_val )
110         CV_ERROR( CV_StsBadArg, "Lower bound of the grid must be less then the upper one" );
111     if( min_val < DBL_EPSILON )
112         CV_ERROR( CV_StsBadArg, "Lower bound of the grid must be positive" );
113     if( step < 1. + FLT_EPSILON )
114         CV_ERROR( CV_StsBadArg, "Grid step must greater then 1" );
115
116     ok = true;
117
118     __END__;
119
120     return ok;
121 }
122
123 CvParamGrid CvSVM::get_default_grid( int param_id )
124 {
125     CvParamGrid grid;
126     if( param_id == CvSVM::C )
127     {
128         grid.min_val = 0.1;
129         grid.max_val = 500;
130         grid.step = 5; // total iterations = 5
131     }
132     else if( param_id == CvSVM::GAMMA )
133     {
134         grid.min_val = 1e-5;
135         grid.max_val = 0.6;
136         grid.step = 15; // total iterations = 4
137     }
138     else if( param_id == CvSVM::P )
139     {
140         grid.min_val = 0.01;
141         grid.max_val = 100;
142         grid.step = 7; // total iterations = 4
143     }
144     else if( param_id == CvSVM::NU )
145     {
146         grid.min_val = 0.01;
147         grid.max_val = 0.2;
148         grid.step = 3; // total iterations = 3
149     }
150     else if( param_id == CvSVM::COEF )
151     {
152         grid.min_val = 0.1;
153         grid.max_val = 300;
154         grid.step = 14; // total iterations = 3
155     }
156     else if( param_id == CvSVM::DEGREE )
157     {
158         grid.min_val = 0.01;
159         grid.max_val = 4;
160         grid.step = 7; // total iterations = 3
161     }
162     else
163         cvError( CV_StsBadArg, "CvSVM::get_default_grid", "Invalid type of parameter "
164             "(use one of CvSVM::C, CvSVM::GAMMA et al.)", __FILE__, __LINE__ );
165     return grid;
166 }
167
168 // SVM training parameters
169 CvSVMParams::CvSVMParams() :
170     svm_type(CvSVM::C_SVC), kernel_type(CvSVM::RBF), degree(0),
171     gamma(1), coef0(0), C(1), nu(0), p(0), class_weights(0)
172 {
173     term_crit = cvTermCriteria( CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, 1000, FLT_EPSILON );
174 }
175
176
177 CvSVMParams::CvSVMParams( int _svm_type, int _kernel_type,
178     double _degree, double _gamma, double _coef0,
179     double _Con, double _nu, double _p,
180     CvMat* _class_weights, CvTermCriteria _term_crit ) :
181     svm_type(_svm_type), kernel_type(_kernel_type),
182     degree(_degree), gamma(_gamma), coef0(_coef0),
183     C(_Con), nu(_nu), p(_p), class_weights(_class_weights), term_crit(_term_crit)
184 {
185 }
186
187
188 /////////////////////////////////////// SVM kernel ///////////////////////////////////////
189
190 CvSVMKernel::CvSVMKernel()
191 {
192     clear();
193 }
194
195
196 void CvSVMKernel::clear()
197 {
198     params = 0;
199     calc_func = 0;
200 }
201
202
203 CvSVMKernel::~CvSVMKernel()
204 {
205 }
206
207
208 CvSVMKernel::CvSVMKernel( const CvSVMParams* _params, Calc _calc_func )
209 {
210     clear();
211     create( _params, _calc_func );
212 }
213
214
215 bool CvSVMKernel::create( const CvSVMParams* _params, Calc _calc_func )
216 {
217     clear();
218     params = _params;
219     calc_func = _calc_func;
220
221     if( !calc_func )
222         calc_func = params->kernel_type == CvSVM::RBF ? &CvSVMKernel::calc_rbf :
223                     params->kernel_type == CvSVM::POLY ? &CvSVMKernel::calc_poly :
224                     params->kernel_type == CvSVM::SIGMOID ? &CvSVMKernel::calc_sigmoid :
225                     &CvSVMKernel::calc_linear;
226
227     return true;
228 }
229
230
231 void CvSVMKernel::calc_non_rbf_base( int vcount, int var_count, const float** vecs,
232                                      const float* another, Qfloat* results,
233                                      double alpha, double beta )
234 {
235     int j, k;
236     for( j = 0; j < vcount; j++ )
237     {
238         const float* sample = vecs[j];
239         double s = 0;
240         for( k = 0; k <= var_count - 4; k += 4 )
241             s += sample[k]*another[k] + sample[k+1]*another[k+1] +
242                  sample[k+2]*another[k+2] + sample[k+3]*another[k+3];
243         for( ; k < var_count; k++ )
244             s += sample[k]*another[k];
245         results[j] = (Qfloat)(s*alpha + beta);
246     }
247 }
248
249
250 void CvSVMKernel::calc_linear( int vcount, int var_count, const float** vecs,
251                                const float* another, Qfloat* results )
252 {
253     calc_non_rbf_base( vcount, var_count, vecs, another, results, 1, 0 );
254 }
255
256
257 void CvSVMKernel::calc_poly( int vcount, int var_count, const float** vecs,
258                              const float* another, Qfloat* results )
259 {
260     CvMat R = cvMat( 1, vcount, QFLOAT_TYPE, results );
261     calc_non_rbf_base( vcount, var_count, vecs, another, results, params->gamma, params->coef0 );
262     cvPow( &R, &R, params->degree );
263 }
264
265
266 void CvSVMKernel::calc_sigmoid( int vcount, int var_count, const float** vecs,
267                                 const float* another, Qfloat* results )
268 {
269     int j;
270     calc_non_rbf_base( vcount, var_count, vecs, another, results,
271                        -2*params->gamma, -2*params->coef0 );
272     // TODO: speedup this
273     for( j = 0; j < vcount; j++ )
274     {
275         Qfloat t = results[j];
276         double e = exp(-fabs(t));
277         if( t > 0 )
278             results[j] = (Qfloat)((1. - e)/(1. + e));
279         else
280             results[j] = (Qfloat)((e - 1.)/(e + 1.));
281     }
282 }
283
284
285 void CvSVMKernel::calc_rbf( int vcount, int var_count, const float** vecs,
286                             const float* another, Qfloat* results )
287 {
288     CvMat R = cvMat( 1, vcount, QFLOAT_TYPE, results );
289     double gamma = -params->gamma;
290     int j, k;
291
292     for( j = 0; j < vcount; j++ )
293     {
294         const float* sample = vecs[j];
295         double s = 0;
296         
297         for( k = 0; k <= var_count - 4; k += 4 )
298         {
299             double t0 = sample[k] - another[k];
300             double t1 = sample[k+1] - another[k+1];
301
302             s += t0*t0 + t1*t1;
303
304             t0 = sample[k+2] - another[k+2];
305             t1 = sample[k+3] - another[k+3];
306
307             s += t0*t0 + t1*t1;
308         }
309
310         for( ; k < var_count; k++ )
311         {
312             double t0 = sample[k] - another[k];
313             s += t0*t0;
314         }
315         results[j] = (Qfloat)(s*gamma);
316     }
317     
318     cvExp( &R, &R );
319 }
320
321
322 void CvSVMKernel::calc( int vcount, int var_count, const float** vecs,
323                         const float* another, Qfloat* results )
324 {
325     const Qfloat max_val = (Qfloat)(FLT_MAX*1e-3);
326     int j;
327     (this->*calc_func)( vcount, var_count, vecs, another, results );
328     for( j = 0; j < vcount; j++ )
329     {
330         if( results[j] > max_val )
331             results[j] = max_val;
332     }
333 }
334
335
336 // Generalized SMO+SVMlight algorithm
337 // Solves:
338 //
339 //  min [0.5(\alpha^T Q \alpha) + b^T \alpha]
340 //
341 //      y^T \alpha = \delta
342 //      y_i = +1 or -1
343 //      0 <= alpha_i <= Cp for y_i = 1
344 //      0 <= alpha_i <= Cn for y_i = -1
345 //
346 // Given:
347 //
348 //  Q, b, y, Cp, Cn, and an initial feasible point \alpha
349 //  l is the size of vectors and matrices
350 //  eps is the stopping criterion
351 //
352 // solution will be put in \alpha, objective value will be put in obj
353 //
354
355 void CvSVMSolver::clear()
356 {
357     G = 0;
358     alpha = 0;
359     y = 0;
360     b = 0;
361     buf[0] = buf[1] = 0;
362     cvReleaseMemStorage( &storage );
363     kernel = 0;
364     select_working_set_func = 0;
365     calc_rho_func = 0;
366
367     rows = 0;
368     samples = 0;
369     get_row_func = 0;
370 }
371
372
373 CvSVMSolver::CvSVMSolver()
374 {
375     storage = 0;
376     clear();
377 }
378
379
380 CvSVMSolver::~CvSVMSolver()
381 {
382     clear();
383 }
384
385
386 CvSVMSolver::CvSVMSolver( int _sample_count, int _var_count, const float** _samples, char* _y,
387                 int _alpha_count, double* _alpha, double _Cp, double _Cn,
388                 CvMemStorage* _storage, CvSVMKernel* _kernel, GetRow _get_row,
389                 SelectWorkingSet _select_working_set, CalcRho _calc_rho )
390 {
391     storage = 0;
392     create( _sample_count, _var_count, _samples, _y, _alpha_count, _alpha, _Cp, _Cn,
393             _storage, _kernel, _get_row, _select_working_set, _calc_rho );
394 }
395
396
397 bool CvSVMSolver::create( int _sample_count, int _var_count, const float** _samples, char* _y,
398                 int _alpha_count, double* _alpha, double _Cp, double _Cn,
399                 CvMemStorage* _storage, CvSVMKernel* _kernel, GetRow _get_row,
400                 SelectWorkingSet _select_working_set, CalcRho _calc_rho )
401 {
402     bool ok = false;
403     int i, svm_type;
404
405     CV_FUNCNAME( "CvSVMSolver::create" );
406
407     __BEGIN__;
408     
409     int rows_hdr_size;
410
411     clear();
412
413     sample_count = _sample_count;
414     var_count = _var_count;
415     samples = _samples;
416     y = _y;
417     alpha_count = _alpha_count;
418     alpha = _alpha;
419     kernel = _kernel;
420
421     C[0] = _Cn;
422     C[1] = _Cp;
423     eps = kernel->params->term_crit.epsilon;
424     max_iter = kernel->params->term_crit.max_iter;
425     storage = cvCreateChildMemStorage( _storage );
426
427     b = (double*)cvMemStorageAlloc( storage, alpha_count*sizeof(b[0]));
428     alpha_status = (char*)cvMemStorageAlloc( storage, alpha_count*sizeof(alpha_status[0]));
429     G = (double*)cvMemStorageAlloc( storage, alpha_count*sizeof(G[0]));
430     for( i = 0; i < 2; i++ )
431         buf[i] = (Qfloat*)cvMemStorageAlloc( storage, sample_count*2*sizeof(buf[i][0]) );
432     svm_type = kernel->params->svm_type;
433
434     select_working_set_func = _select_working_set;
435     if( !select_working_set_func )
436         select_working_set_func = svm_type == CvSVM::NU_SVC || svm_type == CvSVM::NU_SVR ?
437         &CvSVMSolver::select_working_set_nu_svm : &CvSVMSolver::select_working_set;
438
439     calc_rho_func = _calc_rho;
440     if( !calc_rho_func )
441         calc_rho_func = svm_type == CvSVM::NU_SVC || svm_type == CvSVM::NU_SVR ?
442             &CvSVMSolver::calc_rho_nu_svm : &CvSVMSolver::calc_rho;
443
444     get_row_func = _get_row;
445     if( !get_row_func )
446         get_row_func = params->svm_type == CvSVM::EPS_SVR ||
447                        params->svm_type == CvSVM::NU_SVR ? &CvSVMSolver::get_row_svr :
448                        params->svm_type == CvSVM::C_SVC ||
449                        params->svm_type == CvSVM::NU_SVC ? &CvSVMSolver::get_row_svc :
450                        &CvSVMSolver::get_row_one_class;
451
452     cache_line_size = sample_count*sizeof(Qfloat);
453     // cache size = max(num_of_samples^2*sizeof(Qfloat)*0.25, 64Kb)
454     // (assuming that for large training sets ~25% of Q matrix is used)
455     cache_size = MAX( cache_line_size*sample_count/4, CV_SVM_MIN_CACHE_SIZE );
456     
457     // the size of Q matrix row headers
458     rows_hdr_size = sample_count*sizeof(rows[0]);
459     if( rows_hdr_size > storage->block_size )
460         CV_ERROR( CV_StsOutOfRange, "Too small storage block size" );
461
462     lru_list.prev = lru_list.next = &lru_list;
463     rows = (CvSVMKernelRow*)cvMemStorageAlloc( storage, rows_hdr_size );
464     memset( rows, 0, rows_hdr_size );
465
466     ok = true;
467
468     __END__;
469
470     return ok;
471 }
472
473
474 float* CvSVMSolver::get_row_base( int i, bool* _existed )
475 {
476     int i1 = i < sample_count ? i : i - sample_count;
477     CvSVMKernelRow* row = rows + i1;
478     bool existed = row->data != 0;
479     Qfloat* data;
480
481     if( existed || cache_size <= 0 )
482     {
483         CvSVMKernelRow* del_row = existed ? row : lru_list.prev;
484         data = del_row->data;
485         assert( data != 0 );
486
487         // delete row from the LRU list
488         del_row->data = 0;
489         del_row->prev->next = del_row->next;
490         del_row->next->prev = del_row->prev;
491     }
492     else
493     {
494         data = (Qfloat*)cvMemStorageAlloc( storage, cache_line_size );
495         cache_size -= cache_line_size;
496     }
497
498     // insert row into the LRU list
499     row->data = data;
500     row->prev = &lru_list;
501     row->next = lru_list.next;
502     row->prev->next = row->next->prev = row;
503
504     if( !existed )
505     {
506         kernel->calc( sample_count, var_count, samples, samples[i1], row->data );
507     }
508
509     if( _existed )
510         *_existed = existed;
511
512     return row->data;
513 }
514
515
516 float* CvSVMSolver::get_row_svc( int i, float* row, float*, bool existed )
517 {
518     if( !existed )
519     {
520         const char* _y = y;
521         int j, len = sample_count;
522         assert( _y && i < sample_count );
523
524         if( _y[i] > 0 )
525         {
526             for( j = 0; j < len; j++ )
527                 row[j] = _y[j]*row[j];
528         }
529         else
530         {
531             for( j = 0; j < len; j++ )
532                 row[j] = -_y[j]*row[j];
533         }
534     }
535     return row;
536 }
537
538
539 float* CvSVMSolver::get_row_one_class( int, float* row, float*, bool )
540 {
541     return row;
542 }
543
544
545 float* CvSVMSolver::get_row_svr( int i, float* row, float* dst, bool )
546 {
547     int j, len = sample_count;
548     Qfloat* dst_pos = dst;
549     Qfloat* dst_neg = dst + len;
550     if( i >= len )
551     {
552         Qfloat* temp;
553         CV_SWAP( dst_pos, dst_neg, temp );
554     }
555
556     for( j = 0; j < len; j++ )
557     {
558         Qfloat t = row[j];
559         dst_pos[j] = t;
560         dst_neg[j] = -t;
561     }
562     return dst;
563 }
564
565
566
567 float* CvSVMSolver::get_row( int i, float* dst )
568 {
569     bool existed = false;
570     float* row = get_row_base( i, &existed );
571     return (this->*get_row_func)( i, row, dst, existed );
572 }
573
574
575 #undef is_upper_bound
576 #define is_upper_bound(i) (alpha_status[i] > 0)
577
578 #undef is_lower_bound
579 #define is_lower_bound(i) (alpha_status[i] < 0)
580
581 #undef is_free
582 #define is_free(i) (alpha_status[i] == 0)
583
584 #undef get_C
585 #define get_C(i) (C[y[i]>0])
586
587 #undef update_alpha_status
588 #define update_alpha_status(i) \
589     alpha_status[i] = (char)(alpha[i] >= get_C(i) ? 1 : alpha[i] <= 0 ? -1 : 0)
590
591 #undef reconstruct_gradient
592 #define reconstruct_gradient() /* empty for now */
593
594
595 bool CvSVMSolver::solve_generic( CvSVMSolutionInfo& si )
596 {
597     int iter = 0;
598     int i, j, k;
599
600     // 1. initialize gradient and alpha status
601     for( i = 0; i < alpha_count; i++ )
602     {
603         update_alpha_status(i);
604         G[i] = b[i];
605         if( fabs(G[i]) > 1e200 )
606             return false;
607     }
608
609     for( i = 0; i < alpha_count; i++ )
610     {
611         if( !is_lower_bound(i) )
612         {
613             const Qfloat *Q_i = get_row( i, buf[0] );
614             double alpha_i = alpha[i];
615
616             for( j = 0; j < alpha_count; j++ )
617                 G[j] += alpha_i*Q_i[j];
618         }
619     }
620
621     // 2. optimization loop
622     for(;;)
623     {
624         const Qfloat *Q_i, *Q_j;
625         double C_i, C_j;
626         double old_alpha_i, old_alpha_j, alpha_i, alpha_j;
627         double delta_alpha_i, delta_alpha_j;
628         
629 #ifdef _DEBUG        
630         for( i = 0; i < alpha_count; i++ )
631         {
632             if( fabs(G[i]) > 1e+300 )
633                 return false;
634
635             if( fabs(alpha[i]) > 1e16 )
636                 return false;
637         }
638 #endif
639
640         if( (this->*select_working_set_func)( i, j ) != 0 || iter++ >= max_iter )
641             break;
642
643         Q_i = get_row( i, buf[0] );
644         Q_j = get_row( j, buf[1] );
645
646         C_i = get_C(i);
647         C_j = get_C(j);
648
649         alpha_i = old_alpha_i = alpha[i];
650         alpha_j = old_alpha_j = alpha[j];
651
652         if( y[i] != y[j] )
653         {
654             double denom = Q_i[i]+Q_j[j]+2*Q_i[j];
655             double delta = (-G[i]-G[j])/MAX(fabs(denom),FLT_EPSILON);
656             double diff = alpha_i - alpha_j;
657             alpha_i += delta;
658             alpha_j += delta;
659             
660             if( diff > 0 && alpha_j < 0 )
661             {
662                 alpha_j = 0;
663                 alpha_i = diff;
664             }
665             else if( diff <= 0 && alpha_i < 0 )
666             {
667                 alpha_i = 0;
668                 alpha_j = -diff;
669             }
670
671             if( diff > C_i - C_j && alpha_i > C_i )
672             {
673                 alpha_i = C_i;
674                 alpha_j = C_i - diff;
675             }
676             else if( diff <= C_i - C_j && alpha_j > C_j )
677             {
678                 alpha_j = C_j;
679                 alpha_i = C_j + diff;
680             }
681         }
682         else
683         {
684             double denom = Q_i[i]+Q_j[j]-2*Q_i[j];
685             double delta = (G[i]-G[j])/MAX(fabs(denom),FLT_EPSILON);
686             double sum = alpha_i + alpha_j;
687             alpha_i -= delta;
688             alpha_j += delta;
689             
690             if( sum > C_i && alpha_i > C_i )
691             {
692                 alpha_i = C_i;
693                 alpha_j = sum - C_i;
694             }
695             else if( sum <= C_i && alpha_j < 0)
696             {
697                 alpha_j = 0;
698                 alpha_i = sum;
699             }
700
701             if( sum > C_j && alpha_j > C_j )
702             {
703                 alpha_j = C_j;
704                 alpha_i = sum - C_j;
705             }
706             else if( sum <= C_j && alpha_i < 0 )
707             {
708                 alpha_i = 0;
709                 alpha_j = sum;
710             }
711         }
712
713         // update alpha
714         alpha[i] = alpha_i;
715         alpha[j] = alpha_j;
716         update_alpha_status(i);
717         update_alpha_status(j);
718
719         // update G
720         delta_alpha_i = alpha_i - old_alpha_i;
721         delta_alpha_j = alpha_j - old_alpha_j;
722         
723         for( k = 0; k < alpha_count; k++ )
724             G[k] += Q_i[k]*delta_alpha_i + Q_j[k]*delta_alpha_j;
725     }
726
727     // calculate rho
728     (this->*calc_rho_func)( si.rho, si.r );
729
730     // calculate objective value
731     for( i = 0, si.obj = 0; i < alpha_count; i++ )
732         si.obj += alpha[i] * (G[i] + b[i]);
733
734     si.obj *= 0.5;
735
736     si.upper_bound_p = C[1];
737     si.upper_bound_n = C[0];
738
739     return true;
740 }
741
742
743 // return 1 if already optimal, return 0 otherwise
744 bool
745 CvSVMSolver::select_working_set( int& out_i, int& out_j )
746 {
747     // return i,j which maximize -grad(f)^T d , under constraint
748     // if alpha_i == C, d != +1
749     // if alpha_i == 0, d != -1
750     double Gmax1 = -DBL_MAX;        // max { -grad(f)_i * d | y_i*d = +1 }
751     int Gmax1_idx = -1;
752
753     double Gmax2 = -DBL_MAX;        // max { -grad(f)_i * d | y_i*d = -1 }
754     int Gmax2_idx = -1;
755
756     int i;
757
758     for( i = 0; i < alpha_count; i++ )
759     {
760         double t;
761
762         if( y[i] > 0 )    // y = +1
763         {
764             if( !is_upper_bound(i) && (t = -G[i]) > Gmax1 )  // d = +1
765             {
766                 Gmax1 = t;
767                 Gmax1_idx = i;
768             }
769             if( !is_lower_bound(i) && (t = G[i]) > Gmax2 )  // d = -1
770             {
771                 Gmax2 = t;
772                 Gmax2_idx = i;
773             }
774         }
775         else        // y = -1
776         {
777             if( !is_upper_bound(i) && (t = -G[i]) > Gmax2 )  // d = +1
778             {
779                 Gmax2 = t;
780                 Gmax2_idx = i;
781             }
782             if( !is_lower_bound(i) && (t = G[i]) > Gmax1 )  // d = -1
783             {
784                 Gmax1 = t;
785                 Gmax1_idx = i;
786             }
787         }
788     }
789
790     out_i = Gmax1_idx;
791     out_j = Gmax2_idx;
792
793     return Gmax1 + Gmax2 < eps;
794 }
795
796
797 void
798 CvSVMSolver::calc_rho( double& rho, double& r )
799 {
800     int i, nr_free = 0;
801     double ub = DBL_MAX, lb = -DBL_MAX, sum_free = 0;
802     
803     for( i = 0; i < alpha_count; i++ )
804     {
805         double yG = y[i]*G[i];
806
807         if( is_lower_bound(i) )
808         {
809             if( y[i] > 0 )
810                 ub = MIN(ub,yG);
811             else
812                 lb = MAX(lb,yG);
813         }
814         else if( is_upper_bound(i) )
815         {
816             if( y[i] < 0)
817                 ub = MIN(ub,yG);
818             else
819                 lb = MAX(lb,yG);
820         }
821         else
822         {
823             ++nr_free;
824             sum_free += yG;
825         }
826     }
827
828     rho = nr_free > 0 ? sum_free/nr_free : (ub + lb)*0.5;
829     r = 0;
830 }
831
832
833 bool
834 CvSVMSolver::select_working_set_nu_svm( int& out_i, int& out_j )
835 {
836     // return i,j which maximize -grad(f)^T d , under constraint
837     // if alpha_i == C, d != +1
838     // if alpha_i == 0, d != -1
839     double Gmax1 = -DBL_MAX;    // max { -grad(f)_i * d | y_i = +1, d = +1 }
840     int Gmax1_idx = -1;
841
842     double Gmax2 = -DBL_MAX;    // max { -grad(f)_i * d | y_i = +1, d = -1 }
843     int Gmax2_idx = -1;
844
845     double Gmax3 = -DBL_MAX;    // max { -grad(f)_i * d | y_i = -1, d = +1 }
846     int Gmax3_idx = -1;
847
848     double Gmax4 = -DBL_MAX;    // max { -grad(f)_i * d | y_i = -1, d = -1 }
849     int Gmax4_idx = -1;
850
851     int i;
852
853     for( i = 0; i < alpha_count; i++ )
854     {
855         double t;
856
857         if( y[i] > 0 )    // y == +1
858         {
859             if( !is_upper_bound(i) && (t = -G[i]) > Gmax1 )  // d = +1
860             {
861                 Gmax1 = t;
862                 Gmax1_idx = i;
863             }
864             if( !is_lower_bound(i) && (t = G[i]) > Gmax2 )  // d = -1
865             {
866                 Gmax2 = t;
867                 Gmax2_idx = i;
868             }
869         }
870         else        // y == -1
871         {
872             if( !is_upper_bound(i) && (t = -G[i]) > Gmax3 )  // d = +1
873             {
874                 Gmax3 = t;
875                 Gmax3_idx = i;
876             }
877             if( !is_lower_bound(i) && (t = G[i]) > Gmax4 )  // d = -1
878             {
879                 Gmax4 = t;
880                 Gmax4_idx = i;
881             }
882         }
883     }
884
885     if( MAX(Gmax1 + Gmax2, Gmax3 + Gmax4) < eps )
886         return 1;
887
888     if( Gmax1 + Gmax2 > Gmax3 + Gmax4 )
889     {
890         out_i = Gmax1_idx;
891         out_j = Gmax2_idx;
892     }
893     else
894     {
895         out_i = Gmax3_idx;
896         out_j = Gmax4_idx;
897     }
898     return 0;
899 }
900
901
902 void
903 CvSVMSolver::calc_rho_nu_svm( double& rho, double& r )
904 {
905     int nr_free1 = 0, nr_free2 = 0;
906     double ub1 = DBL_MAX, ub2 = DBL_MAX;
907     double lb1 = -DBL_MAX, lb2 = -DBL_MAX;
908     double sum_free1 = 0, sum_free2 = 0;
909     double r1, r2;
910
911     int i;
912
913     for( i = 0; i < alpha_count; i++ )
914     {
915         double G_i = G[i];
916         if( y[i] > 0 )
917         {
918             if( is_lower_bound(i) )
919                 ub1 = MIN( ub1, G_i );
920             else if( is_upper_bound(i) )
921                 lb1 = MAX( lb1, G_i );
922             else
923             {
924                 ++nr_free1;
925                 sum_free1 += G_i;
926             }
927         }
928         else
929         {
930             if( is_lower_bound(i) )
931                 ub2 = MIN( ub2, G_i );
932             else if( is_upper_bound(i) )
933                 lb2 = MAX( lb2, G_i );
934             else
935             {
936                 ++nr_free2;
937                 sum_free2 += G_i;
938             }
939         }
940     }
941
942     r1 = nr_free1 > 0 ? sum_free1/nr_free1 : (ub1 + lb1)*0.5;
943     r2 = nr_free2 > 0 ? sum_free2/nr_free2 : (ub2 + lb2)*0.5;
944     
945     rho = (r1 - r2)*0.5;
946     r = (r1 + r2)*0.5;
947 }
948
949
950 /*
951 ///////////////////////// construct and solve various formulations ///////////////////////
952 */
953
954 bool CvSVMSolver::solve_c_svc( int _sample_count, int _var_count, const float** _samples, char* _y,
955                                double _Cp, double _Cn, CvMemStorage* _storage,
956                                CvSVMKernel* _kernel, double* _alpha, CvSVMSolutionInfo& _si )
957 {
958     int i;
959
960     if( !create( _sample_count, _var_count, _samples, _y, _sample_count,
961                  _alpha, _Cp, _Cn, _storage, _kernel, &CvSVMSolver::get_row_svc,
962                  &CvSVMSolver::select_working_set, &CvSVMSolver::calc_rho ))
963         return false;
964
965     for( i = 0; i < sample_count; i++ )
966     {
967         alpha[i] = 0;
968         b[i] = -1;
969     }
970
971     if( !solve_generic( _si ))
972         return false;
973
974     for( i = 0; i < sample_count; i++ )
975         alpha[i] *= y[i];
976
977     return true;
978 }
979
980
981 bool CvSVMSolver::solve_nu_svc( int _sample_count, int _var_count, const float** _samples, char* _y,
982                                 CvMemStorage* _storage, CvSVMKernel* _kernel,
983                                 double* _alpha, CvSVMSolutionInfo& _si )
984 {
985     int i;
986     double sum_pos, sum_neg, inv_r;
987
988     if( !create( _sample_count, _var_count, _samples, _y, _sample_count,
989                  _alpha, 1., 1., _storage, _kernel, &CvSVMSolver::get_row_svc,
990                  &CvSVMSolver::select_working_set_nu_svm, &CvSVMSolver::calc_rho_nu_svm ))
991         return false;
992
993     sum_pos = kernel->params->nu * sample_count * 0.5;
994     sum_neg = kernel->params->nu * sample_count * 0.5;
995
996     for( i = 0; i < sample_count; i++ )
997     {
998         if( y[i] > 0 )
999         {
1000             alpha[i] = MIN(1.0, sum_pos);
1001             sum_pos -= alpha[i];
1002         }
1003         else
1004         {
1005             alpha[i] = MIN(1.0, sum_neg);
1006             sum_neg -= alpha[i];
1007         }
1008         b[i] = 0;
1009     }
1010
1011     if( !solve_generic( _si ))
1012         return false;
1013
1014     inv_r = 1./_si.r;
1015
1016     for( i = 0; i < sample_count; i++ )
1017         alpha[i] *= y[i]*inv_r;
1018
1019     _si.rho *= inv_r;
1020     _si.obj *= (inv_r*inv_r);
1021     _si.upper_bound_p = inv_r;
1022     _si.upper_bound_n = inv_r;
1023
1024     return true;
1025 }
1026
1027
1028 bool CvSVMSolver::solve_one_class( int _sample_count, int _var_count, const float** _samples,
1029                                    CvMemStorage* _storage, CvSVMKernel* _kernel,
1030                                    double* _alpha, CvSVMSolutionInfo& _si )
1031 {
1032     int i, n;
1033     double nu = _kernel->params->nu;
1034     
1035     if( !create( _sample_count, _var_count, _samples, 0, _sample_count,
1036                  _alpha, 1., 1., _storage, _kernel, &CvSVMSolver::get_row_one_class,
1037                  &CvSVMSolver::select_working_set, &CvSVMSolver::calc_rho ))
1038         return false;
1039
1040     y = (char*)cvMemStorageAlloc( storage, sample_count*sizeof(y[0]) );
1041     n = cvRound( nu*sample_count );
1042
1043     for( i = 0; i < sample_count; i++ )
1044     {
1045         y[i] = 1;
1046         b[i] = 0;
1047         alpha[i] = i < n ? 1 : 0;
1048     }
1049
1050     if( n < sample_count )
1051         alpha[n] = nu * sample_count - n;
1052     else
1053         alpha[n-1] = nu * sample_count - (n-1);
1054     
1055     return solve_generic(_si);
1056 }
1057
1058
1059 bool CvSVMSolver::solve_eps_svr( int _sample_count, int _var_count, const float** _samples,
1060                                  const float* _y, CvMemStorage* _storage,
1061                                  CvSVMKernel* _kernel, double* _alpha, CvSVMSolutionInfo& _si )
1062 {
1063     int i;
1064     double p = _kernel->params->p, C = _kernel->params->C;
1065     
1066     if( !create( _sample_count, _var_count, _samples, 0,
1067                  _sample_count*2, 0, C, C, _storage, _kernel, &CvSVMSolver::get_row_svr,
1068                  &CvSVMSolver::select_working_set, &CvSVMSolver::calc_rho ))
1069         return false;
1070
1071     y = (char*)cvMemStorageAlloc( storage, sample_count*2*sizeof(y[0]) );
1072     alpha = (double*)cvMemStorageAlloc( storage, alpha_count*sizeof(alpha[0]) );
1073
1074     for( i = 0; i < sample_count; i++ )
1075     {
1076         alpha[i] = 0;
1077         b[i] = p - _y[i];
1078         y[i] = 1;
1079
1080         alpha[i+sample_count] = 0;
1081         b[i+sample_count] = p + _y[i];
1082         y[i+sample_count] = -1;
1083     }
1084     
1085     if( !solve_generic( _si ))
1086         return false;
1087
1088     for( i = 0; i < sample_count; i++ )
1089         _alpha[i] = alpha[i] - alpha[i+sample_count];
1090
1091     return true;
1092 }
1093
1094
1095 bool CvSVMSolver::solve_nu_svr( int _sample_count, int _var_count, const float** _samples,
1096                                 const float* _y, CvMemStorage* _storage,
1097                                 CvSVMKernel* _kernel, double* _alpha, CvSVMSolutionInfo& _si )
1098 {
1099     int i;
1100     double C = _kernel->params->C, sum;
1101
1102     if( !create( _sample_count, _var_count, _samples, 0,
1103                  _sample_count*2, 0, 1., 1., _storage, _kernel, &CvSVMSolver::get_row_svr,
1104                  &CvSVMSolver::select_working_set_nu_svm, &CvSVMSolver::calc_rho_nu_svm ))
1105         return false;
1106
1107     y = (char*)cvMemStorageAlloc( storage, sample_count*2*sizeof(y[0]) );
1108     alpha = (double*)cvMemStorageAlloc( storage, alpha_count*sizeof(alpha[0]) );
1109     sum = C * _kernel->params->nu * sample_count * 0.5;
1110
1111     for( i = 0; i < sample_count; i++ )
1112     {
1113         alpha[i] = alpha[i + sample_count] = MIN(sum, C);
1114         sum -= alpha[i];
1115
1116         b[i] = -_y[i];
1117         y[i] = 1;
1118
1119         b[i + sample_count] = _y[i];
1120         y[i + sample_count] = -1;
1121     }
1122     
1123     if( !solve_generic( _si ))
1124         return false;
1125
1126     for( i = 0; i < sample_count; i++ )
1127         _alpha[i] = alpha[i] - alpha[i+sample_count];
1128
1129     return true;
1130 }
1131
1132
1133 //////////////////////////////////////////////////////////////////////////////////////////
1134
1135 CvSVM::CvSVM()
1136 {
1137     decision_func = 0;
1138     class_labels = 0;
1139     class_weights = 0;
1140     storage = 0;
1141     var_idx = 0;
1142     kernel = 0;
1143     solver = 0;
1144     default_model_name = "my_svm";
1145
1146     clear();
1147 }
1148
1149
1150 CvSVM::~CvSVM()
1151 {
1152     clear();
1153 }
1154
1155
1156 void CvSVM::clear()
1157 {
1158     cvFree( &decision_func );
1159     cvReleaseMat( &class_labels );
1160     cvReleaseMat( &class_weights );
1161     cvReleaseMemStorage( &storage );
1162     cvReleaseMat( &var_idx );
1163     delete kernel;
1164     delete solver;
1165     kernel = 0;
1166     solver = 0;
1167     var_all = 0;
1168     sv = 0;
1169     sv_total = 0;
1170 }
1171
1172
1173 CvSVM::CvSVM( const CvMat* _train_data, const CvMat* _responses,
1174     const CvMat* _var_idx, const CvMat* _sample_idx, CvSVMParams _params )
1175 {
1176     decision_func = 0;
1177     class_labels = 0;
1178     class_weights = 0;
1179     storage = 0;
1180     var_idx = 0;
1181     kernel = 0;
1182     solver = 0;
1183     default_model_name = "my_svm";
1184
1185     train( _train_data, _responses, _var_idx, _sample_idx, _params );
1186 }
1187
1188
1189 int CvSVM::get_support_vector_count() const
1190 {
1191     return sv_total;
1192 }
1193
1194
1195 const float* CvSVM::get_support_vector(int i) const
1196 {
1197     return sv && (unsigned)i < (unsigned)sv_total ? sv[i] : 0;
1198 }
1199
1200
1201 bool CvSVM::set_params( const CvSVMParams& _params )
1202 {
1203     bool ok = false;
1204     
1205     CV_FUNCNAME( "CvSVM::set_params" );
1206
1207     __BEGIN__;
1208
1209     int kernel_type, svm_type;
1210
1211     params = _params;
1212
1213     kernel_type = params.kernel_type;
1214     svm_type = params.svm_type;
1215
1216     if( kernel_type != LINEAR && kernel_type != POLY &&
1217         kernel_type != SIGMOID && kernel_type != RBF )
1218         CV_ERROR( CV_StsBadArg, "Unknown/unsupported kernel type" );
1219
1220     if( kernel_type == LINEAR )
1221         params.gamma = 1;
1222     else if( params.gamma <= 0 )
1223         CV_ERROR( CV_StsOutOfRange, "gamma parameter of the kernel must be positive" );
1224
1225     if( kernel_type != SIGMOID && kernel_type != POLY )
1226         params.coef0 = 0;
1227     else if( params.coef0 < 0 )
1228         CV_ERROR( CV_StsOutOfRange, "The kernel parameter <coef0> must be positive or zero" );
1229
1230     if( kernel_type != POLY )
1231         params.degree = 0;
1232     else if( params.degree <= 0 )
1233         CV_ERROR( CV_StsOutOfRange, "The kernel parameter <degree> must be positive" );
1234
1235     if( svm_type != C_SVC && svm_type != NU_SVC &&
1236         svm_type != ONE_CLASS && svm_type != EPS_SVR &&
1237         svm_type != NU_SVR )
1238         CV_ERROR( CV_StsBadArg, "Unknown/unsupported SVM type" );
1239
1240     if( svm_type == ONE_CLASS || svm_type == NU_SVC )
1241         params.C = 0;
1242     else if( params.C <= 0 )
1243         CV_ERROR( CV_StsOutOfRange, "The parameter C must be positive" );
1244
1245     if( svm_type == C_SVC || svm_type == EPS_SVR )
1246         params.nu = 0;
1247     else if( params.nu <= 0 || params.nu >= 1 )
1248         CV_ERROR( CV_StsOutOfRange, "The parameter nu must be between 0 and 1" );
1249
1250     if( svm_type != EPS_SVR )
1251         params.p = 0;
1252     else if( params.p <= 0 )
1253         CV_ERROR( CV_StsOutOfRange, "The parameter p must be positive" );
1254
1255     if( svm_type != C_SVC )
1256         params.class_weights = 0;
1257
1258     params.term_crit = cvCheckTermCriteria( params.term_crit, DBL_EPSILON, INT_MAX );
1259     params.term_crit.epsilon = MAX( params.term_crit.epsilon, DBL_EPSILON );
1260     ok = true;
1261
1262     __END__;
1263
1264     return ok;
1265 }
1266
1267
1268
1269 void CvSVM::create_kernel()
1270 {
1271     kernel = new CvSVMKernel(&params,0);
1272 }
1273
1274
1275 void CvSVM::create_solver( )
1276 {
1277     solver = new CvSVMSolver;
1278 }
1279
1280
1281 // switching function
1282 bool CvSVM::train1( int sample_count, int var_count, const float** samples,
1283                     const void* _responses, double Cp, double Cn,
1284                     CvMemStorage* _storage, double* alpha, double& rho )
1285 {
1286     bool ok = false;
1287     
1288     //CV_FUNCNAME( "CvSVM::train1" );
1289
1290     __BEGIN__;
1291
1292     CvSVMSolutionInfo si;
1293     int svm_type = params.svm_type;
1294
1295     si.rho = 0;
1296
1297     ok = svm_type == C_SVC ? solver->solve_c_svc( sample_count, var_count, samples, (char*)_responses,
1298                                                   Cp, Cn, _storage, kernel, alpha, si ) :
1299          svm_type == NU_SVC ? solver->solve_nu_svc( sample_count, var_count, samples, (char*)_responses,
1300                                                     _storage, kernel, alpha, si ) :
1301          svm_type == ONE_CLASS ? solver->solve_one_class( sample_count, var_count, samples,
1302                                                           _storage, kernel, alpha, si ) :
1303          svm_type == EPS_SVR ? solver->solve_eps_svr( sample_count, var_count, samples, (float*)_responses,
1304                                                       _storage, kernel, alpha, si ) :
1305          svm_type == NU_SVR ? solver->solve_nu_svr( sample_count, var_count, samples, (float*)_responses,
1306                                                     _storage, kernel, alpha, si ) : false;
1307
1308     rho = si.rho;
1309
1310     __END__;
1311
1312     return ok;
1313 }
1314
1315
1316 bool CvSVM::do_train( int svm_type, int sample_count, int var_count, const float** samples,
1317                     const CvMat* responses, CvMemStorage* temp_storage, double* alpha )
1318 {
1319     bool ok = false;
1320
1321     CV_FUNCNAME( "CvSVM::do_train" );
1322
1323     __BEGIN__;
1324
1325     CvSVMDecisionFunc* df = 0;
1326     const int sample_size = var_count*sizeof(samples[0][0]);
1327     int i, j, k;
1328
1329     if( svm_type == ONE_CLASS || svm_type == EPS_SVR || svm_type == NU_SVR )
1330     {
1331         int sv_count = 0;
1332
1333         CV_CALL( decision_func = df =
1334             (CvSVMDecisionFunc*)cvAlloc( sizeof(df[0]) ));
1335
1336         df->rho = 0;
1337         if( !train1( sample_count, var_count, samples, svm_type == ONE_CLASS ? 0 :
1338             responses->data.i, 0, 0, temp_storage, alpha, df->rho ))
1339             EXIT;
1340
1341         for( i = 0; i < sample_count; i++ )
1342             sv_count += fabs(alpha[i]) > 0;
1343
1344         sv_total = df->sv_count = sv_count;
1345         CV_CALL( df->alpha = (double*)cvMemStorageAlloc( storage, sv_count*sizeof(df->alpha[0])) );
1346         CV_CALL( sv = (float**)cvMemStorageAlloc( storage, sv_count*sizeof(sv[0])));
1347
1348         for( i = k = 0; i < sample_count; i++ )
1349         {
1350             if( fabs(alpha[i]) > 0 )
1351             {
1352                 CV_CALL( sv[k] = (float*)cvMemStorageAlloc( storage, sample_size ));
1353                 memcpy( sv[k], samples[i], sample_size );
1354                 df->alpha[k++] = alpha[i];
1355             }
1356         }
1357     }
1358     else
1359     {
1360         int class_count = class_labels->cols;
1361         int* sv_tab = 0;
1362         const float** temp_samples = 0;
1363         int* class_ranges = 0;
1364         char* temp_y = 0;
1365         assert( svm_type == CvSVM::C_SVC || svm_type == CvSVM::NU_SVC );
1366
1367         if( svm_type == CvSVM::C_SVC && params.class_weights )
1368         {
1369             const CvMat* cw = params.class_weights;
1370
1371             if( !CV_IS_MAT(cw) || cw->cols != 1 && cw->rows != 1 ||
1372                 cw->rows + cw->cols - 1 != class_count ||
1373                 CV_MAT_TYPE(cw->type) != CV_32FC1 && CV_MAT_TYPE(cw->type) != CV_64FC1 )
1374                 CV_ERROR( CV_StsBadArg, "params.class_weights must be 1d floating-point vector "
1375                     "containing as many elements as the number of classes" );
1376
1377             CV_CALL( class_weights = cvCreateMat( cw->rows, cw->cols, CV_64F ));
1378             CV_CALL( cvConvert( cw, class_weights ));
1379             CV_CALL( cvScale( class_weights, class_weights, params.C ));
1380         }
1381
1382         CV_CALL( decision_func = df = (CvSVMDecisionFunc*)cvAlloc(
1383             (class_count*(class_count-1)/2)*sizeof(df[0])));
1384
1385         CV_CALL( sv_tab = (int*)cvMemStorageAlloc( temp_storage, sample_count*sizeof(sv_tab[0]) ));
1386         memset( sv_tab, 0, sample_count*sizeof(sv_tab[0]) );
1387         CV_CALL( class_ranges = (int*)cvMemStorageAlloc( temp_storage,
1388                             (class_count + 1)*sizeof(class_ranges[0])));
1389         CV_CALL( temp_samples = (const float**)cvMemStorageAlloc( temp_storage,
1390                             sample_count*sizeof(temp_samples[0])));
1391         CV_CALL( temp_y = (char*)cvMemStorageAlloc( temp_storage, sample_count));
1392
1393         class_ranges[class_count] = 0;
1394         cvSortSamplesByClasses( samples, responses, class_ranges, 0 );
1395         //check that while cross-validation there were the samples from all the classes
1396         if( class_ranges[class_count] <= 0 )
1397             CV_ERROR( CV_StsBadArg, "While cross-validation one or more of the classes have "
1398             "been fell out of the sample. Try to enlarge <CvSVMParams::k_fold>" ); 
1399
1400         if( svm_type == NU_SVC )
1401         {
1402             // check if nu is feasible
1403             for(i = 0; i < class_count; i++ )
1404             {
1405                 int ci = class_ranges[i+1] - class_ranges[i];
1406                 for( j = i+1; j< class_count; j++ )
1407                 {
1408                     int cj = class_ranges[j+1] - class_ranges[j];
1409                     if( params.nu*(ci + cj)*0.5 > MIN( ci, cj ) )
1410                     {
1411                         // !!!TODO!!! add some diagnostic
1412                         EXIT; // exit immediately; will release the model and return NULL pointer
1413                     }
1414                 }
1415             }
1416         }
1417
1418         // train n*(n-1)/2 classifiers
1419         for( i = 0; i < class_count; i++ )
1420         {
1421             for( j = i+1; j < class_count; j++, df++ )
1422             {
1423                 int si = class_ranges[i], ci = class_ranges[i+1] - si;
1424                 int sj = class_ranges[j], cj = class_ranges[j+1] - sj;
1425                 double Cp = params.C, Cn = Cp;
1426                 int k1 = 0, sv_count = 0;
1427
1428                 for( k = 0; k < ci; k++ )
1429                 {
1430                     temp_samples[k] = samples[si + k];
1431                     temp_y[k] = 1;
1432                 }
1433
1434                 for( k = 0; k < cj; k++ )
1435                 {
1436                     temp_samples[ci + k] = samples[sj + k];
1437                     temp_y[ci + k] = -1;
1438                 }
1439
1440                 if( class_weights )
1441                 {
1442                     Cp = class_weights->data.db[i];
1443                     Cn = class_weights->data.db[j];
1444                 }
1445
1446                 if( !train1( ci + cj, var_count, temp_samples, temp_y,
1447                              Cp, Cn, temp_storage, alpha, df->rho ))
1448                     EXIT;
1449
1450                 for( k = 0; k < ci + cj; k++ )
1451                     sv_count += fabs(alpha[k]) > 0;
1452
1453                 df->sv_count = sv_count;
1454
1455                 CV_CALL( df->alpha = (double*)cvMemStorageAlloc( temp_storage,
1456                                                 sv_count*sizeof(df->alpha[0])));
1457                 CV_CALL( df->sv_index = (int*)cvMemStorageAlloc( temp_storage,
1458                                                 sv_count*sizeof(df->sv_index[0])));
1459
1460                 for( k = 0; k < ci; k++ )
1461                 {
1462                     if( fabs(alpha[k]) > 0 )
1463                     {
1464                         sv_tab[si + k] = 1;
1465                         df->sv_index[k1] = si + k;
1466                         df->alpha[k1++] = alpha[k];
1467                     }
1468                 }
1469
1470                 for( k = 0; k < cj; k++ )
1471                 {
1472                     if( fabs(alpha[ci + k]) > 0 )
1473                     {
1474                         sv_tab[sj + k] = 1;
1475                         df->sv_index[k1] = sj + k;
1476                         df->alpha[k1++] = alpha[ci + k];
1477                     }
1478                 }
1479             }
1480         }
1481
1482         // allocate support vectors and initialize sv_tab
1483         for( i = 0, k = 0; i < sample_count; i++ )
1484         {
1485             if( sv_tab[i] )
1486                 sv_tab[i] = ++k;
1487         }
1488
1489         sv_total = k;
1490         CV_CALL( sv = (float**)cvMemStorageAlloc( storage, sv_total*sizeof(sv[0])));
1491
1492         for( i = 0, k = 0; i < sample_count; i++ )
1493         {
1494             if( sv_tab[i] )
1495             {
1496                 CV_CALL( sv[k] = (float*)cvMemStorageAlloc( storage, sample_size ));
1497                 memcpy( sv[k], samples[i], sample_size );
1498                 k++;
1499             }
1500         }
1501
1502         df = (CvSVMDecisionFunc*)decision_func;
1503
1504         // set sv pointers
1505         for( i = 0; i < class_count; i++ )
1506         {
1507             for( j = i+1; j < class_count; j++, df++ )
1508             {
1509                 for( k = 0; k < df->sv_count; k++ )
1510                 {
1511                     df->sv_index[k] = sv_tab[df->sv_index[k]]-1;
1512                     assert( (unsigned)df->sv_index[k] < (unsigned)sv_total );
1513                 }
1514             }
1515         }
1516     }
1517
1518     ok = true;
1519
1520     __END__;
1521
1522     return ok;
1523 }
1524
1525 bool CvSVM::train( const CvMat* _train_data, const CvMat* _responses,
1526     const CvMat* _var_idx, const CvMat* _sample_idx, CvSVMParams _params )
1527 {
1528     bool ok = false;
1529     CvMat* responses = 0;
1530     CvMemStorage* temp_storage = 0;
1531     const float** samples = 0;
1532     
1533     CV_FUNCNAME( "CvSVM::train" );
1534
1535     __BEGIN__;
1536
1537     int svm_type, sample_count, var_count, sample_size;
1538     int block_size = 1 << 16;
1539     double* alpha;
1540
1541     clear();
1542     CV_CALL( set_params( _params ));
1543
1544     svm_type = _params.svm_type;
1545
1546     /* Prepare training data and related parameters */
1547     CV_CALL( cvPrepareTrainData( "CvSVM::train", _train_data, CV_ROW_SAMPLE,
1548                                  svm_type != CvSVM::ONE_CLASS ? _responses : 0,
1549                                  svm_type == CvSVM::C_SVC ||
1550                                  svm_type == CvSVM::NU_SVC ? CV_VAR_CATEGORICAL :
1551                                  CV_VAR_ORDERED, _var_idx, _sample_idx,
1552                                  false, &samples, &sample_count, &var_count, &var_all,
1553                                  &responses, &class_labels, &var_idx ));
1554
1555
1556     sample_size = var_count*sizeof(samples[0][0]);
1557
1558     // make the storage block size large enough to fit all
1559     // the temporary vectors and output support vectors.
1560     block_size = MAX( block_size, sample_count*(int)sizeof(CvSVMKernelRow));
1561     block_size = MAX( block_size, sample_count*2*(int)sizeof(double) + 1024 );
1562     block_size = MAX( block_size, sample_size*2 + 1024 );
1563
1564     CV_CALL( storage = cvCreateMemStorage(block_size));
1565     CV_CALL( temp_storage = cvCreateChildMemStorage(storage));
1566     CV_CALL( alpha = (double*)cvMemStorageAlloc(temp_storage, sample_count*sizeof(double)));
1567
1568     create_kernel();
1569     create_solver();
1570
1571     if( !do_train( svm_type, sample_count, var_count, samples, responses, temp_storage, alpha ))
1572         EXIT;
1573
1574     ok = true; // model has been trained succesfully
1575
1576     __END__;
1577
1578     delete solver;
1579     solver = 0;
1580     cvReleaseMemStorage( &temp_storage );
1581     cvReleaseMat( &responses );
1582     cvFree( &samples );
1583
1584     if( cvGetErrStatus() < 0 || !ok )
1585         clear();
1586
1587     return ok;
1588 }
1589
1590 bool CvSVM::train_auto( const CvMat* _train_data, const CvMat* _responses,
1591     const CvMat* _var_idx, const CvMat* _sample_idx, CvSVMParams _params, int k_fold,
1592     CvParamGrid C_grid, CvParamGrid gamma_grid, CvParamGrid p_grid,
1593     CvParamGrid nu_grid, CvParamGrid coef_grid, CvParamGrid degree_grid )
1594 {
1595     bool ok = false;
1596     CvMat* responses = 0;
1597     CvMat* responses_local = 0;
1598     CvMemStorage* temp_storage = 0;
1599     const float** samples = 0;
1600     const float** samples_local = 0;
1601     
1602     CV_FUNCNAME( "CvSVM::train_auto" );
1603     __BEGIN__;
1604
1605     int svm_type, sample_count, var_count, sample_size;
1606     int block_size = 1 << 16;
1607     double* alpha;
1608     int i, k;
1609     CvRNG rng = cvRNG(-1);
1610
1611     // all steps are logarithmic and must be > 1
1612     double degree_step = 10, g_step = 10, coef_step = 10, C_step = 10, nu_step = 10, p_step = 10;
1613     double gamma = 0, C = 0, degree = 0, coef = 0, p = 0, nu = 0;
1614     double best_degree = 0, best_gamma = 0, best_coef = 0, best_C = 0, best_nu = 0, best_p = 0;
1615     float min_error = FLT_MAX, error;
1616
1617     if( _params.svm_type == CvSVM::ONE_CLASS )
1618     {
1619         if(!train( _train_data, _responses, _var_idx, _sample_idx, _params ))
1620             EXIT;
1621         return true;
1622     }
1623
1624     clear();
1625
1626     if( k_fold < 2 )
1627         CV_ERROR( CV_StsBadArg, "Parameter <k_fold> must be > 1" );
1628
1629     CV_CALL(set_params( _params ));
1630     svm_type = _params.svm_type;
1631
1632     // All the parameters except, possibly, <coef0> are positive.
1633     // <coef0> is nonnegative
1634     if( C_grid.step <= 1 )
1635     {
1636         C_grid.min_val = C_grid.max_val = params.C;
1637         C_grid.step = 10;
1638     }
1639     else
1640         CV_CALL(C_grid.check());
1641
1642     if( gamma_grid.step <= 1 )
1643     {
1644         gamma_grid.min_val = gamma_grid.max_val = params.gamma;
1645         gamma_grid.step = 10;
1646     }
1647     else
1648         CV_CALL(gamma_grid.check());
1649
1650     if( p_grid.step <= 1 )
1651     {
1652         p_grid.min_val = p_grid.max_val = params.p;
1653         p_grid.step = 10;
1654     }
1655     else
1656         CV_CALL(p_grid.check());
1657
1658     if( nu_grid.step <= 1 )
1659     {
1660         nu_grid.min_val = nu_grid.max_val = params.nu;
1661         nu_grid.step = 10;
1662     }
1663     else
1664         CV_CALL(nu_grid.check());
1665
1666     if( coef_grid.step <= 1 )
1667     {
1668         coef_grid.min_val = coef_grid.max_val = params.coef0;
1669         coef_grid.step = 10;
1670     }
1671     else
1672         CV_CALL(coef_grid.check());
1673
1674     if( degree_grid.step <= 1 )
1675     {
1676         degree_grid.min_val = degree_grid.max_val = params.degree;
1677         degree_grid.step = 10;
1678     }
1679     else
1680         CV_CALL(degree_grid.check());
1681
1682     // these parameters are not used:
1683     if( params.kernel_type != CvSVM::POLY )
1684         degree_grid.min_val = degree_grid.max_val = params.degree;
1685     if( params.kernel_type == CvSVM::LINEAR )
1686         gamma_grid.min_val = gamma_grid.max_val = params.gamma;
1687     if( params.kernel_type != CvSVM::POLY && params.kernel_type != CvSVM::SIGMOID )
1688         coef_grid.min_val = coef_grid.max_val = params.coef0;
1689     if( svm_type == CvSVM::NU_SVC || svm_type == CvSVM::ONE_CLASS )
1690         C_grid.min_val = C_grid.max_val = params.C;
1691     if( svm_type == CvSVM::C_SVC || svm_type == CvSVM::EPS_SVR )
1692         nu_grid.min_val = nu_grid.max_val = params.nu;
1693     if( svm_type != CvSVM::EPS_SVR )
1694         p_grid.min_val = p_grid.max_val = params.p;
1695
1696     CV_ASSERT( g_step > 1 && degree_step > 1 && coef_step > 1);
1697     CV_ASSERT( p_step > 1 && C_step > 1 && nu_step > 1 );
1698
1699     /* Prepare training data and related parameters */
1700     CV_CALL(cvPrepareTrainData( "CvSVM::train_auto", _train_data, CV_ROW_SAMPLE,
1701                                  svm_type != CvSVM::ONE_CLASS ? _responses : 0,
1702                                  svm_type == CvSVM::C_SVC ||
1703                                  svm_type == CvSVM::NU_SVC ? CV_VAR_CATEGORICAL :
1704                                  CV_VAR_ORDERED, _var_idx, _sample_idx,
1705                                  false, &samples, &sample_count, &var_count, &var_all,
1706                                  &responses, &class_labels, &var_idx ));
1707
1708     sample_size = var_count*sizeof(samples[0][0]);
1709
1710     // make the storage block size large enough to fit all
1711     // the temporary vectors and output support vectors.
1712     block_size = MAX( block_size, sample_count*(int)sizeof(CvSVMKernelRow));
1713     block_size = MAX( block_size, sample_count*2*(int)sizeof(double) + 1024 );
1714     block_size = MAX( block_size, sample_size*2 + 1024 );
1715
1716     CV_CALL(storage = cvCreateMemStorage(block_size));
1717     CV_CALL(temp_storage = cvCreateChildMemStorage(storage));
1718     CV_CALL(alpha = (double*)cvMemStorageAlloc(temp_storage, sample_count*sizeof(double)));
1719
1720     create_kernel();
1721     create_solver();
1722
1723     {
1724     const int testset_size = sample_count/k_fold;
1725     const int trainset_size = sample_count - testset_size;
1726     const int last_testset_size = sample_count - testset_size*(k_fold-1);
1727     const int last_trainset_size = sample_count - last_testset_size;
1728     const bool is_regression = (svm_type == EPS_SVR) || (svm_type == NU_SVR);
1729
1730     size_t resp_elem_size = CV_ELEM_SIZE(responses->type);
1731     size_t size = 2*last_trainset_size*sizeof(samples[0]);
1732
1733     samples_local = (const float**) cvAlloc( size );
1734     memset( samples_local, 0, size );
1735
1736     responses_local = cvCreateMat( 1, trainset_size, CV_MAT_TYPE(responses->type) );
1737     cvZero( responses_local );
1738
1739     // randomly permute samples and responses
1740     for( i = 0; i < sample_count; i++ )
1741     {
1742         int i1 = cvRandInt( &rng ) % sample_count;
1743         int i2 = cvRandInt( &rng ) % sample_count;
1744         const float* temp;
1745         float t;
1746         int y;
1747
1748         CV_SWAP( samples[i1], samples[i2], temp );
1749         if( is_regression )
1750             CV_SWAP( responses->data.fl[i1], responses->data.fl[i2], t );
1751         else
1752             CV_SWAP( responses->data.i[i1], responses->data.i[i2], y );
1753     }
1754
1755     C = C_grid.min_val;
1756     do
1757     {
1758       params.C = C;
1759       gamma = gamma_grid.min_val;
1760       do
1761       {
1762         params.gamma = gamma;
1763         p = p_grid.min_val;
1764         do
1765         {
1766           params.p = p;
1767           nu = nu_grid.min_val;
1768           do
1769           {
1770             params.nu = nu;
1771             coef = coef_grid.min_val;
1772             do
1773             {
1774               params.coef0 = coef;
1775               degree = degree_grid.min_val;
1776               do
1777               {
1778                 params.degree = degree;
1779
1780                 float** test_samples_ptr = (float**)samples;
1781                 uchar* true_resp = responses->data.ptr;
1782                 int test_size = testset_size;
1783                 int train_size = trainset_size;
1784
1785                 error = 0;
1786                 for( k = 0; k < k_fold; k++ )
1787                 {
1788                     memcpy( samples_local, samples, sizeof(samples[0])*test_size*k );
1789                     memcpy( samples_local + test_size*k, test_samples_ptr + test_size,
1790                         sizeof(samples[0])*(sample_count - testset_size*(k+1)) );
1791
1792                     memcpy( responses_local->data.ptr, responses->data.ptr, resp_elem_size*test_size*k );
1793                     memcpy( responses_local->data.ptr + resp_elem_size*test_size*k,
1794                         true_resp + resp_elem_size*test_size,
1795                         sizeof(samples[0])*(sample_count - testset_size*(k+1)) );
1796
1797                     if( k == k_fold - 1 )
1798                     {
1799                         test_size = last_testset_size;
1800                         train_size = last_trainset_size;
1801                         responses_local->cols = last_trainset_size;
1802                     }
1803
1804                     // Train SVM on <train_size> samples
1805                     if( !do_train( svm_type, train_size, var_count,
1806                         (const float**)samples_local, responses_local, temp_storage, alpha ) )
1807                         EXIT;
1808
1809                     // Compute test set error on <test_size> samples
1810                     CvMat s = cvMat( 1, var_count, CV_32FC1 );
1811                     for( i = 0; i < test_size; i++, true_resp += resp_elem_size, test_samples_ptr++ )
1812                     {
1813                         float resp;
1814                         s.data.fl = *test_samples_ptr;
1815                         resp = predict( &s );
1816                         error += is_regression ? powf( resp - *(float*)true_resp, 2 )
1817                             : ((int)resp != *(int*)true_resp);
1818                     }
1819                 }
1820                 if( min_error > error )
1821                 {
1822                     min_error   = error;
1823                     best_degree = degree;
1824                     best_gamma  = gamma;
1825                     best_coef   = coef;
1826                     best_C      = C;
1827                     best_nu     = nu;
1828                     best_p      = p;
1829                 }
1830                 degree *= degree_grid.step;
1831               }
1832               while( degree < degree_grid.max_val );
1833               coef *= coef_grid.step;
1834             }
1835             while( coef < coef_grid.max_val );
1836             nu *= nu_grid.step;
1837           }
1838           while( nu < nu_grid.max_val );
1839           p *= p_grid.step;
1840         }
1841         while( p < p_grid.max_val );
1842         gamma *= gamma_grid.step;
1843       }
1844       while( gamma < gamma_grid.max_val );
1845       C *= C_grid.step;
1846     }
1847     while( C < C_grid.max_val );
1848     }
1849
1850     min_error /= (float) sample_count;
1851
1852     params.C      = best_C;
1853     params.nu     = best_nu;
1854     params.p      = best_p;
1855     params.gamma  = best_gamma;
1856     params.degree = best_degree;
1857     params.coef0  = best_coef;
1858
1859     CV_CALL(ok = do_train( svm_type, sample_count, var_count, samples, responses, temp_storage, alpha ));
1860  
1861     __END__;
1862
1863     delete solver;
1864     solver = 0;
1865     cvReleaseMemStorage( &temp_storage );
1866     cvReleaseMat( &responses );
1867     cvReleaseMat( &responses_local );
1868     cvFree( &samples );
1869     cvFree( &samples_local );
1870
1871     if( cvGetErrStatus() < 0 || !ok )
1872         clear();
1873
1874     return ok;
1875 }
1876
1877 float CvSVM::predict( const CvMat* sample ) const
1878 {
1879     bool local_alloc = 0;
1880     float result = 0;
1881     float* row_sample = 0;
1882     Qfloat* buffer = 0;
1883
1884     CV_FUNCNAME( "CvSVM::predict" );
1885
1886     __BEGIN__;
1887
1888     int class_count;
1889     int var_count, buf_sz;
1890
1891     if( !kernel )
1892         CV_ERROR( CV_StsBadArg, "The SVM should be trained first" );
1893
1894     class_count = class_labels ? class_labels->cols :
1895                   params.svm_type == ONE_CLASS ? 1 : 0;
1896
1897     CV_CALL( cvPreparePredictData( sample, var_all, var_idx,
1898                                    class_count, 0, &row_sample ));
1899
1900     var_count = get_var_count();
1901
1902     buf_sz = sv_total*sizeof(buffer[0]) + (class_count+1)*sizeof(int);
1903     if( buf_sz <= CV_MAX_LOCAL_SIZE )
1904     {
1905         CV_CALL( buffer = (Qfloat*)cvStackAlloc( buf_sz ));
1906         local_alloc = 1;
1907     }
1908     else
1909         CV_CALL( buffer = (Qfloat*)cvAlloc( buf_sz ));
1910     
1911     if( params.svm_type == EPS_SVR ||
1912         params.svm_type == NU_SVR ||
1913         params.svm_type == ONE_CLASS )
1914     {
1915         CvSVMDecisionFunc* df = (CvSVMDecisionFunc*)decision_func;
1916         int i, sv_count = df->sv_count;
1917         double sum = -df->rho;
1918
1919         kernel->calc( sv_count, var_count, (const float**)sv, row_sample, buffer );
1920         for( i = 0; i < sv_count; i++ )
1921             sum += buffer[i]*df->alpha[i];
1922
1923         result = params.svm_type == ONE_CLASS ? (float)(sum > 0) : (float)sum;
1924     }
1925     else if( params.svm_type == C_SVC ||
1926              params.svm_type == NU_SVC )
1927     {
1928         CvSVMDecisionFunc* df = (CvSVMDecisionFunc*)decision_func;
1929         int* vote = (int*)(buffer + sv_total);
1930         int i, j, k;
1931
1932         memset( vote, 0, class_count*sizeof(vote[0]));
1933         kernel->calc( sv_total, var_count, (const float**)sv, row_sample, buffer );
1934
1935         for( i = 0; i < class_count; i++ )
1936         {
1937             for( j = i+1; j < class_count; j++, df++ )
1938             {
1939                 double sum = -df->rho;
1940                 int sv_count = df->sv_count;
1941                 for( k = 0; k < sv_count; k++ )
1942                     sum += df->alpha[k]*buffer[df->sv_index[k]];
1943
1944                 vote[sum > 0 ? i : j]++;
1945             }
1946         }
1947
1948         for( i = 1, k = 0; i < class_count; i++ )
1949         {
1950             if( vote[i] > vote[k] )
1951                 k = i;
1952         }
1953
1954         result = (float)(class_labels->data.i[k]);
1955     }
1956     else
1957         CV_ERROR( CV_StsBadArg, "INTERNAL ERROR: Unknown SVM type, "
1958                                 "the SVM structure is probably corrupted" );
1959
1960     __END__;
1961
1962     if( sample && (!CV_IS_MAT(sample) || sample->data.fl != row_sample) )
1963         cvFree( &row_sample );
1964
1965     if( !local_alloc )
1966         cvFree( &buffer );
1967
1968     return result;
1969 }
1970
1971
1972 void CvSVM::write_params( CvFileStorage* fs )
1973 {
1974     //CV_FUNCNAME( "CvSVM::write_params" );
1975
1976     __BEGIN__;
1977     
1978     int svm_type = params.svm_type;
1979     int kernel_type = params.kernel_type;
1980
1981     const char* svm_type_str =
1982         svm_type == CvSVM::C_SVC ? "C_SVC" :
1983         svm_type == CvSVM::NU_SVC ? "NU_SVC" :
1984         svm_type == CvSVM::ONE_CLASS ? "ONE_CLASS" :
1985         svm_type == CvSVM::EPS_SVR ? "EPS_SVR" :
1986         svm_type == CvSVM::NU_SVR ? "NU_SVR" : 0;
1987     const char* kernel_type_str =
1988         kernel_type == CvSVM::LINEAR ? "LINEAR" :
1989         kernel_type == CvSVM::POLY ? "POLY" :
1990         kernel_type == CvSVM::RBF ? "RBF" :
1991         kernel_type == CvSVM::SIGMOID ? "SIGMOID" : 0;
1992
1993     if( svm_type_str )
1994         cvWriteString( fs, "svm_type", svm_type_str );
1995     else
1996         cvWriteInt( fs, "svm_type", svm_type );
1997
1998     // save kernel
1999     cvStartWriteStruct( fs, "kernel", CV_NODE_MAP + CV_NODE_FLOW );
2000     
2001     if( kernel_type_str )
2002         cvWriteString( fs, "type", kernel_type_str );
2003     else
2004         cvWriteInt( fs, "type", kernel_type );
2005
2006     if( kernel_type == CvSVM::POLY || !kernel_type_str )
2007         cvWriteReal( fs, "degree", params.degree );
2008
2009     if( kernel_type != CvSVM::LINEAR || !kernel_type_str )
2010         cvWriteReal( fs, "gamma", params.gamma );
2011
2012     if( kernel_type == CvSVM::POLY || kernel_type == CvSVM::SIGMOID || !kernel_type_str )
2013         cvWriteReal( fs, "coef0", params.coef0 );
2014
2015     cvEndWriteStruct(fs);
2016
2017     if( svm_type == CvSVM::C_SVC || svm_type == CvSVM::EPS_SVR ||
2018         svm_type == CvSVM::NU_SVR || !svm_type_str )
2019         cvWriteReal( fs, "C", params.C );
2020
2021     if( svm_type == CvSVM::NU_SVC || svm_type == CvSVM::ONE_CLASS ||
2022         svm_type == CvSVM::NU_SVR || !svm_type_str )
2023         cvWriteReal( fs, "nu", params.nu );
2024
2025     if( svm_type == CvSVM::EPS_SVR || !svm_type_str )
2026         cvWriteReal( fs, "p", params.p );
2027
2028     cvStartWriteStruct( fs, "term_criteria", CV_NODE_MAP + CV_NODE_FLOW );
2029     if( params.term_crit.type & CV_TERMCRIT_EPS )
2030         cvWriteReal( fs, "epsilon", params.term_crit.epsilon );
2031     if( params.term_crit.type & CV_TERMCRIT_ITER )
2032         cvWriteInt( fs, "iterations", params.term_crit.max_iter );
2033     cvEndWriteStruct( fs );
2034
2035     __END__;
2036 }
2037
2038
2039 void CvSVM::write( CvFileStorage* fs, const char* name )
2040 {
2041     CV_FUNCNAME( "CvSVM::write" );
2042
2043     __BEGIN__;
2044
2045     int i, var_count = get_var_count(), df_count, class_count;
2046     const CvSVMDecisionFunc* df = decision_func;
2047
2048     cvStartWriteStruct( fs, name, CV_NODE_MAP, CV_TYPE_NAME_ML_SVM );
2049
2050     write_params( fs );
2051
2052     cvWriteInt( fs, "var_all", var_all );
2053     cvWriteInt( fs, "var_count", var_count );
2054
2055     class_count = class_labels ? class_labels->cols :
2056                   params.svm_type == CvSVM::ONE_CLASS ? 1 : 0;
2057
2058     if( class_count )
2059     {
2060         cvWriteInt( fs, "class_count", class_count );
2061
2062         if( class_labels )
2063             cvWrite( fs, "class_labels", class_labels );
2064
2065         if( class_weights )
2066             cvWrite( fs, "class_weights", class_weights );
2067     }
2068
2069     if( var_idx )
2070         cvWrite( fs, "var_idx", var_idx );
2071
2072     // write the joint collection of support vectors
2073     cvWriteInt( fs, "sv_total", sv_total );
2074     cvStartWriteStruct( fs, "support_vectors", CV_NODE_SEQ );
2075     for( i = 0; i < sv_total; i++ )
2076     {
2077         cvStartWriteStruct( fs, 0, CV_NODE_SEQ + CV_NODE_FLOW );
2078         cvWriteRawData( fs, sv[i], var_count, "f" );
2079         cvEndWriteStruct( fs );
2080     }
2081
2082     cvEndWriteStruct( fs );
2083
2084     // write decision functions
2085     df_count = class_count > 1 ? class_count*(class_count-1)/2 : 1;
2086     df = decision_func;
2087
2088     cvStartWriteStruct( fs, "decision_functions", CV_NODE_SEQ );
2089     for( i = 0; i < df_count; i++ )
2090     {
2091         int sv_count = df[i].sv_count;
2092         cvStartWriteStruct( fs, 0, CV_NODE_MAP );
2093         cvWriteInt( fs, "sv_count", sv_count );
2094         cvWriteReal( fs, "rho", df[i].rho );
2095         cvStartWriteStruct( fs, "alpha", CV_NODE_SEQ+CV_NODE_FLOW );
2096         cvWriteRawData( fs, df[i].alpha, df[i].sv_count, "d" );
2097         cvEndWriteStruct( fs );
2098         if( class_count > 1 )
2099         {
2100             cvStartWriteStruct( fs, "index", CV_NODE_SEQ+CV_NODE_FLOW );
2101             cvWriteRawData( fs, df[i].sv_index, df[i].sv_count, "i" );
2102             cvEndWriteStruct( fs );
2103         }
2104         else
2105             CV_ASSERT( sv_count == sv_total );
2106         cvEndWriteStruct( fs );
2107     }
2108     cvEndWriteStruct( fs );
2109     cvEndWriteStruct( fs );
2110
2111     __END__;
2112 }
2113
2114
2115 void CvSVM::read_params( CvFileStorage* fs, CvFileNode* svm_node )
2116 {
2117     CV_FUNCNAME( "CvSVM::read_params" );
2118     
2119     __BEGIN__;
2120     
2121     int svm_type, kernel_type;
2122     CvSVMParams _params;
2123
2124     CvFileNode* tmp_node = cvGetFileNodeByName( fs, svm_node, "svm_type" );
2125     CvFileNode* kernel_node;
2126     if( !tmp_node )
2127         CV_ERROR( CV_StsBadArg, "svm_type tag is not found" );
2128
2129     if( CV_NODE_TYPE(tmp_node->tag) == CV_NODE_INT )
2130         svm_type = cvReadInt( tmp_node, -1 );
2131     else
2132     {
2133         const char* svm_type_str = cvReadString( tmp_node, "" );
2134         svm_type =
2135             strcmp( svm_type_str, "C_SVC" ) == 0 ? CvSVM::C_SVC :
2136             strcmp( svm_type_str, "NU_SVC" ) == 0 ? CvSVM::NU_SVC :
2137             strcmp( svm_type_str, "ONE_CLASS" ) == 0 ? CvSVM::ONE_CLASS :
2138             strcmp( svm_type_str, "EPS_SVR" ) == 0 ? CvSVM::EPS_SVR :
2139             strcmp( svm_type_str, "NU_SVR" ) == 0 ? CvSVM::NU_SVR : -1;
2140
2141         if( svm_type < 0 )
2142             CV_ERROR( CV_StsParseError, "Missing of invalid SVM type" );
2143     }
2144
2145     kernel_node = cvGetFileNodeByName( fs, svm_node, "kernel" );
2146     if( !kernel_node )
2147         CV_ERROR( CV_StsParseError, "SVM kernel tag is not found" );
2148
2149     tmp_node = cvGetFileNodeByName( fs, kernel_node, "type" );
2150     if( !tmp_node )
2151         CV_ERROR( CV_StsParseError, "SVM kernel type tag is not found" );
2152
2153     if( CV_NODE_TYPE(tmp_node->tag) == CV_NODE_INT )
2154         kernel_type = cvReadInt( tmp_node, -1 );
2155     else
2156     {
2157         const char* kernel_type_str = cvReadString( tmp_node, "" );
2158         kernel_type =
2159             strcmp( kernel_type_str, "LINEAR" ) == 0 ? CvSVM::LINEAR :
2160             strcmp( kernel_type_str, "POLY" ) == 0 ? CvSVM::POLY :
2161             strcmp( kernel_type_str, "RBF" ) == 0 ? CvSVM::RBF :
2162             strcmp( kernel_type_str, "SIGMOID" ) == 0 ? CvSVM::SIGMOID : -1;
2163
2164         if( kernel_type < 0 )
2165             CV_ERROR( CV_StsParseError, "Missing of invalid SVM kernel type" );
2166     }
2167
2168     _params.svm_type = svm_type;
2169     _params.kernel_type = kernel_type;
2170     _params.degree = cvReadRealByName( fs, kernel_node, "degree", 0 );
2171     _params.gamma = cvReadRealByName( fs, kernel_node, "gamma", 0 );
2172     _params.coef0 = cvReadRealByName( fs, kernel_node, "coef0", 0 );
2173
2174     _params.C = cvReadRealByName( fs, svm_node, "C", 0 );
2175     _params.nu = cvReadRealByName( fs, svm_node, "nu", 0 );
2176     _params.p = cvReadRealByName( fs, svm_node, "p", 0 );
2177     _params.class_weights = 0;
2178
2179     tmp_node = cvGetFileNodeByName( fs, svm_node, "term_criteria" );
2180     if( tmp_node )
2181     {
2182         _params.term_crit.epsilon = cvReadRealByName( fs, tmp_node, "epsilon", -1. );
2183         _params.term_crit.max_iter = cvReadIntByName( fs, tmp_node, "iterations", -1 );
2184         _params.term_crit.type = (_params.term_crit.epsilon >= 0 ? CV_TERMCRIT_EPS : 0) +
2185                                (_params.term_crit.max_iter >= 0 ? CV_TERMCRIT_ITER : 0);
2186     }
2187     else
2188         _params.term_crit = cvTermCriteria( CV_TERMCRIT_EPS + CV_TERMCRIT_ITER, 1000, FLT_EPSILON );
2189
2190     set_params( _params );
2191
2192     __END__;
2193 }
2194
2195
2196 void CvSVM::read( CvFileStorage* fs, CvFileNode* svm_node )
2197 {
2198     const double not_found_dbl = DBL_MAX;
2199     
2200     CV_FUNCNAME( "CvSVM::read" );
2201
2202     __BEGIN__;
2203
2204     int i, var_count, df_count, class_count;
2205     int block_size = 1 << 16, sv_size;
2206     CvFileNode *sv_node, *df_node;
2207     CvSVMDecisionFunc* df;
2208     CvSeqReader reader;
2209
2210     if( !svm_node )
2211         CV_ERROR( CV_StsParseError, "The requested element is not found" );
2212
2213     clear();
2214
2215     // read SVM parameters
2216     read_params( fs, svm_node );
2217
2218     // and top-level data
2219     sv_total = cvReadIntByName( fs, svm_node, "sv_total", -1 );
2220     var_all = cvReadIntByName( fs, svm_node, "var_all", -1 );
2221     var_count = cvReadIntByName( fs, svm_node, "var_count", var_all );
2222     class_count = cvReadIntByName( fs, svm_node, "class_count", 0 );
2223
2224     if( sv_total <= 0 || var_all <= 0 || var_count <= 0 || var_count > var_all || class_count < 0 )
2225         CV_ERROR( CV_StsParseError, "SVM model data is invalid, check sv_count, var_* and class_count tags" );
2226
2227     CV_CALL( class_labels = (CvMat*)cvReadByName( fs, svm_node, "class_labels" ));
2228     CV_CALL( class_weights = (CvMat*)cvReadByName( fs, svm_node, "class_weights" ));
2229     CV_CALL( var_idx = (CvMat*)cvReadByName( fs, svm_node, "comp_idx" ));
2230
2231     if( class_count > 1 && (!class_labels ||
2232         !CV_IS_MAT(class_labels) || class_labels->cols != class_count))
2233         CV_ERROR( CV_StsParseError, "Array of class labels is missing or invalid" );
2234
2235     if( var_count < var_all && (!var_idx || !CV_IS_MAT(var_idx) || var_idx->cols != var_count) )
2236         CV_ERROR( CV_StsParseError, "var_idx array is missing or invalid" );
2237
2238     // read support vectors
2239     sv_node = cvGetFileNodeByName( fs, svm_node, "support_vectors" );
2240     if( !sv_node || !CV_NODE_IS_SEQ(sv_node->tag))
2241         CV_ERROR( CV_StsParseError, "Missing or invalid sequence of support vectors" );
2242
2243     block_size = MAX( block_size, sv_total*(int)sizeof(CvSVMKernelRow));
2244     block_size = MAX( block_size, sv_total*2*(int)sizeof(double));
2245     block_size = MAX( block_size, var_all*(int)sizeof(double));
2246     CV_CALL( storage = cvCreateMemStorage( block_size ));
2247     CV_CALL( sv = (float**)cvMemStorageAlloc( storage,
2248                                 sv_total*sizeof(sv[0]) ));
2249     
2250     CV_CALL( cvStartReadSeq( sv_node->data.seq, &reader, 0 ));
2251     sv_size = var_count*sizeof(sv[0][0]);
2252
2253     for( i = 0; i < sv_total; i++ )
2254     {
2255         CvFileNode* sv_elem = (CvFileNode*)reader.ptr;
2256         CV_ASSERT( var_count == 1 || (CV_NODE_IS_SEQ(sv_elem->tag) &&
2257                    sv_elem->data.seq->total == var_count) );
2258
2259         CV_CALL( sv[i] = (float*)cvMemStorageAlloc( storage, sv_size ));
2260         CV_CALL( cvReadRawData( fs, sv_elem, sv[i], "f" ));
2261         CV_NEXT_SEQ_ELEM( sv_node->data.seq->elem_size, reader );
2262     }
2263
2264     // read decision functions
2265     df_count = class_count > 1 ? class_count*(class_count-1)/2 : 1;
2266     df_node = cvGetFileNodeByName( fs, svm_node, "decision_functions" );
2267     if( !df_node || !CV_NODE_IS_SEQ(df_node->tag) ||
2268         df_node->data.seq->total != df_count )
2269         CV_ERROR( CV_StsParseError, "decision_functions is missing or is not a collection "
2270                   "or has a wrong number of elements" );
2271     
2272     CV_CALL( df = decision_func = (CvSVMDecisionFunc*)cvAlloc( df_count*sizeof(df[0]) ));
2273     cvStartReadSeq( df_node->data.seq, &reader, 0 );
2274
2275     for( i = 0; i < df_count; i++ )
2276     {
2277         CvFileNode* df_elem = (CvFileNode*)reader.ptr;
2278         CvFileNode* alpha_node = cvGetFileNodeByName( fs, df_elem, "alpha" );
2279
2280         int sv_count = cvReadIntByName( fs, df_elem, "sv_count", -1 );
2281         if( sv_count <= 0 )
2282             CV_ERROR( CV_StsParseError, "sv_count is missing or non-positive" );
2283         df[i].sv_count = sv_count;
2284
2285         df[i].rho = cvReadRealByName( fs, df_elem, "rho", not_found_dbl );
2286         if( fabs(df[i].rho - not_found_dbl) < DBL_EPSILON )
2287             CV_ERROR( CV_StsParseError, "rho is missing" );
2288
2289         if( !alpha_node )
2290             CV_ERROR( CV_StsParseError, "alpha is missing in the decision function" );
2291
2292         CV_CALL( df[i].alpha = (double*)cvMemStorageAlloc( storage,
2293                                         sv_count*sizeof(df[i].alpha[0])));
2294         CV_ASSERT( sv_count == 1 || CV_NODE_IS_SEQ(alpha_node->tag) &&
2295                    alpha_node->data.seq->total == sv_count );
2296         CV_CALL( cvReadRawData( fs, alpha_node, df[i].alpha, "d" ));
2297
2298         if( class_count > 1 )
2299         {
2300             CvFileNode* index_node = cvGetFileNodeByName( fs, df_elem, "index" );
2301             if( !index_node )
2302                 CV_ERROR( CV_StsParseError, "index is missing in the decision function" );
2303             CV_CALL( df[i].sv_index = (int*)cvMemStorageAlloc( storage,
2304                                             sv_count*sizeof(df[i].sv_index[0])));
2305             CV_ASSERT( sv_count == 1 || CV_NODE_IS_SEQ(index_node->tag) &&
2306                    index_node->data.seq->total == sv_count );
2307             CV_CALL( cvReadRawData( fs, index_node, df[i].sv_index, "i" ));
2308         }
2309         else
2310             df[i].sv_index = 0;
2311
2312         CV_NEXT_SEQ_ELEM( df_node->data.seq->elem_size, reader );
2313     }
2314
2315     create_kernel();
2316
2317     __END__;
2318 }
2319
2320 #if 0
2321
2322 static void*
2323 icvCloneSVM( const void* _src )
2324 {
2325     CvSVMModel* dst = 0;
2326     
2327     CV_FUNCNAME( "icvCloneSVM" );
2328
2329     __BEGIN__;
2330
2331     const CvSVMModel* src = (const CvSVMModel*)_src;
2332     int var_count, class_count;
2333     int i, sv_total, df_count;
2334     int sv_size;
2335
2336     if( !CV_IS_SVM(src) )
2337         CV_ERROR( !src ? CV_StsNullPtr : CV_StsBadArg, "Input pointer is NULL or invalid" );
2338
2339     // 0. create initial CvSVMModel structure
2340     CV_CALL( dst = icvCreateSVM() );
2341     dst->params = src->params;
2342     dst->params.weight_labels = 0;
2343     dst->params.weights = 0;
2344
2345     dst->var_all = src->var_all;
2346     if( src->class_labels )
2347         dst->class_labels = cvCloneMat( src->class_labels );
2348     if( src->class_weights )
2349         dst->class_weights = cvCloneMat( src->class_weights );
2350     if( src->comp_idx )
2351         dst->comp_idx = cvCloneMat( src->comp_idx );
2352
2353     var_count = src->comp_idx ? src->comp_idx->cols : src->var_all;
2354     class_count = src->class_labels ? src->class_labels->cols :
2355                   src->params.svm_type == CvSVM::ONE_CLASS ? 1 : 0;
2356     sv_total = dst->sv_total = src->sv_total;
2357     CV_CALL( dst->storage = cvCreateMemStorage( src->storage->block_size ));
2358     CV_CALL( dst->sv = (float**)cvMemStorageAlloc( dst->storage,
2359                                     sv_total*sizeof(dst->sv[0]) ));
2360     
2361     sv_size = var_count*sizeof(dst->sv[0][0]);
2362     
2363     for( i = 0; i < sv_total; i++ )
2364     {
2365         CV_CALL( dst->sv[i] = (float*)cvMemStorageAlloc( dst->storage, sv_size ));
2366         memcpy( dst->sv[i], src->sv[i], sv_size );
2367     }
2368
2369     df_count = class_count > 1 ? class_count*(class_count-1)/2 : 1;
2370     
2371     CV_CALL( dst->decision_func = cvAlloc( df_count*sizeof(CvSVMDecisionFunc) ));
2372
2373     for( i = 0; i < df_count; i++ )
2374     {
2375         const CvSVMDecisionFunc *sdf =
2376             (const CvSVMDecisionFunc*)src->decision_func+i;
2377         CvSVMDecisionFunc *ddf =
2378             (CvSVMDecisionFunc*)dst->decision_func+i;
2379         int sv_count = sdf->sv_count;
2380         ddf->sv_count = sv_count;
2381         ddf->rho = sdf->rho;
2382         CV_CALL( ddf->alpha = (double*)cvMemStorageAlloc( dst->storage,
2383                                         sv_count*sizeof(ddf->alpha[0])));
2384         memcpy( ddf->alpha, sdf->alpha, sv_count*sizeof(ddf->alpha[0]));
2385
2386         if( class_count > 1 )
2387         {
2388             CV_CALL( ddf->sv_index = (int*)cvMemStorageAlloc( dst->storage,
2389                                                 sv_count*sizeof(ddf->sv_index[0])));
2390             memcpy( ddf->sv_index, sdf->sv_index, sv_count*sizeof(ddf->sv_index[0]));
2391         }
2392         else
2393             ddf->sv_index = 0;
2394     }
2395
2396     __END__;
2397
2398     if( cvGetErrStatus() < 0 && dst )
2399         icvReleaseSVM( &dst );
2400     
2401     return dst;
2402 }
2403
2404 static int icvRegisterSVMType()
2405 {
2406     CvTypeInfo info;
2407     memset( &info, 0, sizeof(info) );
2408
2409     info.flags = 0;
2410     info.header_size = sizeof( info );
2411     info.is_instance = icvIsSVM;
2412     info.release = (CvReleaseFunc)icvReleaseSVM;
2413     info.read = icvReadSVM;
2414     info.write = icvWriteSVM;
2415     info.clone = icvCloneSVM;
2416     info.type_name = CV_TYPE_NAME_ML_SVM;
2417     cvRegisterType( &info );
2418
2419     return 1;
2420 }
2421
2422
2423 static int svm = icvRegisterSVMType();
2424
2425 /* The function trains SVM model with optimal parameters, obtained by using cross-validation.
2426 The parameters to be estimated should be indicated by setting theirs values to FLT_MAX.
2427 The optimal parameters are saved in <model_params> */
2428 CV_IMPL CvStatModel*
2429 cvTrainSVM_CrossValidation( const CvMat* train_data, int tflag,
2430             const CvMat* responses,
2431             CvStatModelParams* model_params,
2432             const CvStatModelParams* cross_valid_params,
2433             const CvMat* comp_idx,
2434             const CvMat* sample_idx,
2435             const CvParamGrid* degree_grid,
2436             const CvParamGrid* gamma_grid,
2437             const CvParamGrid* coef_grid,
2438             const CvParamGrid* C_grid,
2439             const CvParamGrid* nu_grid,
2440             const CvParamGrid* p_grid )
2441 {
2442     CvStatModel* svm = 0;
2443
2444     CV_FUNCNAME("cvTainSVMCrossValidation");
2445     __BEGIN__;
2446
2447     double degree_step = 7,
2448                g_step      = 15,
2449                    coef_step   = 14,
2450                    C_step      = 20,
2451                    nu_step     = 5,
2452                    p_step      = 7; // all steps must be > 1
2453     double degree_begin = 0.01, degree_end = 2;
2454     double g_begin      = 1e-5, g_end      = 0.5;
2455     double coef_begin   = 0.1,  coef_end   = 300;
2456     double C_begin      = 0.1,  C_end      = 6000;
2457     double nu_begin     = 0.01,  nu_end    = 0.4;
2458     double p_begin      = 0.01, p_end      = 100;
2459
2460     double rate = 0, gamma = 0, C = 0, degree = 0, coef = 0, p = 0, nu = 0;
2461
2462         double best_rate    = 0;
2463     double best_degree  = degree_begin;
2464     double best_gamma   = g_begin;
2465     double best_coef    = coef_begin;
2466         double best_C       = C_begin;
2467         double best_nu      = nu_begin;
2468     double best_p       = p_begin;
2469
2470     CvSVMModelParams svm_params, *psvm_params;
2471     CvCrossValidationParams* cv_params = (CvCrossValidationParams*)cross_valid_params;
2472     int svm_type, kernel;
2473     int is_regression;
2474
2475     if( !model_params )
2476         CV_ERROR( CV_StsBadArg, "" );
2477     if( !cv_params )
2478         CV_ERROR( CV_StsBadArg, "" );
2479
2480     svm_params = *(CvSVMModelParams*)model_params;
2481     psvm_params = (CvSVMModelParams*)model_params;
2482     svm_type = svm_params.svm_type;
2483     kernel = svm_params.kernel_type;
2484
2485     svm_params.degree = svm_params.degree > 0 ? svm_params.degree : 1;
2486     svm_params.gamma = svm_params.gamma > 0 ? svm_params.gamma : 1;
2487     svm_params.coef0 = svm_params.coef0 > 0 ? svm_params.coef0 : 1e-6;
2488     svm_params.C = svm_params.C > 0 ? svm_params.C : 1;
2489     svm_params.nu = svm_params.nu > 0 ? svm_params.nu : 1;
2490     svm_params.p = svm_params.p > 0 ? svm_params.p : 1;
2491
2492     if( degree_grid )
2493     {
2494         if( !(degree_grid->max_val == 0 && degree_grid->min_val == 0 &&
2495               degree_grid->step == 0) )
2496         {
2497             if( degree_grid->min_val > degree_grid->max_val )
2498                 CV_ERROR( CV_StsBadArg,
2499                 "low bound of grid should be less then the upper one");
2500             if( degree_grid->step <= 1 )
2501                 CV_ERROR( CV_StsBadArg, "grid step should be greater 1" );
2502             degree_begin = degree_grid->min_val;
2503             degree_end   = degree_grid->max_val;
2504             degree_step  = degree_grid->step;
2505         }
2506     }
2507     else
2508         degree_begin = degree_end = svm_params.degree;
2509
2510     if( gamma_grid )
2511     {
2512         if( !(gamma_grid->max_val == 0 && gamma_grid->min_val == 0 &&
2513               gamma_grid->step == 0) )
2514         {
2515             if( gamma_grid->min_val > gamma_grid->max_val )
2516                 CV_ERROR( CV_StsBadArg,
2517                 "low bound of grid should be less then the upper one");
2518             if( gamma_grid->step <= 1 )
2519                 CV_ERROR( CV_StsBadArg, "grid step should be greater 1" );
2520             g_begin = gamma_grid->min_val;
2521             g_end   = gamma_grid->max_val;
2522             g_step  = gamma_grid->step;
2523         }
2524     }
2525     else
2526         g_begin = g_end = svm_params.gamma;
2527
2528     if( coef_grid )
2529     {
2530         if( !(coef_grid->max_val == 0 && coef_grid->min_val == 0 &&
2531               coef_grid->step == 0) )
2532         {
2533             if( coef_grid->min_val > coef_grid->max_val )
2534                 CV_ERROR( CV_StsBadArg,
2535                 "low bound of grid should be less then the upper one");
2536             if( coef_grid->step <= 1 )
2537                 CV_ERROR( CV_StsBadArg, "grid step should be greater 1" );
2538             coef_begin = coef_grid->min_val;
2539             coef_end   = coef_grid->max_val;
2540             coef_step  = coef_grid->step;
2541         }
2542     }
2543     else
2544         coef_begin = coef_end = svm_params.coef0;
2545
2546     if( C_grid )
2547     {
2548         if( !(C_grid->max_val == 0 && C_grid->min_val == 0 && C_grid->step == 0))
2549         {
2550             if( C_grid->min_val > C_grid->max_val )
2551                 CV_ERROR( CV_StsBadArg,
2552                 "low bound of grid should be less then the upper one");
2553             if( C_grid->step <= 1 )
2554                 CV_ERROR( CV_StsBadArg, "grid step should be greater 1" );
2555             C_begin = C_grid->min_val;
2556             C_end   = C_grid->max_val;
2557             C_step  = C_grid->step;
2558         }
2559     }
2560     else
2561         C_begin = C_end = svm_params.C;
2562
2563     if( nu_grid )
2564     {
2565         if(!(nu_grid->max_val == 0 && nu_grid->min_val == 0 && nu_grid->step==0))
2566         {
2567             if( nu_grid->min_val > nu_grid->max_val )
2568                 CV_ERROR( CV_StsBadArg,
2569                 "low bound of grid should be less then the upper one");
2570             if( nu_grid->step <= 1 )
2571                 CV_ERROR( CV_StsBadArg, "grid step should be greater 1" );
2572             nu_begin = nu_grid->min_val;
2573             nu_end   = nu_grid->max_val;
2574             nu_step  = nu_grid->step;
2575         }
2576     }
2577     else
2578         nu_begin = nu_end = svm_params.nu;
2579
2580     if( p_grid )
2581     {
2582         if( !(p_grid->max_val == 0 && p_grid->min_val == 0 && p_grid->step == 0))
2583         {
2584             if( p_grid->min_val > p_grid->max_val )
2585                 CV_ERROR( CV_StsBadArg,
2586                 "low bound of grid should be less then the upper one");
2587             if( p_grid->step <= 1 )
2588                 CV_ERROR( CV_StsBadArg, "grid step should be greater 1" );
2589             p_begin = p_grid->min_val;
2590             p_end   = p_grid->max_val;
2591             p_step  = p_grid->step;
2592         }
2593     }
2594     else
2595         p_begin = p_end = svm_params.p;
2596
2597     // these parameters are not used:
2598     if( kernel != CvSVM::POLY )
2599         degree_begin = degree_end = svm_params.degree;
2600
2601    if( kernel == CvSVM::LINEAR )
2602         g_begin = g_end = svm_params.gamma;
2603
2604     if( kernel != CvSVM::POLY && kernel != CvSVM::SIGMOID )
2605         coef_begin = coef_end = svm_params.coef0;
2606  
2607     if( svm_type == CvSVM::NU_SVC || svm_type == CvSVM::ONE_CLASS )
2608         C_begin = C_end = svm_params.C;
2609
2610     if( svm_type == CvSVM::C_SVC || svm_type == CvSVM::EPS_SVR )
2611         nu_begin = nu_end = svm_params.nu;
2612
2613     if( svm_type != CvSVM::EPS_SVR )
2614         p_begin = p_end = svm_params.p;
2615
2616     is_regression = cv_params->is_regression;
2617     best_rate = is_regression ? FLT_MAX : 0;
2618
2619     assert( g_step > 1 && degree_step > 1 && coef_step > 1);
2620     assert( p_step > 1 && C_step > 1 && nu_step > 1 );
2621
2622     for( degree = degree_begin; degree <= degree_end; degree *= degree_step )
2623     {
2624       svm_params.degree = degree;
2625       //printf("degree = %.3f\n", degree );
2626       for( gamma= g_begin; gamma <= g_end; gamma *= g_step )
2627       {
2628         svm_params.gamma = gamma;
2629         //printf("   gamma = %.3f\n", gamma );
2630         for( coef = coef_begin; coef <= coef_end; coef *= coef_step )
2631         {
2632           svm_params.coef0 = coef;
2633           //printf("      coef = %.3f\n", coef );
2634           for( C = C_begin; C <= C_end; C *= C_step )
2635           {
2636             svm_params.C = C;
2637             //printf("         C = %.3f\n", C );
2638             for( nu = nu_begin; nu <= nu_end; nu *= nu_step )
2639             {
2640               svm_params.nu = nu;
2641               //printf("            nu = %.3f\n", nu );
2642               for( p = p_begin; p <= p_end; p *= p_step )
2643               {
2644                 int well;
2645                 svm_params.p = p;
2646                 //printf("               p = %.3f\n", p );
2647
2648                 CV_CALL(rate = cvCrossValidation( train_data, tflag, responses, &cvTrainSVM,
2649                     cross_valid_params, (CvStatModelParams*)&svm_params, comp_idx, sample_idx ));
2650
2651                 well =  rate > best_rate && !is_regression || rate < best_rate && is_regression;
2652                 if( well || (rate == best_rate && C < best_C) )
2653                 {
2654                     best_rate   = rate;
2655                     best_degree = degree;
2656                     best_gamma  = gamma;
2657                     best_coef   = coef;
2658                     best_C      = C;
2659                     best_nu     = nu;
2660                     best_p      = p;
2661                 }
2662                 //printf("                  rate = %.2f\n", rate );
2663               }
2664             }
2665           }
2666         }
2667       }
2668     }
2669     //printf("The best:\nrate = %.2f%% degree = %f gamma = %f coef = %f c = %f nu = %f p = %f\n",
2670       //  best_rate, best_degree, best_gamma, best_coef, best_C, best_nu, best_p );
2671
2672     psvm_params->C      = best_C;
2673     psvm_params->nu     = best_nu;
2674     psvm_params->p      = best_p;
2675     psvm_params->gamma  = best_gamma;
2676     psvm_params->degree = best_degree;
2677     psvm_params->coef0  = best_coef;
2678
2679     CV_CALL(svm = cvTrainSVM( train_data, tflag, responses, model_params, comp_idx, sample_idx ));
2680
2681     __END__;
2682
2683     return svm;
2684 }
2685
2686 #endif
2687
2688 /* End of file. */
2689