1 /***********************************************************************
2 * Software License Agreement (BSD License)
4 * Copyright 2008-2009 Marius Muja (mariusm@cs.ubc.ca). All rights reserved.
5 * Copyright 2008-2009 David G. Lowe (lowe@cs.ubc.ca). All rights reserved.
9 * Redistribution and use in source and binary forms, with or without
10 * modification, are permitted provided that the following conditions
13 * 1. Redistributions of source code must retain the above copyright
14 * notice, this list of conditions and the following disclaimer.
15 * 2. Redistributions in binary form must reproduce the above copyright
16 * notice, this list of conditions and the following disclaimer in the
17 * documentation and/or other materials provided with the distribution.
19 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
20 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
21 * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
22 * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
23 * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
24 * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
25 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
26 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
27 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
28 * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 *************************************************************************/
31 #ifndef SIMPLEX_DOWNHILL_H
32 #define SIMPLEX_DOWNHILL_H
/**
    Adds val to array vals (and point to array points), keeping both arrays
    sorted by vals.

    The value is written to vals[pos] and the n coordinates of point to row
    pos of points. The new entry is then moved towards the front while its
    value is strictly smaller than its predecessor's. If vals[0..pos-1] was
    already sorted ascending, vals[0..pos] and the matching rows of points
    are sorted ascending afterwards. Equal values keep their existing order.

    @param pos    index (0-based) at which the new entry is written
    @param val    value to insert
    @param vals   value array; must hold at least pos+1 elements
    @param point  the n coordinates associated with val (may be row pos of points itself)
    @param points row-major matrix; row k (points[k*n .. k*n+n-1]) belongs to vals[k]
    @param n      number of coordinates per row
 */
template <typename T>
void addValue(int pos, float val, float* vals, T* point, T* points, int n)
{
    vals[pos] = val;
    for (int i=0; i<n; ++i) {
        points[pos*n+i] = point[i];
    }

    // Bubble the new entry down. `using std::swap` makes the calls resolve
    // without depending on a `using namespace std;` at the include site,
    // while still letting ADL pick a user-provided swap for T.
    using std::swap;
    for (int j=pos; j>0 && vals[j]<vals[j-1]; --j) {
        swap(vals[j], vals[j-1]);
        for (int i=0; i<n; ++i) {
            swap(points[j*n+i], points[(j-1)*n+i]);
        }
    }
}
61 Simplex downhill optimization function.
62 Preconditions: points is a 2D mattrix of size (n+1) x n
63 func is the cost function taking n an array of n params and returning float
64 vals is the cost function in the n+1 simplex points, if NULL it will be computed
66 Postcondition: returns optimum value and points[0..n] are the optimum parameters
68 template <typename T, typename F>
69 float optimizeSimplexDownhill(T* points, int n, F func, float* vals = NULL )
71 const int MAX_ITERATIONS = 10;
86 vals = new float[n+1];
87 for (int i=0;i<n+1;++i) {
88 float val = func(points+i*n);
89 addValue(i, val, vals, points+i*n, points, n);
96 if (iterations++ > MAX_ITERATIONS) break;
98 // compute average of simplex points (except the highest point)
99 for (int j=0;j<n;++j) {
101 for (int i=0;i<n;++i) {
102 p_o[i] += points[j*n+i];
105 for (int i=0;i<n;++i) {
109 bool converged = true;
110 for (int i=0;i<n;++i) {
111 if (p_o[i] != points[nn+i]) {
115 if (converged) break;
117 // trying a reflection
118 for (int i=0;i<n;++i) {
119 p_r[i] = p_o[i] + alpha*(p_o[i]-points[nn+i]);
121 float val_r = func(p_r);
123 if (val_r>=vals[0] && val_r<vals[n]) {
124 // reflection between second highest and lowest
125 // add it to the simplex
126 logger.info("Choosing reflection\n");
127 addValue(n, val_r,vals, p_r, points, n);
132 // value is smaller than smalest in simplex
134 // expand some more to see if it drops further
135 for (int i=0;i<n;++i) {
136 p_e[i] = 2*p_r[i]-p_o[i];
138 float val_e = func(p_e);
141 logger.info("Choosing reflection and expansion\n");
142 addValue(n, val_e,vals,p_e,points,n);
145 logger.info("Choosing reflection\n");
146 addValue(n, val_r,vals,p_r,points,n);
150 if (val_r>=vals[n]) {
151 for (int i=0;i<n;++i) {
152 p_e[i] = (p_o[i]+points[nn+i])/2;
154 float val_e = func(p_e);
157 logger.info("Choosing contraction\n");
158 addValue(n,val_e,vals,p_e,points,n);
163 logger.info("Full contraction\n");
164 for (int j=1;j<=n;++j) {
165 for (int i=0;i<n;++i) {
166 points[j*n+i] = (points[j*n+i]+points[i])/2;
168 float val = func(points+j*n);
169 addValue(j,val,vals,points+j*n,points,n);
174 float bestVal = vals[0];
179 if (ownVals) delete[] vals;
186 #endif //SIMPLEX_DOWNHILL_H