1 /***********************************************************************
2 * Software License Agreement (BSD License)
4 * Copyright 2008-2009 Marius Muja (mariusm@cs.ubc.ca). All rights reserved.
5 * Copyright 2008-2009 David G. Lowe (lowe@cs.ubc.ca). All rights reserved.
9 * Redistribution and use in source and binary forms, with or without
10 * modification, are permitted provided that the following conditions
13 * 1. Redistributions of source code must retain the above copyright
14 * notice, this list of conditions and the following disclaimer.
15 * 2. Redistributions in binary form must reproduce the above copyright
16 * notice, this list of conditions and the following disclaimer in the
17 * documentation and/or other materials provided with the distribution.
19 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
20 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
21 * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
22 * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
23 * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
24 * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
25 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
26 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
27 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
28 * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 *************************************************************************/
37 #include "index_testing.h"
39 #include "object_factory.h"
41 #include "kdtree_index.h"
42 #include "kmeans_index.h"
43 #include "composite_index.h"
44 #include "linear_index.h"
45 #include "autotuned_index.h"
// EXPORTED decorates the C-API entry points defined further down: both
// variants give the function extern "C" linkage, and the first additionally
// applies __declspec(dllexport).
// NOTE(review): the two definitions are presumably alternatives selected by a
// preprocessor conditional around them -- confirm.
55 #define EXPORTED extern "C" __declspec(dllexport)
57 #define EXPORTED extern "C"
// Factory that maps a flann_algorithm_t value to an IndexParams object. It is
// always reached through ParamsFactory::instance() in this file.
64 typedef ObjectFactory<IndexParams, flann_algorithm_t> ParamsFactory;
// Creates the IndexParams object that the factory associates with p.algorithm
// and lets it populate itself from p through fromParameters().
67 IndexParams* IndexParams::createFromParameters(const FLANNParameters& p)
69 IndexParams* params = ParamsFactory::instance().create(p.algorithm);
// NOTE(review): params is dereferenced without a visible check -- confirm what
// create() yields for an algorithm value that was never registered.
70 params->fromParameters(p);
// Allocates (with new) a LinearIndex over `dataset`, configured by this
// parameter object. Release of the returned pointer is left to the caller.
75 NNIndex* LinearIndexParams::createIndex(const Matrix<float>& dataset) const
77 return new LinearIndex(dataset, *this);
// Allocates (with new) a KDTreeIndex over `dataset`, configured by this
// parameter object. Release of the returned pointer is left to the caller.
81 NNIndex* KDTreeIndexParams::createIndex(const Matrix<float>& dataset) const
83 return new KDTreeIndex(dataset, *this);
// Allocates (with new) a KMeansIndex over `dataset`, configured by this
// parameter object. Release of the returned pointer is left to the caller.
86 NNIndex* KMeansIndexParams::createIndex(const Matrix<float>& dataset) const
88 return new KMeansIndex(dataset, *this);
// Allocates (with new) a CompositeIndex over `dataset`, configured by this
// parameter object. Release of the returned pointer is left to the caller.
92 NNIndex* CompositeIndexParams::createIndex(const Matrix<float>& dataset) const
94 return new CompositeIndex(dataset, *this);
// Allocates (with new) an AutotunedIndex over `dataset`, configured by this
// parameter object. Release of the returned pointer is left to the caller.
98 NNIndex* AutotunedIndexParams::createIndex(const Matrix<float>& dataset) const
100 return new AutotunedIndex(dataset, *this);
// Re-creates an index from the file named by the `filename` member: the file
// is opened for binary reading, its header is read, the IndexParams object
// registered for header.index_type creates an index over `dataset`, and that
// index then loads its contents from the same file handle.
104 NNIndex* SavedIndexParams::createIndex(const Matrix<float>& dataset) const
// NOTE(review): confirm that a failed fopen (NULL) is handled before the
// handle is passed to load_header, and that the handle is closed afterwards.
107 FILE* fin = fopen(filename.c_str(), "rb");
111 IndexHeader header = load_header(fin);
// The concrete parameter type is chosen by the index type stored in the header.
// NOTE(review): params comes from the factory; confirm who releases it.
113 IndexParams* params = ParamsFactory::instance().create(header.index_type);
114 NNIndex* nnIndex = params->createIndex(dataset);
115 nnIndex->loadIndex(fin);
// Registers every IndexParams subclass with the factory under its algorithm
// id, so that ParamsFactory::instance().create(id) can produce the matching
// parameter object.
// NOTE(review): presumably executed once during start-up by an enclosing
// initializer -- confirm.
128 ParamsFactory::instance().register_<LinearIndexParams>(LINEAR);
129 ParamsFactory::instance().register_<KDTreeIndexParams>(KDTREE);
130 ParamsFactory::instance().register_<KMeansIndexParams>(KMEANS);
131 ParamsFactory::instance().register_<CompositeIndexParams>(COMPOSITE);
132 ParamsFactory::instance().register_<AutotunedIndexParams>(AUTOTUNED);
133 ParamsFactory::instance().register_<SavedIndexParams>(SAVED);
// Constructs the wrapper: `params` creates the concrete NNIndex for `dataset`
// and the index is built immediately.
140 Index::Index(const Matrix<float>& dataset, const IndexParams& params)
142 nnIndex = params.createIndex(dataset);
// NOTE(review): the pointer returned by createIndex is used without a visible
// check -- confirm it can never be NULL (e.g. for SavedIndexParams).
143 nnIndex->buildIndex();
// Runs a nearest-neighbor search for the rows of `queries` and writes the
// results into `indices` and `dists`.
//   queries      - one query vector per row; cols must equal veclen()
//   indices      - output; needs at least queries.rows rows and knn columns
//   dists        - output; same minimum shape as indices
//   knn          - number of neighbors expected per query
//   searchParams - forwarded unchanged to search_for_neighbors
152 void Index::knnSearch(const Matrix<float>& queries, Matrix<int>& indices, Matrix<float>& dists, int knn, const SearchParams& searchParams)
// Shape preconditions are enforced with assert only.
154 assert(queries.cols==nnIndex->veclen());
155 assert(indices.rows>=queries.rows);
156 assert(dists.rows>=queries.rows);
157 assert(indices.cols>=knn);
158 assert(dists.cols>=knn);
// knn itself is not forwarded; search_for_neighbors only receives the
// matrices, so it presumably derives the neighbor count from them -- confirm.
161 search_for_neighbors(*nnIndex, queries, indices, dists, searchParams);
// Finds the neighbors of a single query point within `radius` and copies as
// many of them as fit into row 0 of `indices` and `dists`.
//   query        - the query vector; cols must equal veclen()
//   indices      - output neighbor ids (row 0 is written)
//   dists        - output distances (row 0 is written)
//   radius       - passed to the RadiusResultSet constructor
//   searchParams - forwarded to findNeighbors
// NOTE(review): presumably returns the number of neighbors copied -- confirm.
// NOTE(review): indices and dists are taken by value; the writes below only
// reach the caller if Matrix is a non-owning view of the buffer -- confirm.
164 int Index::radiusSearch(const Matrix<float>& query, Matrix<int> indices, Matrix<float> dists, float radius, const SearchParams& searchParams)
// NOTE(review): presumably printed when `query` has more than one row --
// confirm the guarding condition.
167 printf("I can only search one feature at a time for range search\n");
170 assert(query.cols==nnIndex->veclen());
// Collect every neighbor inside the radius into the result set.
172 RadiusResultSet resultSet(radius);
173 resultSet.init(query.data, query.cols);
174 nnIndex->findNeighbors(resultSet,query.data,searchParams);
176 // TODO: optimize here
177 int* neighbors = resultSet.getNeighbors();
178 float* distances = resultSet.getDistances();
// Clamp the number of results to the width of the index output buffer.
179 int count_nn = min((long)resultSet.size(), indices.cols);
// The distance buffer must be able to hold the same number of entries.
181 assert (dists.cols>=count_nn);
183 for (int i=0;i<count_nn;++i) {
184 indices[0][i] = neighbors[i];
185 dists[0][i] = distances[i];
// Writes the wrapped index to `filename`: the file is opened for binary
// writing and handed to nnIndex->saveIndex().
// Throws FLANNException("Cannot open file") on the error path below.
192 void Index::save(string filename)
194 FILE* fout = fopen(filename.c_str(), "wb");
// NOTE(review): the log-and-throw pair is presumably guarded by a check for a
// failed fopen -- confirm.
196 logger.error("Cannot open file: %s", filename.c_str());
197 throw FLANNException("Cannot open file");
// NOTE(review): confirm the file handle is closed after saving.
199 nnIndex->saveIndex(fout);
// Forwards to the wrapped NNIndex and returns its size().
203 int Index::size() const
205 return nnIndex->size();
// Forwards to the wrapped NNIndex and returns its veclen(); the search
// methods compare this value against the column count of query matrices.
208 int Index::veclen() const
210 return nnIndex->veclen();
// Clusters `features` with a KMeansIndex configured by `params` and stores the
// resulting cluster centers in `centers`.
// NOTE(review): presumably returns clusterNum, the value reported by
// getClusterCenters -- confirm.
214 int hierarchicalClustering(const Matrix<float>& features, Matrix<float>& centers, const KMeansIndexParams& params)
216 KMeansIndex kmeans(features, params);
// NOTE(review): confirm the index is built before its centers are requested.
219 int clusterNum = kmeans.getClusterCenters(centers);
// The C bindings below refer to flann types (Index, Matrix, ...) unqualified.
227 using namespace flann;
// Pointer aliases used by the C bindings: NNIndexPtr in the test_with_*
// functions, MatrixPtr in flann_compute_cluster_centers.
229 typedef NNIndex* NNIndexPtr;
230 typedef Matrix<float>* MatrixPtr;
// Applies the global settings carried by `p`: sets the log verbosity from
// p->log_level and, when p->random_seed is positive, seeds the random
// generator with it.
// NOTE(review): p is dereferenced here, and flann_build_index calls this
// function before its own NULL check -- confirm that a NULL guard surrounds
// these statements.
234 void init_flann_parameters(FLANNParameters* p)
237 flann_log_verbosity(p->log_level);
// Seeds that are zero or negative leave the random generator untouched.
238 if (p->random_seed>0) {
239 seed_random(p->random_seed);
// C API: sets the level of the library logger to `level`.
245 EXPORTED void flann_log_verbosity(int level)
248 logger.setLevel(level);
// C API: stores the requested distance type and order in
// flann_distance_type and flann_minkowski_order, both declared outside this
// function.
// NOTE(review): `order` is presumably only meaningful for the Minkowski
// distance -- confirm.
252 EXPORTED void flann_set_distance_type(flann_distance_t distance_type, int order)
254 flann_distance_type = distance_type;
255 flann_minkowski_order = order;
// C API: builds an index over a rows x cols float dataset.
//   dataset      - the data, wrapped (not copied here) as Matrix<float>(rows, cols, dataset)
//   rows, cols   - shape of the dataset
//   (4th param)  - unnamed and therefore unused in this function
//   flann_params - algorithm selection and settings; must be non-null
// runtime_error exceptions are caught and logged instead of propagating.
// NOTE(review): presumably returns the new Index as the opaque handle and a
// null handle on error -- confirm.
259 EXPORTED flann_index_t flann_build_index(float* dataset, int rows, int cols, float* /*speedup*/, FLANNParameters* flann_params)
// NOTE(review): this runs before the NULL check below, so
// init_flann_parameters must itself tolerate NULL -- confirm.
262 init_flann_parameters(flann_params);
263 if (flann_params == NULL) {
264 throw FLANNException("The flann_params argument must be non-null");
// NOTE(review): params is heap-allocated by the factory; confirm who releases
// it and whether the created index keeps referring to it.
266 IndexParams* params = IndexParams::createFromParameters(*flann_params);
267 Index* index = new Index(Matrix<float>(rows,cols,dataset), *params);
271 catch (runtime_error& e) {
272 logger.error("Caught exception: %s\n",e.what());
// C API: saves the index behind `index_ptr` to `filename` via Index::save.
// A NULL handle raises FLANNException("Invalid index"); runtime_error
// exceptions are caught and logged instead of propagating.
// NOTE(review): the int result presumably signals success/failure -- confirm
// the exact values.
279 EXPORTED int flann_save_index(flann_index_t index_ptr, char* filename)
282 if (index_ptr==NULL) {
283 throw FLANNException("Invalid index");
// The opaque handle is treated as an Index*.
286 Index* index = (Index*)index_ptr;
287 index->save(filename);
291 catch(runtime_error& e) {
292 logger.error("Caught exception: %s\n",e.what());
// C API: creates an Index over the rows x cols `dataset` using
// SavedIndexParams(filename), i.e. the index contents come from a previously
// saved file. runtime_error exceptions are caught and logged.
// NOTE(review): presumably returns the new Index as the opaque handle and a
// null handle on error -- confirm.
298 EXPORTED FLANN_INDEX flann_load_index(char* filename, float* dataset, int rows, int cols)
301 Index* index = new Index(Matrix<float>(rows,cols,dataset), SavedIndexParams(filename));
304 catch(runtime_error& e) {
305 logger.error("Caught exception: %s\n",e.what());
// C API: one-shot search. Builds an index over the rows x cols `dataset`, then
// looks up `nn` neighbors for each of the `tcount` vectors in `testset`.
//   result       - output neighbor ids, used as a tcount x nn matrix
//   dists        - output distances, used as a tcount x nn matrix
//   flann_params - selects the algorithm; its `checks` field drives the search
// runtime_error exceptions are caught and logged instead of propagating.
312 EXPORTED int flann_find_nearest_neighbors(float* dataset, int rows, int cols, float* testset, int tcount, int* result, float* dists, int nn, FLANNParameters* flann_params)
316 init_flann_parameters(flann_params);
// NOTE(review): flann_params is dereferenced here and below without a visible
// NULL check -- confirm callers never pass NULL.
318 IndexParams* params = IndexParams::createFromParameters(*flann_params);
// NOTE(review): params and index are heap-allocated; confirm both are released
// once the search has finished.
319 Index* index = new Index(Matrix<float>(rows,cols,dataset), *params);
320 Matrix<int> m_indices(tcount, nn, result);
321 Matrix<float> m_dists(tcount, nn, dists);
// The query matrix width is taken from the index rather than from `cols`.
322 index->knnSearch(Matrix<float>(tcount, index->veclen(), testset),
324 m_dists, nn, SearchParams(flann_params->checks) );
326 catch(runtime_error& e) {
327 logger.error("Caught exception: %s\n",e.what());
// C API: looks up `nn` neighbors for each of the `tcount` vectors in `testset`
// using an already built index.
//   index_ptr    - opaque handle, treated as an Index*; NULL raises
//                  FLANNException("Invalid index")
//   result/dists - outputs, each used as a tcount x nn matrix
//   checks       - passed to SearchParams (flann_params->checks is not used)
//   flann_params - only handed to init_flann_parameters
// runtime_error exceptions are caught and logged instead of propagating.
335 EXPORTED int flann_find_nearest_neighbors_index(flann_index_t index_ptr, float* testset, int tcount, int* result, float* dists, int nn, int checks, FLANNParameters* flann_params)
338 init_flann_parameters(flann_params);
339 if (index_ptr==NULL) {
340 throw FLANNException("Invalid index");
342 Index* index = (Index*) index_ptr;
344 Matrix<int> m_indices(tcount, nn, result);
345 Matrix<float> m_dists(tcount, nn, dists);
// The query matrix width comes from the index itself.
346 index->knnSearch(Matrix<float>(tcount, index->veclen(), testset),
348 m_dists, nn, SearchParams(checks) );
351 catch(runtime_error& e) {
352 logger.error("Caught exception: %s\n",e.what());
// C API: radius search for a single query vector on an existing index.
// The body uses `query`, `indices`, `dists`, `max_nn`, `radius` and `checks`
// in addition to the parameters declared on the lines shown here.
// A NULL handle raises FLANNException("Invalid index"); runtime_error
// exceptions are caught and logged instead of propagating.
360 EXPORTED int flann_radius_search(FLANN_INDEX index_ptr,
367 FLANNParameters* flann_params)
370 init_flann_parameters(flann_params);
371 if (index_ptr==NULL) {
372 throw FLANNException("Invalid index");
374 Index* index = (Index*) index_ptr;
// Output buffers are viewed as single-row matrices holding up to max_nn entries.
376 Matrix<int> m_indices(1, max_nn, indices);
377 Matrix<float> m_dists(1, max_nn, dists);
// NOTE(review): count presumably becomes the function's result -- confirm.
378 int count = index->radiusSearch(Matrix<float>(1, index->veclen(), query),
380 m_dists, radius, SearchParams(checks) );
385 catch(runtime_error& e) {
386 logger.error("Caught exception: %s\n",e.what());
// C API: releases an index handle. A NULL handle raises
// FLANNException("Invalid index"); runtime_error exceptions are caught and
// logged instead of propagating.
393 EXPORTED int flann_free_index(FLANN_INDEX index_ptr, FLANNParameters* flann_params)
396 init_flann_parameters(flann_params);
397 if (index_ptr==NULL) {
398 throw FLANNException("Invalid index");
// NOTE(review): the Index is presumably deleted right after this cast --
// confirm.
400 Index* index = (Index*) index_ptr;
405 catch(runtime_error& e) {
406 logger.error("Caught exception: %s\n",e.what());
// C API: clusters the rows x cols `dataset` with hierarchicalClustering and
// writes the centers into `result`, used as a clusters x cols matrix.
// The k-means settings (branching, iterations, centers_init, cb_index) are
// read from flann_params. runtime_error exceptions are caught and logged.
412 EXPORTED int flann_compute_cluster_centers(float* dataset, int rows, int cols, int clusters, float* result, FLANNParameters* flann_params)
415 init_flann_parameters(flann_params);
// NOTE(review): the wrapper is heap-allocated; confirm it is released before
// the function returns.
417 MatrixPtr inputData = new Matrix<float>(rows,cols,dataset);
// NOTE(review): flann_params is dereferenced without a visible NULL check.
418 KMeansIndexParams params(flann_params->branching, flann_params->iterations, flann_params->centers_init, flann_params->cb_index);
419 Matrix<float> centers(clusters, cols, result);
// NOTE(review): clusterNum is presumably the function's result -- confirm.
420 int clusterNum = hierarchicalClustering(*inputData,centers, params);
423 } catch (runtime_error& e) {
424 logger.error("Caught exception: %s\n",e.what());
// C API: wraps the raw buffers as matrices and delegates to
// compute_ground_truth. Each *shape array supplies {rows, cols} for the
// buffer that precedes it; `match` receives the output and `skip` is
// forwarded unchanged.
430 EXPORTED void compute_ground_truth_float(float* dataset, int dshape[], float* testset, int tshape[], int* match, int mshape[], int skip)
// Dataset and test set must have the same number of columns, and there must
// be one row of matches per test vector (checked with assert only).
432 assert(dshape[1]==tshape[1]);
433 assert(tshape[0]==mshape[0]);
435 Matrix<int> _match(mshape[0], mshape[1], match);
436 compute_ground_truth(Matrix<float>(dshape[0], dshape[1], dataset), Matrix<float>(tshape[0], tshape[1], testset), _match, skip);
// C API: wraps the raw buffers as matrices and returns the result of
// test_index_precision for the given index. Each *shape array supplies
// {rows, cols} for the buffer that precedes it. A NULL handle raises
// FLANNException("Invalid index"); runtime_error exceptions are caught and
// logged.
// NOTE(review): `checks` is passed dereferenced; it presumably receives an
// output value from test_index_precision -- confirm.
440 EXPORTED float test_with_precision(FLANN_INDEX index_ptr, float* dataset, int dshape[], float* testset, int tshape[], int* matches, int mshape[],
441 int nn, float precision, int* checks, int skip = 0)
// Same column count for dataset and test set, one row of matches per test
// vector (checked with assert only).
443 assert(dshape[1]==tshape[1]);
444 assert(tshape[0]==mshape[0]);
447 if (index_ptr==NULL) {
448 throw FLANNException("Invalid index");
// NOTE(review): here the handle is cast to NNIndex*, whereas the other entry
// points (e.g. flann_save_index) cast it to Index* -- confirm which type the
// handle really points to.
450 NNIndexPtr index = (NNIndexPtr)index_ptr;
451 return test_index_precision(*index, Matrix<float>(dshape[0], dshape[1],dataset), Matrix<float>(tshape[0], tshape[1], testset),
452 Matrix<int>(mshape[0],mshape[1],matches), precision, *checks, nn, skip);
453 } catch (runtime_error& e) {
454 logger.error("Caught exception: %s\n",e.what());
// C API: wraps the raw buffers as matrices and returns the result of
// test_index_checks for the given index. Each *shape array supplies
// {rows, cols} for the buffer that precedes it. A NULL handle raises
// FLANNException("Invalid index"); runtime_error exceptions are caught and
// logged.
// NOTE(review): `precision` is passed dereferenced; it presumably receives an
// output value from test_index_checks -- confirm.
459 EXPORTED float test_with_checks(FLANN_INDEX index_ptr, float* dataset, int dshape[], float* testset, int tshape[], int* matches, int mshape[],
460 int nn, int checks, float* precision, int skip = 0)
// Same column count for dataset and test set, one row of matches per test
// vector (checked with assert only).
462 assert(dshape[1]==tshape[1]);
463 assert(tshape[0]==mshape[0]);
466 if (index_ptr==NULL) {
467 throw FLANNException("Invalid index");
// NOTE(review): here the handle is cast to NNIndex*, whereas the other entry
// points (e.g. flann_save_index) cast it to Index* -- confirm which type the
// handle really points to.
469 NNIndexPtr index = (NNIndexPtr)index_ptr;
470 return test_index_checks(*index, Matrix<float>(dshape[0], dshape[1],dataset), Matrix<float>(tshape[0], tshape[1], testset),
471 Matrix<int>(mshape[0],mshape[1],matches), checks, *precision, nn, skip);
472 } catch (runtime_error& e) {
473 logger.error("Caught exception: %s\n",e.what());