View Javadoc

1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.commons.math.optimization.direct;
19  
20  import java.util.Comparator;
21  
22  import org.apache.commons.math.FunctionEvaluationException;
23  import org.apache.commons.math.optimization.OptimizationException;
24  import org.apache.commons.math.optimization.RealPointValuePair;
25  
26  /** 
27   * This class implements the Nelder-Mead direct search method.
28   *
29   * @version $Revision: 799857 $ $Date: 2009-08-01 09:07:12 -0400 (Sat, 01 Aug 2009) $
30   * @see MultiDirectional
31   * @since 1.2
32   */
33  public class NelderMead extends DirectSearchOptimizer {
34  
35      /** Reflection coefficient. */
36      private final double rho;
37  
38      /** Expansion coefficient. */
39      private final double khi;
40  
41      /** Contraction coefficient. */
42      private final double gamma;
43  
44      /** Shrinkage coefficient. */
45      private final double sigma;
46  
47      /** Build a Nelder-Mead optimizer with default coefficients.
48       * <p>The default coefficients are 1.0 for rho, 2.0 for khi and 0.5
49       * for both gamma and sigma.</p>
50       */
51      public NelderMead() {
52          this.rho   = 1.0;
53          this.khi   = 2.0;
54          this.gamma = 0.5;
55          this.sigma = 0.5;
56      }
57  
58      /** Build a Nelder-Mead optimizer with specified coefficients.
59       * @param rho reflection coefficient
60       * @param khi expansion coefficient
61       * @param gamma contraction coefficient
62       * @param sigma shrinkage coefficient
63       */
64      public NelderMead(final double rho, final double khi,
65                        final double gamma, final double sigma) {
66          this.rho   = rho;
67          this.khi   = khi;
68          this.gamma = gamma;
69          this.sigma = sigma;
70      }
71  
72      /** {@inheritDoc} */
73      @Override
74      protected void iterateSimplex(final Comparator<RealPointValuePair> comparator)
75          throws FunctionEvaluationException, OptimizationException {
76  
77          incrementIterationsCounter();
78  
79          // the simplex has n+1 point if dimension is n
80          final int n = simplex.length - 1;
81  
82          // interesting values
83          final RealPointValuePair best       = simplex[0];
84          final RealPointValuePair secondBest = simplex[n-1];
85          final RealPointValuePair worst      = simplex[n];
86          final double[] xWorst = worst.getPointRef();
87  
88          // compute the centroid of the best vertices
89          // (dismissing the worst point at index n)
90          final double[] centroid = new double[n];
91          for (int i = 0; i < n; ++i) {
92              final double[] x = simplex[i].getPointRef();
93              for (int j = 0; j < n; ++j) {
94                  centroid[j] += x[j];
95              }
96          }
97          final double scaling = 1.0 / n;
98          for (int j = 0; j < n; ++j) {
99              centroid[j] *= scaling;
100         }
101 
102         // compute the reflection point
103         final double[] xR = new double[n];
104         for (int j = 0; j < n; ++j) {
105             xR[j] = centroid[j] + rho * (centroid[j] - xWorst[j]);
106         }
107         final RealPointValuePair reflected = new RealPointValuePair(xR, evaluate(xR), false);
108 
109         if ((comparator.compare(best, reflected) <= 0) &&
110             (comparator.compare(reflected, secondBest) < 0)) {
111 
112             // accept the reflected point
113             replaceWorstPoint(reflected, comparator);
114 
115         } else if (comparator.compare(reflected, best) < 0) {
116 
117             // compute the expansion point
118             final double[] xE = new double[n];
119             for (int j = 0; j < n; ++j) {
120                 xE[j] = centroid[j] + khi * (xR[j] - centroid[j]);
121             }
122             final RealPointValuePair expanded = new RealPointValuePair(xE, evaluate(xE), false);
123 
124             if (comparator.compare(expanded, reflected) < 0) {
125                 // accept the expansion point
126                 replaceWorstPoint(expanded, comparator);
127             } else {
128                 // accept the reflected point
129                 replaceWorstPoint(reflected, comparator);
130             }
131 
132         } else {
133 
134             if (comparator.compare(reflected, worst) < 0) {
135 
136                 // perform an outside contraction
137                 final double[] xC = new double[n];
138                 for (int j = 0; j < n; ++j) {
139                     xC[j] = centroid[j] + gamma * (xR[j] - centroid[j]);
140                 }
141                 final RealPointValuePair outContracted = new RealPointValuePair(xC, evaluate(xC), false);
142 
143                 if (comparator.compare(outContracted, reflected) <= 0) {
144                     // accept the contraction point
145                     replaceWorstPoint(outContracted, comparator);
146                     return;
147                 }
148 
149             } else {
150 
151                 // perform an inside contraction
152                 final double[] xC = new double[n];
153                 for (int j = 0; j < n; ++j) {
154                     xC[j] = centroid[j] - gamma * (centroid[j] - xWorst[j]);
155                 }
156                 final RealPointValuePair inContracted = new RealPointValuePair(xC, evaluate(xC), false);
157 
158                 if (comparator.compare(inContracted, worst) < 0) {
159                     // accept the contraction point
160                     replaceWorstPoint(inContracted, comparator);
161                     return;
162                 }
163 
164             }
165 
166             // perform a shrink
167             final double[] xSmallest = simplex[0].getPointRef();
168             for (int i = 1; i < simplex.length; ++i) {
169                 final double[] x = simplex[i].getPoint();
170                 for (int j = 0; j < n; ++j) {
171                     x[j] = xSmallest[j] + sigma * (x[j] - xSmallest[j]);
172                 }
173                 simplex[i] = new RealPointValuePair(x, Double.NaN, false);
174             }
175             evaluateSimplex(comparator);
176 
177         }
178 
179     }
180 
181 }