Update the trunk to the OpenCV's CVS (2008-07-14)
[opencv] / samples / c / letter_recog.cpp
1 #include "ml.h"
2 #include <stdio.h>
3
4 /*
5 The sample demonstrates how to train Random Trees classifier
6 (or Boosting classifier, or MLP - see main()) using the provided dataset.
7
8 We use the sample database letter-recognition.data
9 from UCI Repository, here is the link:
10
11 Newman, D.J. & Hettich, S. & Blake, C.L. & Merz, C.J. (1998).
12 UCI Repository of machine learning databases
13 [http://www.ics.uci.edu/~mlearn/MLRepository.html].
14 Irvine, CA: University of California, Department of Information and Computer Science.
15
16 The dataset consists of 20000 feature vectors along with the
17 responses - capital latin letters A..Z.
18 The first 16000 (10000 for boosting)) samples are used for training
19 and the remaining 4000 (10000 for boosting) - to test the classifier.
20 */
21
22 // This function reads data and responses from the file <filename>
23 static int
24 read_num_class_data( const char* filename, int var_count,
25                      CvMat** data, CvMat** responses )
26 {
27     const int M = 1024;
28     FILE* f = fopen( filename, "rt" );
29     CvMemStorage* storage;
30     CvSeq* seq;
31     char buf[M+2];
32     float* el_ptr;
33     CvSeqReader reader;
34     int i, j;
35
36     if( !f )
37         return 0;
38
39     el_ptr = new float[var_count+1];
40     storage = cvCreateMemStorage();
41     seq = cvCreateSeq( 0, sizeof(*seq), (var_count+1)*sizeof(float), storage );
42
43     for(;;)
44     {
45         char* ptr;
46         if( !fgets( buf, M, f ) || !strchr( buf, ',' ) )
47             break;
48         el_ptr[0] = buf[0];
49         ptr = buf+2;
50         for( i = 1; i <= var_count; i++ )
51         {
52             int n = 0;
53             sscanf( ptr, "%f%n", el_ptr + i, &n );
54             ptr += n + 1;
55         }
56         if( i <= var_count )
57             break;
58         cvSeqPush( seq, el_ptr );
59     }
60     fclose(f);
61
62     *data = cvCreateMat( seq->total, var_count, CV_32F );
63     *responses = cvCreateMat( seq->total, 1, CV_32F );
64
65     cvStartReadSeq( seq, &reader );
66
67     for( i = 0; i < seq->total; i++ )
68     {
69         const float* sdata = (float*)reader.ptr + 1;
70         float* ddata = data[0]->data.fl + var_count*i;
71         float* dr = responses[0]->data.fl + i;
72
73         for( j = 0; j < var_count; j++ )
74             ddata[j] = sdata[j];
75         *dr = sdata[-1];
76         CV_NEXT_SEQ_ELEM( seq->elem_size, reader );
77     }
78
79     cvReleaseMemStorage( &storage );
80     delete el_ptr;
81     return 1;
82 }
83
84 static
85 int build_rtrees_classifier( char* data_filename,
86     char* filename_to_save, char* filename_to_load )
87 {
88     CvMat* data = 0;
89     CvMat* responses = 0;
90     CvMat* var_type = 0;
91     CvMat* sample_idx = 0;
92
93     int ok = read_num_class_data( data_filename, 16, &data, &responses );
94     int nsamples_all = 0, ntrain_samples = 0;
95     int i = 0;
96     double train_hr = 0, test_hr = 0;
97     CvRTrees forest;
98     CvMat* var_importance = 0;
99
100     if( !ok )
101     {
102         printf( "Could not read the database %s\n", data_filename );
103         return -1;
104     }
105
106     printf( "The database %s is loaded.\n", data_filename );
107     nsamples_all = data->rows;
108     ntrain_samples = (int)(nsamples_all*0.8);
109
110     // Create or load Random Trees classifier
111     if( filename_to_load )
112     {
113         // load classifier from the specified file
114         forest.load( filename_to_load );
115         ntrain_samples = 0;
116         if( forest.get_tree_count() == 0 )
117         {
118             printf( "Could not read the classifier %s\n", filename_to_load );
119             return -1;
120         }
121         printf( "The classifier %s is loaded.\n", data_filename );
122     }
123     else
124     {
125         // create classifier by using <data> and <responses>
126         printf( "Training the classifier ...");
127
128         // 1. create type mask
129         var_type = cvCreateMat( data->cols + 1, 1, CV_8U );
130         cvSet( var_type, cvScalarAll(CV_VAR_ORDERED) );
131         cvSetReal1D( var_type, data->cols, CV_VAR_CATEGORICAL );
132
133         // 2. create sample_idx
134         sample_idx = cvCreateMat( 1, nsamples_all, CV_8UC1 );
135         {
136             CvMat mat;
137             cvGetCols( sample_idx, &mat, 0, ntrain_samples );
138             cvSet( &mat, cvRealScalar(1) );
139
140             cvGetCols( sample_idx, &mat, ntrain_samples, nsamples_all );
141             cvSetZero( &mat );
142         }
143
144         // 3. train classifier
145         forest.train( data, CV_ROW_SAMPLE, responses, 0, sample_idx, var_type, 0,
146             CvRTParams(10,10,0,false,15,0,true,4,100,0.01f,CV_TERMCRIT_ITER));
147         printf( "\n");
148     }
149
150     // compute prediction error on train and test data
151     for( i = 0; i < nsamples_all; i++ )
152     {
153         double r;
154         CvMat sample;
155         cvGetRow( data, &sample, i );
156
157         r = forest.predict( &sample );
158         r = fabs((double)r - responses->data.fl[i]) <= FLT_EPSILON ? 1 : 0;
159
160         if( i < ntrain_samples )
161             train_hr += r;
162         else
163             test_hr += r;
164     }
165
166     test_hr /= (double)(nsamples_all-ntrain_samples);
167     train_hr /= (double)ntrain_samples;
168     printf( "Recognition rate: train = %.1f%%, test = %.1f%%\n",
169             train_hr*100., test_hr*100. );
170
171     printf( "Number of trees: %d\n", forest.get_tree_count() );
172
173     // Print variable importance
174     var_importance = (CvMat*)forest.get_var_importance();
175     if( var_importance )
176     {
177         double rt_imp_sum = cvSum( var_importance ).val[0];
178         printf("var#\timportance (in %%):\n");
179         for( i = 0; i < var_importance->cols; i++ )
180             printf( "%-2d\t%-4.1f\n", i,
181             100.f*var_importance->data.fl[i]/rt_imp_sum);
182     }
183
184     //Print some proximitites
185     printf( "Proximities between some samples corresponding to the letter 'T':\n" );
186     {
187         CvMat sample1, sample2;
188         const int pairs[][2] = {{0,103}, {0,106}, {106,103}, {-1,-1}};
189
190         for( i = 0; pairs[i][0] >= 0; i++ )
191         {
192             cvGetRow( data, &sample1, pairs[i][0] );
193             cvGetRow( data, &sample2, pairs[i][1] );
194             printf( "proximity(%d,%d) = %.1f%%\n", pairs[i][0], pairs[i][1],
195                 forest.get_proximity( &sample1, &sample2 )*100. );
196         }
197     }
198
199     // Save Random Trees classifier to file if needed
200     if( filename_to_save )
201         forest.save( filename_to_save );
202
203     cvReleaseMat( &sample_idx );
204     cvReleaseMat( &var_type );
205     cvReleaseMat( &data );
206     cvReleaseMat( &responses );
207
208     return 0;
209 }
210
211
212 static
213 int build_boost_classifier( char* data_filename,
214     char* filename_to_save, char* filename_to_load )
215 {
216     const int class_count = 26;
217     CvMat* data = 0;
218     CvMat* responses = 0;
219     CvMat* var_type = 0;
220     CvMat* temp_sample = 0;
221     CvMat* weak_responses = 0;
222
223     int ok = read_num_class_data( data_filename, 16, &data, &responses );
224     int nsamples_all = 0, ntrain_samples = 0;
225     int var_count;
226     int i, j, k;
227     double train_hr = 0, test_hr = 0;
228     CvBoost boost;
229
230     if( !ok )
231     {
232         printf( "Could not read the database %s\n", data_filename );
233         return -1;
234     }
235
236     printf( "The database %s is loaded.\n", data_filename );
237     nsamples_all = data->rows;
238     ntrain_samples = (int)(nsamples_all*0.5);
239     var_count = data->cols;
240
241     // Create or load Boosted Tree classifier
242     if( filename_to_load )
243     {
244         // load classifier from the specified file
245         boost.load( filename_to_load );
246         ntrain_samples = 0;
247         if( !boost.get_weak_predictors() )
248         {
249             printf( "Could not read the classifier %s\n", filename_to_load );
250             return -1;
251         }
252         printf( "The classifier %s is loaded.\n", data_filename );
253     }
254     else
255     {
256         // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
257         //
258         // As currently boosted tree classifier in MLL can only be trained
259         // for 2-class problems, we transform the training database by
260         // "unrolling" each training sample as many times as the number of
261         // classes (26) that we have.
262         //
263         // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
264
265         CvMat* new_data = cvCreateMat( ntrain_samples*class_count, var_count + 1, CV_32F );
266         CvMat* new_responses = cvCreateMat( ntrain_samples*class_count, 1, CV_32S );
267
268         // 1. unroll the database type mask
269         printf( "Unrolling the database...\n");
270         for( i = 0; i < ntrain_samples; i++ )
271         {
272             float* data_row = (float*)(data->data.ptr + data->step*i);
273             for( j = 0; j < class_count; j++ )
274             {
275                 float* new_data_row = (float*)(new_data->data.ptr +
276                                 new_data->step*(i*class_count+j));
277                 for( k = 0; k < var_count; k++ )
278                     new_data_row[k] = data_row[k];
279                 new_data_row[var_count] = (float)j;
280                 new_responses->data.i[i*class_count + j] = responses->data.fl[i] == j+'A';
281             }
282         }
283
284         // 2. create type mask
285         var_type = cvCreateMat( var_count + 2, 1, CV_8U );
286         cvSet( var_type, cvScalarAll(CV_VAR_ORDERED) );
287         // the last indicator variable, as well
288         // as the new (binary) response are categorical
289         cvSetReal1D( var_type, var_count, CV_VAR_CATEGORICAL );
290         cvSetReal1D( var_type, var_count+1, CV_VAR_CATEGORICAL );
291
292         // 3. train classifier
293         printf( "Training the classifier (may take a few minutes)...");
294         boost.train( new_data, CV_ROW_SAMPLE, new_responses, 0, 0, var_type, 0,
295             CvBoostParams(CvBoost::REAL, 100, 0.95, 5, false, 0 ));
296         cvReleaseMat( &new_data );
297         cvReleaseMat( &new_responses );
298         printf("\n");
299     }
300
301     temp_sample = cvCreateMat( 1, var_count + 1, CV_32F );
302     weak_responses = cvCreateMat( 1, boost.get_weak_predictors()->total, CV_32F ); 
303
304     // compute prediction error on train and test data
305     for( i = 0; i < nsamples_all; i++ )
306     {
307         int best_class = 0;
308         double max_sum = -DBL_MAX;
309         double r;
310         CvMat sample;
311         cvGetRow( data, &sample, i );
312         for( k = 0; k < var_count; k++ )
313             temp_sample->data.fl[k] = sample.data.fl[k];
314
315         for( j = 0; j < class_count; j++ )
316         {
317             temp_sample->data.fl[var_count] = (float)j;
318             boost.predict( temp_sample, 0, weak_responses );
319             double sum = cvSum( weak_responses ).val[0];
320             if( max_sum < sum )
321             {
322                 max_sum = sum;
323                 best_class = j + 'A';
324             }
325         }
326
327         r = fabs(best_class - responses->data.fl[i]) < FLT_EPSILON ? 1 : 0;
328
329         if( i < ntrain_samples )
330             train_hr += r;
331         else
332             test_hr += r;
333     }
334
335     test_hr /= (double)(nsamples_all-ntrain_samples);
336     train_hr /= (double)ntrain_samples;
337     printf( "Recognition rate: train = %.1f%%, test = %.1f%%\n",
338             train_hr*100., test_hr*100. );
339
340     printf( "Number of trees: %d\n", boost.get_weak_predictors()->total );
341
342     // Save classifier to file if needed
343     if( filename_to_save )
344         boost.save( filename_to_save );
345
346     cvReleaseMat( &temp_sample );
347     cvReleaseMat( &weak_responses );
348     cvReleaseMat( &var_type );
349     cvReleaseMat( &data );
350     cvReleaseMat( &responses );
351
352     return 0;
353 }
354
355
356 static
357 int build_mlp_classifier( char* data_filename,
358     char* filename_to_save, char* filename_to_load )
359 {
360     const int class_count = 26;
361     CvMat* data = 0;
362     CvMat train_data;
363     CvMat* responses = 0;
364     CvMat* mlp_response = 0;
365
366     int ok = read_num_class_data( data_filename, 16, &data, &responses );
367     int nsamples_all = 0, ntrain_samples = 0;
368     int i, j;
369     double train_hr = 0, test_hr = 0;
370     CvANN_MLP mlp;
371
372     if( !ok )
373     {
374         printf( "Could not read the database %s\n", data_filename );
375         return -1;
376     }
377
378     printf( "The database %s is loaded.\n", data_filename );
379     nsamples_all = data->rows;
380     ntrain_samples = (int)(nsamples_all*0.8);
381
382     // Create or load MLP classifier
383     if( filename_to_load )
384     {
385         // load classifier from the specified file
386         mlp.load( filename_to_load );
387         ntrain_samples = 0;
388         if( !mlp.get_layer_count() )
389         {
390             printf( "Could not read the classifier %s\n", filename_to_load );
391             return -1;
392         }
393         printf( "The classifier %s is loaded.\n", data_filename );
394     }
395     else
396     {
397         // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
398         //
399         // MLP does not support categorical variables by explicitly.
400         // So, instead of the output class label, we will use
401         // a binary vector of <class_count> components for training and,
402         // therefore, MLP will give us a vector of "probabilities" at the
403         // prediction stage
404         //
405         // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
406
407         CvMat* new_responses = cvCreateMat( ntrain_samples, class_count, CV_32F );
408
409         // 1. unroll the responses
410         printf( "Unrolling the responses...\n");
411         for( i = 0; i < ntrain_samples; i++ )
412         {
413             int cls_label = cvRound(responses->data.fl[i]) - 'A';
414             float* bit_vec = (float*)(new_responses->data.ptr + i*new_responses->step);
415             for( j = 0; j < class_count; j++ )
416                 bit_vec[j] = 0.f;
417             bit_vec[cls_label] = 1.f;
418         }
419         cvGetRows( data, &train_data, 0, ntrain_samples );
420
421         // 2. train classifier
422         int layer_sz[] = { data->cols, 100, 100, class_count };
423         CvMat layer_sizes =
424             cvMat( 1, (int)(sizeof(layer_sz)/sizeof(layer_sz[0])), CV_32S, layer_sz );
425         mlp.create( &layer_sizes );
426         printf( "Training the classifier (may take a few minutes)...");
427         mlp.train( &train_data, new_responses, 0, 0,
428             CvANN_MLP_TrainParams(cvTermCriteria(CV_TERMCRIT_ITER,300,0.01),
429             CvANN_MLP_TrainParams::RPROP,0.01));
430         cvReleaseMat( &new_responses );
431         printf("\n");
432     }
433
434     mlp_response = cvCreateMat( 1, class_count, CV_32F );
435
436     // compute prediction error on train and test data
437     for( i = 0; i < nsamples_all; i++ )
438     {
439         int best_class;
440         CvMat sample;
441         cvGetRow( data, &sample, i );
442         CvPoint max_loc = {0,0};
443         mlp.predict( &sample, mlp_response );
444         cvMinMaxLoc( mlp_response, 0, 0, 0, &max_loc, 0 );
445         best_class = max_loc.x + 'A';
446
447         int r = fabs((double)best_class - responses->data.fl[i]) < FLT_EPSILON ? 1 : 0;
448
449         if( i < ntrain_samples )
450             train_hr += r;
451         else
452             test_hr += r;
453     }
454
455     test_hr /= (double)(nsamples_all-ntrain_samples);
456     train_hr /= (double)ntrain_samples;
457     printf( "Recognition rate: train = %.1f%%, test = %.1f%%\n",
458             train_hr*100., test_hr*100. );
459
460     // Save classifier to file if needed
461     if( filename_to_save )
462         mlp.save( filename_to_save );
463
464     cvReleaseMat( &mlp_response );
465     cvReleaseMat( &data );
466     cvReleaseMat( &responses );
467
468     return 0;
469 }
470
471
472 int main( int argc, char *argv[] )
473 {
474     char* filename_to_save = 0;
475     char* filename_to_load = 0;
476     char default_data_filename[] = "./letter-recognition.data";
477     char* data_filename = default_data_filename;
478     int method = 0;
479
480     int i;
481     for( i = 1; i < argc; i++ )
482     {
483         if( strcmp(argv[i],"-data") == 0 ) // flag "-data letter_recognition.xml"
484         {
485             i++;
486             data_filename = argv[i];
487         }
488         else if( strcmp(argv[i],"-save") == 0 ) // flag "-save filename.xml"
489         {
490             i++;
491             filename_to_save = argv[i];
492         }
493         else if( strcmp(argv[i],"-load") == 0) // flag "-load filename.xml"
494         {
495             i++;
496             filename_to_load = argv[i];
497         }
498         else if( strcmp(argv[i],"-boost") == 0)
499         {
500             method = 1;
501         }
502         else if( strcmp(argv[i],"-mlp") == 0 )
503         {
504             method = 2;
505         }
506         else
507             break;
508     }
509
510     if( i < argc ||
511         (method == 0 ?
512         build_rtrees_classifier( data_filename, filename_to_save, filename_to_load ) :
513         method == 1 ?
514         build_boost_classifier( data_filename, filename_to_save, filename_to_load ) :
515         method == 2 ?
516         build_mlp_classifier( data_filename, filename_to_save, filename_to_load ) :
517         -1) < 0)
518     {
519         printf("This is letter recognition sample.\n"
520                 "The usage: letter_recog [-data <path to letter-recognition.data>] \\\n"
521                 "  [-save <output XML file for the classifier>] \\\n"
522                 "  [-load <XML file with the pre-trained classifier>] \\\n"
523                 "  [-boost|-mlp] # to use boost/mlp classifier instead of default Random Trees\n" );
524     }
525     return 0;
526 }