1 /***********************************************************************
2 * Software License Agreement (BSD License)
4 * Copyright 2008-2009 Marius Muja (mariusm@cs.ubc.ca). All rights reserved.
5 * Copyright 2008-2009 David G. Lowe (lowe@cs.ubc.ca). All rights reserved.
9 * Redistribution and use in source and binary forms, with or without
10 * modification, are permitted provided that the following conditions
13 * 1. Redistributions of source code must retain the above copyright
14 * notice, this list of conditions and the following disclaimer.
15 * 2. Redistributions in binary form must reproduce the above copyright
16 * notice, this list of conditions and the following disclaimer in the
17 * documentation and/or other materials provided with the distribution.
19 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
20 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
21 * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
22 * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
23 * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
24 * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
25 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
26 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
27 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
28 * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 *************************************************************************/
31 #include "index_testing.h"
32 #include "result_set.h"
46 const float SEARCH_EPS = 0.001f;
48 int countCorrectMatches(int* neighbors, int* groundTruth, int n)
51 for (int i=0;i<n;++i) {
52 for (int k=0;k<n;++k) {
53 if (neighbors[i]==groundTruth[k]) {
63 float computeDistanceRaport(const Matrix<float>& inputData, float* target, int* neighbors, int* groundTruth, int veclen, int n)
65 float* target_end = target + veclen;
67 for (int i=0;i<n;++i) {
68 float den = (float)flann_dist(target,target_end, inputData[groundTruth[i]]);
69 float num = (float)flann_dist(target,target_end, inputData[neighbors[i]]);
71 // printf("den=%g,num=%g\n",den,num);
73 if (den==0 && num==0) {
83 float search_with_ground_truth(NNIndex& index, const Matrix<float>& inputData, const Matrix<float>& testData, const Matrix<int>& matches, int nn, int checks, float& time, float& dist, int skipMatches)
85 if (matches.cols<nn) {
86 logger.info("matches.cols=%d, nn=%d\n",matches.cols,nn);
88 throw FLANNException("Ground truth is not computed for as many neighbors as requested");
91 KNNResultSet resultSet(nn+skipMatches);
92 SearchParams searchParams(checks);
103 for (int i = 0; i < testData.rows; i++) {
104 float* target = testData[i];
105 resultSet.init(target, testData.cols);
106 index.findNeighbors(resultSet,target, searchParams);
107 int* neighbors = resultSet.getNeighbors();
108 neighbors = neighbors+skipMatches;
110 correct += countCorrectMatches(neighbors,matches[i], nn);
111 distR += computeDistanceRaport(inputData, target,neighbors,matches[i], testData.cols, nn);
115 time = (float)(t.value/repeats);
118 float precicion = (float)correct/(nn*testData.rows);
120 dist = distR/(testData.rows*nn);
122 logger.info("%8d %10.4g %10.5g %10.5g %10.5g\n",
123 checks, precicion, time, 1000.0 * time / testData.rows, dist);
128 void search_for_neighbors(NNIndex& index, const Matrix<float>& testset, Matrix<int>& result, Matrix<float>& dists, const SearchParams& searchParams, int skip)
130 assert(testset.rows == result.rows);
132 int nn = result.cols;
133 KNNResultSet resultSet(nn+skip);
136 for (int i = 0; i < testset.rows; i++) {
137 float* target = testset[i];
138 resultSet.init(target, testset.cols);
140 index.findNeighbors(resultSet,target, searchParams);
142 int* neighbors = resultSet.getNeighbors();
143 float* distances = resultSet.getDistances();
144 memcpy(result[i], neighbors+skip, nn*sizeof(int));
145 memcpy(dists[i], distances+skip, nn*sizeof(float));
150 float test_index_checks(NNIndex& index, const Matrix<float>& inputData, const Matrix<float>& testData, const Matrix<int>& matches, int checks, float& precision, int nn, int skipMatches)
152 logger.info(" Nodes Precision(%) Time(s) Time/vec(ms) Mean dist\n");
153 logger.info("---------------------------------------------------------\n");
157 precision = search_with_ground_truth(index, inputData, testData, matches, nn, checks, time, dist, skipMatches);
163 float test_index_precision(NNIndex& index, const Matrix<float>& inputData, const Matrix<float>& testData, const Matrix<int>& matches,
164 float precision, int& checks, int nn, int skipMatches)
166 logger.info(" Nodes Precision(%) Time(s) Time/vec(ms) Mean dist\n");
167 logger.info("---------------------------------------------------------\n");
176 p2 = search_with_ground_truth(index, inputData, testData, matches, nn, c2, time, dist, skipMatches);
179 logger.info("Got as close as I can\n");
184 while (p2<precision) {
188 p2 = search_with_ground_truth(index, inputData, testData, matches, nn, c2, time, dist, skipMatches);
193 if (fabs(p2-precision)>SEARCH_EPS) {
194 logger.info("Start linear estimation\n");
195 // after we get to values in the vicinity of the desired precision
196 // use linear approximation to get a better estimation
199 realPrecision = search_with_ground_truth(index, inputData, testData, matches, nn, cx, time, dist, skipMatches);
200 while (fabs(realPrecision-precision)>SEARCH_EPS) {
202 if (realPrecision<precision) {
210 logger.info("Got as close as I can\n");
213 realPrecision = search_with_ground_truth(index, inputData, testData, matches, nn, cx, time, dist, skipMatches);
220 logger.info("No need for linear estimation\n");
230 float test_index_precisions(NNIndex& index, const Matrix<float>& inputData, const Matrix<float>& testData, const Matrix<int>& matches,
231 float* precisions, int precisions_length, int nn, int skipMatches, float maxTime)
233 // make sure precisions array is sorted
234 sort(precisions, precisions+precisions_length);
237 float precision = precisions[pindex];
239 logger.info(" Nodes Precision(%) Time(s) Time/vec(ms) Mean dist");
240 logger.info("---------------------------------------------------------");
251 p2 = search_with_ground_truth(index, inputData, testData, matches, nn, c2, time, dist, skipMatches);
253 // if precision for 1 run down the tree is already
254 // better than some of the requested precisions, then
256 while (precisions[pindex]<p2 && pindex<precisions_length) {
260 if (pindex==precisions_length) {
261 logger.info("Got as close as I can\n");
265 for (int i=pindex;i<precisions_length;++i) {
267 precision = precisions[i];
268 while (p2<precision) {
272 p2 = search_with_ground_truth(index, inputData, testData, matches, nn, c2, time, dist, skipMatches);
273 if (maxTime> 0 && time > maxTime && p2<precision) return time;
278 if (fabs(p2-precision)>SEARCH_EPS) {
279 logger.info("Start linear estimation\n");
280 // after we get to values in the vicinity of the desired precision
281 // use linear approximation to get a better estimation
284 realPrecision = search_with_ground_truth(index, inputData, testData, matches, nn, cx, time, dist, skipMatches);
285 while (fabs(realPrecision-precision)>SEARCH_EPS) {
287 if (realPrecision<precision) {
295 logger.info("Got as close as I can\n");
298 realPrecision = search_with_ground_truth(index, inputData, testData, matches, nn, cx, time, dist, skipMatches);
305 logger.info("No need for linear estimation\n");