TooN 2.1
optimization/brent.h
00001 #ifndef TOON_BRENT_H
00002 #define TOON_BRENT_H
00003 #include <TooN/TooN.h>
00004 #include <TooN/helpers.h>
00005 #include <limits>
00006 #include <cmath>
00007 #include <cstdlib>
00008 #include <iomanip>
00009 
00010 
00011 namespace TooN
00012 {
00013     using std::numeric_limits;
00014 
00015     /// brent_line_search performs Brent's golden section/quadratic interpolation search
00016     /// on the functor provided. The inputs a, x, b must bracket the minimum, and
00017     /// must be in order, so  that \f$ a < x < b \f$ and \f$ f(a) > f(x) < f(b) \f$.
00018     /// @param a The most negative point along the line.
00019     /// @param x The central point.
00020     /// @param fx The value of the function at the central point (\f$b\f$).
00021     /// @param b The most positive point along the line.
00022     /// @param func The functor to minimize
00023     /// @param maxiterations  Maximum number of iterations
00024     /// @param tolerance Tolerance at which the search should be stopped (defults to sqrt machine precision)
00025     /// @param epsilon Minimum bracket width (defaults to machine precision)
00026     /// @return The minima position is returned as the first element of the vector,
00027     ///         and the minimal value as the second element.
00028     /// @ingroup gOptimize
00029     template<class Functor, class Precision> Vector<2, Precision> brent_line_search(Precision a, Precision x, Precision b, Precision fx, const Functor& func, int maxiterations, Precision tolerance = sqrt(numeric_limits<Precision>::epsilon()), Precision epsilon = numeric_limits<Precision>::epsilon())
00030     {
00031         using std::min;
00032         using std::max;
00033 
00034         using std::abs;
00035         using std::sqrt;
00036 
00037         //The golden ratio:
00038         const Precision g = (3.0 - sqrt(5))/2;
00039         
00040         //The following points are tracked by the algorithm:
00041         //a, b bracket the interval
00042         // x   is the best value so far
00043         // w   second best point so far
00044         // v   third best point so far
00045         // These may not be unique.
00046         
00047         //The following points are used during iteration
00048         // u   the point currently being evaluated
00049         // xm   (a+b)/2
00050         
00051         //The updates are tracked as:
00052         //e is the distance moved last step, or current if golden section is used
00053         //d is the point moved in the current step
00054         
00055         Precision w=x, v=x, fw=fx, fv=fx;
00056         
00057         Precision d=0, e=0;
00058         int i=0;
00059 
00060         while(abs(b-a) > (abs(a) + abs(b)) * tolerance + epsilon && i < maxiterations)
00061         {
00062             i++;
00063             //The midpoint of the bracket
00064             const Precision xm = (a+b)/2;
00065 
00066             //Per-iteration tolerance 
00067             const Precision tol1 = abs(x)*tolerance + epsilon;
00068 
00069             //If we recently had an unhelpful step, then do
00070             //not attempt a parabolic fit. This prevents bad parabolic
00071             //fits spoiling the convergence. Also, do not attempt to fit if
00072             //there is not yet enough unique information in x, w, v.
00073             if(abs(e) > tol1 && w != v)
00074             {
00075                 //Attempt a parabolic through the best 3 points. The pdata is shifted
00076                 //so that x = 0 and f(x) = 0. The remaining parameters are:
00077                 //
00078                 // xw  = w'    = w-x
00079                 // fxw = f'(w) = f(w) - f(x)
00080                 //
00081                 // etc:
00082                 const Precision fxw = fw - fx;
00083                 const Precision fxv = fv - fx;
00084                 const Precision xw = w-x;
00085                 const Precision xv = v-x;
00086 
00087                 //The parabolic fit has only second and first order coefficients:
00088                 //const Precision c1 = (fxv*xw - fxw*xv) / (xw*xv*(xv-xw));
00089                 //const Precision c2 = (fxw*xv*xv - fxv*xw*xw) / (xw*xv*(xv-xw));
00090                 
00091                 //The minimum is at -.5*c2 / c1;
00092                 //
00093                 //This can be written as p/q where:
00094                 const Precision p = fxv*xw*xw - fxw*xv*xv;
00095                 const Precision q = 2*(fxv*xw - fxw*xv);
00096 
00097                 //The minimum is at p/q. The minimum must lie within the bracket for it
00098                 //to be accepted. 
00099                 // Also, we want the step to be smaller than half the old one. So:
00100 
00101                 if(q == 0 || x + p/q < a || x+p/q > b || abs(p/q) > abs(e/2))
00102                 {
00103                     //Parabolic fit no good. Take a golden section step instead
00104                     //and reset d and e.
00105                     if(x > xm)
00106                         e = a-x;
00107                     else
00108                         e = b-x;
00109 
00110                     d = g*e;
00111                 }
00112                 else
00113                 {
00114                     //Parabolic fit was good. Shift d and e
00115                     e = d;
00116                     d = p/q;
00117                 }
00118             }
00119             else
00120             {
00121                 //Don't attempt a parabolic fit. Take a golden section step
00122                 //instead and reset d and e.
00123                 if(x > xm)
00124                     e = a-x;
00125                 else
00126                     e = b-x;
00127 
00128                 d = g*e;
00129             }
00130 
00131             const Precision u = x+d;
00132             //Our one function evaluation per iteration
00133             const Precision fu = func(u);
00134 
00135             if(fu < fx)
00136             {
00137                 //U is the best known point.
00138 
00139                 //Update the bracket
00140                 if(u > x)
00141                     a = x;
00142                 else
00143                     b = x;
00144 
00145                 //Shift v, w, x
00146                 v=w; fv = fw;
00147                 w=x; fw = fx;
00148                 x=u; fx = fu;
00149             }
00150             else
00151             {
00152                 //u is not the best known point. However, it is within the
00153                 //bracket.
00154                 if(u < x)
00155                     a = u;
00156                 else
00157                     b = u;
00158 
00159                 if(fu <= fw || w == x)
00160                 {
00161                     //Here, u is the new second-best point
00162                     v = w; fv = fw;
00163                     w = u; fw = fu;
00164                 }
00165                 else if(fu <= fv || v==x || v == w)
00166                 {
00167                     //Here, u is the new third-best point.
00168                     v = u; fv = fu;
00169                 }
00170             }
00171         }
00172 
00173         return makeVector(x, fx);
00174     }
00175 }
00176 #endif