View Javadoc

1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  
18  package org.apache.commons.math.optimization.fitting;
19  
20  import java.util.ArrayList;
21  import java.util.List;
22  
23  import org.apache.commons.math.FunctionEvaluationException;
24  import org.apache.commons.math.analysis.DifferentiableMultivariateVectorialFunction;
25  import org.apache.commons.math.analysis.MultivariateMatrixFunction;
26  import org.apache.commons.math.optimization.DifferentiableMultivariateVectorialOptimizer;
27  import org.apache.commons.math.optimization.OptimizationException;
28  import org.apache.commons.math.optimization.VectorialPointValuePair;
29  
30  /** Fitter for parametric univariate real functions y = f(x).
31   * <p>When a univariate real function y = f(x) does depend on some
32   * unknown parameters p<sub>0</sub>, p<sub>1</sub> ... p<sub>n-1</sub>,
33   * this class can be used to find these parameters. It does this
34   * by <em>fitting</em> the curve so it remains very close to a set of
35   * observed points (x<sub>0</sub>, y<sub>0</sub>), (x<sub>1</sub>,
36   * y<sub>1</sub>) ... (x<sub>k-1</sub>, y<sub>k-1</sub>). This fitting
37   * is done by finding the parameters values that minimizes the objective
38   * function &sum;(y<sub>i</sub>-f(x<sub>i</sub>))<sup>2</sup>. This is
39   * really a least squares problem.</p>
40   * @version $Revision: 790380 $ $Date: 2009-07-01 17:03:38 -0400 (Wed, 01 Jul 2009) $
41   * @since 2.0
42   */
43  public class CurveFitter {
44  
45      /** Optimizer to use for the fitting. */
46      private final DifferentiableMultivariateVectorialOptimizer optimizer;
47  
48      /** Observed points. */
49      private final List<WeightedObservedPoint> observations;
50  
51      /** Simple constructor.
52       * @param optimizer optimizer to use for the fitting
53       */
54      public CurveFitter(final DifferentiableMultivariateVectorialOptimizer optimizer) {
55          this.optimizer = optimizer;
56          observations = new ArrayList<WeightedObservedPoint>();
57      }
58  
59      /** Add an observed (x,y) point to the sample with unit weight.
60       * <p>Calling this method is equivalent to call
61       * <code>addObservedPoint(1.0, x, y)</code>.</p>
62       * @param x abscissa of the point
63       * @param y observed value of the point at x, after fitting we should
64       * have f(x) as close as possible to this value
65       * @see #addObservedPoint(double, double, double)
66       * @see #addObservedPoint(WeightedObservedPoint)
67       * @see #getObservations()
68       */
69      public void addObservedPoint(double x, double y) {
70          addObservedPoint(1.0, x, y);
71      }
72  
73      /** Add an observed weighted (x,y) point to the sample.
74       * @param weight weight of the observed point in the fit
75       * @param x abscissa of the point
76       * @param y observed value of the point at x, after fitting we should
77       * have f(x) as close as possible to this value
78       * @see #addObservedPoint(double, double)
79       * @see #addObservedPoint(WeightedObservedPoint)
80       * @see #getObservations()
81       */
82      public void addObservedPoint(double weight, double x, double y) {
83          observations.add(new WeightedObservedPoint(weight, x, y));
84      }
85  
86      /** Add an observed weighted (x,y) point to the sample.
87       * @param observed observed point to add
88       * @see #addObservedPoint(double, double)
89       * @see #addObservedPoint(double, double, double)
90       * @see #getObservations()
91       */
92      public void addObservedPoint(WeightedObservedPoint observed) {
93          observations.add(observed);
94      }
95  
96      /** Get the observed points.
97       * @return observed points
98       * @see #addObservedPoint(double, double)
99       * @see #addObservedPoint(double, double, double)
100      * @see #addObservedPoint(WeightedObservedPoint)
101      */
102     public WeightedObservedPoint[] getObservations() {
103         return observations.toArray(new WeightedObservedPoint[observations.size()]);
104     }
105 
106     /** Fit a curve.
107      * <p>This method compute the coefficients of the curve that best
108      * fit the sample of observed points previously given through calls
109      * to the {@link #addObservedPoint(WeightedObservedPoint)
110      * addObservedPoint} method.</p>
111      * @param f parametric function to fit
112      * @param initialGuess first guess of the function parameters
113      * @return fitted parameters
114      * @exception FunctionEvaluationException if the objective function throws one during
115      * the search
116      * @exception OptimizationException if the algorithm failed to converge
117      * @exception IllegalArgumentException if the start point dimension is wrong
118      */
119     public double[] fit(final ParametricRealFunction f,
120                         final double[] initialGuess)
121         throws FunctionEvaluationException, OptimizationException, IllegalArgumentException {
122 
123         // prepare least squares problem
124         double[] target  = new double[observations.size()];
125         double[] weights = new double[observations.size()];
126         int i = 0;
127         for (WeightedObservedPoint point : observations) {
128             target[i]  = point.getY();
129             weights[i] = point.getWeight();
130             ++i;
131         }
132 
133         // perform the fit
134         VectorialPointValuePair optimum =
135             optimizer.optimize(new TheoreticalValuesFunction(f), target, weights, initialGuess);
136 
137         // extract the coefficients
138         return optimum.getPointRef();
139 
140     }
141 
142     /** Vectorial function computing function theoretical values. */
143     private class TheoreticalValuesFunction
144         implements DifferentiableMultivariateVectorialFunction {
145 
146         /** Function to fit. */
147         private final ParametricRealFunction f;
148 
149         /** Simple constructor.
150          * @param f function to fit.
151          */
152         public TheoreticalValuesFunction(final ParametricRealFunction f) {
153             this.f = f;
154         }
155 
156         /** {@inheritDoc} */
157         public MultivariateMatrixFunction jacobian() {
158             return new MultivariateMatrixFunction() {
159                 public double[][] value(double[] point)
160                     throws FunctionEvaluationException, IllegalArgumentException {
161 
162                     final double[][] jacobian = new double[observations.size()][];
163 
164                     int i = 0;
165                     for (WeightedObservedPoint observed : observations) {
166                         jacobian[i++] = f.gradient(observed.getX(), point);
167                     }
168 
169                     return jacobian;
170 
171                 }
172             };
173         }
174 
175         /** {@inheritDoc} */
176         public double[] value(double[] point)
177                 throws FunctionEvaluationException, IllegalArgumentException {
178 
179             // compute the residuals
180             final double[] values = new double[observations.size()];
181             int i = 0;
182             for (WeightedObservedPoint observed : observations) {
183                 values[i++] = f.value(observed.getX(), point);
184             }
185 
186             return values;
187 
188         }
189 
190     }
191 
192 }