1 /***********************************************************************
2 * Software License Agreement (BSD License)
4 * Copyright 2008-2009 Marius Muja (mariusm@cs.ubc.ca). All rights reserved.
5 * Copyright 2008-2009 David G. Lowe (lowe@cs.ubc.ca). All rights reserved.
9 * Redistribution and use in source and binary forms, with or without
10 * modification, are permitted provided that the following conditions
13 * 1. Redistributions of source code must retain the above copyright
14 * notice, this list of conditions and the following disclaimer.
15 * 2. Redistributions in binary form must reproduce the above copyright
16 * notice, this list of conditions and the following disclaimer in the
17 * documentation and/or other materials provided with the distribution.
19 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
20 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
21 * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
22 * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
23 * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
24 * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
25 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
26 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
27 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
28 * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 *************************************************************************/
41 #include "constants.h"
44 #include "allocator.h"
46 #include "result_set.h"
* Chooses the initial centers in the k-means clustering in a random manner.
* Uses UniqueRandom so the same dataset index is not drawn twice, and skips
* candidates that coincide with an already-chosen center.
*
* Params:
* k = number of centers
* vecs = the dataset of points
* indices = indices in the dataset
* indices_length = length of indices vector
* centers = output array receiving pointers to the chosen center vectors
* centers_length = output count of centers actually chosen (may be < k)
*
* NOTE(review): several statements of this function (braces, the random draw
* into `rnd`, the duplicate-rejection logic) are missing from this copy of
* the file — restore from the upstream FLANN sources before compiling.
66 void chooseCentersRandom(int k, const Matrix<float>& vecs, int* indices, int indices_length, float** centers, int& centers_length)
68 UniqueRandom r(indices_length);
// try to pick k distinct points; stop early if the unique sampler is exhausted
71 for (index=0;index<k;++index) {
72 bool duplicate = true;
// sampler exhausted: report how many centers were actually produced
78 centers_length = index;
82 centers[index] = vecs[indices[rnd]];
// compare the candidate against every center chosen so far; a (near-)zero
// distance marks it as a duplicate to be rejected
84 for (int j=0;j<index;++j) {
85 float sq = flann_dist(centers[index],centers[index]+vecs.cols,centers[j]);
93 centers_length = index;
* Chooses the initial centers in the k-means using Gonzales' algorithm
* so that the centers are spaced apart from each other.
* The first center is a random point; each subsequent center is the point
* whose minimum distance to the already-chosen centers is largest.
*
* Params:
* k = number of centers
* vecs = the dataset of points
* indices = indices in the dataset
* centers = output array receiving pointers to the chosen center vectors
* centers_length = output count of centers actually chosen (may be < k)
*
* NOTE(review): the farthest-point bookkeeping (best_index/best_val updates,
* braces) is missing from this copy of the file — restore from upstream.
107 void chooseCentersGonzales(int k, const Matrix<float>& vecs, int* indices, int indices_length, float** centers, int& centers_length)
109 int n = indices_length;
// first center: a uniformly random dataset point
112 int rnd = rand_int(n);
113 assert(rnd >=0 && rnd < n);
115 centers[0] = vecs[indices[rnd]];
118 for (index=1; index<k; ++index) {
// for every candidate point, compute its distance to the closest
// already-chosen center (start with center 0, then tighten)
122 for (int j=0;j<n;++j) {
123 float dist = flann_dist(centers[0],centers[0]+vecs.cols,vecs[indices[j]]);
124 for (int i=1;i<index;++i) {
125 float tmp_dist = flann_dist(centers[i],centers[i]+vecs.cols,vecs[indices[j]]);
// best_index == -1 presumably means no further distinct point was found
135 if (best_index!=-1) {
136 centers[index] = vecs[indices[best_index]];
142 centers_length = index;
* Chooses the initial centers in the k-means using the algorithm
* proposed in the KMeans++ paper:
* Arthur, David; Vassilvitskii, Sergei - k-means++: The Advantages of Careful Seeding
*
* Implementation of this function was converted from the one provided in Arthur's code.
*
* Params:
* k = number of centers
* vecs = the dataset of points
* indices = indices in the dataset
* centers = output array receiving pointers to the chosen center vectors
* centers_length = output count of centers actually chosen
*
* NOTE(review): some statements (braces, `newPot` declaration, loop exits)
* are missing from this copy of the file — restore from upstream.
159 void chooseCentersKMeanspp(int k, const Matrix<float>& vecs, int* indices, int indices_length, float** centers, int& centers_length)
161 int n = indices_length;
// currentPot = sum of squared distances from each point to its closest
// chosen center (the k-means "potential" being minimized)
163 double currentPot = 0;
164 double* closestDistSq = new double[n];
166 // Choose one random center and set the closestDistSq values
167 int index = rand_int(n);
168 assert(index >=0 && index < n);
169 centers[0] = vecs[indices[index]];
171 for (int i = 0; i < n; i++) {
172 closestDistSq[i] = flann_dist(vecs[indices[i]], vecs[indices[i]] + vecs.cols, vecs[indices[index]]);
173 currentPot += closestDistSq[i];
// number of candidate draws per new center; kept at 1 here
177 const int numLocalTries = 1;
179 // Choose each center
181 for (centerCount = 1; centerCount < k; centerCount++) {
183 // Repeat several trials
184 double bestNewPot = -1;
185 int bestNewIndex = 0;
186 for (int localTrial = 0; localTrial < numLocalTries; localTrial++) {
188 // Choose our center - have to be slightly careful to return a valid answer even accounting
189 // for possible rounding errors
// sample a point with probability proportional to its closestDistSq
190 double randVal = rand_double(currentPot);
191 for (index = 0; index < n-1; index++) {
192 if (randVal <= closestDistSq[index])
195 randVal -= closestDistSq[index];
198 // Compute the new potential
200 for (int i = 0; i < n; i++)
201 newPot += min( (double)flann_dist(vecs[indices[i]], vecs[indices[i]] + vecs.cols, vecs[indices[index]]), closestDistSq[i] );
203 // Store the best result
204 if (bestNewPot < 0 || newPot < bestNewPot) {
206 bestNewIndex = index;
210 // Add the appropriate center
211 centers[centerCount] = vecs[indices[bestNewIndex]];
212 currentPot = bestNewPot;
// refresh each point's distance to its now-closest center
213 for (int i = 0; i < n; i++)
214 closestDistSq[i] = min( (double)flann_dist(vecs[indices[i]], vecs[indices[i]]+vecs.cols, vecs[indices[bestNewIndex]]), closestDistSq[i] );
217 centers_length = centerCount;
219 delete[] closestDistSq;
// Signature shared by all center-initialization strategies above.
227 typedef void (*centersAlgFunction)(int, const Matrix<float>&, int*, int, float**, int&);
229 * Associative array with functions to use for choosing the cluster centers.
231 map<flann_centers_init_t,centersAlgFunction> centerAlgs;
233 * Static initializer. Performs initialization before the program starts.
// Register the three seeding strategies under their enum keys so the
// KMeansIndex constructor can look them up by flann_centers_init_t.
238 centerAlgs[CENTERS_RANDOM] = &chooseCentersRandom;
239 centerAlgs[CENTERS_GONZALES] = &chooseCentersGonzales;
240 centerAlgs[CENTERS_KMEANSPP] = &chooseCentersKMeanspp;
// Init's constructor runs centers_init() before main() (static-object trick).
244 Init() { centers_init(); }
* Hierarchical kmeans index
*
* Contains a tree constructed through a hierarchical kmeans clustering
* and other information for indexing a set of points for nearest-neighbor matching.
259 class KMeansIndex : public NNIndex
263 * The branching factor used in the hierarchical k-means clustering
268 * Maximum number of iterations to use when performing k-means
274 * Cluster border index. This is used in the tree search phase when determining
275 * the closest cluster to explore next. A zero value takes into account only
276 * the cluster centers, a value greater than zero also takes into account the size
282 * The dataset used by this index
284 const Matrix<float> dataset;
287 * Number of features in the dataset.
292 * Length of each feature.
298 * Structure representing a node in the hierarchical k-means tree.
300 struct KMeansNodeSt {
302 * The cluster center.
306 * The cluster radius.
310 * The cluster mean radius.
314 * The cluster variance.
318 * The cluster size (number of points in the cluster)
322 * Child nodes (only for non-terminal nodes)
324 KMeansNodeSt** childs;
326 * Node points (only for terminal nodes)
// Convenience alias used throughout the index for node pointers.
334 typedef KMeansNodeSt* KMeansNode;
339 * Alias definition for a nicer syntax.
341 typedef BranchStruct<KMeansNode> BranchSt;
344 * Priority queue storing intermediate branches in the best-bin-first search
346 Heap<BranchSt>* heap;
351 * The root node in the tree.
356 * Array of indices to vectors in the dataset.
362 * Pooled memory allocator.
364 * Using a pooled memory allocator is more efficient
365 * than allocating memory directly when there is a large
366 * number of small memory allocations.
368 PooledAllocator pool;
371 * Memory occupied by the index.
377 * The function used for choosing the cluster centers.
379 centersAlgFunction chooseCenters;
// Reports which algorithm this index implements (body elided in this copy).
386 flann_algorithm_t getType() const
* Index constructor.
*
* Params:
* inputData = dataset with the input features
* params = parameters passed to the hierarchical k-means algorithm
398 KMeansIndex(const Matrix<float>& inputData, const KMeansIndexParams& params = KMeansIndexParams() )
399 : dataset(inputData), root(NULL), indices(NULL)
403 size_ = dataset.rows;
404 veclen_ = dataset.cols;
406 branching = params.branching;
407 max_iter = params.iterations;
// a non-positive iteration count presumably means "iterate to convergence"
409 max_iter = numeric_limits<int>::max();
// look up the requested seeding strategy in the registry built at startup
411 flann_centers_init_t centersInit = params.centers_init;
413 if ( centerAlgs.find(centersInit) != centerAlgs.end() ) {
414 chooseCenters = centerAlgs[centersInit];
417 throw FLANNException("Unknown algorithm for choosing initial centers.");
// heap sized to the dataset; reused across searches for branch ordering
421 heap = new Heap<BranchSt>(size_);
* Release the memory used by the index.
430 virtual ~KMeansIndex()
* Returns size of index.
* Returns the length of an index feature.
// Setter for the cluster-border index used during search (body elided here).
458 void set_cb_index( float index)
* Computes the index memory usage
* Returns: memory used by the index
468 int usedMemory() const
470 return pool.usedMemory+pool.wastedMemory+memoryCounter;
// NOTE(review): the following statements belong to the index-building method
// (presumably buildIndex()); its signature and several lines are missing from
// this copy of the file — restore from upstream before compiling.
479 throw FLANNException("Branching factor must be at least 2");
// identity permutation over the dataset; the tree partitions this array
482 indices = new int[size_];
483 for (int i=0;i<size_;++i) {
487 root = pool.allocate<KMeansNodeSt>();
488 computeNodeStatistics(root, indices, size_);
489 computeClustering(root, indices, size_, branching,0);
// Serializes the index parameters, the permutation array and the tree to
// the given stream (counterpart of loadIndex below).
493 void saveIndex(FILE* stream)
495 save_header(stream, *this);
496 save_value(stream, branching);
497 save_value(stream, max_iter);
498 save_value(stream, memoryCounter);
499 save_value(stream, cb_index);
500 save_value(stream, *indices, size_);
502 save_tree(stream, root);
// Restores an index previously written by saveIndex; the header must match
// the dataset this object was constructed with.
507 void loadIndex(FILE* stream)
509 IndexHeader header = load_header(stream);
511 if (header.rows!=size() || header.cols!=veclen()) {
512 throw FLANNException("The index saved belongs to a different dataset");
514 load_value(stream, branching);
515 load_value(stream, max_iter);
516 load_value(stream, memoryCounter);
517 load_value(stream, cb_index);
// NOTE(review): if `indices` was already allocated, lines freeing it appear
// to be elided here — verify against upstream to avoid a leak.
521 indices = new int[size_];
522 load_value(stream, *indices, size_);
527 load_tree(stream, root);
* Find set of nearest neighbors to vec. Their indices are stored inside
* the result object.
*
* Params:
* result = the result object in which the indices of the nearest-neighbors are stored
* vec = the vector for which to search the nearest neighbors
* searchParams = parameters that influence the search algorithm (checks, cb_index)
540 void findNeighbors(ResultSet& result, const float* vec, const SearchParams& searchParams)
542 int maxChecks = searchParams.checks;
// exhaustive tree traversal (presumably when maxChecks requests exact search)
545 findExactNN(root, result, vec);
// best-bin-first: one initial descent, then keep popping the most promising
// unexplored branch from the heap until the check budget is exhausted
551 findNN(root, result, vec, checks, maxChecks);
554 while (heap->popMin(branch) && (checks<maxChecks || !result.full())) {
555 KMeansNode node = branch.node;
556 findNN(node, result, vec, checks, maxChecks);
558 assert(result.full());
* Clustering function that takes a cut in the hierarchical k-means
* tree and returns the cluster centers of that clustering.
* Params:
* numClusters = number of clusters to have in the clustering computed
* Returns: number of cluster centers
571 int getClusterCenters(Matrix<float>& centers)
573 int numClusters = centers.rows;
575 throw FLANNException("Number of clusters must be at least 1");
// pick the tree cut that minimizes total variance, then copy each chosen
// node's pivot into the caller-provided matrix
579 KMeansNode* clusters = new KMeansNode[numClusters];
581 int clusterCount = getMinVarianceClusters(root, clusters, numClusters, variance);
583 // logger.info("Clusters requested: %d, returning %d\n",numClusters, clusterCount);
586 for (int i=0;i<clusterCount;++i) {
587 float* center = clusters[i]->pivot;
588 for (int j=0;j<veclen_;++j) {
589 centers[i][j] = center[j];
597 // Params estimateSearchParams(float precision, Dataset<float>* testset = NULL)
// Recursively serializes a subtree: the node struct, its pivot vector, and
// either the offset of its point range (leaf) or all children (internal).
609 void save_tree(FILE* stream, KMeansNode node)
611 save_value(stream, *node);
612 save_value(stream, *(node->pivot), veclen_);
613 if (node->childs==NULL) {
// leaves store their points as a slice of the shared `indices` array, so
// only the offset needs to be persisted
614 int indices_offset = node->indices - indices;
615 save_value(stream, indices_offset);
618 for(int i=0; i<branching; ++i) {
619 save_tree(stream, node->childs[i]);
// Recursively deserializes a subtree written by save_tree: allocates the
// node from the pool, reads the struct and pivot, then either rebinds the
// leaf's point range into `indices` or recurses into the children.
625 void load_tree(FILE* stream, KMeansNode& node)
627 node = pool.allocate<KMeansNodeSt>();
628 load_value(stream, *node);
629 node->pivot = new float[veclen_];
630 load_value(stream, *(node->pivot), veclen_);
// the childs pointer read from the stream is only used as a leaf/internal
// flag; the actual pointers are rebuilt below
631 if (node->childs==NULL) {
633 load_value(stream, indices_offset);
634 node->indices = indices + indices_offset;
637 node->childs = pool.allocate<KMeansNode>(branching);
638 for(int i=0; i<branching; ++i) {
639 load_tree(stream, node->childs[i]);
// Recursively releases the heap-allocated pivot vectors of a subtree.
// (The node structs themselves live in the pooled allocator and are freed
// with the pool.)
648 void free_centers(KMeansNode node)
650 delete[] node->pivot;
651 if (node->childs!=NULL) {
652 for (int k=0;k<branching;++k) {
653 free_centers(node->childs[k]);
* Computes the statistics of a node (mean, radius, variance).
* Params:
* node = the node to use
* indices = the indices of the points belonging to the node
665 void computeNodeStatistics(KMeansNode node, int* indices, int indices_length) {
669 float* mean = new float[veclen_];
670 memoryCounter += veclen_*sizeof(float);
672 memset(mean,0,veclen_*sizeof(float));
// NOTE(review): this loop bound is size_ (the whole dataset) although the
// parameters carry indices_length for this node — looks suspicious; the
// radius loop below uses indices_length. Verify against upstream FLANN.
674 for (int i=0;i<size_;++i) {
675 float* vec = dataset[indices[i]];
676 for (int j=0;j<veclen_;++j) {
679 variance += flann_dist(vec,vec+veclen_,zero);
681 for (int j=0;j<veclen_;++j) {
// variance = E[||x||^2] - ||mean||^2 (computed against the zero vector)
685 variance -= flann_dist(mean,mean+veclen_,zero);
// radius = max distance from the mean to any point of the node
688 for (int i=0;i<indices_length;++i) {
689 tmp = flann_dist(mean, mean + veclen_, dataset[indices[i]]);
695 node->variance = variance;
696 node->radius = radius;
* The method responsible with actually doing the recursive hierarchical
* clustering (Lloyd's k-means per level, then recursion into each cluster).
* Params:
* node = the node to cluster
* indices = indices of the points belonging to the current node
* branching = the branching factor to use in the clustering
*
* TODO: for 1-sized clusters don't store a cluster center (it's the same as the single cluster point)
*
* NOTE(review): many statements (braces, several loop bodies, early returns)
* are missing from this copy of the file — restore from upstream.
712 void computeClustering(KMeansNode node, int* indices, int indices_length, int branching, int level)
714 node->size = indices_length;
// too few points to split further: make this a (sorted) leaf
717 if (indices_length < branching) {
718 node->indices = indices;
719 sort(node->indices,node->indices+indices_length);
// seed the clustering with the strategy chosen at construction time
724 float** initial_centers = new float*[branching];
726 chooseCenters(branching, dataset, indices, indices_length, initial_centers, centers_length);
// seeding could not produce enough distinct centers: fall back to a leaf
728 if (centers_length<branching) {
729 node->indices = indices;
730 sort(node->indices,node->indices+indices_length);
// accumulate centers in double precision during the Lloyd iterations
736 Matrix<double> dcenters(branching,veclen_);
737 for (int i=0; i<centers_length; ++i) {
738 for (int k=0; k<veclen_; ++k) {
739 dcenters[i][k] = double(initial_centers[i][k]);
742 delete[] initial_centers;
744 float* radiuses = new float[branching];
745 int* count = new int[branching];
746 for (int i=0;i<branching;++i) {
751 // assign points to clusters
752 int* belongs_to = new int[indices_length];
753 for (int i=0;i<indices_length;++i) {
755 float sq_dist = flann_dist(dataset[indices[i]], dataset[indices[i]] + veclen_ ,dcenters[0]);
757 for (int j=1;j<branching;++j) {
758 float new_sq_dist = flann_dist(dataset[indices[i]], dataset[indices[i]]+veclen_, dcenters[j]);
759 if (sq_dist>new_sq_dist) {
761 sq_dist = new_sq_dist;
764 if (sq_dist>radiuses[belongs_to[i]]) {
765 radiuses[belongs_to[i]] = sq_dist;
767 count[belongs_to[i]]++;
// Lloyd iterations: recompute centers, reassign points, until stable or
// the iteration budget is exhausted
770 bool converged = false;
772 while (!converged && iteration<max_iter) {
776 // compute the new cluster centers
777 for (int i=0;i<branching;++i) {
778 memset(dcenters[i],0,sizeof(double)*veclen_);
781 for (int i=0;i<indices_length;++i) {
782 float* vec = dataset[indices[i]];
783 double* center = dcenters[belongs_to[i]];
784 for (int k=0;k<veclen_;++k) {
788 for (int i=0;i<branching;++i) {
790 for (int k=0;k<veclen_;++k) {
791 dcenters[i][k] /= cnt;
795 // reassign points to clusters
796 for (int i=0;i<indices_length;++i) {
797 float sq_dist = flann_dist(dataset[indices[i]], dataset[indices[i]]+veclen_ ,dcenters[0]);
798 int new_centroid = 0;
799 for (int j=1;j<branching;++j) {
800 float new_sq_dist = flann_dist(dataset[indices[i]], dataset[indices[i]]+veclen_,dcenters[j]);
801 if (sq_dist>new_sq_dist) {
803 sq_dist = new_sq_dist;
806 if (sq_dist>radiuses[new_centroid]) {
807 radiuses[new_centroid] = sq_dist;
809 if (new_centroid != belongs_to[i]) {
810 count[belongs_to[i]]--;
811 count[new_centroid]++;
812 belongs_to[i] = new_centroid;
818 for (int i=0;i<branching;++i) {
819 // if one cluster converges to an empty cluster,
820 // move an element into that cluster
822 int j = (i+1)%branching;
823 while (count[j]<=1) {
827 for (int k=0;k<indices_length;++k) {
828 if (belongs_to[k]==j) {
// final centers are stored back in single precision and owned by the nodes
841 float** centers = new float*[branching];
843 for (int i=0; i<branching; ++i) {
844 centers[i] = new float[veclen_];
845 memoryCounter += veclen_*sizeof(float);
846 for (int k=0; k<veclen_; ++k) {
847 centers[i][k] = (float)dcenters[i][k];
852 // compute kmeans clustering for each of the resulting clusters
853 node->childs = pool.allocate<KMeansNode>(branching);
856 for (int c=0;c<branching;++c) {
860 float mean_radius =0;
// partition `indices`/`belongs_to` in place so cluster c occupies a
// contiguous [start, end) range, accumulating its variance and mean radius
861 for (int i=0;i<indices_length;++i) {
862 if (belongs_to[i]==c) {
863 float d = flann_dist(dataset[indices[i]],dataset[indices[i]]+veclen_,zero);
865 mean_radius += sqrt(d);
866 swap(indices[i],indices[end]);
867 swap(belongs_to[i],belongs_to[end]);
873 variance -= flann_dist(centers[c],centers[c]+veclen_,zero);
875 node->childs[c] = pool.allocate<KMeansNodeSt>();
876 node->childs[c]->radius = radiuses[c];
877 node->childs[c]->pivot = centers[c];
878 node->childs[c]->variance = variance;
879 node->childs[c]->mean_radius = mean_radius;
880 node->childs[c]->indices = NULL;
881 computeClustering(node->childs[c],indices+start, end-start, branching, level+1);
* Performs one descent in the hierarchical k-means tree. The branches not
* visited are stored in a priority queue.
*
* Params:
* node = node to explore
* result = container for the k-nearest neighbors found
* checks = how many points in the dataset have been checked so far
* maxChecks = maximum dataset points to check
906 void findNN(KMeansNode node, ResultSet& result, const float* vec, int& checks, int maxChecks)
908 // Ignore those clusters that are too far away
// squared-distance pruning test: the cluster ball (radius rsq around the
// pivot) cannot contain anything closer than the current worst result
910 float bsq = flann_dist(vec, vec+veclen_, node->pivot);
911 float rsq = node->radius;
912 float wsq = result.worstDist();
914 float val = bsq-rsq-wsq;
915 float val2 = val*val-4*rsq*wsq;
918 if (val>0 && val2>0) {
// leaf: add the node's points to the result, counting them against budget
923 if (node->childs==NULL) {
924 if (checks>=maxChecks) {
925 if (result.full()) return;
927 checks += node->size;
928 for (int i=0;i<node->size;++i) {
929 result.addPoint(dataset[node->indices[i]], node->indices[i]);
// internal node: descend into the closest child; the siblings are pushed
// onto the heap by exploreNodeBranches for later best-bin-first visits
933 float* domain_distances = new float[branching];
934 int closest_center = exploreNodeBranches(node, vec, domain_distances);
935 delete[] domain_distances;
936 findNN(node->childs[closest_center],result,vec, checks, maxChecks);
* Helper function that computes the nearest child of a node to a given query point.
* The remaining children are inserted into the search heap with a priority
* adjusted by cb_index times their variance.
*
* Params:
* q = the query point
* domain_distances = array with the distances to each child node.
* Returns: index of the closest child (not pushed onto the heap)
948 int exploreNodeBranches(KMeansNode node, const float* q, float* domain_distances)
952 domain_distances[best_index] = flann_dist(q,q+veclen_,node->childs[best_index]->pivot);
953 for (int i=1;i<branching;++i) {
954 domain_distances[i] = flann_dist(q,q+veclen_,node->childs[i]->pivot);
955 if (domain_distances[i]<domain_distances[best_index]) {
960 // float* best_center = node->childs[best_index]->pivot;
961 for (int i=0;i<branching;++i) {
962 if (i != best_index) {
// cb_index weights the cluster's variance into the branch priority
// (cb_index == 0 ranks branches by pivot distance alone)
963 domain_distances[i] -= cb_index*node->childs[i]->variance;
965 // float dist_to_border = getDistanceToBorder(node.childs[i].pivot,best_center,q);
966 // if (domain_distances[i]<dist_to_border) {
967 // domain_distances[i] = dist_to_border;
969 heap->insert(BranchSt::make_branch(node->childs[i],domain_distances[i]));
* Function that performs exact nearest neighbor search by traversing the entire tree.
* Subtrees are still pruned with the same ball test as findNN, and children
* are visited in order of increasing pivot distance.
980 void findExactNN(KMeansNode node, ResultSet& result, const float* vec)
982 // Ignore those clusters that are too far away
984 float bsq = flann_dist(vec, vec+veclen_, node->pivot);
985 float rsq = node->radius;
986 float wsq = result.worstDist();
988 float val = bsq-rsq-wsq;
989 float val2 = val*val-4*rsq*wsq;
992 if (val>0 && val2>0) {
// leaf: add all points of the node to the result set
998 if (node->childs==NULL) {
999 for (int i=0;i<node->size;++i) {
1000 result.addPoint(dataset[node->indices[i]], node->indices[i]);
// internal node: recurse into every child, nearest pivot first
1004 int* sort_indices = new int[branching];
1006 getCenterOrdering(node, vec, sort_indices);
1008 for (int i=0; i<branching; ++i) {
1009 findExactNN(node->childs[sort_indices[i]],result,vec);
1012 delete[] sort_indices;
* It computes the order in which to traverse the child nodes of a particular
* node (insertion sort of the children by distance from q to their pivots;
* the permutation is written into sort_indices).
1022 void getCenterOrdering(KMeansNode node, const float* q, int* sort_indices)
1024 float* domain_distances = new float[branching];
1025 for (int i=0;i<branching;++i) {
1026 float dist = flann_dist(q, q+veclen_, node->childs[i]->pivot);
// NOTE(review): domain_distances[j] is read before the j<i bound is
// checked — when j reaches i this reads an uninitialized slot; the
// condition order should be (j<i && domain_distances[j]<dist). Verify
// against upstream before changing.
1029 while (domain_distances[j]<dist && j<i) j++;
// shift the tail right to make room at position j
1030 for (int k=i;k>j;--k) {
1031 domain_distances[k] = domain_distances[k-1];
1032 sort_indices[k] = sort_indices[k-1];
1034 domain_distances[j] = dist;
1035 sort_indices[j] = i;
1037 delete[] domain_distances;
* Method that computes the squared distance from the query point q
* from inside region with center c to the border between this
* region and the region with center p (the perpendicular bisector
* hyperplane of segment cp).
1045 float getDistanceToBorder(float* p, float* c, float* q)
1050 for (int i=0;i<veclen_; ++i) {
1051 float t = c[i]-p[i];
// projection of (q - midpoint) onto the c-p direction
1052 sum += t*(q[i]-(c[i]+p[i])/2);
// NOTE(review): sum2 is presumably ||c-p||^2 (accumulation elided in this
// copy); if p == c this divides by zero — confirm callers never pass
// coincident centers.
1056 return sum*sum/sum2;
* Helper function that descends in the hierarchical k-means tree by splitting those clusters that minimize
* the overall variance of the clustering.
* Params:
* clusters = array with cluster centers (return value)
* varianceValue = variance of the clustering (return value)
* Returns: number of clusters in the computed cut (may be < clusters_length)
1069 int getMinVarianceClusters(KMeansNode root, KMeansNode* clusters, int clusters_length, float& varianceValue)
1071 int clusterCount = 1;
// meanVariance tracks the (unnormalized) total variance of the current cut
1074 float meanVariance = root->variance*root->size;
1076 while (clusterCount<clusters_length) {
1077 float minVariance = numeric_limits<float>::max();
1078 int splitIndex = -1;
// evaluate each splittable (internal) cluster: replacing it by its
// children changes the total variance; keep the split that minimizes it
1080 for (int i=0;i<clusterCount;++i) {
1081 if (clusters[i]->childs != NULL) {
1083 float variance = meanVariance - clusters[i]->variance*clusters[i]->size;
1085 for (int j=0;j<branching;++j) {
1086 variance += clusters[i]->childs[j]->variance*clusters[i]->childs[j]->size;
1088 if (variance<minVariance) {
1089 minVariance = variance;
// stop when no cluster can be split, or splitting would exceed the budget
1095 if (splitIndex==-1) break;
1096 if ( (branching+clusterCount-1) > clusters_length) break;
1098 meanVariance = minVariance;
// perform the chosen split: child 0 replaces the parent in place, the
// remaining children are appended
1101 KMeansNode toSplit = clusters[splitIndex];
1102 clusters[splitIndex] = toSplit->childs[0];
1103 for (int i=1;i<branching;++i) {
1104 clusters[clusterCount++] = toSplit->childs[i];
// normalize back to a per-point variance for the caller
1108 varianceValue = meanVariance/root->size;
1109 return clusterCount;
1115 //register_index(KMEANS,KMeansTree)
1119 #endif //KMEANSTREE_H