1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.commons.math.optimization.direct;
19  
20  import static org.junit.Assert.assertEquals;
21  import static org.junit.Assert.assertNotNull;
22  import static org.junit.Assert.assertNull;
23  import static org.junit.Assert.assertTrue;
24  import static org.junit.Assert.fail;
25  
26  import org.apache.commons.math.ConvergenceException;
27  import org.apache.commons.math.FunctionEvaluationException;
28  import org.apache.commons.math.MathException;
29  import org.apache.commons.math.MaxEvaluationsExceededException;
30  import org.apache.commons.math.MaxIterationsExceededException;
31  import org.apache.commons.math.analysis.MultivariateRealFunction;
32  import org.apache.commons.math.analysis.MultivariateVectorialFunction;
33  import org.apache.commons.math.linear.Array2DRowRealMatrix;
34  import org.apache.commons.math.linear.RealMatrix;
35  import org.apache.commons.math.optimization.GoalType;
36  import org.apache.commons.math.optimization.LeastSquaresConverter;
37  import org.apache.commons.math.optimization.OptimizationException;
38  import org.apache.commons.math.optimization.RealPointValuePair;
39  import org.apache.commons.math.optimization.SimpleRealPointChecker;
40  import org.apache.commons.math.optimization.SimpleScalarValueChecker;
41  import org.junit.Test;
42  
43  public class NelderMeadTest {
44  
45    @Test
46    public void testFunctionEvaluationExceptions() {
47        MultivariateRealFunction wrong =
48            new MultivariateRealFunction() {
49              private static final long serialVersionUID = 4751314470965489371L;
50              public double value(double[] x) throws FunctionEvaluationException {
51                  if (x[0] < 0) {
52                      throw new FunctionEvaluationException(x, "{0}", "oops");
53                  } else if (x[0] > 1) {
54                      throw new FunctionEvaluationException(new RuntimeException("oops"), x);
55                  } else {
56                      return x[0] * (1 - x[0]);
57                  }
58              }
59        };
60        try {
61            NelderMead optimizer = new NelderMead(0.9, 1.9, 0.4, 0.6);
62            optimizer.optimize(wrong, GoalType.MINIMIZE, new double[] { -1.0 });
63            fail("an exception should have been thrown");
64        } catch (FunctionEvaluationException ce) {
65            // expected behavior
66            assertNull(ce.getCause());
67        } catch (Exception e) {
68            fail("wrong exception caught: " + e.getMessage());
69        } 
70        try {
71            NelderMead optimizer = new NelderMead(0.9, 1.9, 0.4, 0.6);
72            optimizer.optimize(wrong, GoalType.MINIMIZE, new double[] { +2.0 });
73            fail("an exception should have been thrown");
74        } catch (FunctionEvaluationException ce) {
75            // expected behavior
76            assertNotNull(ce.getCause());
77        } catch (Exception e) {
78            fail("wrong exception caught: " + e.getMessage());
79        } 
80    }
81  
82    @Test
83    public void testMinimizeMaximize()
84        throws FunctionEvaluationException, ConvergenceException {
85  
86        // the following function has 4 local extrema:
87        final double xM        = -3.841947088256863675365;
88        final double yM        = -1.391745200270734924416;
89        final double xP        =  0.2286682237349059125691;
90        final double yP        = -yM;
91        final double valueXmYm =  0.2373295333134216789769; // local  maximum
92        final double valueXmYp = -valueXmYm;                // local  minimum
93        final double valueXpYm = -0.7290400707055187115322; // global minimum
94        final double valueXpYp = -valueXpYm;                // global maximum
95        MultivariateRealFunction fourExtrema = new MultivariateRealFunction() {
96            private static final long serialVersionUID = -7039124064449091152L;
97            public double value(double[] variables) throws FunctionEvaluationException {
98                final double x = variables[0];
99                final double y = variables[1];
100               return ((x == 0) || (y == 0)) ? 0 : (Math.atan(x) * Math.atan(x + 2) * Math.atan(y) * Math.atan(y) / (x * y));
101           }
102       };
103 
104       NelderMead optimizer = new NelderMead();
105       optimizer.setConvergenceChecker(new SimpleScalarValueChecker(1.0e-10, 1.0e-30));
106       optimizer.setMaxIterations(100);
107       optimizer.setStartConfiguration(new double[] { 0.2, 0.2 });
108       RealPointValuePair optimum;
109 
110       // minimization
111       optimum = optimizer.optimize(fourExtrema, GoalType.MINIMIZE, new double[] { -3.0, 0 });
112       assertEquals(xM,        optimum.getPoint()[0], 2.0e-7);
113       assertEquals(yP,        optimum.getPoint()[1], 2.0e-5);
114       assertEquals(valueXmYp, optimum.getValue(),    6.0e-12);
115       assertTrue(optimizer.getEvaluations() > 60);
116       assertTrue(optimizer.getEvaluations() < 90);
117 
118       optimum = optimizer.optimize(fourExtrema, GoalType.MINIMIZE, new double[] { +1, 0 });
119       assertEquals(xP,        optimum.getPoint()[0], 5.0e-6);
120       assertEquals(yM,        optimum.getPoint()[1], 6.0e-6);
121       assertEquals(valueXpYm, optimum.getValue(),    1.0e-11);              
122       assertTrue(optimizer.getEvaluations() > 60);
123       assertTrue(optimizer.getEvaluations() < 90);
124 
125       // maximization
126       optimum = optimizer.optimize(fourExtrema, GoalType.MAXIMIZE, new double[] { -3.0, 0.0 });
127       assertEquals(xM,        optimum.getPoint()[0], 1.0e-5);
128       assertEquals(yM,        optimum.getPoint()[1], 3.0e-6);
129       assertEquals(valueXmYm, optimum.getValue(),    3.0e-12);
130       assertTrue(optimizer.getEvaluations() > 60);
131       assertTrue(optimizer.getEvaluations() < 90);
132 
133       optimum = optimizer.optimize(fourExtrema, GoalType.MAXIMIZE, new double[] { +1, 0 });
134       assertEquals(xP,        optimum.getPoint()[0], 4.0e-6);
135       assertEquals(yP,        optimum.getPoint()[1], 5.0e-6);
136       assertEquals(valueXpYp, optimum.getValue(),    7.0e-12);
137       assertTrue(optimizer.getEvaluations() > 60);
138       assertTrue(optimizer.getEvaluations() < 90);
139 
140   }
141 
142   @Test
143   public void testRosenbrock()
144     throws FunctionEvaluationException, ConvergenceException {
145 
146     Rosenbrock rosenbrock = new Rosenbrock();
147     NelderMead optimizer = new NelderMead();
148     optimizer.setConvergenceChecker(new SimpleScalarValueChecker(-1, 1.0e-3));
149     optimizer.setMaxIterations(100);
150     optimizer.setStartConfiguration(new double[][] {
151             { -1.2,  1.0 }, { 0.9, 1.2 } , {  3.5, -2.3 }
152     });
153     RealPointValuePair optimum =
154         optimizer.optimize(rosenbrock, GoalType.MINIMIZE, new double[] { -1.2, 1.0 });
155 
156     assertEquals(rosenbrock.getCount(), optimizer.getEvaluations());
157     assertTrue(optimizer.getEvaluations() > 40);
158     assertTrue(optimizer.getEvaluations() < 50);
159     assertTrue(optimum.getValue() < 8.0e-4);
160 
161   }
162 
163   @Test
164   public void testPowell()
165     throws FunctionEvaluationException, ConvergenceException {
166 
167     Powell powell = new Powell();
168     NelderMead optimizer = new NelderMead();
169     optimizer.setConvergenceChecker(new SimpleScalarValueChecker(-1.0, 1.0e-3));
170     optimizer.setMaxIterations(200);
171     RealPointValuePair optimum =
172       optimizer.optimize(powell, GoalType.MINIMIZE, new double[] { 3.0, -1.0, 0.0, 1.0 });
173     assertEquals(powell.getCount(), optimizer.getEvaluations());
174     assertTrue(optimizer.getEvaluations() > 110);
175     assertTrue(optimizer.getEvaluations() < 130);
176     assertTrue(optimum.getValue() < 2.0e-3);
177 
178   }
179 
180   @Test
181   public void testLeastSquares1()
182   throws FunctionEvaluationException, ConvergenceException {
183 
184       final RealMatrix factors =
185           new Array2DRowRealMatrix(new double[][] {
186               { 1.0, 0.0 },
187               { 0.0, 1.0 }
188           }, false);
189       LeastSquaresConverter ls = new LeastSquaresConverter(new MultivariateVectorialFunction() {
190           public double[] value(double[] variables) {
191               return factors.operate(variables);
192           }
193       }, new double[] { 2.0, -3.0 });
194       NelderMead optimizer = new NelderMead();
195       optimizer.setConvergenceChecker(new SimpleScalarValueChecker(-1.0, 1.0e-6));
196       optimizer.setMaxIterations(200);
197       RealPointValuePair optimum =
198           optimizer.optimize(ls, GoalType.MINIMIZE, new double[] { 10.0, 10.0 });
199       assertEquals( 2.0, optimum.getPointRef()[0], 3.0e-5);
200       assertEquals(-3.0, optimum.getPointRef()[1], 4.0e-4);
201       assertTrue(optimizer.getEvaluations() > 60);
202       assertTrue(optimizer.getEvaluations() < 80);
203       assertTrue(optimum.getValue() < 1.0e-6);
204   }
205 
206   @Test
207   public void testLeastSquares2()
208   throws FunctionEvaluationException, ConvergenceException {
209 
210       final RealMatrix factors =
211           new Array2DRowRealMatrix(new double[][] {
212               { 1.0, 0.0 },
213               { 0.0, 1.0 }
214           }, false);
215       LeastSquaresConverter ls = new LeastSquaresConverter(new MultivariateVectorialFunction() {
216           public double[] value(double[] variables) {
217               return factors.operate(variables);
218           }
219       }, new double[] { 2.0, -3.0 }, new double[] { 10.0, 0.1 });
220       NelderMead optimizer = new NelderMead();
221       optimizer.setConvergenceChecker(new SimpleScalarValueChecker(-1.0, 1.0e-6));
222       optimizer.setMaxIterations(200);
223       RealPointValuePair optimum =
224           optimizer.optimize(ls, GoalType.MINIMIZE, new double[] { 10.0, 10.0 });
225       assertEquals( 2.0, optimum.getPointRef()[0], 5.0e-5);
226       assertEquals(-3.0, optimum.getPointRef()[1], 8.0e-4);
227       assertTrue(optimizer.getEvaluations() > 60);
228       assertTrue(optimizer.getEvaluations() < 80);
229       assertTrue(optimum.getValue() < 1.0e-6);
230   }
231 
232   @Test
233   public void testLeastSquares3()
234   throws FunctionEvaluationException, ConvergenceException {
235 
236       final RealMatrix factors =
237           new Array2DRowRealMatrix(new double[][] {
238               { 1.0, 0.0 },
239               { 0.0, 1.0 }
240           }, false);
241       LeastSquaresConverter ls = new LeastSquaresConverter(new MultivariateVectorialFunction() {
242           public double[] value(double[] variables) {
243               return factors.operate(variables);
244           }
245       }, new double[] { 2.0, -3.0 }, new Array2DRowRealMatrix(new double [][] {
246           { 1.0, 1.2 }, { 1.2, 2.0 }
247       }));
248       NelderMead optimizer = new NelderMead();
249       optimizer.setConvergenceChecker(new SimpleScalarValueChecker(-1.0, 1.0e-6));
250       optimizer.setMaxIterations(200);
251       RealPointValuePair optimum =
252           optimizer.optimize(ls, GoalType.MINIMIZE, new double[] { 10.0, 10.0 });
253       assertEquals( 2.0, optimum.getPointRef()[0], 2.0e-3);
254       assertEquals(-3.0, optimum.getPointRef()[1], 8.0e-4);
255       assertTrue(optimizer.getEvaluations() > 60);
256       assertTrue(optimizer.getEvaluations() < 80);
257       assertTrue(optimum.getValue() < 1.0e-6);
258   }
259 
260   @Test(expected = MaxIterationsExceededException.class)
261   public void testMaxIterations() throws MathException {
262       try {
263           Powell powell = new Powell();
264           NelderMead optimizer = new NelderMead();
265           optimizer.setConvergenceChecker(new SimpleScalarValueChecker(-1.0, 1.0e-3));
266           optimizer.setMaxIterations(20);
267           optimizer.optimize(powell, GoalType.MINIMIZE, new double[] { 3.0, -1.0, 0.0, 1.0 });
268       } catch (OptimizationException oe) {
269           if (oe.getCause() instanceof ConvergenceException) {
270               throw (ConvergenceException) oe.getCause();
271           }
272           throw oe;
273       }
274   }
275 
276   @Test(expected = MaxEvaluationsExceededException.class)
277   public void testMaxEvaluations() throws MathException {
278       try {
279           Powell powell = new Powell();
280           NelderMead optimizer = new NelderMead();
281           optimizer.setConvergenceChecker(new SimpleRealPointChecker(-1.0, 1.0e-3));
282           optimizer.setMaxEvaluations(20);
283           optimizer.optimize(powell, GoalType.MINIMIZE, new double[] { 3.0, -1.0, 0.0, 1.0 });
284       } catch (FunctionEvaluationException fee) {
285           if (fee.getCause() instanceof ConvergenceException) {
286               throw (ConvergenceException) fee.getCause();
287           }
288           throw fee;
289       }
290   }
291 
292   private static class Rosenbrock implements MultivariateRealFunction {
293 
294       private int count;
295 
296       public Rosenbrock() {
297           count = 0;
298       }
299 
300       public double value(double[] x) throws FunctionEvaluationException {
301           ++count;
302           double a = x[1] - x[0] * x[0];
303           double b = 1.0 - x[0];
304           return 100 * a * a + b * b;
305       }
306 
307       public int getCount() {
308           return count;
309       }
310 
311   }
312 
313   private static class Powell implements MultivariateRealFunction {
314 
315       private int count;
316 
317       public Powell() {
318           count = 0;
319       }
320 
321       public double value(double[] x) throws FunctionEvaluationException {
322           ++count;
323           double a = x[0] + 10 * x[1];
324           double b = x[2] - x[3];
325           double c = x[1] - 2 * x[2];
326           double d = x[0] - x[3];
327           return a * a + 5 * b * b + c * c * c * c + 10 * d * d * d * d;
328       }
329 
330       public int getCount() {
331           return count;
332       }
333 
334   }
335 
336 }