Update to 2.0.0 tree from current Fremantle build
[opencv] / src / cv / cvspilltree.cpp
diff --git a/src/cv/cvspilltree.cpp b/src/cv/cvspilltree.cpp
new file mode 100644 (file)
index 0000000..de61b89
--- /dev/null
@@ -0,0 +1,484 @@
+/* Original code has been submitted by Liu Liu.
+   ----------------------------------------------------------------------------------
+   * Spill-Tree for Approximate KNN Search
+   * Author: Liu Liu
+   * mailto: liuliu.1987+opencv@gmail.com
+   * Refer to Paper:
+   * An Investigation of Practical Approximate Nearest Neighbor Algorithms
+   * cvMergeSpillTree TBD
+   *
+   * Redistribution and use in source and binary forms, with or
+   * without modification, are permitted provided that the following
+   * conditions are met:
+   *   Redistributions of source code must retain the above
+   *   copyright notice, this list of conditions and the following
+   *   disclaimer.
+   *   Redistributions in binary form must reproduce the above
+   *   copyright notice, this list of conditions and the following
+   *   disclaimer in the documentation and/or other materials
+   *   provided with the distribution.
+   *   The name of Contributor may not be used to endorse or
+   *   promote products derived from this software without
+   *   specific prior written permission.
+   *
+   * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
+   * CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES,
+   * INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
+   * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+   * DISCLAIMED. IN NO EVENT SHALL THE CONTRIBUTORS BE LIABLE FOR ANY
+   * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+   * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+   * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA,
+   * OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+   * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR
+   * TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT
+   * OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY
+   * OF SUCH DAMAGE.
+   */
+
+#include "_cv.h"
+#include "_cvfeaturetree.h"
+
+struct CvSpillTreeNode
+{
+  bool leaf; // is leaf or not (leaf is the point that have no more child)
+  bool spill; // is not a non-overlapping point (defeatist search)
+  CvSpillTreeNode* lc; // left child (<)
+  CvSpillTreeNode* rc; // right child (>)
+  int cc; // child count
+  CvMat* u; // projection vector
+  CvMat* center; // center
+  int i; // original index
+  double r; // radius of remaining feature point
+  double ub; // upper bound
+  double lb; // lower bound
+  double mp; // mean point
+  double p; // projection value
+};
+
+struct CvSpillTree
+{
+  CvSpillTreeNode* root;
+  CvMat** refmat; // leaf ref matrix
+  bool* cache; // visited or not
+  int total; // total leaves
+  int naive; // under this value, we perform naive search
+  int type; // mat type
+  double rho; // under this value, it is a spill tree
+  double tau; // the overlapping buffer ratio
+};
+
+// find the farthest node in the "list" from "node"
+static inline CvSpillTreeNode*
+icvFarthestNode( CvSpillTreeNode* node,
+                CvSpillTreeNode* list,
+                int total )
+{
+  double farthest = -1.;
+  CvSpillTreeNode* result = NULL;
+  for ( int i = 0; i < total; i++ )
+    {
+      double norm = cvNorm( node->center, list->center );
+      if ( norm > farthest )
+       {
+         farthest = norm;
+         result = list;
+       }
+      list = list->rc;
+    }
+  return result;
+}
+
+// clone a new tree node
+static inline CvSpillTreeNode*
+icvCloneSpillTreeNode( CvSpillTreeNode* node )
+{
+  CvSpillTreeNode* result = (CvSpillTreeNode*)cvAlloc( sizeof(CvSpillTreeNode) );
+  memcpy( result, node, sizeof(CvSpillTreeNode) );
+  return result;
+}
+
+// append the link-list of a tree node
+static inline void
+icvAppendSpillTreeNode( CvSpillTreeNode* node,
+                       CvSpillTreeNode* append )
+{
+  if ( node->lc == NULL )
+    {
+      node->lc = node->rc = append;
+      node->lc->lc = node->rc->rc = NULL;
+    } else {
+      append->lc = node->rc;
+      append->rc = NULL;
+      node->rc->rc = append;
+      node->rc = append;
+    }
+  node->cc++;
+}
+
+#define _dispatch_mat_ptr(x, step) (CV_MAT_DEPTH((x)->type) == CV_32F ? (void*)((x)->data.fl+(step)) : (CV_MAT_DEPTH((x)->type) == CV_64F ? (void*)((x)->data.db+(step)) : (void*)(0)))
+
+static void
+icvDFSInitSpillTreeNode( const CvSpillTree* tr,
+                        const int d,
+                        CvSpillTreeNode* node )
+{
+  if ( node->cc <= tr->naive )
+    {
+      // already get to a leaf, terminate the recursion.
+      node->leaf = true;
+      node->spill = false;
+      return;
+    }
+
+  // random select a node, then find a farthest node from this one, then find a farthest from that one...
+  // to approximate the farthest node-pair
+  static CvRNG rng_state = cvRNG(0xdeadbeef);
+  int rn = cvRandInt( &rng_state ) % node->cc;
+  CvSpillTreeNode* lnode = NULL;
+  CvSpillTreeNode* rnode = node->lc;
+  for ( int i = 0; i < rn; i++ )
+    rnode = rnode->rc;
+  lnode = icvFarthestNode( rnode, node->lc, node->cc );
+  rnode = icvFarthestNode( lnode, node->lc, node->cc );
+
+  // u is the projection vector
+  node->u = cvCreateMat( 1, d, tr->type );
+  cvSub( lnode->center, rnode->center, node->u );
+  cvNormalize( node->u, node->u );
+       
+  // find the center of node in hyperspace
+  node->center = cvCreateMat( 1, d, tr->type );
+  cvZero( node->center );
+  CvSpillTreeNode* it = node->lc;
+  for ( int i = 0; i < node->cc; i++ )
+    {
+      cvAdd( it->center, node->center, node->center );
+      it = it->rc;
+    }
+  cvConvertScale( node->center, node->center, 1./node->cc );
+
+  // project every node to "u", and find the mean point "mp"
+  it = node->lc;
+  node->r = -1.;
+  node->mp = 0;
+  for ( int i = 0; i < node->cc; i++ )
+    {
+      node->mp += ( it->p = cvDotProduct( it->center, node->u ) );
+      double norm = cvNorm( node->center, it->center );
+      if ( norm > node->r )
+       node->r = norm;
+      it = it->rc;
+    }
+  node->mp = node->mp / node->cc;
+
+  // overlapping buffer and upper bound, lower bound
+  double ob = (lnode->p-rnode->p)*tr->tau*.5;
+  node->ub = node->mp+ob;
+  node->lb = node->mp-ob;
+  int sl = 0, l = 0;
+  int sr = 0, r = 0;
+  it = node->lc;
+  for ( int i = 0; i < node->cc; i++ )
+    {
+      if ( it->p <= node->ub )
+       sl++;
+      if ( it->p >= node->lb )
+       sr++;
+      if ( it->p < node->mp )
+       l++;
+      else
+       r++;
+      it = it->rc;
+    }
+  // precision problem, return the node as it is.
+  if (( l == 0 )||( r == 0 ))
+    {
+      cvReleaseMat( &(node->u) );
+      cvReleaseMat( &(node->center) );
+      node->leaf = true;
+      node->spill = false;
+      return;
+    }
+  CvSpillTreeNode* lc = (CvSpillTreeNode*)cvAlloc( sizeof(CvSpillTreeNode) );
+  memset(lc, 0, sizeof(CvSpillTreeNode));
+  CvSpillTreeNode* rc = (CvSpillTreeNode*)cvAlloc( sizeof(CvSpillTreeNode) );
+  memset(rc, 0, sizeof(CvSpillTreeNode));
+  lc->lc = lc->rc = rc->lc = rc->rc = NULL;
+  lc->cc = rc->cc = 0;
+  int undo = cvRound(node->cc*tr->rho);
+  if (( sl >= undo )||( sr >= undo ))
+    {
+      // it is not a spill point (defeatist search disabled)
+      it = node->lc;
+      for ( int i = 0; i < node->cc; i++ )
+       {
+         CvSpillTreeNode* next = it->rc;
+         if ( it->p < node->mp )
+           icvAppendSpillTreeNode( lc, it );
+         else
+           icvAppendSpillTreeNode( rc, it );
+         it = next;
+       }
+      node->spill = false;
+    } else {
+      // a spill point
+      it = node->lc;
+      for ( int i = 0; i < node->cc; i++ )
+       {
+         CvSpillTreeNode* next = it->rc;
+         if ( it->p < node->lb )
+           icvAppendSpillTreeNode( lc, it );
+         else if ( it->p > node->ub )
+           icvAppendSpillTreeNode( rc, it );
+         else {
+           CvSpillTreeNode* cit = icvCloneSpillTreeNode( it );
+           icvAppendSpillTreeNode( lc, it );
+           icvAppendSpillTreeNode( rc, cit );
+         }
+         it = next;
+       }
+      node->spill = true;
+    }
+  node->lc = lc;
+  node->rc = rc;
+
+  // recursion process
+  icvDFSInitSpillTreeNode( tr, d, node->lc );
+  icvDFSInitSpillTreeNode( tr, d, node->rc );
+}
+
+static CvSpillTree*
+icvCreateSpillTree( const CvMat* raw_data,
+                   const int naive,
+                   const double rho,
+                   const double tau )
+{
+  int n = raw_data->rows;
+  int d = raw_data->cols;
+
+  CvSpillTree* tr = (CvSpillTree*)cvAlloc( sizeof(CvSpillTree) );
+  tr->root = (CvSpillTreeNode*)cvAlloc( sizeof(CvSpillTreeNode) );
+  memset(tr->root, 0, sizeof(CvSpillTreeNode));
+  tr->refmat = (CvMat**)cvAlloc( sizeof(CvMat*)*n );
+  tr->cache = (bool*)cvAlloc( sizeof(bool)*n );
+  tr->total = n;
+  tr->naive = naive;
+  tr->rho = rho;
+  tr->tau = tau;
+  tr->type = raw_data->type;
+
+  // tie a link-list to the root node
+  tr->root->lc = (CvSpillTreeNode*)cvAlloc( sizeof(CvSpillTreeNode) );
+  memset(tr->root->lc, 0, sizeof(CvSpillTreeNode));
+  tr->root->lc->center = cvCreateMatHeader( 1, d, tr->type );
+  cvSetData( tr->root->lc->center, _dispatch_mat_ptr(raw_data, 0), raw_data->step );
+  tr->refmat[0] = tr->root->lc->center;
+  tr->root->lc->lc = NULL;
+  tr->root->lc->leaf = true;
+  tr->root->lc->i = 0;
+  CvSpillTreeNode* node = tr->root->lc;
+  for ( int i = 1; i < n; i++ )
+    {
+      CvSpillTreeNode* newnode = (CvSpillTreeNode*)cvAlloc( sizeof(CvSpillTreeNode) );
+      memset(newnode, 0, sizeof(CvSpillTreeNode));
+      newnode->center = cvCreateMatHeader( 1, d, tr->type );
+      cvSetData( newnode->center, _dispatch_mat_ptr(raw_data, i*d), raw_data->step );
+      tr->refmat[i] = newnode->center;
+      newnode->lc = node;
+      newnode->i = i;
+      newnode->leaf = true;
+      newnode->rc = NULL;
+      node->rc = newnode;
+      node = newnode;
+    }
+  tr->root->rc = node;
+  tr->root->cc = n;
+  icvDFSInitSpillTreeNode( tr, d, tr->root );
+  return tr;
+}
+
+static void
+icvSpillTreeNodeHeapify( CvSpillTreeNode** heap,
+                        int i,
+                        const int k )
+{
+  if ( heap[i] == NULL )
+    return;
+  int l, r, largest = i;
+  CvSpillTreeNode* inp;
+  do {
+    i = largest;
+    r = (i+1)<<1;
+    l = r-1;
+    if (( l < k )&&( heap[l] == NULL ))
+      largest = l;
+    else if (( r < k )&&( heap[r] == NULL ))
+      largest = r;
+       else {
+      if (( l < k )&&( heap[l]->mp > heap[i]->mp ))
+        largest = l;
+      if (( r < k )&&( heap[r]->mp > heap[largest]->mp ))
+        largest = r;
+    }
+    if ( largest != i )
+      CV_SWAP( heap[largest], heap[i], inp );
+  } while ( largest != i );
+}
+
+static void
+icvSpillTreeDFSearch( CvSpillTree* tr,
+                     CvSpillTreeNode* node,
+                     CvSpillTreeNode** heap,
+                     int* es,
+                     const CvMat* desc,
+                     const int k,
+                     const int emax )
+{
+  if ((emax > 0)&&( *es >= emax ))
+    return;
+  double dist, p=0;
+  while ( node->spill )
+    {
+      // defeatist search
+      if ( !node->leaf )
+       p = cvDotProduct( node->u, desc );
+      if ( p < node->lb && node->lc->cc >= k ) // check the number of children larger than k otherwise you'll skip over better neighbor
+       node = node->lc;
+      else if ( p > node->ub && node->rc->cc >= k )
+       node = node->rc;
+      else
+       break;
+      if ( NULL == node )
+       return;
+    }
+  if ( node->leaf )
+    {
+      // a leaf, naive search
+      CvSpillTreeNode* it = node->lc;
+      for ( int i = 0; i < node->cc; i++ )
+        {
+          if ( !tr->cache[it->i] )
+          {
+           it->mp = cvNorm( it->center, desc );
+            tr->cache[it->i] = true;
+           if (( heap[0] == NULL)||( it->mp < heap[0]->mp ))
+             {
+                heap[0] = it;
+               icvSpillTreeNodeHeapify( heap, 0, k );
+               (*es)++;
+             }
+          }
+          it = it->rc;
+       }
+      return;
+    }
+  dist = cvNorm( node->center, desc );
+  // impossible case, skip
+  if (( heap[0] != NULL )&&( dist-node->r > heap[0]->mp ))
+    return;
+  p = cvDotProduct( node->u, desc );
+  // guided dfs
+  if ( p < node->mp )
+    {
+      icvSpillTreeDFSearch( tr, node->lc, heap, es, desc, k, emax );
+      icvSpillTreeDFSearch( tr, node->rc, heap, es, desc, k, emax );
+    } else {
+      icvSpillTreeDFSearch( tr, node->rc, heap, es, desc, k, emax );
+      icvSpillTreeDFSearch( tr, node->lc, heap, es, desc, k, emax );
+    }
+}
+
+static void
+icvFindSpillTreeFeatures( CvSpillTree* tr,
+                         const CvMat* desc,
+                         CvMat* results,
+                         CvMat* dist,
+                         const int k,
+                         const int emax )
+{
+  assert( desc->type == tr->type );
+  CvSpillTreeNode** heap = (CvSpillTreeNode**)cvAlloc( k*sizeof(heap[0]) );
+  for ( int j = 0; j < desc->rows; j++ )
+    {
+      CvMat _desc = cvMat( 1, desc->cols, desc->type, _dispatch_mat_ptr(desc, j*desc->cols) );
+      for ( int i = 0; i < k; i++ )
+       heap[i] = NULL;
+      memset( tr->cache, 0, sizeof(bool)*tr->total );
+      int es = 0;
+      icvSpillTreeDFSearch( tr, tr->root, heap, &es, &_desc, k, emax );
+      CvSpillTreeNode* inp;
+      for ( int i = k-1; i > 0; i-- )
+       {
+         CV_SWAP( heap[i], heap[0], inp );
+         icvSpillTreeNodeHeapify( heap, 0, i );
+       }
+      int* rs = results->data.i+j*results->cols;
+      double* dt = dist->data.db+j*dist->cols;
+      for ( int i = 0; i < k; i++, rs++, dt++ )
+       if ( heap[i] != NULL )
+         {
+           *rs = heap[i]->i;
+           *dt = heap[i]->mp;
+         } else
+           *rs = -1;
+    }
+  cvFree( &heap );
+}
+
+static void
+icvDFSReleaseSpillTreeNode( CvSpillTreeNode* node )
+{
+  if ( node->leaf )
+    {
+      CvSpillTreeNode* it = node->lc;
+      for ( int i = 0; i < node->cc; i++ )
+        {
+          CvSpillTreeNode* s = it;
+          it = it->rc;
+          cvFree( &s );
+        }
+    } else {
+      cvReleaseMat( &node->u );
+      cvReleaseMat( &node->center );
+      icvDFSReleaseSpillTreeNode( node->lc );
+      icvDFSReleaseSpillTreeNode( node->rc );
+    }
+  cvFree( &node );
+}
+
+static void
+icvReleaseSpillTree( CvSpillTree** tr )
+{
+  for ( int i = 0; i < (*tr)->total; i++ )
+    cvReleaseMat( &((*tr)->refmat[i]) );
+  cvFree( &((*tr)->refmat) );
+  cvFree( &((*tr)->cache) );
+  icvDFSReleaseSpillTreeNode( (*tr)->root );
+  cvFree( tr );
+}
+
+class CvSpillTreeWrap : public CvFeatureTree {
+  CvSpillTree* tr;
+public:
+  CvSpillTreeWrap(const CvMat* raw_data,
+                 const int naive,
+                 const double rho,
+                 const double tau) {
+    tr = icvCreateSpillTree(raw_data, naive, rho, tau);
+  }
+  ~CvSpillTreeWrap() {
+    icvReleaseSpillTree(&tr);
+  }
+
+  void FindFeatures(const CvMat* desc, int k, int emax, CvMat* results, CvMat* dist) {
+    icvFindSpillTreeFeatures(tr, desc, results, dist, k, emax);
+  }
+};
+
+CvFeatureTree* cvCreateSpillTree( const CvMat* raw_data,
+                                 const int naive,
+                                 const double rho,
+                                 const double tau ) {
+  return new CvSpillTreeWrap(raw_data, naive, rho, tau);
+}