Update to 2.0.0 tree from current Fremantle build
[opencv] / 3rdparty / flann / flann.cpp
1 /***********************************************************************
2  * Software License Agreement (BSD License)
3  *
4  * Copyright 2008-2009  Marius Muja (mariusm@cs.ubc.ca). All rights reserved.
5  * Copyright 2008-2009  David G. Lowe (lowe@cs.ubc.ca). All rights reserved.
6  *
7  * THE BSD LICENSE
8  *
9  * Redistribution and use in source and binary forms, with or without
10  * modification, are permitted provided that the following conditions
11  * are met:
12  *
13  * 1. Redistributions of source code must retain the above copyright
14  *    notice, this list of conditions and the following disclaimer.
15  * 2. Redistributions in binary form must reproduce the above copyright
16  *    notice, this list of conditions and the following disclaimer in the
17  *    documentation and/or other materials provided with the distribution.
18  *
19  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
20  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
21  * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
22  * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
23  * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
24  * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
25  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
26  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
27  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
28  * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29  *************************************************************************/
30
31 #include <stdexcept>
32 #include <vector>
33 #include "flann.h"
34 #include "timer.h"
35 #include "common.h"
36 #include "logger.h"
37 #include "index_testing.h"
38 #include "saving.h"
39 #include "object_factory.h"
40 // index types
41 #include "kdtree_index.h"
42 #include "kmeans_index.h"
43 #include "composite_index.h"
44 #include "linear_index.h"
45 #include "autotuned_index.h"
46
47 #include <typeinfo>
48 using namespace std;
49
50
51
52 #include "flann.h"
53
#if defined(WIN32)
/* Windows: give the C entry points C linkage and mark them for DLL export. */
#  define EXPORTED extern "C" __declspec(dllexport)
#else
/* Other platforms: C linkage only. */
#  define EXPORTED extern "C"
#endif
59
60
61 namespace flann
62 {
63
64 typedef ObjectFactory<IndexParams, flann_algorithm_t> ParamsFactory;
65
66
67 IndexParams* IndexParams::createFromParameters(const FLANNParameters& p)
68 {
69         IndexParams* params = ParamsFactory::instance().create(p.algorithm);
70         params->fromParameters(p);
71
72         return params;
73 }
74
75 NNIndex* LinearIndexParams::createIndex(const Matrix<float>& dataset) const
76 {
77         return new LinearIndex(dataset, *this);
78 }
79
80
81 NNIndex* KDTreeIndexParams::createIndex(const Matrix<float>& dataset) const
82 {
83         return new KDTreeIndex(dataset, *this);
84 }
85
86 NNIndex* KMeansIndexParams::createIndex(const Matrix<float>& dataset) const
87 {
88         return new KMeansIndex(dataset, *this);
89 }
90
91
92 NNIndex* CompositeIndexParams::createIndex(const Matrix<float>& dataset) const
93 {
94         return new CompositeIndex(dataset, *this);
95 }
96
97
98 NNIndex* AutotunedIndexParams::createIndex(const Matrix<float>& dataset) const
99 {
100         return new AutotunedIndex(dataset, *this);
101 }
102
103
104 NNIndex* SavedIndexParams::createIndex(const Matrix<float>& dataset) const
105 {
106
107         FILE* fin = fopen(filename.c_str(), "rb");
108         if (fin==NULL) {
109                 return NULL;
110         }
111         IndexHeader header = load_header(fin);
112         rewind(fin);
113         IndexParams* params = ParamsFactory::instance().create(header.index_type);
114         NNIndex* nnIndex =  params->createIndex(dataset);
115         nnIndex->loadIndex(fin);
116         fclose(fin);
117         delete params; //?
118
119         return nnIndex;
120
121 }
122
123 class StaticInit
124 {
125 public:
126         StaticInit()
127         {
128                 ParamsFactory::instance().register_<LinearIndexParams>(LINEAR);
129                 ParamsFactory::instance().register_<KDTreeIndexParams>(KDTREE);
130                 ParamsFactory::instance().register_<KMeansIndexParams>(KMEANS);
131                 ParamsFactory::instance().register_<CompositeIndexParams>(COMPOSITE);
132                 ParamsFactory::instance().register_<AutotunedIndexParams>(AUTOTUNED);
133                 ParamsFactory::instance().register_<SavedIndexParams>(SAVED);
134         }
135 };
136 StaticInit __init;
137
138
139
140 Index::Index(const Matrix<float>& dataset, const IndexParams& params)
141 {
142         nnIndex = params.createIndex(dataset);
143         nnIndex->buildIndex();
144 }
145
146 Index::~Index()
147 {
148         delete nnIndex;
149 }
150
151
152 void Index::knnSearch(const Matrix<float>& queries, Matrix<int>& indices, Matrix<float>& dists, int knn, const SearchParams& searchParams)
153 {
154         assert(queries.cols==nnIndex->veclen());
155         assert(indices.rows>=queries.rows);
156         assert(dists.rows>=queries.rows);
157         assert(indices.cols>=knn);
158         assert(dists.cols>=knn);
159
160
161     search_for_neighbors(*nnIndex, queries, indices, dists, searchParams);
162 }
163
164 int Index::radiusSearch(const Matrix<float>& query, Matrix<int> indices, Matrix<float> dists, float radius, const SearchParams& searchParams)
165 {
166         if (query.rows!=1) {
167                 printf("I can only search one feature at a time for range search\n");
168                 return -1;
169         }
170         assert(query.cols==nnIndex->veclen());
171
172         RadiusResultSet resultSet(radius);
173         resultSet.init(query.data, query.cols);
174         nnIndex->findNeighbors(resultSet,query.data,searchParams);
175
176         // TODO: optimize here
177         int* neighbors = resultSet.getNeighbors();
178         float* distances = resultSet.getDistances();
179         int count_nn = min((long)resultSet.size(), indices.cols);
180
181         assert (dists.cols>=count_nn);
182
183         for (int i=0;i<count_nn;++i) {
184                 indices[0][i] = neighbors[i];
185                 dists[0][i] = distances[i];
186         }
187
188         return count_nn;
189 }
190
191
192 void Index::save(string filename)
193 {
194         FILE* fout = fopen(filename.c_str(), "wb");
195         if (fout==NULL) {
196                 logger.error("Cannot open file: %s", filename.c_str());
197                 throw FLANNException("Cannot open file");
198         }
199         nnIndex->saveIndex(fout);
200         fclose(fout);
201 }
202
203 int Index::size() const
204 {
205         return nnIndex->size();
206 }
207
208 int Index::veclen() const
209 {
210         return nnIndex->veclen();
211 }
212
213
214 int hierarchicalClustering(const Matrix<float>& features, Matrix<float>& centers, const KMeansIndexParams& params)
215 {
216     KMeansIndex kmeans(features, params);
217         kmeans.buildIndex();
218
219     int clusterNum = kmeans.getClusterCenters(centers);
220         return clusterNum;
221 }
222
223 } // namespace FLANN
224
225
226
227 using namespace flann;
228
229 typedef NNIndex* NNIndexPtr;
230 typedef Matrix<float>* MatrixPtr;
231
232
233
234 void init_flann_parameters(FLANNParameters* p)
235 {
236         if (p != NULL) {
237                 flann_log_verbosity(p->log_level);
238         if (p->random_seed>0) {
239                   seed_random(p->random_seed);
240         }
241         }
242 }
243
244
245 EXPORTED void flann_log_verbosity(int level)
246 {
247     if (level>=0) {
248         logger.setLevel(level);
249     }
250 }
251
252 EXPORTED void flann_set_distance_type(flann_distance_t distance_type, int order)
253 {
254         flann_distance_type = distance_type;
255         flann_minkowski_order = order;
256 }
257
258
259 EXPORTED flann_index_t flann_build_index(float* dataset, int rows, int cols, float* /*speedup*/, FLANNParameters* flann_params)
260 {
261         try {
262                 init_flann_parameters(flann_params);
263                 if (flann_params == NULL) {
264                         throw FLANNException("The flann_params argument must be non-null");
265                 }
266                 IndexParams* params = IndexParams::createFromParameters(*flann_params);
267                 Index* index = new Index(Matrix<float>(rows,cols,dataset), *params);
268
269                 return index;
270         }
271         catch (runtime_error& e) {
272                 logger.error("Caught exception: %s\n",e.what());
273                 return NULL;
274         }
275 }
276
277
278
279 EXPORTED int flann_save_index(flann_index_t index_ptr, char* filename)
280 {
281         try {
282                 if (index_ptr==NULL) {
283                         throw FLANNException("Invalid index");
284                 }
285
286                 Index* index = (Index*)index_ptr;
287                 index->save(filename);
288
289                 return 0;
290         }
291         catch(runtime_error& e) {
292                 logger.error("Caught exception: %s\n",e.what());
293                 return -1;
294         }
295 }
296
297
298 EXPORTED FLANN_INDEX flann_load_index(char* filename, float* dataset, int rows, int cols)
299 {
300         try {
301                 Index* index = new Index(Matrix<float>(rows,cols,dataset), SavedIndexParams(filename));
302                 return index;
303         }
304         catch(runtime_error& e) {
305                 logger.error("Caught exception: %s\n",e.what());
306                 return NULL;
307         }
308 }
309
310
311
312 EXPORTED int flann_find_nearest_neighbors(float* dataset,  int rows, int cols, float* testset, int tcount, int* result, float* dists, int nn, FLANNParameters* flann_params)
313 {
314     int _result = 0;
315         try {
316                 init_flann_parameters(flann_params);
317
318                 IndexParams* params = IndexParams::createFromParameters(*flann_params);
319                 Index* index = new Index(Matrix<float>(rows,cols,dataset), *params);
320                 Matrix<int> m_indices(tcount, nn, result);
321                 Matrix<float> m_dists(tcount, nn, dists);
322                 index->knnSearch(Matrix<float>(tcount, index->veclen(), testset),
323                                                 m_indices,
324                                                 m_dists, nn, SearchParams(flann_params->checks) );
325         }
326         catch(runtime_error& e) {
327                 logger.error("Caught exception: %s\n",e.what());
328         _result = -1;
329         }
330
331         return _result;
332 }
333
334
335 EXPORTED int flann_find_nearest_neighbors_index(flann_index_t index_ptr, float* testset, int tcount, int* result, float* dists, int nn, int checks, FLANNParameters* flann_params)
336 {
337         try {
338                 init_flann_parameters(flann_params);
339                 if (index_ptr==NULL) {
340                         throw FLANNException("Invalid index");
341                 }
342                 Index* index = (Index*) index_ptr;
343
344                 Matrix<int> m_indices(tcount, nn, result);
345                 Matrix<float> m_dists(tcount, nn, dists);
346                 index->knnSearch(Matrix<float>(tcount, index->veclen(), testset),
347                                                 m_indices,
348                                                 m_dists, nn, SearchParams(checks) );
349
350         }
351         catch(runtime_error& e) {
352                 logger.error("Caught exception: %s\n",e.what());
353                 return -1;
354         }
355
356         return -1;
357 }
358
359
360 EXPORTED int flann_radius_search(FLANN_INDEX index_ptr,
361                                                                                 float* query,
362                                                                                 int* indices,
363                                                                                 float* dists,
364                                                                                 int max_nn,
365                                                                                 float radius,
366                                                                                 int checks,
367                                                                                 FLANNParameters* flann_params)
368 {
369         try {
370                 init_flann_parameters(flann_params);
371                 if (index_ptr==NULL) {
372                         throw FLANNException("Invalid index");
373                 }
374                 Index* index = (Index*) index_ptr;
375
376                 Matrix<int> m_indices(1, max_nn, indices);
377                 Matrix<float> m_dists(1, max_nn, dists);
378                 int count = index->radiusSearch(Matrix<float>(1, index->veclen(), query),
379                                                 m_indices,
380                                                 m_dists, radius, SearchParams(checks) );
381
382
383                 return count;
384         }
385         catch(runtime_error& e) {
386                 logger.error("Caught exception: %s\n",e.what());
387                 return -1;
388         }
389
390 }
391
392
393 EXPORTED int flann_free_index(FLANN_INDEX index_ptr, FLANNParameters* flann_params)
394 {
395         try {
396                 init_flann_parameters(flann_params);
397         if (index_ptr==NULL) {
398             throw FLANNException("Invalid index");
399         }
400         Index* index = (Index*) index_ptr;
401         delete index;
402
403         return 0;
404         }
405         catch(runtime_error& e) {
406                 logger.error("Caught exception: %s\n",e.what());
407         return -1;
408         }
409 }
410
411
412 EXPORTED int flann_compute_cluster_centers(float* dataset, int rows, int cols, int clusters, float* result, FLANNParameters* flann_params)
413 {
414         try {
415                 init_flann_parameters(flann_params);
416
417         MatrixPtr inputData = new Matrix<float>(rows,cols,dataset);
418         KMeansIndexParams params(flann_params->branching, flann_params->iterations, flann_params->centers_init, flann_params->cb_index);
419                 Matrix<float> centers(clusters, cols, result);
420         int clusterNum = hierarchicalClustering(*inputData,centers, params);
421
422                 return clusterNum;
423         } catch (runtime_error& e) {
424                 logger.error("Caught exception: %s\n",e.what());
425                 return -1;
426         }
427 }
428
429
430 EXPORTED void compute_ground_truth_float(float* dataset, int dshape[], float* testset, int tshape[], int* match, int mshape[], int skip)
431 {
432     assert(dshape[1]==tshape[1]);
433     assert(tshape[0]==mshape[0]);
434
435     Matrix<int> _match(mshape[0], mshape[1], match);
436     compute_ground_truth(Matrix<float>(dshape[0], dshape[1], dataset), Matrix<float>(tshape[0], tshape[1], testset), _match, skip);
437 }
438
439
440 EXPORTED float test_with_precision(FLANN_INDEX index_ptr, float* dataset, int dshape[], float* testset, int tshape[], int* matches, int mshape[],
441              int nn, float precision, int* checks, int skip = 0)
442 {
443     assert(dshape[1]==tshape[1]);
444     assert(tshape[0]==mshape[0]);
445
446     try {
447         if (index_ptr==NULL) {
448             throw FLANNException("Invalid index");
449         }
450         NNIndexPtr index = (NNIndexPtr)index_ptr;
451         return test_index_precision(*index, Matrix<float>(dshape[0], dshape[1],dataset), Matrix<float>(tshape[0], tshape[1], testset),
452                 Matrix<int>(mshape[0],mshape[1],matches), precision, *checks, nn, skip);
453     } catch (runtime_error& e) {
454         logger.error("Caught exception: %s\n",e.what());
455         return -1;
456     }
457 }
458
459 EXPORTED float test_with_checks(FLANN_INDEX index_ptr, float* dataset, int dshape[], float* testset, int tshape[], int* matches, int mshape[],
460              int nn, int checks, float* precision, int skip = 0)
461 {
462     assert(dshape[1]==tshape[1]);
463     assert(tshape[0]==mshape[0]);
464
465     try {
466         if (index_ptr==NULL) {
467             throw FLANNException("Invalid index");
468         }
469         NNIndexPtr index = (NNIndexPtr)index_ptr;
470         return test_index_checks(*index, Matrix<float>(dshape[0], dshape[1],dataset), Matrix<float>(tshape[0], tshape[1], testset),
471                 Matrix<int>(mshape[0],mshape[1],matches), checks, *precision, nn, skip);
472     } catch (runtime_error& e) {
473         logger.error("Caught exception: %s\n",e.what());
474         return -1;
475     }
476 }