1 /***********************************************************************
2 * Software License Agreement (BSD License)
4 * Copyright 2008-2009 Marius Muja (mariusm@cs.ubc.ca). All rights reserved.
5 * Copyright 2008-2009 David G. Lowe (lowe@cs.ubc.ca). All rights reserved.
9 * Redistribution and use in source and binary forms, with or without
10 * modification, are permitted provided that the following conditions
13 * 1. Redistributions of source code must retain the above copyright
14 * notice, this list of conditions and the following disclaimer.
15 * 2. Redistributions in binary form must reproduce the above copyright
16 * notice, this list of conditions and the following disclaimer in the
17 * documentation and/or other materials provided with the distribution.
19 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
20 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
21 * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
22 * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
23 * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
24 * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
25 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
26 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
27 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
28 * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 *************************************************************************/
40 #include "constants.h"
41 #include "allocator.h"
43 #include "result_set.h"
56 * Randomized kd-tree index
58 * Contains the k-d trees and other information for indexing a set of points
59 * for nearest-neighbor matching.
61 class KDTreeIndex : public NNIndex
66 * To improve efficiency, only SAMPLE_MEAN random values are used to
67 * compute the mean and variance at each level when building a tree.
68 * A value of 100 seems to perform as well as using all values.
72 * Top random dimensions to consider
74 * When creating random trees, the dimension on which to subdivide is
75 * selected at random from among the top RAND_DIM dimensions with the
76 * highest variance. A value of 5 works well.
83 * Number of randomized trees that are used
88 * Array of indices to vectors in the dataset. During lookup this array
89 * is instead reused to mark already-checked vectors with checkID.
94 * A unique ID for each lookup.
99 * The dataset used by this index
101 const Matrix<float> dataset;
111 /*--------------------- Internal Data Structures --------------------------*/
114 * A node of the binary k-d tree.
116 * All nodes that have vec[divfeat] < divval are placed in the
117 * child1 subtree, else child2. A leaf node is indicated if both children are NULL.
121 * Index of the vector feature used for subdivision.
122 * If this is a leaf node (both children are NULL) then
123 * this holds vector index for this leaf.
127 * The value used for subdivision.
133 TreeSt *child1, *child2;
135 typedef TreeSt* Tree;
138 * Array of k-d trees used to find neighbors.
141 typedef BranchStruct<Tree> BranchSt;
142 typedef BranchSt* Branch;
144 * Priority queue storing intermediate branches in the best-bin-first search
146 Heap<BranchSt>* heap;
150 * Pooled memory allocator.
152 * Using a pooled memory allocator is more efficient
153 * than allocating memory directly when there is a large
154 * number of small memory allocations.
156 PooledAllocator pool;
162 flann_algorithm_t getType() const
171 * inputData = dataset with the input features
172 * params = parameters passed to the kdtree algorithm
// Sets up the bookkeeping for a randomized kd-tree index over inputData.
// No tree is built here: the constructor records the dataset dimensions and
// allocates the working arrays that the build and search routines use later.
//
// Params:
//     inputData = matrix whose rows are the vectors to index; copied into the
//                 dataset member (presumably a non-owning handle copy; confirm
//                 the copy semantics of Matrix)
//     params    = kd-tree parameters; only params.trees is read here
//
// NOTE(review): trees, heap, vind, mean and var are raw owning pointers
// obtained with new. Confirm that the destructor releases all five and that
// the class is never copied, since no copy control is declared here.
174 KDTreeIndex(const Matrix<float>& inputData, const KDTreeIndexParams& params = KDTreeIndexParams() ) : dataset(inputData)
176 size_ = dataset.rows;
177 veclen_ = dataset.cols;
179 numTrees = params.trees;
// One root pointer per randomized tree; the roots are filled in later by
// divideTree or load_tree.
180 trees = new Tree[numTrees];
182 // get the parameters
183 // if (params.find("trees") != params.end()) {
184 // numTrees = (int)params["trees"];
185 // trees = new Tree[numTrees];
// Queue of not-yet-explored branches used by the approximate search.
// NOTE(review): size_ is presumably the heap capacity; confirm against Heap.
191 heap = new Heap<BranchSt>(size_);
194 // Create a permutable array of indices to the input vectors.
195 vind = new int[size_];
196 for (int i = 0; i < size_; i++) {
// Per-dimension scratch buffers; chooseDivision zeroes and refills both on
// every call.
200 mean = new float[veclen_];
201 var = new float[veclen_];
205 * Standard destructor
224 /* Construct the randomized trees. */
225 for (int i = 0; i < numTrees; i++) {
226 /* Randomize the order of vectors to allow for unbiased sampling. */
227 for (int j = size_; j > 0; --j) {
228 // int rand = cast(int) (drand48() * size);
229 int rnd = rand_int(j);
230 assert(rnd >=0 && rnd < size_);
231 swap(vind[j-1], vind[rnd]);
234 divideTree(&trees[i], 0, size_ - 1);
// Serializes the index to stream.
//
// Layout written, in order: the index header, the number of trees, then each
// tree in array order via save_tree. loadIndex reads the same sequence back.
//
// Params:
//     stream = open, writable C stream that receives the index
240 void saveIndex(FILE* stream)
242 save_header(stream, *this);
243 save_value(stream, numTrees);
244 for (int i=0;i<numTrees;++i) {
245 save_tree(stream, trees[i]);
// Restores an index previously written by saveIndex.
//
// The saved header must describe a dataset with the same number of rows and
// columns as the one this object was constructed with; otherwise a
// FLANNException is thrown before any tree data is read.
//
// Params:
//     stream = open, readable C stream positioned at the start of a saved index
251 void loadIndex(FILE* stream)
253 IndexHeader header = load_header(stream);
255 if (header.rows!=size() || header.cols!=veclen()) {
256 throw FLANNException("The index saved belongs to a different dataset");
// numTrees is overwritten with the value stored in the file.
258 load_value(stream, numTrees);
// NOTE(review): the constructor already allocated a trees array. Confirm
// that the previous array is released before this assignment, otherwise it
// leaks.
263 trees = new Tree[numTrees];
264 for (int i=0;i<numTrees;++i) {
265 load_tree(stream,trees[i]);
272 * Returns size of index.
280 * Returns the length of an index feature.
289 * Computes the index memory usage
290 * Returns: memory used by the index
// Returns an estimate of the memory held by the index: the used and wasted
// counters of the node pool plus one int per dataset row for the vind array.
//
// NOTE(review): the sum involves sizeof, so it is computed in an unsigned
// size type and then narrowed to int on return. The trees, heap, mean and var
// allocations made by the constructor are not included in the total.
292 int usedMemory() const
294 return pool.usedMemory+pool.wastedMemory+dataset.rows*sizeof(int); // pool memory and vind array memory
299 * Find set of nearest neighbors to vec. Their indices are stored inside
303 * result = the result object in which the indices of the nearest-neighbors are stored
304 * vec = the vector for which to search the nearest neighbors
305 * searchParams = search parameters; searchParams.checks is the maximum number of restarts (in a best-bin-first manner)
// Entry point for a nearest-neighbor query.
//
// Reads the check budget from searchParams.checks and hands the query to
// either getExactNeighbors or getNeighbors; only the approximate path receives
// the budget.
// NOTE(review): the choice between the two is presumably driven by the value
// of maxChecks; confirm which values select the exact search.
307 void findNeighbors(ResultSet& result, const float* vec, const SearchParams& searchParams)
309 int maxChecks = searchParams.checks;
312 getExactNeighbors(result, vec);
314 getNeighbors(result, vec, maxChecks);
// Resumes a best-bin-first search using the branches already queued in heap.
//
// Params:
//     result   = result set that collects the neighbors
//     vec      = query vector
//     maxCheck = check budget compared against checkCount
//
// NOTE(review): the loop also ends when popMin yields nothing more, so the
// assertion below can fail in debug builds if the heap drains before the
// result set fills up.
319 void continueSearch(ResultSet& result, float* vec, int maxCheck)
325 /* Keep searching other branches from heap until finished. */
// Continue while a branch is available and either the budget is not yet
// spent or the result set still has free slots.
326 while ( heap->popMin(branch) && (checkCount < maxCheck || !result.full() )) {
// Each popped branch restarts the descent at its node, carrying the lower
// bound on the distance that was stored with it.
327 searchLevel(result, vec, branch.node,branch.mindistsq, checkCount, maxCheck);
330 assert(result.full());
334 // Params estimateSearchParams(float precision, Dataset<float>* testset = NULL)
// Writes the subtree rooted at tree to stream in pre-order: the node itself
// first, then the child1 subtree, then the child2 subtree. Children that are
// NULL are skipped.
//
// The node is passed to save_value as a whole, so the stored record carries
// the child pointer members as well. load_tree relies on those stored values
// only to tell whether a child record follows.
// NOTE(review): save_value presumably writes the raw bytes of the node;
// confirm, since that ties the file format to the in-memory layout of TreeSt.
345 void save_tree(FILE* stream, Tree tree)
347 save_value(stream, *tree);
348 if (tree->child1!=NULL) {
349 save_tree(stream, tree->child1);
351 if (tree->child2!=NULL) {
352 save_tree(stream, tree->child2);
// Reads back one subtree written by save_tree and stores its root in tree.
//
// Params:
//     stream = stream positioned at a saved node record
//     tree   = reference that receives the newly allocated node
357 void load_tree(FILE* stream, Tree& tree)
// Nodes come from the pooled allocator, the same source divideTree uses.
359 tree = pool.allocate<TreeSt>();
360 load_value(stream, *tree);
// After load_value the child members hold whatever was stored in the file.
// They are only tested against NULL here, as a flag that a child record
// follows; the recursive call then overwrites each one through the reference
// parameter with a freshly allocated node.
361 if (tree->child1!=NULL) {
362 load_tree(stream, tree->child1);
364 if (tree->child2!=NULL) {
365 load_tree(stream, tree->child2);
371 * Create a tree node that subdivides the list of vecs from vind[first]
372 * to vind[last]. The routine is called recursively on each sublist.
373 * Place a pointer to this new tree node in the location pTree.
375 * Params: pTree = the new node to create
376 * first = index of the first vector
377 * last = index of the last vector
// Builds the node covering vind[first] through vind[last]. The node is either
// a leaf holding a single dataset index, or an internal node whose split is
// chosen by chooseDivision and whose children are built by subdivide, which
// calls back into this routine.
379 void divideTree(Tree* pTree, int first, int last)
383 node = pool.allocate<TreeSt>(); // allocate memory
386 /* If only one exemplar remains, then make this a leaf node. */
388 node->child1 = node->child2 = NULL; /* Mark as leaf node. */
// In a leaf, divfeat is reused to hold the dataset row index rather than a
// split dimension; the search routines read it that way.
389 node->divfeat = vind[first]; /* Store index of this vec. */
391 chooseDivision(node, first, last);
392 subdivide(node, first, last);
400 * Choose which feature to use in order to subdivide this set of vectors.
401 * Make a random choice among those with the highest variance, and use
402 * its mean as the threshold value.
// Picks the split for node from the vectors vind[first] through vind[last].
//
// On return node->divfeat is the chosen dimension and node->divval is the
// sample mean of that dimension. The mean and var members are used as scratch
// space and are overwritten.
404 void chooseDivision(Tree node, int first, int last)
406 memset(mean,0,veclen_*sizeof(float));
407 memset(var,0,veclen_*sizeof(float));
409 /* Compute mean values. Only the first SAMPLE_MEAN values need to be
410 sampled to get a good estimate.
// NOTE(review): the sample window is first through end inclusive, so when the
// range is long enough it covers SAMPLE_MEAN + 1 vectors, one more than the
// comment above states. count matches the window that is actually visited.
412 int end = min(first + SAMPLE_MEAN, last);
413 int count = end - first + 1;
414 for (int j = first; j <= end; ++j) {
415 float* v = dataset[vind[j]];
416 for (int k=0; k<veclen_; ++k) {
420 for (int k=0; k<veclen_; ++k) {
424 /* Compute variances (no need to divide by count). */
// var ends up holding sums of squared deviations over the same window. They
// are only handed to selectDivision, which compares them against each other.
425 for (int j = first; j <= end; ++j) {
426 float* v = dataset[vind[j]];
427 for (int k=0; k<veclen_; ++k) {
428 float dist = v[k] - mean[k];
429 var[k] += dist * dist;
432 /* Select one of the highest variance indices at random. */
433 node->divfeat = selectDivision(var);
// The threshold is the mean of the chosen dimension, not its variance.
434 node->divval = mean[node->divfeat];
440 * Select the top RAND_DIM largest values from v and return the index of
441 * one of these selected at random.
// Chooses a split dimension from the entries of v, which has veclen_ entries.
//
// The routine keeps the indices of the largest values seen so far in topind,
// at most RAND_DIM of them, ordered so that the largest value comes first and
// the smallest of the kept values sits in the last used slot. A random
// position among the kept indices is then drawn.
// NOTE(review): rand_int(num) is presumably uniform over 0 to num - 1, as the
// assertion below expects; confirm against its definition.
443 int selectDivision(float* v)
446 int topind[RAND_DIM];
448 /* Create a list of the indices of the top RAND_DIM values. */
449 for (int i = 0; i < veclen_; ++i) {
// Accept i while the list is not full, or when it beats the smallest kept
// value. The first test short-circuits, so topind[num-1] is only read once
// the list is full.
450 if (num < RAND_DIM || v[i] > v[topind[num-1]]) {
451 /* Put this element at end of topind. */
452 if (num < RAND_DIM) {
453 topind[num++] = i; /* Add to list. */
456 topind[num-1] = i; /* Replace last element. */
458 /* Bubble end value down to right location by repeated swapping. */
// Swapping toward the front while the entry is larger than its predecessor
// restores the largest-first order.
460 while (j > 0 && v[topind[j]] > v[topind[j-1]]) {
461 swap(topind[j], topind[j-1]);
466 /* Select a random integer in range [0,num-1], and return that index. */
467 // int rand = cast(int) (drand48() * num);
468 int rnd = rand_int(num);
469 assert(rnd >=0 && rnd < num);
475 * Subdivide the list of exemplars using the feature and division
476 * value given in this node. Call divideTree recursively on each list.
// Partitions vind[first] through vind[last] around the split stored in node
// and builds both child subtrees.
//
// Entries whose value in dimension node->divfeat is below node->divval are
// gathered at the front. After partitioning, i is the first position of the
// second group: child1 is built from first through i - 1 and child2 from i
// through last.
478 void subdivide(Tree node, int first, int last)
480 /* Move vector indices for left subtree to front of list. */
485 float val = dataset[ind][node->divfeat];
486 if (val < node->divval) {
489 /* Move to end of list by swapping vind i and j. */
490 swap(vind[i], vind[j]);
494 /* If either list is empty, it means we have hit the unlikely case
495 in which all remaining features are identical. Split in the middle
496 to maintain a balanced tree.
// i equal to first means nothing fell below the threshold; i equal to
// last + 1 means everything did. Either way one side would be empty, so the
// range is cut at its midpoint and both recursive calls get a smaller range.
498 if ( (i == first) || (i == last+1)) {
499 i = (first+last+1)/2;
502 divideTree(& node->child1, first, i - 1);
503 divideTree(& node->child2, i, last);
509 * Performs an exact nearest neighbor search. The exact search performs a full
510 * traversal of the tree.
// Runs the exact search for vec and collects the neighbors in result.
//
// Only trees[0] is traversed, starting with a distance lower bound of zero.
// NOTE(review): the warning written to stderr has no trailing newline, so it
// runs into whatever is printed next.
// NOTE(review): the final assertion assumes the traversal always fills the
// result set; confirm for datasets smaller than the requested neighbor count.
512 void getExactNeighbors(ResultSet& result, const float* vec)
// A fresh checkID makes every mark left in vind by earlier queries stale, so
// nothing has to be cleared between searches.
514 checkID -= 1; /* Set a different unique ID for each search. */
517 fprintf(stderr,"It doesn't make any sense to use more than one tree for exact search");
520 searchLevelExact(result, vec, trees[0], 0.0);
522 assert(result.full());
526 * Performs the approximate nearest-neighbor search. The search is approximate
527 * because the tree traversal is abandoned after a given number of descents in
// Runs the approximate best-bin-first search for vec.
//
// Phase one descends every tree once from its root; searchLevel queues the
// branches it does not follow in heap. Phase two keeps popping the queued
// branches and descending them until the heap yields nothing more, or until
// the check budget is spent and the result set is full.
//
// Params:
//     result   = result set that collects the neighbors
//     vec      = query vector
//     maxCheck = check budget compared against checkCount
//
// NOTE(review): the final assertion can fail in debug builds if the heap
// drains before the result set fills up.
530 void getNeighbors(ResultSet& result, const float* vec, int maxCheck)
// A fresh checkID makes every mark left in vind by earlier queries stale.
537 checkID -= 1; /* Set a different unique ID for each search. */
539 /* Search once through each tree, from the root down to a leaf. */
540 for (i = 0; i < numTrees; ++i) {
541 searchLevel(result, vec, trees[i], 0.0, checkCount, maxCheck);
544 /* Keep searching other branches from heap until finished. */
545 while ( heap->popMin(branch) && (checkCount < maxCheck || !result.full() )) {
546 searchLevel(result, vec, branch.node,branch.mindistsq, checkCount, maxCheck);
549 assert(result.full());
554 * Search starting from a given node of the tree. Based on any mismatches at
555 * higher levels, all exemplars below this level must have a distance of
556 * at least "mindistsq".
// One descent of the approximate search, from node down to a leaf.
//
// At every internal node the child on the query side of the split is followed
// and the other child may be queued in heap for later. At the leaf the stored
// dataset vector is offered to result.
//
// Params:
//     result     = result set that collects the neighbors
//     vec        = query vector
//     node       = subtree to descend
//     mindistsq  = lower bound on the distance from vec to anything below node
//     checkCount = running counter shared across descents, compared against
//                  maxCheck
//     maxCheck   = check budget
558 void searchLevel(ResultSet& result, const float* vec, Tree node, float mindistsq, int& checkCount, int maxCheck)
// Prune: nothing below this node can beat the current worst kept distance.
560 if (result.worstDist()<mindistsq) {
561 // printf("Ignoring branch, too far\n");
566 Tree bestChild, otherChild;
568 /* If this is a leaf node, then do check and return. */
569 if (node->child1 == NULL && node->child2 == NULL) {
571 /* Do not check same node more than once when searching multiple trees.
572 Once a vector is checked, we set its location in vind to the
// In a leaf, divfeat is the dataset row index. A leaf that was already
// visited in this query, or that arrives after the budget is spent, is
// skipped only when the result set is already full.
575 if (vind[node->divfeat] == checkID || checkCount>=maxCheck) {
576 if (result.full()) return;
579 vind[node->divfeat] = checkID;
581 result.addPoint(dataset[node->divfeat],node->divfeat);
585 /* Which child branch should be taken first? */
586 val = vec[node->divfeat];
587 diff = val - node->divval;
588 bestChild = (diff < 0) ? node->child1 : node->child2;
589 otherChild = (diff < 0) ? node->child2 : node->child1;
591 /* Create a branch record for the branch not taken. Add distance
592 of this feature boundary (we don't attempt to correct for any
593 use of this feature in a parent node, which is unlikely to
594 happen and would have only a small effect). Don't bother
595 adding more branches to heap after halfway point, as cost of
596 adding exceeds their value.
// NOTE(review): flann_dist is called on the one-element range holding val
// against divval, with mindistsq as its last argument; presumably it returns
// mindistsq plus the distance contribution of this single coordinate. Confirm.
599 float new_distsq = flann_dist(&val, &val+1, &node->divval, mindistsq);
// The halfway-point rule described above is the disabled condition on the
// next line. The active test queues the other child when its bound could
// still improve the result, or when the result set is not yet full.
600 // if (2 * checkCount < maxCheck || !result.full()) {
601 if (new_distsq < result.worstDist() || !result.full()) {
602 heap->insert( BranchSt::make_branch(otherChild, new_distsq) );
605 /* Call recursively to search next level down. */
// The followed child inherits the unchanged bound.
606 searchLevel(result, vec, bestChild, mindistsq, checkCount, maxCheck);
610 * Performs an exact search in the tree starting from a node.
// Exact search of the subtree rooted at node.
//
// Unlike searchLevel, no heap and no check budget are involved: both children
// of every internal node are visited, the one on the query side of the split
// first, and a subtree is skipped only when its distance lower bound already
// exceeds the worst distance kept in result.
//
// Params:
//     result    = result set that collects the neighbors
//     vec       = query vector
//     node      = subtree to search
//     mindistsq = lower bound on the distance from vec to anything below node
612 void searchLevelExact(ResultSet& result, const float* vec, Tree node, float mindistsq)
// Prune: nothing below this node can beat the current worst kept distance.
614 if (mindistsq>result.worstDist()) {
619 Tree bestChild, otherChild;
621 /* If this is a leaf node, then do check and return. */
622 if (node->child1 == NULL && node->child2 == NULL) {
624 /* Do not check same node more than once when searching multiple trees.
625 Once a vector is checked, we set its location in vind to the
// In a leaf, divfeat is the dataset row index; vind doubles as the per-query
// visited marker.
628 if (vind[node->divfeat] == checkID)
630 vind[node->divfeat] = checkID;
632 result.addPoint(dataset[node->divfeat],node->divfeat);
636 /* Which child branch should be taken first? */
637 val = vec[node->divfeat];
638 diff = val - node->divval;
639 bestChild = (diff < 0) ? node->child1 : node->child2;
640 otherChild = (diff < 0) ? node->child2 : node->child1;
643 /* Call recursively to search next level down. */
644 searchLevelExact(result, vec, bestChild, mindistsq);
// The far side is searched second with a tightened bound, so the points found
// on the near side can prune it at the top of the recursive call.
// NOTE(review): flann_dist presumably returns mindistsq plus the distance
// contribution of this single coordinate; confirm against its definition.
645 float new_distsq = flann_dist(&val, &val+1, &node->divval, mindistsq);
646 searchLevelExact(result, vec, otherChild, new_distsq);