View Javadoc

1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  package org.apache.commons.math.stat.regression;
18  
19  import org.apache.commons.math.MathRuntimeException;
20  import org.apache.commons.math.linear.RealMatrix;
21  import org.apache.commons.math.linear.Array2DRowRealMatrix;
22  import org.apache.commons.math.linear.RealVector;
23  import org.apache.commons.math.linear.ArrayRealVector;
24  
25  /**
26   * Abstract base class for implementations of MultipleLinearRegression.
27   * @version $Revision: 791244 $ $Date: 2009-07-05 09:29:37 -0400 (Sun, 05 Jul 2009) $
28   * @since 2.0
29   */
30  public abstract class AbstractMultipleLinearRegression implements
31          MultipleLinearRegression {
32  
33      /** X sample data. */
34      protected RealMatrix X;
35  
36      /** Y sample data. */
37      protected RealVector Y;
38  
39      /**
40       * Loads model x and y sample data from a flat array of data, overriding any previous sample.
41       * Assumes that rows are concatenated with y values first in each row.
42       * 
43       * @param data input data array
44       * @param nobs number of observations (rows)
45       * @param nvars number of independent variables (columns, not counting y)
46       */
47      public void newSampleData(double[] data, int nobs, int nvars) {
48          double[] y = new double[nobs];
49          double[][] x = new double[nobs][nvars + 1];
50          int pointer = 0;
51          for (int i = 0; i < nobs; i++) {
52              y[i] = data[pointer++];
53              x[i][0] = 1.0d;
54              for (int j = 1; j < nvars + 1; j++) {
55                  x[i][j] = data[pointer++];
56              }
57          }
58          this.X = new Array2DRowRealMatrix(x);
59          this.Y = new ArrayRealVector(y);
60      }
61      
62      /**
63       * Loads new y sample data, overriding any previous sample
64       * 
65       * @param y the [n,1] array representing the y sample
66       */
67      protected void newYSampleData(double[] y) {
68          this.Y = new ArrayRealVector(y);
69      }
70  
71      /**
72       * Loads new x sample data, overriding any previous sample
73       * 
74       * @param x the [n,k] array representing the x sample
75       */
76      protected void newXSampleData(double[][] x) {
77          this.X = new Array2DRowRealMatrix(x);
78      }
79  
80      /**
81       * Validates sample data.
82       * 
83       * @param x the [n,k] array representing the x sample
84       * @param y the [n,1] array representing the y sample
85       * @throws IllegalArgumentException if the x and y array data are not
86       *             compatible for the regression
87       */
88      protected void validateSampleData(double[][] x, double[] y) {
89          if ((x == null) || (y == null) || (x.length != y.length)) {
90              throw MathRuntimeException.createIllegalArgumentException(
91                    "dimension mismatch {0} != {1}",
92                    (x == null) ? 0 : x.length,
93                    (y == null) ? 0 : y.length);
94          } else if ((x.length > 0) && (x[0].length > x.length)) {
95              throw MathRuntimeException.createIllegalArgumentException(
96                    "not enough data ({0} rows) for this many predictors ({1} predictors)",
97                    x.length, x[0].length);
98          }
99      }
100 
101     /**
102      * Validates sample data.
103      * 
104      * @param x the [n,k] array representing the x sample
105      * @param covariance the [n,n] array representing the covariance matrix
106      * @throws IllegalArgumentException if the x sample data or covariance
107      *             matrix are not compatible for the regression
108      */
109     protected void validateCovarianceData(double[][] x, double[][] covariance) {
110         if (x.length != covariance.length) {
111             throw MathRuntimeException.createIllegalArgumentException(
112                  "dimension mismatch {0} != {1}", x.length, covariance.length);
113         }
114         if (covariance.length > 0 && covariance.length != covariance[0].length) {
115             throw MathRuntimeException.createIllegalArgumentException(
116                   "a {0}x{1} matrix was provided instead of a square matrix",
117                   covariance.length, covariance[0].length);
118         }
119     }
120 
121     /**
122      * {@inheritDoc}
123      */
124     public double[] estimateRegressionParameters() {
125         RealVector b = calculateBeta();
126         return b.getData();
127     }
128 
129     /**
130      * {@inheritDoc}
131      */
132     public double[] estimateResiduals() {
133         RealVector b = calculateBeta();
134         RealVector e = Y.subtract(X.operate(b));
135         return e.getData();
136     }
137 
138     /**
139      * {@inheritDoc}
140      */
141     public double[][] estimateRegressionParametersVariance() {
142         return calculateBetaVariance().getData();
143     }
144     
145     /**
146      * {@inheritDoc}
147      */
148     public double[] estimateRegressionParametersStandardErrors() {
149         double[][] betaVariance = estimateRegressionParametersVariance();
150         double sigma = calculateYVariance();
151         int length = betaVariance[0].length;
152         double[] result = new double[length];
153         for (int i = 0; i < length; i++) {
154             result[i] = Math.sqrt(sigma * betaVariance[i][i]);
155         }
156         return result;
157     }
158 
159     /**
160      * {@inheritDoc}
161      */
162     public double estimateRegressandVariance() {
163         return calculateYVariance();
164     }
165 
166     /**
167      * Calculates the beta of multiple linear regression in matrix notation.
168      * 
169      * @return beta
170      */
171     protected abstract RealVector calculateBeta();
172 
173     /**
174      * Calculates the beta variance of multiple linear regression in matrix
175      * notation.
176      * 
177      * @return beta variance
178      */
179     protected abstract RealMatrix calculateBetaVariance();
180 
181     /**
182      * Calculates the Y variance of multiple linear regression.
183      * 
184      * @return Y variance
185      */
186     protected abstract double calculateYVariance();
187 
188     /**
189      * Calculates the residuals of multiple linear regression in matrix
190      * notation.
191      * 
192      * <pre>
193      * u = y - X * b
194      * </pre>
195      * 
196      * @return The residuals [n,1] matrix
197      */
198     protected RealVector calculateResiduals() {
199         RealVector b = calculateBeta();
200         return Y.subtract(X.operate(b));
201     }
202 
203 }