Update the changelog
[opencv] / cv / src / cvemd.cpp
1 /*M///////////////////////////////////////////////////////////////////////////////////////
2 //
3 //  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
4 //
5 //  By downloading, copying, installing or using the software you agree to this license.
6 //  If you do not agree to this license, do not download, install,
7 //  copy or use the software.
8 //
9 //
10 //                        Intel License Agreement
11 //                For Open Source Computer Vision Library
12 //
13 // Copyright (C) 2000, Intel Corporation, all rights reserved.
14 // Third party copyrights are property of their respective owners.
15 //
16 // Redistribution and use in source and binary forms, with or without modification,
17 // are permitted provided that the following conditions are met:
18 //
19 //   * Redistribution's of source code must retain the above copyright notice,
20 //     this list of conditions and the following disclaimer.
21 //
22 //   * Redistribution's in binary form must reproduce the above copyright notice,
23 //     this list of conditions and the following disclaimer in the documentation
24 //     and/or other materials provided with the distribution.
25 //
26 //   * The name of Intel Corporation may not be used to endorse or promote products
27 //     derived from this software without specific prior written permission.
28 //
29 // This software is provided by the copyright holders and contributors "as is" and
30 // any express or implied warranties, including, but not limited to, the implied
31 // warranties of merchantability and fitness for a particular purpose are disclaimed.
32 // In no event shall the Intel Corporation or contributors be liable for any direct,
33 // indirect, incidental, special, exemplary, or consequential damages
34 // (including, but not limited to, procurement of substitute goods or services;
35 // loss of use, data, or profits; or business interruption) however caused
36 // and on any theory of liability, whether in contract, strict liability,
37 // or tort (including negligence or otherwise) arising in any way out of
38 // the use of this software, even if advised of the possibility of such damage.
39 //
40 //M*/
41
42 /*
43     Partially based on Yossi Rubner code:
44     =========================================================================
45     emd.c
46
47     Last update: 3/14/98
48
49     An implementation of the Earth Movers Distance.
50     Based of the solution for the Transportation problem as described in
51     "Introduction to Mathematical Programming" by F. S. Hillier and 
52     G. J. Lieberman, McGraw-Hill, 1990.
53
54     Copyright (C) 1998 Yossi Rubner
55     Computer Science Department, Stanford University
56     E-Mail: rubner@cs.stanford.edu   URL: http://vision.stanford.edu/~rubner
57     ==========================================================================
58 */
59 #include "_cv.h"
60
61 #define MAX_ITERATIONS 500
62 #define CV_EMD_INF   ((float)1e20)
63 #define CV_EMD_EPS   ((float)1e-5)
64
65 /* CvNode1D is used for lists, representing 1D sparse array */
66 typedef struct CvNode1D
67 {
68     float val;
69     struct CvNode1D *next;
70 }
71 CvNode1D;
72
73 /* CvNode2D is used for lists, representing 2D sparse matrix */
74 typedef struct CvNode2D
75 {
76     float val;
77     struct CvNode2D *next[2];  /* next row & next column */
78     int i, j;
79 }
80 CvNode2D;
81
82
83 typedef struct CvEMDState
84 {
85     int ssize, dsize;
86
87     float **cost;
88     CvNode2D *_x;
89     CvNode2D *end_x;
90     CvNode2D *enter_x;
91     char **is_x;
92
93     CvNode2D **rows_x;
94     CvNode2D **cols_x;
95
96     CvNode1D *u;
97     CvNode1D *v;
98
99     int* idx1;
100     int* idx2;
101
102     /* find_loop buffers */
103     CvNode2D **loop;
104     char *is_used;
105
106     /* russel buffers */
107     float *s;
108     float *d;
109     float **delta;
110
111     float weight, max_cost;
112     char *buffer;
113 }
114 CvEMDState;
115
116 /* static function declaration */
117 static CvStatus icvInitEMD( const float *signature1, int size1,
118                             const float *signature2, int size2,
119                             int dims, CvDistanceFunction dist_func, void *user_param,
120                             const float* cost, int cost_step,
121                             CvEMDState * state, float *lower_bound,
122                             char *local_buffer, int local_buffer_size );
123
124 static CvStatus icvFindBasicVariables( float **cost, char **is_x,
125                                        CvNode1D * u, CvNode1D * v, int ssize, int dsize );
126
127 static float icvIsOptimal( float **cost, char **is_x,
128                            CvNode1D * u, CvNode1D * v,
129                            int ssize, int dsize, CvNode2D * enter_x );
130
131 static void icvRussel( CvEMDState * state );
132
133
134 static CvStatus icvNewSolution( CvEMDState * state );
135 static int icvFindLoop( CvEMDState * state );
136
137 static void icvAddBasicVariable( CvEMDState * state,
138                                  int min_i, int min_j,
139                                  CvNode1D * prev_u_min_i,
140                                  CvNode1D * prev_v_min_j,
141                                  CvNode1D * u_head );
142
143 static float icvDistL2( const float *x, const float *y, void *user_param );
144 static float icvDistL1( const float *x, const float *y, void *user_param );
145 static float icvDistC( const float *x, const float *y, void *user_param );
146
147 /* The main function */
148 CV_IMPL float
149 cvCalcEMD2( const CvArr* signature_arr1,
150             const CvArr* signature_arr2,
151             int dist_type,
152             CvDistanceFunction dist_func,
153             const CvArr* cost_matrix,
154             CvArr* flow_matrix,
155             float *lower_bound,
156             void *user_param )
157 {
158     char local_buffer[16384];
159     char *local_buffer_ptr = (char *)cvAlignPtr(local_buffer,16);
160     CvEMDState state;
161     float emd = 0;
162
163     CV_FUNCNAME( "cvCalcEMD2" );
164
165     memset( &state, 0, sizeof(state));
166
167     __BEGIN__;
168
169     double total_cost = 0;
170     CvStatus result = CV_NO_ERR;
171     float eps, min_delta;
172     CvNode2D *xp = 0;
173     CvMat sign_stub1, *signature1 = (CvMat*)signature_arr1;
174     CvMat sign_stub2, *signature2 = (CvMat*)signature_arr2;
175     CvMat cost_stub, *cost = &cost_stub;
176     CvMat flow_stub, *flow = (CvMat*)flow_matrix;
177     int dims, size1, size2;
178
179     CV_CALL( signature1 = cvGetMat( signature1, &sign_stub1 ));
180     CV_CALL( signature2 = cvGetMat( signature2, &sign_stub2 ));
181
182     if( signature1->cols != signature2->cols )
183         CV_ERROR( CV_StsUnmatchedSizes, "The arrays must have equal number of columns (which is number of dimensions but 1)" );
184
185     dims = signature1->cols - 1;
186     size1 = signature1->rows;
187     size2 = signature2->rows;
188
189     if( !CV_ARE_TYPES_EQ( signature1, signature2 ))
190         CV_ERROR( CV_StsUnmatchedFormats, "The array must have equal types" );
191
192     if( CV_MAT_TYPE( signature1->type ) != CV_32FC1 )
193         CV_ERROR( CV_StsUnsupportedFormat, "The signatures must be 32fC1" );
194
195     if( flow )
196     {
197         CV_CALL( flow = cvGetMat( flow, &flow_stub ));
198
199         if( flow->rows != size1 || flow->cols != size2 )
200             CV_ERROR( CV_StsUnmatchedSizes,
201             "The flow matrix size does not match to the signatures' sizes" );
202
203         if( CV_MAT_TYPE( flow->type ) != CV_32FC1 )
204             CV_ERROR( CV_StsUnsupportedFormat, "The flow matrix must be 32fC1" );
205     }
206
207     cost->data.fl = 0;
208     cost->step = 0;
209
210     if( dist_type < 0 )
211     {
212         if( cost_matrix )
213         {
214             if( dist_func )
215                 CV_ERROR( CV_StsBadArg,
216                 "Only one of cost matrix or distance function should be non-NULL in case of user-defined distance" );
217
218             if( lower_bound )
219                 CV_ERROR( CV_StsBadArg,
220                 "The lower boundary can not be calculated if the cost matrix is used" );
221
222             CV_CALL( cost = cvGetMat( cost_matrix, &cost_stub ));
223             if( cost->rows != size1 || cost->cols != size2 )
224                 CV_ERROR( CV_StsUnmatchedSizes,
225                 "The cost matrix size does not match to the signatures' sizes" );
226
227             if( CV_MAT_TYPE( cost->type ) != CV_32FC1 )
228                 CV_ERROR( CV_StsUnsupportedFormat, "The cost matrix must be 32fC1" );
229         }
230         else if( !dist_func )
231             CV_ERROR( CV_StsNullPtr, "In case of user-defined distance Distance function is undefined" );
232     }
233     else
234     {
235         if( dims == 0 )
236             CV_ERROR( CV_StsBadSize,
237             "Number of dimensions can be 0 only if a user-defined metric is used" );
238         user_param = (void *) (size_t)dims;
239         switch (dist_type)
240         {
241         case CV_DIST_L1:
242             dist_func = icvDistL1;
243             break;
244         case CV_DIST_L2:
245             dist_func = icvDistL2;
246             break;
247         case CV_DIST_C:
248             dist_func = icvDistC;
249             break;
250         default:
251             CV_ERROR( CV_StsBadFlag, "Bad or unsupported metric type" );
252         }
253     }
254
255     IPPI_CALL( result = icvInitEMD( signature1->data.fl, size1,
256                                     signature2->data.fl, size2,
257                                     dims, dist_func, user_param,
258                                     cost->data.fl, cost->step,
259                                     &state, lower_bound, local_buffer_ptr,
260                                     sizeof( local_buffer ) - 16 ));
261
262     if( result > 0 && lower_bound )
263     {
264         emd = *lower_bound;
265         EXIT;
266     }
267
268     eps = CV_EMD_EPS * state.max_cost;
269
270     /* if ssize = 1 or dsize = 1 then we are done, else ... */
271     if( state.ssize > 1 && state.dsize > 1 )
272     {
273         int itr;
274
275         for( itr = 1; itr < MAX_ITERATIONS; itr++ )
276         {
277             /* find basic variables */
278             result = icvFindBasicVariables( state.cost, state.is_x,
279                                             state.u, state.v, state.ssize, state.dsize );
280             if( result < 0 )
281                 break;
282
283             /* check for optimality */
284             min_delta = icvIsOptimal( state.cost, state.is_x,
285                                       state.u, state.v,
286                                       state.ssize, state.dsize, state.enter_x );
287
288             if( min_delta == CV_EMD_INF )
289             {
290                 CV_ERROR( CV_StsNoConv, "" );
291             }
292
293             /* if no negative deltamin, we found the optimal solution */
294             if( min_delta >= -eps )
295                 break;
296
297             /* improve solution */
298             IPPI_CALL( icvNewSolution( &state ));
299         }
300     }
301
302     /* compute the total flow */
303     for( xp = state._x; xp < state.end_x; xp++ )
304     {
305         float val = xp->val;
306         int i = xp->i;
307         int j = xp->j;
308         int ci = state.idx1[i];
309         int cj = state.idx2[j];
310
311         if( xp != state.enter_x && ci >= 0 && cj >= 0 )
312         {
313             total_cost += (double)val * state.cost[i][j];
314             if( flow )
315                 ((float*)(flow->data.ptr + flow->step*ci))[cj] = val;
316         }
317     }
318
319     emd = (float) (total_cost / state.weight);
320
321     __END__;
322
323     if( state.buffer && state.buffer != local_buffer_ptr )
324         cvFree( &state.buffer );
325
326     return emd;
327 }
328
329
330 /************************************************************************************\
331 *          initialize structure, allocate buffers and generate initial golution      *
332 \************************************************************************************/
333 static CvStatus
334 icvInitEMD( const float* signature1, int size1,
335             const float* signature2, int size2,
336             int dims, CvDistanceFunction dist_func, void* user_param,
337             const float* cost, int cost_step,
338             CvEMDState* state, float* lower_bound,
339             char* local_buffer, int local_buffer_size )
340 {
341     float s_sum = 0, d_sum = 0, diff;
342     int i, j;
343     int ssize = 0, dsize = 0;
344     int equal_sums = 1;
345     int buffer_size;
346     float max_cost = 0;
347     char *buffer, *buffer_end;
348
349     memset( state, 0, sizeof( *state ));
350     assert( cost_step % sizeof(float) == 0 );
351     cost_step /= sizeof(float);
352
353     /* calculate buffer size */
354     buffer_size = (size1+1) * (size2+1) * (sizeof( float ) +    /* cost */
355                                    sizeof( char ) +     /* is_x */
356                                    sizeof( float )) +   /* delta matrix */
357         (size1 + size2 + 2) * (sizeof( CvNode2D ) + /* _x */
358                            sizeof( CvNode2D * ) +  /* cols_x & rows_x */
359                            sizeof( CvNode1D ) + /* u & v */
360                            sizeof( float ) + /* s & d */
361                            sizeof( int ) + sizeof(CvNode2D*)) +  /* idx1 & idx2 */ 
362         (size1+1) * (sizeof( float * ) + sizeof( char * ) + /* rows pointers for */
363                  sizeof( float * )) + 256;      /*  cost, is_x and delta */
364
365     if( buffer_size < (int) (dims * 2 * sizeof( float )))
366     {
367         buffer_size = dims * 2 * sizeof( float );
368     }
369
370     /* allocate buffers */
371     if( local_buffer != 0 && local_buffer_size >= buffer_size )
372     {
373         buffer = local_buffer;
374     }
375     else
376     {
377         buffer = (char*)cvAlloc( buffer_size );
378         if( !buffer )
379             return CV_OUTOFMEM_ERR;
380     }
381
382     state->buffer = buffer;
383     buffer_end = buffer + buffer_size;
384
385     state->idx1 = (int*) buffer;
386     buffer += (size1 + 1) * sizeof( int );
387
388     state->idx2 = (int*) buffer;
389     buffer += (size2 + 1) * sizeof( int );
390
391     state->s = (float *) buffer;
392     buffer += (size1 + 1) * sizeof( float );
393
394     state->d = (float *) buffer;
395     buffer += (size2 + 1) * sizeof( float );
396
397     /* sum up the supply and demand */
398     for( i = 0; i < size1; i++ )
399     {
400         float weight = signature1[i * (dims + 1)];
401
402         if( weight > 0 )
403         {
404             s_sum += weight;
405             state->s[ssize] = weight;
406             state->idx1[ssize++] = i;
407             
408         }
409         else if( weight < 0 )
410             return CV_BADRANGE_ERR;
411     }
412
413     for( i = 0; i < size2; i++ )
414     {
415         float weight = signature2[i * (dims + 1)];
416
417         if( weight > 0 )
418         {
419             d_sum += weight;
420             state->d[dsize] = weight;
421             state->idx2[dsize++] = i;
422         }
423         else if( weight < 0 )
424             return CV_BADRANGE_ERR;
425     }
426
427     if( ssize == 0 || dsize == 0 )
428         return CV_BADRANGE_ERR;
429
430     /* if supply different than the demand, add a zero-cost dummy cluster */
431     diff = s_sum - d_sum;
432     if( fabs( diff ) >= CV_EMD_EPS * s_sum )
433     {
434         equal_sums = 0;
435         if( diff < 0 )
436         {
437             state->s[ssize] = -diff;
438             state->idx1[ssize++] = -1;
439         }    
440         else
441         {
442             state->d[dsize] = diff;
443             state->idx2[dsize++] = -1;
444         }
445     }
446
447     state->ssize = ssize;
448     state->dsize = dsize;
449     state->weight = s_sum > d_sum ? s_sum : d_sum;
450
451     if( lower_bound && equal_sums )     /* check lower bound */
452     {
453         int sz1 = size1 * (dims + 1), sz2 = size2 * (dims + 1);
454         float lb = 0;
455
456         float* xs = (float *) buffer;
457         float* xd = xs + dims;
458
459         memset( xs, 0, dims*sizeof(xs[0]));
460         memset( xd, 0, dims*sizeof(xd[0]));
461
462         for( j = 0; j < sz1; j += dims + 1 )
463         {
464             float weight = signature1[j];
465             for( i = 0; i < dims; i++ )
466                 xs[i] += signature1[j + i + 1] * weight;
467         }
468
469         for( j = 0; j < sz2; j += dims + 1 )
470         {
471             float weight = signature2[j];
472             for( i = 0; i < dims; i++ )
473                 xd[i] += signature2[j + i + 1] * weight;
474         }
475
476         lb = dist_func( xs, xd, user_param ) / state->weight;
477         i = *lower_bound <= lb;
478         *lower_bound = lb;
479         if( i )
480             return ( CvStatus ) 1;
481     }
482
483     /* assign pointers */
484     state->is_used = (char *) buffer;
485     /* init delta matrix */
486     state->delta = (float **) buffer;
487     buffer += ssize * sizeof( float * );
488
489     for( i = 0; i < ssize; i++ )
490     {
491         state->delta[i] = (float *) buffer;
492         buffer += dsize * sizeof( float );
493     }
494
495     state->loop = (CvNode2D **) buffer;
496     buffer += (ssize + dsize + 1) * sizeof(CvNode2D*);
497
498     state->_x = state->end_x = (CvNode2D *) buffer;
499     buffer += (ssize + dsize) * sizeof( CvNode2D );
500
501     /* init cost matrix */
502     state->cost = (float **) buffer;
503     buffer += ssize * sizeof( float * );
504
505     /* compute the distance matrix */
506     for( i = 0; i < ssize; i++ )
507     {
508         int ci = state->idx1[i];
509
510         state->cost[i] = (float *) buffer;
511         buffer += dsize * sizeof( float );
512
513         if( ci >= 0 )
514         {
515             for( j = 0; j < dsize; j++ )
516             {
517                 int cj = state->idx2[j];
518                 if( cj < 0 )
519                     state->cost[i][j] = 0;
520                 else
521                 {
522                     float val;
523                     if( dist_func )
524                     {
525                         val = dist_func( signature1 + ci * (dims + 1) + 1,
526                                          signature2 + cj * (dims + 1) + 1,
527                                          user_param );
528                     }
529                     else
530                     {
531                         assert( cost );
532                         val = cost[cost_step*ci + cj];
533                     }
534                     state->cost[i][j] = val;
535                     if( max_cost < val )
536                         max_cost = val;
537                 }
538             }
539         }
540         else
541         {
542             for( j = 0; j < dsize; j++ )
543                 state->cost[i][j] = 0;
544         }
545     }
546
547     state->max_cost = max_cost;
548     
549     memset( buffer, 0, buffer_end - buffer );
550
551     state->rows_x = (CvNode2D **) buffer;
552     buffer += ssize * sizeof( CvNode2D * );
553
554     state->cols_x = (CvNode2D **) buffer;
555     buffer += dsize * sizeof( CvNode2D * );
556
557     state->u = (CvNode1D *) buffer;
558     buffer += ssize * sizeof( CvNode1D );
559
560     state->v = (CvNode1D *) buffer;
561     buffer += dsize * sizeof( CvNode1D );
562
563     /* init is_x matrix */
564     state->is_x = (char **) buffer;
565     buffer += ssize * sizeof( char * );
566
567     for( i = 0; i < ssize; i++ )
568     {
569         state->is_x[i] = buffer;
570         buffer += dsize;
571     }
572
573     assert( buffer <= buffer_end );
574
575     icvRussel( state );
576
577     state->enter_x = (state->end_x)++;
578     return CV_NO_ERR;
579 }
580
581
582 /****************************************************************************************\
583 *                              icvFindBasicVariables                                   *
584 \****************************************************************************************/
585 static CvStatus
586 icvFindBasicVariables( float **cost, char **is_x,
587                        CvNode1D * u, CvNode1D * v, int ssize, int dsize )
588 {
589     int i, j, found;
590     int u_cfound, v_cfound;
591     CvNode1D u0_head, u1_head, *cur_u, *prev_u;
592     CvNode1D v0_head, v1_head, *cur_v, *prev_v;
593
594     /* initialize the rows list (u) and the columns list (v) */
595     u0_head.next = u;
596     for( i = 0; i < ssize; i++ )
597     {
598         u[i].next = u + i + 1;
599     }
600     u[ssize - 1].next = 0;
601     u1_head.next = 0;
602
603     v0_head.next = ssize > 1 ? v + 1 : 0;
604     for( i = 1; i < dsize; i++ )
605     {
606         v[i].next = v + i + 1;
607     }
608     v[dsize - 1].next = 0;
609     v1_head.next = 0;
610
611     /* there are ssize+dsize variables but only ssize+dsize-1 independent equations,
612        so set v[0]=0 */
613     v[0].val = 0;
614     v1_head.next = v;
615     v1_head.next->next = 0;
616
617     /* loop until all variables are found */
618     u_cfound = v_cfound = 0;
619     while( u_cfound < ssize || v_cfound < dsize )
620     {
621         found = 0;
622         if( v_cfound < dsize )
623         {
624             /* loop over all marked columns */
625             prev_v = &v1_head;
626
627             for( found |= (cur_v = v1_head.next) != 0; cur_v != 0; cur_v = cur_v->next )
628             {
629                 float cur_v_val = cur_v->val;
630
631                 j = (int)(cur_v - v);
632                 /* find the variables in column j */
633                 prev_u = &u0_head;
634                 for( cur_u = u0_head.next; cur_u != 0; )
635                 {
636                     i = (int)(cur_u - u);
637                     if( is_x[i][j] )
638                     {
639                         /* compute u[i] */
640                         cur_u->val = cost[i][j] - cur_v_val;
641                         /* ...and add it to the marked list */
642                         prev_u->next = cur_u->next;
643                         cur_u->next = u1_head.next;
644                         u1_head.next = cur_u;
645                         cur_u = prev_u->next;
646                     }
647                     else
648                     {
649                         prev_u = cur_u;
650                         cur_u = cur_u->next;
651                     }
652                 }
653                 prev_v->next = cur_v->next;
654                 v_cfound++;
655             }
656         }
657
658         if( u_cfound < ssize )
659         {
660             /* loop over all marked rows */
661             prev_u = &u1_head;
662             for( found |= (cur_u = u1_head.next) != 0; cur_u != 0; cur_u = cur_u->next )
663             {
664                 float cur_u_val = cur_u->val;
665                 float *_cost;
666                 char *_is_x;
667
668                 i = (int)(cur_u - u);
669                 _cost = cost[i];
670                 _is_x = is_x[i];
671                 /* find the variables in rows i */
672                 prev_v = &v0_head;
673                 for( cur_v = v0_head.next; cur_v != 0; )
674                 {
675                     j = (int)(cur_v - v);
676                     if( _is_x[j] )
677                     {
678                         /* compute v[j] */
679                         cur_v->val = _cost[j] - cur_u_val;
680                         /* ...and add it to the marked list */
681                         prev_v->next = cur_v->next;
682                         cur_v->next = v1_head.next;
683                         v1_head.next = cur_v;
684                         cur_v = prev_v->next;
685                     }
686                     else
687                     {
688                         prev_v = cur_v;
689                         cur_v = cur_v->next;
690                     }
691                 }
692                 prev_u->next = cur_u->next;
693                 u_cfound++;
694             }
695         }
696
697         if( !found )
698         {
699             return CV_NOTDEFINED_ERR;
700         }
701     }
702
703     return CV_NO_ERR;
704 }
705
706
707 /****************************************************************************************\
708 *                                   icvIsOptimal                                       *
709 \****************************************************************************************/
710 static float
711 icvIsOptimal( float **cost, char **is_x,
712               CvNode1D * u, CvNode1D * v, int ssize, int dsize, CvNode2D * enter_x )
713 {
714     float delta, min_delta = CV_EMD_INF;
715     int i, j, min_i = 0, min_j = 0;
716
717     /* find the minimal cij-ui-vj over all i,j */
718     for( i = 0; i < ssize; i++ )
719     {
720         float u_val = u[i].val;
721         float *_cost = cost[i];
722         char *_is_x = is_x[i];
723
724         for( j = 0; j < dsize; j++ )
725         {
726             if( !_is_x[j] )
727             {
728                 delta = _cost[j] - u_val - v[j].val;
729                 if( min_delta > delta )
730                 {
731                     min_delta = delta;
732                     min_i = i;
733                     min_j = j;
734                 }
735             }
736         }
737     }
738
739     enter_x->i = min_i;
740     enter_x->j = min_j;
741
742     return min_delta;
743 }
744
745 /****************************************************************************************\
746 *                                   icvNewSolution                                     *
747 \****************************************************************************************/
748 static CvStatus
749 icvNewSolution( CvEMDState * state )
750 {
751     int i, j;
752     float min_val = CV_EMD_INF;
753     int steps;
754     CvNode2D head, *cur_x, *next_x, *leave_x = 0;
755     CvNode2D *enter_x = state->enter_x;
756     CvNode2D **loop = state->loop;
757
758     /* enter the new basic variable */
759     i = enter_x->i;
760     j = enter_x->j;
761     state->is_x[i][j] = 1;
762     enter_x->next[0] = state->rows_x[i];
763     enter_x->next[1] = state->cols_x[j];
764     enter_x->val = 0;
765     state->rows_x[i] = enter_x;
766     state->cols_x[j] = enter_x;
767
768     /* find a chain reaction */
769     steps = icvFindLoop( state );
770
771     if( steps == 0 )
772         return CV_NOTDEFINED_ERR;
773
774     /* find the largest value in the loop */
775     for( i = 1; i < steps; i += 2 )
776     {
777         float temp = loop[i]->val;
778
779         if( min_val > temp )
780         {
781             leave_x = loop[i];
782             min_val = temp;
783         }
784     }
785
786     /* update the loop */
787     for( i = 0; i < steps; i += 2 )
788     {
789         float temp0 = loop[i]->val + min_val;
790         float temp1 = loop[i + 1]->val - min_val;
791
792         loop[i]->val = temp0;
793         loop[i + 1]->val = temp1;
794     }
795
796     /* remove the leaving basic variable */
797     i = leave_x->i;
798     j = leave_x->j;
799     state->is_x[i][j] = 0;
800
801     head.next[0] = state->rows_x[i];
802     cur_x = &head;
803     while( (next_x = cur_x->next[0]) != leave_x )
804     {
805         cur_x = next_x;
806         assert( cur_x );
807     }
808     cur_x->next[0] = next_x->next[0];
809     state->rows_x[i] = head.next[0];
810
811     head.next[1] = state->cols_x[j];
812     cur_x = &head;
813     while( (next_x = cur_x->next[1]) != leave_x )
814     {
815         cur_x = next_x;
816         assert( cur_x );
817     }
818     cur_x->next[1] = next_x->next[1];
819     state->cols_x[j] = head.next[1];
820
821     /* set enter_x to be the new empty slot */
822     state->enter_x = leave_x;
823
824     return CV_NO_ERR;
825 }
826
827
828
829 /****************************************************************************************\
830 *                                    icvFindLoop                                       *
831 \****************************************************************************************/
832 static int
833 icvFindLoop( CvEMDState * state )
834 {
835     int i, steps = 1;
836     CvNode2D *new_x;
837     CvNode2D **loop = state->loop;
838     CvNode2D *enter_x = state->enter_x, *_x = state->_x;
839     char *is_used = state->is_used;
840
841     memset( is_used, 0, state->ssize + state->dsize );
842
843     new_x = loop[0] = enter_x;
844     is_used[enter_x - _x] = 1;
845     steps = 1;
846
847     do
848     {
849         if( (steps & 1) == 1 )
850         {
851             /* find an unused x in the row */
852             new_x = state->rows_x[new_x->i];
853             while( new_x != 0 && is_used[new_x - _x] )
854                 new_x = new_x->next[0];
855         }
856         else
857         {
858             /* find an unused x in the column, or the entering x */
859             new_x = state->cols_x[new_x->j];
860             while( new_x != 0 && is_used[new_x - _x] && new_x != enter_x )
861                 new_x = new_x->next[1];
862             if( new_x == enter_x )
863                 break;
864         }
865
866         if( new_x != 0 )        /* found the next x */
867         {
868             /* add x to the loop */
869             loop[steps++] = new_x;
870             is_used[new_x - _x] = 1;
871         }
872         else                    /* didn't find the next x */
873         {
874             /* backtrack */
875             do
876             {
877                 i = steps & 1;
878                 new_x = loop[steps - 1];
879                 do
880                 {
881                     new_x = new_x->next[i];
882                 }
883                 while( new_x != 0 && is_used[new_x - _x] );
884
885                 if( new_x == 0 )
886                 {
887                     is_used[loop[--steps] - _x] = 0;
888                 }
889             }
890             while( new_x == 0 && steps > 0 );
891
892             is_used[loop[steps - 1] - _x] = 0;
893             loop[steps - 1] = new_x;
894             is_used[new_x - _x] = 1;
895         }
896     }
897     while( steps > 0 );
898
899     return steps;
900 }
901
902
903
904 /****************************************************************************************\
905 *                                        icvRussel                                     *
906 \****************************************************************************************/
907 static void
908 icvRussel( CvEMDState * state )
909 {
910     int i, j, min_i = -1, min_j = -1;
911     float min_delta, diff;
912     CvNode1D u_head, *cur_u, *prev_u;
913     CvNode1D v_head, *cur_v, *prev_v;
914     CvNode1D *prev_u_min_i = 0, *prev_v_min_j = 0, *remember;
915     CvNode1D *u = state->u, *v = state->v;
916     int ssize = state->ssize, dsize = state->dsize;
917     float eps = CV_EMD_EPS * state->max_cost;
918     float **cost = state->cost;
919     float **delta = state->delta;
920
921     /* initialize the rows list (ur), and the columns list (vr) */
922     u_head.next = u;
923     for( i = 0; i < ssize; i++ )
924     {
925         u[i].next = u + i + 1;
926     }
927     u[ssize - 1].next = 0;
928
929     v_head.next = v;
930     for( i = 0; i < dsize; i++ )
931     {
932         v[i].val = -CV_EMD_INF;
933         v[i].next = v + i + 1;
934     }
935     v[dsize - 1].next = 0;
936
937     /* find the maximum row and column values (ur[i] and vr[j]) */
938     for( i = 0; i < ssize; i++ )
939     {
940         float u_val = -CV_EMD_INF;
941         float *cost_row = cost[i];
942
943         for( j = 0; j < dsize; j++ )
944         {
945             float temp = cost_row[j];
946
947             if( u_val < temp )
948                 u_val = temp;
949             if( v[j].val < temp )
950                 v[j].val = temp;
951         }
952         u[i].val = u_val;
953     }
954
955     /* compute the delta matrix */
956     for( i = 0; i < ssize; i++ )
957     {
958         float u_val = u[i].val;
959         float *delta_row = delta[i];
960         float *cost_row = cost[i];
961
962         for( j = 0; j < dsize; j++ )
963         {
964             delta_row[j] = cost_row[j] - u_val - v[j].val;
965         }
966     }
967
968     /* find the basic variables */
969     do
970     {
971         /* find the smallest delta[i][j] */
972         min_i = -1;
973         min_delta = CV_EMD_INF;
974         prev_u = &u_head;
975         for( cur_u = u_head.next; cur_u != 0; cur_u = cur_u->next )
976         {
977             i = (int)(cur_u - u);
978             float *delta_row = delta[i];
979
980             prev_v = &v_head;
981             for( cur_v = v_head.next; cur_v != 0; cur_v = cur_v->next )
982             {
983                 j = (int)(cur_v - v);
984                 if( min_delta > delta_row[j] )
985                 {
986                     min_delta = delta_row[j];
987                     min_i = i;
988                     min_j = j;
989                     prev_u_min_i = prev_u;
990                     prev_v_min_j = prev_v;
991                 }
992                 prev_v = cur_v;
993             }
994             prev_u = cur_u;
995         }
996
997         if( min_i < 0 )
998             break;
999
1000         /* add x[min_i][min_j] to the basis, and adjust supplies and cost */
1001         remember = prev_u_min_i->next;
1002         icvAddBasicVariable( state, min_i, min_j, prev_u_min_i, prev_v_min_j, &u_head );
1003
1004         /* update the necessary delta[][] */
1005         if( remember == prev_u_min_i->next )    /* line min_i was deleted */
1006         {
1007             for( cur_v = v_head.next; cur_v != 0; cur_v = cur_v->next )
1008             {
1009                 j = (int)(cur_v - v);
1010                 if( cur_v->val == cost[min_i][j] )      /* column j needs updating */
1011                 {
1012                     float max_val = -CV_EMD_INF;
1013
1014                     /* find the new maximum value in the column */
1015                     for( cur_u = u_head.next; cur_u != 0; cur_u = cur_u->next )
1016                     {
1017                         float temp = cost[cur_u - u][j];
1018
1019                         if( max_val < temp )
1020                             max_val = temp;
1021                     }
1022
1023                     /* if needed, adjust the relevant delta[*][j] */
1024                     diff = max_val - cur_v->val;
1025                     cur_v->val = max_val;
1026                     if( fabs( diff ) < eps )
1027                     {
1028                         for( cur_u = u_head.next; cur_u != 0; cur_u = cur_u->next )
1029                             delta[cur_u - u][j] += diff;
1030                     }
1031                 }
1032             }
1033         }
1034         else                    /* column min_j was deleted */
1035         {
1036             for( cur_u = u_head.next; cur_u != 0; cur_u = cur_u->next )
1037             {
1038                 i = (int)(cur_u - u);
1039                 if( cur_u->val == cost[i][min_j] )      /* row i needs updating */
1040                 {
1041                     float max_val = -CV_EMD_INF;
1042
1043                     /* find the new maximum value in the row */
1044                     for( cur_v = v_head.next; cur_v != 0; cur_v = cur_v->next )
1045                     {
1046                         float temp = cost[i][cur_v - v];
1047
1048                         if( max_val < temp )
1049                             max_val = temp;
1050                     }
1051
1052                     /* if needed, adjust the relevant delta[i][*] */
1053                     diff = max_val - cur_u->val;
1054                     cur_u->val = max_val;
1055
1056                     if( fabs( diff ) < eps )
1057                     {
1058                         for( cur_v = v_head.next; cur_v != 0; cur_v = cur_v->next )
1059                             delta[i][cur_v - v] += diff;
1060                     }
1061                 }
1062             }
1063         }
1064     }
1065     while( u_head.next != 0 || v_head.next != 0 );
1066 }
1067
1068
1069
1070 /****************************************************************************************\
1071 *                                   icvAddBasicVariable                                *
1072 \****************************************************************************************/
1073 static void
1074 icvAddBasicVariable( CvEMDState * state,
1075                      int min_i, int min_j,
1076                      CvNode1D * prev_u_min_i, CvNode1D * prev_v_min_j, CvNode1D * u_head )
1077 {
1078     float temp;
1079     CvNode2D *end_x = state->end_x;
1080
1081     if( state->s[min_i] < state->d[min_j] + state->weight * CV_EMD_EPS )
1082     {                           /* supply exhausted */
1083         temp = state->s[min_i];
1084         state->s[min_i] = 0;
1085         state->d[min_j] -= temp;
1086     }
1087     else                        /* demand exhausted */
1088     {
1089         temp = state->d[min_j];
1090         state->d[min_j] = 0;
1091         state->s[min_i] -= temp;
1092     }
1093
1094     /* x(min_i,min_j) is a basic variable */
1095     state->is_x[min_i][min_j] = 1;
1096
1097     end_x->val = temp;
1098     end_x->i = min_i;
1099     end_x->j = min_j;
1100     end_x->next[0] = state->rows_x[min_i];
1101     end_x->next[1] = state->cols_x[min_j];
1102     state->rows_x[min_i] = end_x;
1103     state->cols_x[min_j] = end_x;
1104     state->end_x = end_x + 1;
1105
1106     /* delete supply row only if the empty, and if not last row */
1107     if( state->s[min_i] == 0 && u_head->next->next != 0 )
1108         prev_u_min_i->next = prev_u_min_i->next->next;  /* remove row from list */
1109     else
1110         prev_v_min_j->next = prev_v_min_j->next->next;  /* remove column from list */
1111 }
1112
1113
1114 /****************************************************************************************\
1115 *                                  standard  metrics                                     *
1116 \****************************************************************************************/
1117 static float
1118 icvDistL1( const float *x, const float *y, void *user_param )
1119 {
1120     int i, dims = (int)(size_t)user_param;
1121     double s = 0;
1122
1123     for( i = 0; i < dims; i++ )
1124     {
1125         double t = x[i] - y[i];
1126
1127         s += fabs( t );
1128     }
1129     return (float)s;
1130 }
1131
1132 static float
1133 icvDistL2( const float *x, const float *y, void *user_param )
1134 {
1135     int i, dims = (int)(size_t)user_param;
1136     double s = 0;
1137
1138     for( i = 0; i < dims; i++ )
1139     {
1140         double t = x[i] - y[i];
1141
1142         s += t * t;
1143     }
1144     return cvSqrt( (float)s );
1145 }
1146
1147 static float
1148 icvDistC( const float *x, const float *y, void *user_param )
1149 {
1150     int i, dims = (int)(size_t)user_param;
1151     double s = 0;
1152
1153     for( i = 0; i < dims; i++ )
1154     {
1155         double t = fabs( x[i] - y[i] );
1156
1157         if( s < t )
1158             s = t;
1159     }
1160     return (float)s;
1161 }
1162
1163 /* End of file. */
1164