1   /*
2    * Licensed to the Apache Software Foundation (ASF) under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    * 
9    *      http://www.apache.org/licenses/LICENSE-2.0
10   * 
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  package org.apache.commons.math.stat.descriptive;
18  
19  
20  import java.util.Locale;
21  
22  import junit.framework.Test;
23  import junit.framework.TestCase;
24  import junit.framework.TestSuite;
25  
26  import org.apache.commons.math.DimensionMismatchException;
27  import org.apache.commons.math.TestUtils;
28  import org.apache.commons.math.stat.descriptive.moment.Mean;
29  
30  /**
31   * Test cases for the {@link MultivariateSummaryStatistics} class.
32   *
33   * @version $Revision: 797744 $ $Date: 2009-07-25 07:09:14 -0400 (Sat, 25 Jul 2009) $
34   */
35  
36  public class MultivariateSummaryStatisticsTest extends TestCase {
37  
38      public MultivariateSummaryStatisticsTest(String name) {
39          super(name);
40      }
41      
42      public static Test suite() {
43          TestSuite suite = new TestSuite(MultivariateSummaryStatisticsTest.class);
44          suite.setName("MultivariateSummaryStatistics tests");
45          return suite;
46      }
47  
48      protected MultivariateSummaryStatistics createMultivariateSummaryStatistics(int k, boolean isCovarianceBiasCorrected) {
49          return new MultivariateSummaryStatistics(k, isCovarianceBiasCorrected);
50      }
51  
52      public void testSetterInjection() throws Exception {
53          MultivariateSummaryStatistics u = createMultivariateSummaryStatistics(2, true);
54          u.setMeanImpl(new StorelessUnivariateStatistic[] {
55                          new sumMean(), new sumMean()
56                        });
57          u.addValue(new double[] { 1, 2 });
58          u.addValue(new double[] { 3, 4 });
59          assertEquals(4, u.getMean()[0], 1E-14);
60          assertEquals(6, u.getMean()[1], 1E-14);
61          u.clear();
62          u.addValue(new double[] { 1, 2 });
63          u.addValue(new double[] { 3, 4 });
64          assertEquals(4, u.getMean()[0], 1E-14);
65          assertEquals(6, u.getMean()[1], 1E-14);
66          u.clear();
67          u.setMeanImpl(new StorelessUnivariateStatistic[] {
68                          new Mean(), new Mean()
69                        }); // OK after clear
70          u.addValue(new double[] { 1, 2 });
71          u.addValue(new double[] { 3, 4 });
72          assertEquals(2, u.getMean()[0], 1E-14);
73          assertEquals(3, u.getMean()[1], 1E-14);
74          assertEquals(2, u.getDimension());
75      }
76      
77      public void testSetterIllegalState() throws Exception {
78          MultivariateSummaryStatistics u = createMultivariateSummaryStatistics(2, true);
79          u.addValue(new double[] { 1, 2 });
80          u.addValue(new double[] { 3, 4 });
81          try {
82              u.setMeanImpl(new StorelessUnivariateStatistic[] {
83                              new sumMean(), new sumMean()
84                            });
85              fail("Expecting IllegalStateException");
86          } catch (IllegalStateException ex) {
87              // expected
88          }
89      }
90  
91      public void testToString() throws DimensionMismatchException {
92          MultivariateSummaryStatistics stats = createMultivariateSummaryStatistics(2, true);
93          stats.addValue(new double[] {1, 3});
94          stats.addValue(new double[] {2, 2});
95          stats.addValue(new double[] {3, 1});
96          Locale d = Locale.getDefault();
97          Locale.setDefault(Locale.US);
98          assertEquals("MultivariateSummaryStatistics:\n" +
99                       "n: 3\n" +
100                      "min: 1.0, 1.0\n" +
101                      "max: 3.0, 3.0\n" +
102                      "mean: 2.0, 2.0\n" +
103                      "geometric mean: 1.817..., 1.817...\n" +
104                      "sum of squares: 14.0, 14.0\n" +
105                      "sum of logarithms: 1.791..., 1.791...\n" +
106                      "standard deviation: 1.0, 1.0\n" +
107                      "covariance: Array2DRowRealMatrix{{1.0,-1.0},{-1.0,1.0}}\n",
108                      stats.toString().replaceAll("([0-9]+\\.[0-9][0-9][0-9])[0-9]+", "$1..."));
109         Locale.setDefault(d);
110     }
111 
112     public void testShuffledStatistics() throws DimensionMismatchException {
113         // the purpose of this test is only to check the get/set methods
114         // we are aware shuffling statistics like this is really not
115         // something sensible to do in production ...
116         MultivariateSummaryStatistics reference = createMultivariateSummaryStatistics(2, true);
117         MultivariateSummaryStatistics shuffled  = createMultivariateSummaryStatistics(2, true);
118 
119         StorelessUnivariateStatistic[] tmp = shuffled.getGeoMeanImpl();
120         shuffled.setGeoMeanImpl(shuffled.getMeanImpl());
121         shuffled.setMeanImpl(shuffled.getMaxImpl());
122         shuffled.setMaxImpl(shuffled.getMinImpl());
123         shuffled.setMinImpl(shuffled.getSumImpl());
124         shuffled.setSumImpl(shuffled.getSumsqImpl());
125         shuffled.setSumsqImpl(shuffled.getSumLogImpl());
126         shuffled.setSumLogImpl(tmp);
127 
128         for (int i = 100; i > 0; --i) {
129             reference.addValue(new double[] {i, i});
130             shuffled.addValue(new double[] {i, i});
131         }
132 
133         TestUtils.assertEquals(reference.getMean(),          shuffled.getGeometricMean(), 1.0e-10);
134         TestUtils.assertEquals(reference.getMax(),           shuffled.getMean(),          1.0e-10);
135         TestUtils.assertEquals(reference.getMin(),           shuffled.getMax(),           1.0e-10);
136         TestUtils.assertEquals(reference.getSum(),           shuffled.getMin(),           1.0e-10);
137         TestUtils.assertEquals(reference.getSumSq(),         shuffled.getSum(),           1.0e-10);
138         TestUtils.assertEquals(reference.getSumLog(),        shuffled.getSumSq(),         1.0e-10);
139         TestUtils.assertEquals(reference.getGeometricMean(), shuffled.getSumLog(),        1.0e-10);
140 
141     }
142     
143     /**
144      * Bogus mean implementation to test setter injection.
145      * Returns the sum instead of the mean.
146      */
147     static class sumMean implements StorelessUnivariateStatistic {   
148         private double sum = 0;
149         private long n = 0;
150         public double evaluate(double[] values, int begin, int length) {
151             return 0;
152         }
153         public double evaluate(double[] values) {
154             return 0;
155         }
156         public void clear() {
157           sum = 0; 
158           n = 0;
159         }
160         public long getN() {
161             return n;
162         }
163         public double getResult() {
164             return sum;
165         }
166         public void increment(double d) {
167             sum += d;
168             n++;
169         }
170         public void incrementAll(double[] values, int start, int length) {
171         }
172         public void incrementAll(double[] values) {
173         }   
174         public StorelessUnivariateStatistic copy() {
175             return new sumMean();
176         }
177     }
178 
179     public void testDimension() {
180         try {
181             createMultivariateSummaryStatistics(2, true).addValue(new double[3]);
182         } catch (DimensionMismatchException dme) {
183             // expected behavior
184         } catch (Exception e) {
185             fail("wrong exception caught");
186         }
187     }
188 
189     /** test stats */
190     public void testStats() throws DimensionMismatchException {
191         MultivariateSummaryStatistics u = createMultivariateSummaryStatistics(2, true);
192         assertEquals(0, u.getN());
193         u.addValue(new double[] { 1, 2 });
194         u.addValue(new double[] { 2, 3 });
195         u.addValue(new double[] { 2, 3 });
196         u.addValue(new double[] { 3, 4 });
197         assertEquals( 4, u.getN());
198         assertEquals( 8, u.getSum()[0], 1.0e-10);
199         assertEquals(12, u.getSum()[1], 1.0e-10);
200         assertEquals(18, u.getSumSq()[0], 1.0e-10);
201         assertEquals(38, u.getSumSq()[1], 1.0e-10);
202         assertEquals( 1, u.getMin()[0], 1.0e-10);
203         assertEquals( 2, u.getMin()[1], 1.0e-10);
204         assertEquals( 3, u.getMax()[0], 1.0e-10);
205         assertEquals( 4, u.getMax()[1], 1.0e-10);
206         assertEquals(2.4849066497880003102, u.getSumLog()[0], 1.0e-10);
207         assertEquals( 4.276666119016055311, u.getSumLog()[1], 1.0e-10);
208         assertEquals( 1.8612097182041991979, u.getGeometricMean()[0], 1.0e-10);
209         assertEquals( 2.9129506302439405217, u.getGeometricMean()[1], 1.0e-10);
210         assertEquals( 2, u.getMean()[0], 1.0e-10);
211         assertEquals( 3, u.getMean()[1], 1.0e-10);
212         assertEquals(Math.sqrt(2.0 / 3.0), u.getStandardDeviation()[0], 1.0e-10);
213         assertEquals(Math.sqrt(2.0 / 3.0), u.getStandardDeviation()[1], 1.0e-10);
214         assertEquals(2.0 / 3.0, u.getCovariance().getEntry(0, 0), 1.0e-10);
215         assertEquals(2.0 / 3.0, u.getCovariance().getEntry(0, 1), 1.0e-10);
216         assertEquals(2.0 / 3.0, u.getCovariance().getEntry(1, 0), 1.0e-10);
217         assertEquals(2.0 / 3.0, u.getCovariance().getEntry(1, 1), 1.0e-10);
218         u.clear();
219         assertEquals(0, u.getN());    
220     }     
221 
222     public void testN0andN1Conditions() throws Exception {
223         MultivariateSummaryStatistics u = createMultivariateSummaryStatistics(1, true);
224         assertTrue(Double.isNaN(u.getMean()[0]));
225         assertTrue(Double.isNaN(u.getStandardDeviation()[0]));
226 
227         /* n=1 */
228         u.addValue(new double[] { 1 });
229         assertEquals(1.0, u.getMean()[0], 1.0e-10);
230         assertEquals(1.0, u.getGeometricMean()[0], 1.0e-10);
231         assertEquals(0.0, u.getStandardDeviation()[0], 1.0e-10);
232 
233         /* n=2 */               
234         u.addValue(new double[] { 2 });
235         assertTrue(u.getStandardDeviation()[0] > 0);
236 
237     }
238 
239     public void testNaNContracts() throws DimensionMismatchException {
240         MultivariateSummaryStatistics u = createMultivariateSummaryStatistics(1, true);
241         assertTrue(Double.isNaN(u.getMean()[0])); 
242         assertTrue(Double.isNaN(u.getMin()[0])); 
243         assertTrue(Double.isNaN(u.getStandardDeviation()[0])); 
244         assertTrue(Double.isNaN(u.getGeometricMean()[0]));
245 
246         u.addValue(new double[] { 1.0 });
247         assertFalse(Double.isNaN(u.getMean()[0])); 
248         assertFalse(Double.isNaN(u.getMin()[0])); 
249         assertFalse(Double.isNaN(u.getStandardDeviation()[0])); 
250         assertFalse(Double.isNaN(u.getGeometricMean()[0]));
251 
252     }
253 
254     public void testSerialization() throws DimensionMismatchException {
255         MultivariateSummaryStatistics u = createMultivariateSummaryStatistics(2, true);
256         // Empty test
257         TestUtils.checkSerializedEquality(u);
258         MultivariateSummaryStatistics s = (MultivariateSummaryStatistics) TestUtils.serializeAndRecover(u);
259         assertEquals(u, s);
260 
261         // Add some data
262         u.addValue(new double[] { 2d, 1d });
263         u.addValue(new double[] { 1d, 1d });
264         u.addValue(new double[] { 3d, 1d });
265         u.addValue(new double[] { 4d, 1d });
266         u.addValue(new double[] { 5d, 1d });
267 
268         // Test again
269         TestUtils.checkSerializedEquality(u);
270         s = (MultivariateSummaryStatistics) TestUtils.serializeAndRecover(u);
271         assertEquals(u, s);
272 
273     }
274 
275     public void testEqualsAndHashCode() throws DimensionMismatchException {
276         MultivariateSummaryStatistics u = createMultivariateSummaryStatistics(2, true);
277         MultivariateSummaryStatistics t = null;
278         int emptyHash = u.hashCode();
279         assertTrue(u.equals(u));
280         assertFalse(u.equals(t));
281         assertFalse(u.equals(Double.valueOf(0)));
282         t = createMultivariateSummaryStatistics(2, true);
283         assertTrue(t.equals(u));
284         assertTrue(u.equals(t));
285         assertEquals(emptyHash, t.hashCode());
286 
287         // Add some data to u
288         u.addValue(new double[] { 2d, 1d });
289         u.addValue(new double[] { 1d, 1d });
290         u.addValue(new double[] { 3d, 1d });
291         u.addValue(new double[] { 4d, 1d });
292         u.addValue(new double[] { 5d, 1d });
293         assertFalse(t.equals(u));
294         assertFalse(u.equals(t));
295         assertTrue(u.hashCode() != t.hashCode());
296 
297         //Add data in same order to t
298         t.addValue(new double[] { 2d, 1d });
299         t.addValue(new double[] { 1d, 1d });
300         t.addValue(new double[] { 3d, 1d });
301         t.addValue(new double[] { 4d, 1d });
302         t.addValue(new double[] { 5d, 1d });
303         assertTrue(t.equals(u));
304         assertTrue(u.equals(t));
305         assertEquals(u.hashCode(), t.hashCode());   
306 
307         // Clear and make sure summaries are indistinguishable from empty summary
308         u.clear();
309         t.clear();
310         assertTrue(t.equals(u));
311         assertTrue(u.equals(t));
312         assertEquals(emptyHash, t.hashCode());
313         assertEquals(emptyHash, u.hashCode());
314     }
315 
316 }