1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  package org.apache.commons.math.stat.regression;
18  
19  import static org.junit.Assert.assertEquals;
20  
21  import org.apache.commons.math.TestUtils;
22  import org.apache.commons.math.linear.DefaultRealMatrixChangingVisitor;
23  import org.apache.commons.math.linear.MatrixUtils;
24  import org.apache.commons.math.linear.MatrixVisitorException;
25  import org.apache.commons.math.linear.RealMatrix;
26  import org.apache.commons.math.linear.Array2DRowRealMatrix;
27  import org.junit.Before;
28  import org.junit.Test;
29  
30  public class OLSMultipleLinearRegressionTest extends MultipleLinearRegressionAbstractTest {
31  
32      private double[] y;
33      private double[][] x;
34      
35      @Before
36      @Override
37      public void setUp(){
38          y = new double[]{11.0, 12.0, 13.0, 14.0, 15.0, 16.0};
39          x = new double[6][];
40          x[0] = new double[]{1.0, 0, 0, 0, 0, 0};
41          x[1] = new double[]{1.0, 2.0, 0, 0, 0, 0};
42          x[2] = new double[]{1.0, 0, 3.0, 0, 0, 0};
43          x[3] = new double[]{1.0, 0, 0, 4.0, 0, 0};
44          x[4] = new double[]{1.0, 0, 0, 0, 5.0, 0};
45          x[5] = new double[]{1.0, 0, 0, 0, 0, 6.0};
46          super.setUp();
47      }
48  
49      @Override
50      protected OLSMultipleLinearRegression createRegression() {
51          OLSMultipleLinearRegression regression = new OLSMultipleLinearRegression();
52          regression.newSampleData(y, x);
53          return regression;
54      }
55  
56      @Override
57      protected int getNumberOfRegressors() {
58          return x[0].length;
59      }
60  
61      @Override
62      protected int getSampleSize() {
63          return y.length;
64      }
65      
66      @Test(expected=IllegalArgumentException.class)
67      public void cannotAddXSampleData() {
68          createRegression().newSampleData(new double[]{}, null);
69      }
70  
71      @Test(expected=IllegalArgumentException.class)
72      public void cannotAddNullYSampleData() {
73          createRegression().newSampleData(null, new double[][]{});
74      }
75      
76      @Test(expected=IllegalArgumentException.class)
77      public void cannotAddSampleDataWithSizeMismatch() {
78          double[] y = new double[]{1.0, 2.0};
79          double[][] x = new double[1][];
80          x[0] = new double[]{1.0, 0};
81          createRegression().newSampleData(y, x);
82      }
83      
84      @Test
85      public void testPerfectFit() {
86          double[] betaHat = regression.estimateRegressionParameters();
87          TestUtils.assertEquals(betaHat, 
88                                 new double[]{ 11.0, 1.0 / 2.0, 2.0 / 3.0, 3.0 / 4.0, 4.0 / 5.0, 5.0 / 6.0 },
89                                 1e-14);
90          double[] residuals = regression.estimateResiduals();
91          TestUtils.assertEquals(residuals, new double[]{0d,0d,0d,0d,0d,0d},
92                                 1e-14);
93          RealMatrix errors =
94              new Array2DRowRealMatrix(regression.estimateRegressionParametersVariance(), false);
95          final double[] s = { 1.0, -1.0 /  2.0, -1.0 /  3.0, -1.0 /  4.0, -1.0 /  5.0, -1.0 /  6.0 };
96          RealMatrix referenceVariance = new Array2DRowRealMatrix(s.length, s.length);
97          referenceVariance.walkInOptimizedOrder(new DefaultRealMatrixChangingVisitor() {
98              @Override
99              public double visit(int row, int column, double value)
100                 throws MatrixVisitorException {
101                 if (row == 0) {
102                     return s[column];
103                 }
104                 double x = s[row] * s[column];
105                 return (row == column) ? 2 * x : x;
106             }
107         });
108        assertEquals(0.0,
109                      errors.subtract(referenceVariance).getNorm(),
110                      5.0e-16 * referenceVariance.getNorm());
111     }
112     
113     
114     /**
115      * Test Longley dataset against certified values provided by NIST.
116      * Data Source: J. Longley (1967) "An Appraisal of Least Squares
117      * Programs for the Electronic Computer from the Point of View of the User"
118      * Journal of the American Statistical Association, vol. 62. September,
119      * pp. 819-841.
120      * 
121      * Certified values (and data) are from NIST:
122      * http://www.itl.nist.gov/div898/strd/lls/data/LINKS/DATA/Longley.dat
123      */
124     @Test
125     public void testLongly() {
126         // Y values are first, then independent vars
127         // Each row is one observation
128         double[] design = new double[] {
129             60323,83.0,234289,2356,1590,107608,1947,
130             61122,88.5,259426,2325,1456,108632,1948,
131             60171,88.2,258054,3682,1616,109773,1949,
132             61187,89.5,284599,3351,1650,110929,1950,
133             63221,96.2,328975,2099,3099,112075,1951,
134             63639,98.1,346999,1932,3594,113270,1952,
135             64989,99.0,365385,1870,3547,115094,1953,
136             63761,100.0,363112,3578,3350,116219,1954,
137             66019,101.2,397469,2904,3048,117388,1955,
138             67857,104.6,419180,2822,2857,118734,1956,
139             68169,108.4,442769,2936,2798,120445,1957,
140             66513,110.8,444546,4681,2637,121950,1958,
141             68655,112.6,482704,3813,2552,123366,1959,
142             69564,114.2,502601,3931,2514,125368,1960,
143             69331,115.7,518173,4806,2572,127852,1961,
144             70551,116.9,554894,4007,2827,130081,1962
145         };
146         
147         // Transform to Y and X required by interface
148         int nobs = 16;
149         int nvars = 6;
150         
151         // Estimate the model
152         OLSMultipleLinearRegression model = new OLSMultipleLinearRegression();
153         model.newSampleData(design, nobs, nvars);
154         
155         // Check expected beta values from NIST
156         double[] betaHat = model.estimateRegressionParameters();
157         TestUtils.assertEquals(betaHat, 
158           new double[]{-3482258.63459582, 15.0618722713733,
159                 -0.358191792925910E-01,-2.02022980381683,
160                 -1.03322686717359,-0.511041056535807E-01,
161                  1829.15146461355}, 2E-8); // 
162         
163         // Check expected residuals from R
164         double[] residuals = model.estimateResiduals();
165         TestUtils.assertEquals(residuals, new double[]{
166                 267.340029759711,-94.0139423988359,46.28716775752924,
167                 -410.114621930906,309.7145907602313,-249.3112153297231,
168                 -164.0489563956039,-13.18035686637081,14.30477260005235,
169                  455.394094551857,-17.26892711483297,-39.0550425226967,
170                 -155.5499735953195,-85.6713080421283,341.9315139607727,
171                 -206.7578251937366},
172                       1E-8);
173         
174         // Check standard errors from NIST
175         double[] errors = model.estimateRegressionParametersStandardErrors();
176         TestUtils.assertEquals(new double[] {890420.383607373,
177                        84.9149257747669,
178                        0.334910077722432E-01,
179                        0.488399681651699,
180                        0.214274163161675,
181                        0.226073200069370,
182                        455.478499142212}, errors, 1E-6); 
183     }
184     
185     /**
186      * Test R Swiss fertility dataset against R.
187      * Data Source: R datasets package
188      */
189     @Test
190     public void testSwissFertility() {
191         double[] design = new double[] {
192             80.2,17.0,15,12,9.96,
193             83.1,45.1,6,9,84.84,
194             92.5,39.7,5,5,93.40,
195             85.8,36.5,12,7,33.77,
196             76.9,43.5,17,15,5.16,
197             76.1,35.3,9,7,90.57,
198             83.8,70.2,16,7,92.85,
199             92.4,67.8,14,8,97.16,
200             82.4,53.3,12,7,97.67,
201             82.9,45.2,16,13,91.38,
202             87.1,64.5,14,6,98.61,
203             64.1,62.0,21,12,8.52,
204             66.9,67.5,14,7,2.27,
205             68.9,60.7,19,12,4.43,
206             61.7,69.3,22,5,2.82,
207             68.3,72.6,18,2,24.20,
208             71.7,34.0,17,8,3.30,
209             55.7,19.4,26,28,12.11,
210             54.3,15.2,31,20,2.15,
211             65.1,73.0,19,9,2.84,
212             65.5,59.8,22,10,5.23,
213             65.0,55.1,14,3,4.52,
214             56.6,50.9,22,12,15.14,
215             57.4,54.1,20,6,4.20,
216             72.5,71.2,12,1,2.40,
217             74.2,58.1,14,8,5.23,
218             72.0,63.5,6,3,2.56,
219             60.5,60.8,16,10,7.72,
220             58.3,26.8,25,19,18.46,
221             65.4,49.5,15,8,6.10,
222             75.5,85.9,3,2,99.71,
223             69.3,84.9,7,6,99.68,
224             77.3,89.7,5,2,100.00,
225             70.5,78.2,12,6,98.96,
226             79.4,64.9,7,3,98.22,
227             65.0,75.9,9,9,99.06,
228             92.2,84.6,3,3,99.46,
229             79.3,63.1,13,13,96.83,
230             70.4,38.4,26,12,5.62,
231             65.7,7.7,29,11,13.79,
232             72.7,16.7,22,13,11.22,
233             64.4,17.6,35,32,16.92,
234             77.6,37.6,15,7,4.97,
235             67.6,18.7,25,7,8.65,
236             35.0,1.2,37,53,42.34,
237             44.7,46.6,16,29,50.43,
238             42.8,27.7,22,29,58.33
239         };
240 
241         // Transform to Y and X required by interface
242         int nobs = 47;
243         int nvars = 4;
244 
245         // Estimate the model
246         OLSMultipleLinearRegression model = new OLSMultipleLinearRegression();
247         model.newSampleData(design, nobs, nvars);
248 
249         // Check expected beta values from R
250         double[] betaHat = model.estimateRegressionParameters();
251         TestUtils.assertEquals(betaHat, 
252                 new double[]{91.05542390271397,
253                 -0.22064551045715,
254                 -0.26058239824328,
255                 -0.96161238456030,
256                  0.12441843147162}, 1E-12);
257 
258         // Check expected residuals from R
259         double[] residuals = model.estimateResiduals();
260         TestUtils.assertEquals(residuals, new double[]{
261                 7.1044267859730512,1.6580347433531366,
262                 4.6944952770029644,8.4548022690166160,13.6547432343186212,
263                -9.3586864458500774,7.5822446330520386,15.5568995563859289,
264                 0.8113090736598980,7.1186762732484308,7.4251378771228724,
265                 2.6761316873234109,0.8351584810309354,7.1769991119615177,
266                -3.8746753206299553,-3.1337779476387251,-0.1412575244091504,
267                 1.1186809170469780,-6.3588097346816594,3.4039270429434074,
268                 2.3374058329820175,-7.9272368576900503,-7.8361010968497959,
269                -11.2597369269357070,0.9445333697827101,6.6544245101380328,
270                -0.9146136301118665,-4.3152449403848570,-4.3536932047009183,
271                -3.8907885169304661,-6.3027643926302188,-7.8308982189289091,
272                -3.1792280015332750,-6.7167298771158226,-4.8469946718041754,
273                -10.6335664353633685,11.1031134362036958,6.0084032641811733,
274                 5.4326230830188482,-7.2375578629692230,2.1671550814448222,
275                 15.0147574652763112,4.8625103516321015,-7.1597256413907706,
276                 -0.4515205619767598,-10.2916870903837587,-15.7812984571900063},
277                 1E-12); 
278         
279         // Check standard errors from R
280         double[] errors = model.estimateRegressionParametersStandardErrors();
281         TestUtils.assertEquals(new double[] {6.94881329475087,
282                 0.07360008972340,
283                 0.27410957467466,
284                 0.19454551679325,
285                 0.03726654773803}, errors, 1E-10); 
286     }
287     
288     /**
289      * Test hat matrix computation
290      * 
291      * @throws Exception
292      */
293     @Test
294     public void testHat() throws Exception {
295         
296         /*
297          * This example is from "The Hat Matrix in Regression and ANOVA", 
298          * David C. Hoaglin and Roy E. Welsch, 
299          * The American Statistician, Vol. 32, No. 1 (Feb., 1978), pp. 17-22.
300          * 
301          */
302         double[] design = new double[] {
303                 11.14, .499, 11.1,
304                 12.74, .558, 8.9,
305                 13.13, .604, 8.8,
306                 11.51, .441, 8.9,
307                 12.38, .550, 8.8,
308                 12.60, .528, 9.9,
309                 11.13, .418, 10.7,
310                 11.7, .480, 10.5,
311                 11.02, .406, 10.5,
312                 11.41, .467, 10.7
313         };
314         
315         int nobs = 10;
316         int nvars = 2;
317         
318         // Estimate the model
319         OLSMultipleLinearRegression model = new OLSMultipleLinearRegression();
320         model.newSampleData(design, nobs, nvars);
321         
322         RealMatrix hat = model.calculateHat();
323         
324         // Reference data is upper half of symmetric hat matrix
325         double[] referenceData = new double[] {
326                 .418, -.002,  .079, -.274, -.046,  .181,  .128,  .222,  .050,  .242,
327                        .242,  .292,  .136,  .243,  .128, -.041,  .033, -.035,  .004,
328                               .417, -.019,  .273,  .187, -.126,  .044, -.153,  .004,
329                                      .604,  .197, -.038,  .168, -.022,  .275, -.028,
330                                             .252,  .111, -.030,  .019, -.010, -.010,
331                                                    .148,  .042,  .117,  .012,  .111,
332                                                           .262,  .145,  .277,  .174,
333                                                                  .154,  .120,  .168,
334                                                                         .315,  .148,
335                                                                                .187
336         };
337         
338         // Check against reference data and verify symmetry
339         int k = 0;
340         for (int i = 0; i < 10; i++) {
341             for (int j = i; j < 10; j++) {
342                 assertEquals(referenceData[k], hat.getEntry(i, j), 10e-3);
343                 assertEquals(hat.getEntry(i, j), hat.getEntry(j, i), 10e-12);
344                 k++;  
345             }
346         }
347         
348         /* 
349          * Verify that residuals computed using the hat matrix are close to 
350          * what we get from direct computation, i.e. r = (I - H) y
351          */
352         double[] residuals = model.estimateResiduals();
353         RealMatrix I = MatrixUtils.createRealIdentityMatrix(10);
354         double[] hatResiduals = I.subtract(hat).operate(model.Y).getData();
355         TestUtils.assertEquals(residuals, hatResiduals, 10e-12);    
356     }
357 }