Update to 2.0.0 tree from current Fremantle build
[opencv] / 3rdparty / flann / util / result_set.h
1 /***********************************************************************
2  * Software License Agreement (BSD License)
3  *
4  * Copyright 2008-2009  Marius Muja (mariusm@cs.ubc.ca). All rights reserved.
5  * Copyright 2008-2009  David G. Lowe (lowe@cs.ubc.ca). All rights reserved.
6  *
7  * THE BSD LICENSE
8  *
9  * Redistribution and use in source and binary forms, with or without
10  * modification, are permitted provided that the following conditions
11  * are met:
12  *
13  * 1. Redistributions of source code must retain the above copyright
14  *    notice, this list of conditions and the following disclaimer.
15  * 2. Redistributions in binary form must reproduce the above copyright
16  *    notice, this list of conditions and the following disclaimer in the
17  *    documentation and/or other materials provided with the distribution.
18  *
19  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
20  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
21  * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
22  * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
23  * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
24  * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
25  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
26  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
27  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
28  * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29  *************************************************************************/
30
31 #ifndef RESULTSET_H
32 #define RESULTSET_H
33
34
35 #include <algorithm>
36 #include <limits>
37 #include <vector>
38 #include "dist.h"
39
40 using namespace std;
41
42
43 namespace flann
44 {
45
46 /* This record represents a branch point when finding neighbors in
47         the tree.  It contains a record of the minimum distance to the query
48         point, as well as the node at which the search resumes.
49 */
50
51 template <typename T>
52 struct BranchStruct {
53         T node;           /* Tree node at which search resumes */
54         float mindistsq;     /* Minimum distance to query for all nodes below. */
55
56         bool operator<(const BranchStruct<T>& rhs)
57         {
58         return mindistsq<rhs.mindistsq;
59         }
60
61     static BranchStruct<T> make_branch(T aNode, float dist)
62     {
63         BranchStruct<T> branch;
64         branch.node = aNode;
65         branch.mindistsq = dist;
66         return branch;
67     }
68 };
69
70
71
72
73
74
75 class ResultSet
76 {
77 protected:
78         const float* target;
79         const float* target_end;
80     int veclen;
81
82 public:
83
84         ResultSet(float* target_ = NULL, int veclen_ = 0) :
85                 target(target_), veclen(veclen_) { target_end = target + veclen;}
86
87         virtual ~ResultSet() {}
88
89         virtual void init(const float* target_, int veclen_) = 0;
90
91         virtual int* getNeighbors() = 0;
92
93         virtual float* getDistances() = 0;
94
95         virtual int size() const = 0;
96
97         virtual bool full() const = 0;
98
99         virtual bool addPoint(float* point, int index) = 0;
100
101         virtual float worstDist() const = 0;
102
103 };
104
105
106 class KNNResultSet : public ResultSet
107 {
108         int* indices;
109         float* dists;
110     int capacity;
111
112         int count;
113
114 public:
115         KNNResultSet(int capacity_, float* target_ = NULL, int veclen_ = 0 ) :
116         ResultSet(target_, veclen_), capacity(capacity_), count(0)
117         {
118         indices = new int[capacity_];
119         dists = new float[capacity_];
120         }
121
122         ~KNNResultSet()
123         {
124                 delete[] indices;
125                 delete[] dists;
126         }
127
128         void init(const float* target_, int veclen_)
129         {
130         target = target_;
131         veclen = veclen_;
132         target_end = target + veclen;
133         count = 0;
134         }
135
136
137         int* getNeighbors()
138         {
139                 return indices;
140         }
141
142     float* getDistances()
143     {
144         return dists;
145     }
146
147     int size() const
148     {
149         return count;
150     }
151
152         bool full() const
153         {
154                 return count == capacity;
155         }
156
157
158         bool addPoint(float* point, int index)
159         {
160                 for (int i=0;i<count;++i) {
161                         if (indices[i]==index) return false;
162                 }
163                 float dist = (float)flann_dist(target, target_end, point);
164
165                 if (count<capacity) {
166                         indices[count] = index;
167                         dists[count] = dist;
168                         ++count;
169                 }
170                 else if (dist < dists[count-1] || (dist == dists[count-1] && index < indices[count-1])) {
171 //         else if (dist < dists[count-1]) {
172                         indices[count-1] = index;
173                         dists[count-1] = dist;
174                 }
175                 else {
176                         return false;
177                 }
178
179                 int i = count-1;
180                 // bubble up
181                 while (i>=1 && (dists[i]<dists[i-1] || (dists[i]==dists[i-1] && indices[i]<indices[i-1]) ) ) {
182 //         while (i>=1 && (dists[i]<dists[i-1]) ) {
183                         swap(indices[i],indices[i-1]);
184                         swap(dists[i],dists[i-1]);
185                         i--;
186                 }
187
188                 return true;
189         }
190
191         float worstDist() const
192         {
193                 return (count<capacity) ? numeric_limits<float>::max() : dists[count-1];
194         }
195 };
196
197
198 /**
199  * A result-set class used when performing a radius based search.
200  */
201 class RadiusResultSet : public ResultSet
202 {
203         struct Item {
204                 int index;
205                 float dist;
206
207                 bool operator<(Item rhs) {
208                         return dist<rhs.dist;
209                 }
210         };
211
212         vector<Item> items;
213         float radius;
214
215         bool sorted;
216         int* indices;
217         float* dists;
218         size_t count;
219
220 private:
221         void resize_vecs()
222         {
223                 if (items.size()>count) {
224                         if (indices!=NULL) delete[] indices;
225                         if (dists!=NULL) delete[] dists;
226                         count = items.size();
227                         indices = new int[count];
228                         dists = new float[count];
229                 }
230         }
231
232 public:
233         RadiusResultSet(float radius_) :
234                 radius(radius_), indices(NULL), dists(NULL)
235         {
236                 sorted = false;
237                 items.reserve(16);
238                 count = 0;
239         }
240
241         ~RadiusResultSet()
242         {
243                 if (indices!=NULL) delete[] indices;
244                 if (dists!=NULL) delete[] dists;
245         }
246
247         void init(const float* target_, int veclen_)
248         {
249         target = target_;
250         veclen = veclen_;
251         target_end = target + veclen;
252         items.clear();
253         sorted = false;
254         }
255
256         int* getNeighbors()
257         {
258                 if (!sorted) {
259                         sorted = true;
260                         sort_heap(items.begin(), items.end());
261                 }
262                 resize_vecs();
263                 for (size_t i=0;i<items.size();++i) {
264                         indices[i] = items[i].index;
265                 }
266                 return indices;
267         }
268
269     float* getDistances()
270     {
271                 if (!sorted) {
272                         sorted = true;
273                         sort_heap(items.begin(), items.end());
274                 }
275                 resize_vecs();
276                 for (size_t i=0;i<items.size();++i) {
277                         dists[i] = items[i].dist;
278                 }
279         return dists;
280     }
281
282     int size() const
283     {
284         return items.size();
285     }
286
287         bool full() const
288         {
289                 return true;
290         }
291
292         bool addPoint(float* point, int index)
293         {
294                 Item it;
295                 it.index = index;
296                 it.dist = (float)flann_dist(target, target_end, point);
297                 if (it.dist<=radius) {
298                         items.push_back(it);
299                         push_heap(items.begin(), items.end());
300             return true;
301                 }
302         return false;
303         }
304
305         float worstDist() const
306         {
307                 return radius;
308         }
309
310 };
311
312 }
313
314 #endif //RESULTSET_H