2afda210167a6014b57c09a016c4cbf502798351
[opencv] / src / cv / _cvkdtree.hpp
1 /*M///////////////////////////////////////////////////////////////////////////////////////
2 //
3 //  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
4 //
5 //  By downloading, copying, installing or using the software you agree to this license.
6 //  If you do not agree to this license, do not download, install,
7 //  copy or use the software.
8 //
9 //
10 //                        Intel License Agreement
11 //                For Open Source Computer Vision Library
12 //
13 // Copyright (C) 2008, Xavier Delacour, all rights reserved.
14 // Third party copyrights are property of their respective owners.
15 //
16 // Redistribution and use in source and binary forms, with or without modification,
17 // are permitted provided that the following conditions are met:
18 //
19 //   * Redistribution's of source code must retain the above copyright notice,
20 //     this list of conditions and the following disclaimer.
21 //
22 //   * Redistribution's in binary form must reproduce the above copyright notice,
23 //     this list of conditions and the following disclaimer in the documentation
24 //     and/or other materials provided with the distribution.
25 //
26 //   * The name of Intel Corporation may not be used to endorse or promote products
27 //     derived from this software without specific prior written permission.
28 //
29 // This software is provided by the copyright holders and contributors "as is" and
30 // any express or implied warranties, including, but not limited to, the implied
31 // warranties of merchantability and fitness for a particular purpose are disclaimed.
32 // In no event shall the Intel Corporation or contributors be liable for any direct,
33 // indirect, incidental, special, exemplary, or consequential damages
34 // (including, but not limited to, procurement of substitute goods or services;
35 // loss of use, data, or profits; or business interruption) however caused
36 // and on any theory of liability, whether in contract, strict liability,
37 // or tort (including negligence or otherwise) arising in any way out of
38 // the use of this software, even if advised of the possibility of such damage.
39 //
40 //M*/
41
42 // 2008-05-13, Xavier Delacour <xavier.delacour@gmail.com>
43
44 #ifndef __cv_kdtree_h__
45 #define __cv_kdtree_h__
46
47 #include "_cv.h"
48
49 #include <vector>
50 #include <algorithm>
51 #include <limits>
52 #include <iostream>
53 #include "assert.h"
54 #include "math.h"
55
56 #if _MSC_VER >= 1400
57 #pragma warning(disable: 4512) // suppress "assignment operator could not be generated"
58 #endif
59
60 // J.S. Beis and D.G. Lowe. Shape indexing using approximate nearest-neighbor search 
61 // in highdimensional spaces. In Proc. IEEE Conf. Comp. Vision Patt. Recog., 
62 // pages 1000--1006, 1997. http://citeseer.ist.psu.edu/beis97shape.html 
63 #undef __deref
64 #undef __valuetype
65
66 template < class __valuetype, class __deref >
67 class CvKDTree {
68 public:
69   typedef __deref deref_type;
70   typedef typename __deref::scalar_type scalar_type;
71   typedef typename __deref::accum_type accum_type;
72
73 private:
74   struct node {
75     int dim;                    // split dimension; >=0 for nodes, -1 for leaves
76     __valuetype value;          // if leaf, value of leaf
77     int left, right;            // node indices of left and right branches
78     scalar_type boundary;       // left if deref(value,dim)<=boundary, otherwise right
79   };
80   typedef std::vector < node > node_array;
81
82   __deref deref;                // requires operator() (__valuetype lhs,int dim)
83
84   node_array nodes;             // node storage
85   int point_dim;                // dimension of points (the k in kd-tree)
86   int root_node;                // index of root node, -1 if empty tree
87
88   // for given set of point indices, compute dimension of highest variance
89   template < class __instype, class __valuector >
90   int dimension_of_highest_variance(__instype * first, __instype * last,
91                                     __valuector ctor) {
92     assert(last - first > 0);
93
94     accum_type maxvar = -std::numeric_limits < accum_type >::max();
95     int maxj = -1;
96     for (int j = 0; j < point_dim; ++j) {
97       accum_type mean = 0;
98       for (__instype * k = first; k < last; ++k)
99         mean += deref(ctor(*k), j);
100       mean /= last - first;
101       accum_type var = 0;
102       for (__instype * k = first; k < last; ++k) {
103         accum_type diff = accum_type(deref(ctor(*k), j)) - mean;
104         var += diff * diff;
105       }
106       var /= last - first;
107
108       assert(maxj != -1 || var >= maxvar);
109
110       if (var >= maxvar) {
111         maxvar = var;
112         maxj = j;
113       }
114     }
115
116     return maxj;
117   }
118
119   // given point indices and dimension, find index of median; (almost) modifies [first,last) 
120   // such that points_in[first,median]<=point[median], points_in(median,last)>point[median].
121   // implemented as partial quicksort; expected linear perf.
122   template < class __instype, class __valuector >
123   __instype * median_partition(__instype * first, __instype * last,
124                                int dim, __valuector ctor) {
125     assert(last - first > 0);
126     __instype *k = first + (last - first) / 2;
127     median_partition(first, last, k, dim, ctor);
128     return k;
129   }
130
131   template < class __instype, class __valuector >
132   struct median_pr {
133     const __instype & pivot;
134     int dim;
135     __deref deref;
136     __valuector ctor;
137     median_pr(const __instype & _pivot, int _dim, __deref _deref, __valuector _ctor)
138       : pivot(_pivot), dim(_dim), deref(_deref), ctor(_ctor) {
139     }
140     bool operator() (const __instype & lhs) const {
141       return deref(ctor(lhs), dim) <= deref(ctor(pivot), dim);
142     }
143   };
144
145   template < class __instype, class __valuector >
146   void median_partition(__instype * first, __instype * last, 
147                         __instype * k, int dim, __valuector ctor) {
148     int pivot = (int)((last - first) / 2);
149
150     std::swap(first[pivot], last[-1]);
151     __instype *middle = std::partition(first, last - 1,
152                                        median_pr < __instype, __valuector > 
153                                        (last[-1], dim, deref, ctor));
154     std::swap(*middle, last[-1]);
155
156     if (middle < k)
157       median_partition(middle + 1, last, k, dim, ctor);
158     else if (middle > k)
159       median_partition(first, middle, k, dim, ctor);
160   }
161
162   // insert given points into the tree; return created node
163   template < class __instype, class __valuector >
164   int insert(__instype * first, __instype * last, __valuector ctor) {
165     if (first == last)
166       return -1;
167     else {
168
169       int dim = dimension_of_highest_variance(first, last, ctor);
170       __instype *median = median_partition(first, last, dim, ctor);
171
172       __instype *split = median;
173       for (; split != last && deref(ctor(*split), dim) == 
174              deref(ctor(*median), dim); ++split);
175
176       if (split == last) { // leaf
177         int nexti = -1;
178         for (--split; split >= first; --split) {
179           int i = (int)nodes.size();
180           node & n = *nodes.insert(nodes.end(), node());
181           n.dim = -1;
182           n.value = ctor(*split);
183           n.left = -1;
184           n.right = nexti;
185           nexti = i;
186         }
187
188         return nexti;
189       } else { // node
190         int i = (int)nodes.size();
191         // note that recursive insert may invalidate this ref
192         node & n = *nodes.insert(nodes.end(), node());
193
194         n.dim = dim;
195         n.boundary = deref(ctor(*median), dim);
196
197         int left = insert(first, split, ctor);
198         nodes[i].left = left;
199         int right = insert(split, last, ctor);
200         nodes[i].right = right;
201
202         return i;
203       }
204     }
205   }
206
207   // run to leaf; linear search for p;
208   // if found, remove paths to empty leaves on unwind
209   bool remove(int *i, const __valuetype & p) {
210     if (*i == -1)
211       return false;
212     node & n = nodes[*i];
213     bool r;
214
215     if (n.dim >= 0) { // node
216       if (deref(p, n.dim) <= n.boundary) // left
217         r = remove(&n.left, p);
218       else // right
219         r = remove(&n.right, p);
220
221       // if terminal, remove this node
222       if (n.left == -1 && n.right == -1)
223         *i = -1;
224
225       return r;
226     } else { // leaf
227       if (n.value == p) {
228         *i = n.right;
229         return true;
230       } else
231         return remove(&n.right, p);
232     }
233   }
234
235 public:
236   struct identity_ctor {
237     const __valuetype & operator() (const __valuetype & rhs) const {
238       return rhs;
239     }
240   };
241
242   // initialize an empty tree
243   CvKDTree(__deref _deref = __deref())
244     : deref(_deref), root_node(-1) {
245   }
246   // given points, initialize a balanced tree
247   CvKDTree(__valuetype * first, __valuetype * last, int _point_dim,
248            __deref _deref = __deref())
249     : deref(_deref) {
250     set_data(first, last, _point_dim, identity_ctor());
251   }
252   // given points, initialize a balanced tree
253   template < class __instype, class __valuector >
254   CvKDTree(__instype * first, __instype * last, int _point_dim,
255            __valuector ctor, __deref _deref = __deref())
256     : deref(_deref) {
257     set_data(first, last, _point_dim, ctor);
258   }
259
260   void set_deref(__deref _deref) {
261     deref = _deref;
262   }
263
264   void set_data(__valuetype * first, __valuetype * last, int _point_dim) {
265     set_data(first, last, _point_dim, identity_ctor());
266   }
267   template < class __instype, class __valuector >
268   void set_data(__instype * first, __instype * last, int _point_dim,
269                 __valuector ctor) {
270     point_dim = _point_dim;
271     nodes.clear();
272     nodes.reserve(last - first);
273     root_node = insert(first, last, ctor);
274   }
275
276   int dims() const {
277     return point_dim;
278   }
279
280   // remove the given point
281   bool remove(const __valuetype & p) {
282     return remove(&root_node, p);
283   }
284
285   void print() const {
286     print(root_node);
287   }
288   void print(int i, int indent = 0) const {
289     if (i == -1)
290       return;
291     for (int j = 0; j < indent; ++j)
292       std::cout << " ";
293     const node & n = nodes[i];
294     if (n.dim >= 0) {
295       std::cout << "node " << i << ", left " << nodes[i].left << ", right " << 
296         nodes[i].right << ", dim " << nodes[i].dim << ", boundary " << 
297         nodes[i].boundary << std::endl;
298       print(n.left, indent + 3);
299       print(n.right, indent + 3);
300     } else
301       std::cout << "leaf " << i << ", value = " << nodes[i].value << std::endl;
302   }
303
304   ////////////////////////////////////////////////////////////////////////////////////////
305   // bbf search
306 public:
307   struct bbf_nn {               // info on found neighbors (approx k nearest)
308     const __valuetype *p;       // nearest neighbor
309     accum_type dist;            // distance from d to query point
310     bbf_nn(const __valuetype & _p, accum_type _dist)
311       : p(&_p), dist(_dist) {
312     }
313     bool operator<(const bbf_nn & rhs) const {
314       return dist < rhs.dist;
315     }
316   };
317   typedef std::vector < bbf_nn > bbf_nn_pqueue;
318 private:
319   struct bbf_node {             // info on branches not taken
320     int node;                   // corresponding node
321     accum_type dist;            // minimum distance from bounds to query point
322     bbf_node(int _node, accum_type _dist)
323       : node(_node), dist(_dist) {
324     }
325     bool operator<(const bbf_node & rhs) const {
326       return dist > rhs.dist;
327     }
328   };
329   typedef std::vector < bbf_node > bbf_pqueue;
330   mutable bbf_pqueue tmp_pq;
331
332   // called for branches not taken, as bbf walks to leaf;
333   // construct bbf_node given minimum distance to bounds of alternate branch
334   void pq_alternate(int alt_n, bbf_pqueue & pq, scalar_type dist) const {
335     if (alt_n == -1)
336       return;
337
338     // add bbf_node for alternate branch in priority queue
339     pq.push_back(bbf_node(alt_n, dist));
340     push_heap(pq.begin(), pq.end());
341   }
342
343   // called by bbf to walk to leaf;
344   // takes one step down the tree towards query point d
345   template < class __desctype >
346   int bbf_branch(int i, const __desctype * d, bbf_pqueue & pq) const {
347     const node & n = nodes[i];
348     // push bbf_node with bounds of alternate branch, then branch
349     if (d[n.dim] <= n.boundary) {       // left
350       pq_alternate(n.right, pq, n.boundary - d[n.dim]);
351       return n.left;
352     } else {                    // right
353       pq_alternate(n.left, pq, d[n.dim] - n.boundary);
354       return n.right;
355     }
356   }
357
358   // compute euclidean distance between two points
359   template < class __desctype >
360   accum_type distance(const __desctype * d, const __valuetype & p) const {
361     accum_type dist = 0;
362     for (int j = 0; j < point_dim; ++j) {
363       accum_type diff = accum_type(d[j]) - accum_type(deref(p, j));
364       dist += diff * diff;
365     } return (accum_type) sqrt(dist);
366   }
367
368   // called per candidate nearest neighbor; constructs new bbf_nn for
369   // candidate and adds it to priority queue of all candidates; if 
370   // queue len exceeds k, drops the point furthest from query point d.
371   template < class __desctype >
372   void bbf_new_nn(bbf_nn_pqueue & nn_pq, int k, 
373                   const __desctype * d, const __valuetype & p) const {
374     bbf_nn nn(p, distance(d, p));
375     if ((int) nn_pq.size() < k) {
376       nn_pq.push_back(nn);
377       push_heap(nn_pq.begin(), nn_pq.end());
378     } else if (nn_pq[0].dist > nn.dist) {
379       pop_heap(nn_pq.begin(), nn_pq.end());
380       nn_pq.end()[-1] = nn;
381       push_heap(nn_pq.begin(), nn_pq.end());
382     }
383     assert(nn_pq.size() < 2 || nn_pq[0].dist >= nn_pq[1].dist);
384   }
385
386 public:
387   // finds (with high probability) the k nearest neighbors of d, 
388   // searching at most emax leaves/bins.
389   // ret_nn_pq is an array containing the (at most) k nearest neighbors 
390   // (see bbf_nn structure def above).
391   template < class __desctype >
392   int find_nn_bbf(const __desctype * d, 
393                   int k, int emax, 
394                   bbf_nn_pqueue & ret_nn_pq) const {
395     assert(k > 0);
396     ret_nn_pq.clear();
397
398     if (root_node == -1)
399       return 0;
400
401     // add root_node to bbf_node priority queue;
402     // iterate while queue non-empty and emax>0
403     tmp_pq.clear();
404     tmp_pq.push_back(bbf_node(root_node, 0));
405     while (tmp_pq.size() && emax > 0) {
406
407       // from node nearest query point d, run to leaf
408       pop_heap(tmp_pq.begin(), tmp_pq.end());
409       bbf_node bbf(tmp_pq.end()[-1]);
410       tmp_pq.erase(tmp_pq.end() - 1);
411
412       int i;
413       for (i = bbf.node;
414            i != -1 && nodes[i].dim >= 0; 
415            i = bbf_branch(i, d, tmp_pq));
416
417       if (i != -1) {
418
419         // add points in leaf/bin to ret_nn_pq
420         do {
421           bbf_new_nn(ret_nn_pq, k, d, nodes[i].value);
422         } while (-1 != (i = nodes[i].right));
423
424         --emax;
425       }
426     }
427
428     tmp_pq.clear();
429     return (int)ret_nn_pq.size();
430   }
431
432   ////////////////////////////////////////////////////////////////////////////////////////
433   // orthogonal range search
434 private:
435   void find_ortho_range(int i, scalar_type * bounds_min,
436                         scalar_type * bounds_max,
437                         std::vector < __valuetype > &inbounds) const {
438     if (i == -1)
439       return;
440     const node & n = nodes[i];
441     if (n.dim >= 0) { // node
442       if (bounds_min[n.dim] <= n.boundary)
443         find_ortho_range(n.left, bounds_min, bounds_max, inbounds);
444       if (bounds_max[n.dim] > n.boundary)
445         find_ortho_range(n.right, bounds_min, bounds_max, inbounds);
446     } else { // leaf
447       do {
448         inbounds.push_back(nodes[i].value);
449       } while (-1 != (i = nodes[i].right));
450     }
451   }
452 public:
453   // return all points that lie within the given bounds; inbounds is cleared
454   int find_ortho_range(scalar_type * bounds_min,
455                        scalar_type * bounds_max,
456                        std::vector < __valuetype > &inbounds) const {
457     inbounds.clear();
458     find_ortho_range(root_node, bounds_min, bounds_max, inbounds);
459     return (int)inbounds.size();
460   }
461 };
462
463 #endif // __cv_kdtree_h__
464
465 // Local Variables:
466 // mode:C++
467 // End: