001    /*
002     * Licensed to the Apache Software Foundation (ASF) under one or more
003     * contributor license agreements.  See the NOTICE file distributed with
004     * this work for additional information regarding copyright ownership.
005     * The ASF licenses this file to You under the Apache License, Version 2.0
006     * (the "License"); you may not use this file except in compliance with
007     * the License.  You may obtain a copy of the License at
008     *
009     *      http://www.apache.org/licenses/LICENSE-2.0
010     *
011     * Unless required by applicable law or agreed to in writing, software
012     * distributed under the License is distributed on an "AS IS" BASIS,
013     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014     * See the License for the specific language governing permissions and
015     * limitations under the License.
016     */
017    
018    package org.apache.commons.math.optimization;
019    
020    import static org.junit.Assert.assertEquals;
021    import static org.junit.Assert.assertTrue;
022    
023    import java.awt.geom.Point2D;
024    import java.util.ArrayList;
025    
026    import org.apache.commons.math.FunctionEvaluationException;
027    import org.apache.commons.math.analysis.DifferentiableMultivariateRealFunction;
028    import org.apache.commons.math.analysis.MultivariateRealFunction;
029    import org.apache.commons.math.analysis.MultivariateVectorialFunction;
030    import org.apache.commons.math.analysis.solvers.BrentSolver;
031    import org.apache.commons.math.optimization.general.ConjugateGradientFormula;
032    import org.apache.commons.math.optimization.general.NonLinearConjugateGradientOptimizer;
033    import org.apache.commons.math.random.GaussianRandomGenerator;
034    import org.apache.commons.math.random.JDKRandomGenerator;
035    import org.apache.commons.math.random.RandomVectorGenerator;
036    import org.apache.commons.math.random.UncorrelatedRandomVectorGenerator;
037    import org.junit.Test;
038    
039    public class MultiStartDifferentiableMultivariateRealOptimizerTest {
040    
041        @Test
042        public void testCircleFitting() throws FunctionEvaluationException, OptimizationException {
043            Circle circle = new Circle();
044            circle.addPoint( 30.0,  68.0);
045            circle.addPoint( 50.0,  -6.0);
046            circle.addPoint(110.0, -20.0);
047            circle.addPoint( 35.0,  15.0);
048            circle.addPoint( 45.0,  97.0);
049            NonLinearConjugateGradientOptimizer underlying =
050                new NonLinearConjugateGradientOptimizer(ConjugateGradientFormula.POLAK_RIBIERE);
051            JDKRandomGenerator g = new JDKRandomGenerator();
052            g.setSeed(753289573253l);
053            RandomVectorGenerator generator =
054                new UncorrelatedRandomVectorGenerator(new double[] { 50.0, 50.0 }, new double[] { 10.0, 10.0 },
055                                                      new GaussianRandomGenerator(g));
056            MultiStartDifferentiableMultivariateRealOptimizer optimizer =
057                new MultiStartDifferentiableMultivariateRealOptimizer(underlying, 10, generator);
058            optimizer.setMaxIterations(100);
059            assertEquals(100, optimizer.getMaxIterations());
060            optimizer.setMaxEvaluations(100);
061            assertEquals(100, optimizer.getMaxEvaluations());
062            optimizer.setConvergenceChecker(new SimpleScalarValueChecker(1.0e-10, 1.0e-10));
063            BrentSolver solver = new BrentSolver();
064            solver.setAbsoluteAccuracy(1.0e-13);
065            solver.setRelativeAccuracy(1.0e-15);
066            RealPointValuePair optimum =
067                optimizer.optimize(circle, GoalType.MINIMIZE, new double[] { 98.680, 47.345 });
068            RealPointValuePair[] optima = optimizer.getOptima();
069            for (RealPointValuePair o : optima) {
070                Point2D.Double center = new Point2D.Double(o.getPointRef()[0], o.getPointRef()[1]);
071                assertEquals(69.960161753, circle.getRadius(center), 1.0e-8);
072                assertEquals(96.075902096, center.x, 1.0e-8);
073                assertEquals(48.135167894, center.y, 1.0e-8);
074            }
075            assertTrue(optimizer.getGradientEvaluations() > 650);
076            assertTrue(optimizer.getGradientEvaluations() < 700);
077            assertTrue(optimizer.getEvaluations() > 70);
078            assertTrue(optimizer.getEvaluations() < 90);
079            assertTrue(optimizer.getIterations() > 70);
080            assertTrue(optimizer.getIterations() < 90);
081            assertEquals(3.1267527, optimum.getValue(), 1.0e-8);
082        }
083    
084        private static class Circle implements DifferentiableMultivariateRealFunction {
085    
086            private ArrayList<Point2D.Double> points;
087    
088            public Circle() {
089                points  = new ArrayList<Point2D.Double>();
090            }
091    
092            public void addPoint(double px, double py) {
093                points.add(new Point2D.Double(px, py));
094            }
095    
096            public double getRadius(Point2D.Double center) {
097                double r = 0;
098                for (Point2D.Double point : points) {
099                    r += point.distance(center);
100                }
101                return r / points.size();
102            }
103    
104            private double[] gradient(double[] point) {
105    
106                // optimal radius
107                Point2D.Double center = new Point2D.Double(point[0], point[1]);
108                double radius = getRadius(center);
109    
110                // gradient of the sum of squared residuals
111                double dJdX = 0;
112                double dJdY = 0;
113                for (Point2D.Double pk : points) {
114                    double dk = pk.distance(center);
115                    dJdX += (center.x - pk.x) * (dk - radius) / dk;
116                    dJdY += (center.y - pk.y) * (dk - radius) / dk;
117                }
118                dJdX *= 2;
119                dJdY *= 2;
120    
121                return new double[] { dJdX, dJdY };
122    
123            }
124    
125            public double value(double[] variables)
126            throws IllegalArgumentException, FunctionEvaluationException {
127    
128                Point2D.Double center = new Point2D.Double(variables[0], variables[1]);
129                double radius = getRadius(center);
130    
131                double sum = 0;
132                for (Point2D.Double point : points) {
133                    double di = point.distance(center) - radius;
134                    sum += di * di;
135                }
136    
137                return sum;
138    
139            }
140    
141            public MultivariateVectorialFunction gradient() {
142                return new MultivariateVectorialFunction() {
143                    public double[] value(double[] point) {
144                        return gradient(point);
145                    }
146                };
147            }
148    
149            public MultivariateRealFunction partialDerivative(final int k) {
150                return new MultivariateRealFunction() {
151                    public double value(double[] point) {
152                        return gradient(point)[k];
153                    }
154                };
155            }
156    
157        }
158    
159    }