1 /***********************************************************************
2 * Software License Agreement (BSD License)
4 * Copyright 2008-2009 Marius Muja (mariusm@cs.ubc.ca). All rights reserved.
5 * Copyright 2008-2009 David G. Lowe (lowe@cs.ubc.ca). All rights reserved.
9 * Redistribution and use in source and binary forms, with or without
10 * modification, are permitted provided that the following conditions
13 * 1. Redistributions of source code must retain the above copyright
14 * notice, this list of conditions and the following disclaimer.
15 * 2. Redistributions in binary form must reproduce the above copyright
16 * notice, this list of conditions and the following disclaimer in the
17 * documentation and/or other materials provided with the distribution.
19 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
20 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
21 * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
22 * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
23 * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
24 * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
25 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
26 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
27 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
28 * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 *************************************************************************/
46 /* This record represents a branch point when finding neighbors in
47 the tree. It contains a record of the minimum distance to the query
48 point, as well as the node at which the search resumes.
53 T node; /* Tree node at which search resumes */
54 float mindistsq; /* Minimum distance to query for all nodes below. */
56 bool operator<(const BranchStruct<T>& rhs)
58 return mindistsq<rhs.mindistsq;
61 static BranchStruct<T> make_branch(T aNode, float dist)
63 BranchStruct<T> branch;
65 branch.mindistsq = dist;
79 const float* target_end; // one past the last query component; kept equal to target + veclen
84 ResultSet(float* target_ = NULL, int veclen_ = 0) :
85 target(target_), veclen(veclen_) { target_end = target + veclen;}
87 virtual ~ResultSet() {}
// Interface that every concrete result set implements.
89 virtual void init(const float* target_, int veclen_) = 0; // (re)bind the query vector: veclen_ floats starting at target_
91 virtual int* getNeighbors() = 0; // dataset indices of the collected results
93 virtual float* getDistances() = 0; // distances of the collected results, parallel to getNeighbors()
95 virtual int size() const = 0; // presumably the number of results currently held — confirm in overrides
97 virtual bool full() const = 0; // whether the result set has reached its capacity (meaning is up to each override)
99 virtual bool addPoint(float* point, int index) = 0; // offer a candidate vector with its dataset index; returns false if it is not added
101 virtual float worstDist() const = 0; // distance of the worst result the set would still keep; farther candidates can be skipped
106 class KNNResultSet : public ResultSet
115 KNNResultSet(int capacity_, float* target_ = NULL, int veclen_ = 0 ) :
116 ResultSet(target_, veclen_), capacity(capacity_), count(0)
118 indices = new int[capacity_];
119 dists = new float[capacity_];
128 void init(const float* target_, int veclen_)
132 target_end = target + veclen;
142 float* getDistances()
154 return count == capacity;
158 bool addPoint(float* point, int index)
160 for (int i=0;i<count;++i) {
161 if (indices[i]==index) return false;
163 float dist = (float)flann_dist(target, target_end, point);
165 if (count<capacity) {
166 indices[count] = index;
170 else if (dist < dists[count-1] || (dist == dists[count-1] && index < indices[count-1])) {
171 // else if (dist < dists[count-1]) {
172 indices[count-1] = index;
173 dists[count-1] = dist;
181 while (i>=1 && (dists[i]<dists[i-1] || (dists[i]==dists[i-1] && indices[i]<indices[i-1]) ) ) {
182 // while (i>=1 && (dists[i]<dists[i-1]) ) {
183 swap(indices[i],indices[i-1]);
184 swap(dists[i],dists[i-1]);
191 float worstDist() const
193 return (count<capacity) ? numeric_limits<float>::max() : dists[count-1];
199 * A result-set class used when performing a radius based search.
201 class RadiusResultSet : public ResultSet
207 bool operator<(Item rhs) {
208 return dist<rhs.dist;
223 if (items.size()>count) {
224 if (indices!=NULL) delete[] indices;
225 if (dists!=NULL) delete[] dists;
226 count = items.size();
227 indices = new int[count];
228 dists = new float[count];
233 RadiusResultSet(float radius_) :
234 radius(radius_), indices(NULL), dists(NULL)
243 if (indices!=NULL) delete[] indices;
244 if (dists!=NULL) delete[] dists;
247 void init(const float* target_, int veclen_)
251 target_end = target + veclen;
260 sort_heap(items.begin(), items.end());
263 for (size_t i=0;i<items.size();++i) {
264 indices[i] = items[i].index;
269 float* getDistances()
273 sort_heap(items.begin(), items.end());
276 for (size_t i=0;i<items.size();++i) {
277 dists[i] = items[i].dist;
292 bool addPoint(float* point, int index)
296 it.dist = (float)flann_dist(target, target_end, point);
297 if (it.dist<=radius) {
299 push_heap(items.begin(), items.end());
305 float worstDist() const