001    /*
002     * Licensed to the Apache Software Foundation (ASF) under one or more
003     * contributor license agreements.  See the NOTICE file distributed with
004     * this work for additional information regarding copyright ownership.
005     * The ASF licenses this file to You under the Apache License, Version 2.0
006     * (the "License"); you may not use this file except in compliance with
007     * the License.  You may obtain a copy of the License at
008     * 
009     *      http://www.apache.org/licenses/LICENSE-2.0
010     * 
011     * Unless required by applicable law or agreed to in writing, software
012     * distributed under the License is distributed on an "AS IS" BASIS,
013     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014     * See the License for the specific language governing permissions and
015     * limitations under the License.
016     */
017    package org.apache.commons.math.stat.descriptive;
018    
019    
020    import java.util.Locale;
021    
022    import junit.framework.Test;
023    import junit.framework.TestCase;
024    import junit.framework.TestSuite;
025    
026    import org.apache.commons.math.DimensionMismatchException;
027    import org.apache.commons.math.TestUtils;
028    import org.apache.commons.math.stat.descriptive.moment.Mean;
029    
030    /**
031     * Test cases for the {@link MultivariateSummaryStatistics} class.
032     *
033     * @version $Revision: 797744 $ $Date: 2009-07-25 07:09:14 -0400 (Sat, 25 Jul 2009) $
034     */
035    
036    public class MultivariateSummaryStatisticsTest extends TestCase {
037    
038        public MultivariateSummaryStatisticsTest(String name) {
039            super(name);
040        }
041        
042        public static Test suite() {
043            TestSuite suite = new TestSuite(MultivariateSummaryStatisticsTest.class);
044            suite.setName("MultivariateSummaryStatistics tests");
045            return suite;
046        }
047    
048        protected MultivariateSummaryStatistics createMultivariateSummaryStatistics(int k, boolean isCovarianceBiasCorrected) {
049            return new MultivariateSummaryStatistics(k, isCovarianceBiasCorrected);
050        }
051    
052        public void testSetterInjection() throws Exception {
053            MultivariateSummaryStatistics u = createMultivariateSummaryStatistics(2, true);
054            u.setMeanImpl(new StorelessUnivariateStatistic[] {
055                            new sumMean(), new sumMean()
056                          });
057            u.addValue(new double[] { 1, 2 });
058            u.addValue(new double[] { 3, 4 });
059            assertEquals(4, u.getMean()[0], 1E-14);
060            assertEquals(6, u.getMean()[1], 1E-14);
061            u.clear();
062            u.addValue(new double[] { 1, 2 });
063            u.addValue(new double[] { 3, 4 });
064            assertEquals(4, u.getMean()[0], 1E-14);
065            assertEquals(6, u.getMean()[1], 1E-14);
066            u.clear();
067            u.setMeanImpl(new StorelessUnivariateStatistic[] {
068                            new Mean(), new Mean()
069                          }); // OK after clear
070            u.addValue(new double[] { 1, 2 });
071            u.addValue(new double[] { 3, 4 });
072            assertEquals(2, u.getMean()[0], 1E-14);
073            assertEquals(3, u.getMean()[1], 1E-14);
074            assertEquals(2, u.getDimension());
075        }
076        
077        public void testSetterIllegalState() throws Exception {
078            MultivariateSummaryStatistics u = createMultivariateSummaryStatistics(2, true);
079            u.addValue(new double[] { 1, 2 });
080            u.addValue(new double[] { 3, 4 });
081            try {
082                u.setMeanImpl(new StorelessUnivariateStatistic[] {
083                                new sumMean(), new sumMean()
084                              });
085                fail("Expecting IllegalStateException");
086            } catch (IllegalStateException ex) {
087                // expected
088            }
089        }
090    
091        public void testToString() throws DimensionMismatchException {
092            MultivariateSummaryStatistics stats = createMultivariateSummaryStatistics(2, true);
093            stats.addValue(new double[] {1, 3});
094            stats.addValue(new double[] {2, 2});
095            stats.addValue(new double[] {3, 1});
096            Locale d = Locale.getDefault();
097            Locale.setDefault(Locale.US);
098            assertEquals("MultivariateSummaryStatistics:\n" +
099                         "n: 3\n" +
100                         "min: 1.0, 1.0\n" +
101                         "max: 3.0, 3.0\n" +
102                         "mean: 2.0, 2.0\n" +
103                         "geometric mean: 1.817..., 1.817...\n" +
104                         "sum of squares: 14.0, 14.0\n" +
105                         "sum of logarithms: 1.791..., 1.791...\n" +
106                         "standard deviation: 1.0, 1.0\n" +
107                         "covariance: Array2DRowRealMatrix{{1.0,-1.0},{-1.0,1.0}}\n",
108                         stats.toString().replaceAll("([0-9]+\\.[0-9][0-9][0-9])[0-9]+", "$1..."));
109            Locale.setDefault(d);
110        }
111    
112        public void testShuffledStatistics() throws DimensionMismatchException {
113            // the purpose of this test is only to check the get/set methods
114            // we are aware shuffling statistics like this is really not
115            // something sensible to do in production ...
116            MultivariateSummaryStatistics reference = createMultivariateSummaryStatistics(2, true);
117            MultivariateSummaryStatistics shuffled  = createMultivariateSummaryStatistics(2, true);
118    
119            StorelessUnivariateStatistic[] tmp = shuffled.getGeoMeanImpl();
120            shuffled.setGeoMeanImpl(shuffled.getMeanImpl());
121            shuffled.setMeanImpl(shuffled.getMaxImpl());
122            shuffled.setMaxImpl(shuffled.getMinImpl());
123            shuffled.setMinImpl(shuffled.getSumImpl());
124            shuffled.setSumImpl(shuffled.getSumsqImpl());
125            shuffled.setSumsqImpl(shuffled.getSumLogImpl());
126            shuffled.setSumLogImpl(tmp);
127    
128            for (int i = 100; i > 0; --i) {
129                reference.addValue(new double[] {i, i});
130                shuffled.addValue(new double[] {i, i});
131            }
132    
133            TestUtils.assertEquals(reference.getMean(),          shuffled.getGeometricMean(), 1.0e-10);
134            TestUtils.assertEquals(reference.getMax(),           shuffled.getMean(),          1.0e-10);
135            TestUtils.assertEquals(reference.getMin(),           shuffled.getMax(),           1.0e-10);
136            TestUtils.assertEquals(reference.getSum(),           shuffled.getMin(),           1.0e-10);
137            TestUtils.assertEquals(reference.getSumSq(),         shuffled.getSum(),           1.0e-10);
138            TestUtils.assertEquals(reference.getSumLog(),        shuffled.getSumSq(),         1.0e-10);
139            TestUtils.assertEquals(reference.getGeometricMean(), shuffled.getSumLog(),        1.0e-10);
140    
141        }
142        
143        /**
144         * Bogus mean implementation to test setter injection.
145         * Returns the sum instead of the mean.
146         */
147        static class sumMean implements StorelessUnivariateStatistic {   
148            private double sum = 0;
149            private long n = 0;
150            public double evaluate(double[] values, int begin, int length) {
151                return 0;
152            }
153            public double evaluate(double[] values) {
154                return 0;
155            }
156            public void clear() {
157              sum = 0; 
158              n = 0;
159            }
160            public long getN() {
161                return n;
162            }
163            public double getResult() {
164                return sum;
165            }
166            public void increment(double d) {
167                sum += d;
168                n++;
169            }
170            public void incrementAll(double[] values, int start, int length) {
171            }
172            public void incrementAll(double[] values) {
173            }   
174            public StorelessUnivariateStatistic copy() {
175                return new sumMean();
176            }
177        }
178    
179        public void testDimension() {
180            try {
181                createMultivariateSummaryStatistics(2, true).addValue(new double[3]);
182            } catch (DimensionMismatchException dme) {
183                // expected behavior
184            } catch (Exception e) {
185                fail("wrong exception caught");
186            }
187        }
188    
189        /** test stats */
190        public void testStats() throws DimensionMismatchException {
191            MultivariateSummaryStatistics u = createMultivariateSummaryStatistics(2, true);
192            assertEquals(0, u.getN());
193            u.addValue(new double[] { 1, 2 });
194            u.addValue(new double[] { 2, 3 });
195            u.addValue(new double[] { 2, 3 });
196            u.addValue(new double[] { 3, 4 });
197            assertEquals( 4, u.getN());
198            assertEquals( 8, u.getSum()[0], 1.0e-10);
199            assertEquals(12, u.getSum()[1], 1.0e-10);
200            assertEquals(18, u.getSumSq()[0], 1.0e-10);
201            assertEquals(38, u.getSumSq()[1], 1.0e-10);
202            assertEquals( 1, u.getMin()[0], 1.0e-10);
203            assertEquals( 2, u.getMin()[1], 1.0e-10);
204            assertEquals( 3, u.getMax()[0], 1.0e-10);
205            assertEquals( 4, u.getMax()[1], 1.0e-10);
206            assertEquals(2.4849066497880003102, u.getSumLog()[0], 1.0e-10);
207            assertEquals( 4.276666119016055311, u.getSumLog()[1], 1.0e-10);
208            assertEquals( 1.8612097182041991979, u.getGeometricMean()[0], 1.0e-10);
209            assertEquals( 2.9129506302439405217, u.getGeometricMean()[1], 1.0e-10);
210            assertEquals( 2, u.getMean()[0], 1.0e-10);
211            assertEquals( 3, u.getMean()[1], 1.0e-10);
212            assertEquals(Math.sqrt(2.0 / 3.0), u.getStandardDeviation()[0], 1.0e-10);
213            assertEquals(Math.sqrt(2.0 / 3.0), u.getStandardDeviation()[1], 1.0e-10);
214            assertEquals(2.0 / 3.0, u.getCovariance().getEntry(0, 0), 1.0e-10);
215            assertEquals(2.0 / 3.0, u.getCovariance().getEntry(0, 1), 1.0e-10);
216            assertEquals(2.0 / 3.0, u.getCovariance().getEntry(1, 0), 1.0e-10);
217            assertEquals(2.0 / 3.0, u.getCovariance().getEntry(1, 1), 1.0e-10);
218            u.clear();
219            assertEquals(0, u.getN());    
220        }     
221    
222        public void testN0andN1Conditions() throws Exception {
223            MultivariateSummaryStatistics u = createMultivariateSummaryStatistics(1, true);
224            assertTrue(Double.isNaN(u.getMean()[0]));
225            assertTrue(Double.isNaN(u.getStandardDeviation()[0]));
226    
227            /* n=1 */
228            u.addValue(new double[] { 1 });
229            assertEquals(1.0, u.getMean()[0], 1.0e-10);
230            assertEquals(1.0, u.getGeometricMean()[0], 1.0e-10);
231            assertEquals(0.0, u.getStandardDeviation()[0], 1.0e-10);
232    
233            /* n=2 */               
234            u.addValue(new double[] { 2 });
235            assertTrue(u.getStandardDeviation()[0] > 0);
236    
237        }
238    
239        public void testNaNContracts() throws DimensionMismatchException {
240            MultivariateSummaryStatistics u = createMultivariateSummaryStatistics(1, true);
241            assertTrue(Double.isNaN(u.getMean()[0])); 
242            assertTrue(Double.isNaN(u.getMin()[0])); 
243            assertTrue(Double.isNaN(u.getStandardDeviation()[0])); 
244            assertTrue(Double.isNaN(u.getGeometricMean()[0]));
245    
246            u.addValue(new double[] { 1.0 });
247            assertFalse(Double.isNaN(u.getMean()[0])); 
248            assertFalse(Double.isNaN(u.getMin()[0])); 
249            assertFalse(Double.isNaN(u.getStandardDeviation()[0])); 
250            assertFalse(Double.isNaN(u.getGeometricMean()[0]));
251    
252        }
253    
254        public void testSerialization() throws DimensionMismatchException {
255            MultivariateSummaryStatistics u = createMultivariateSummaryStatistics(2, true);
256            // Empty test
257            TestUtils.checkSerializedEquality(u);
258            MultivariateSummaryStatistics s = (MultivariateSummaryStatistics) TestUtils.serializeAndRecover(u);
259            assertEquals(u, s);
260    
261            // Add some data
262            u.addValue(new double[] { 2d, 1d });
263            u.addValue(new double[] { 1d, 1d });
264            u.addValue(new double[] { 3d, 1d });
265            u.addValue(new double[] { 4d, 1d });
266            u.addValue(new double[] { 5d, 1d });
267    
268            // Test again
269            TestUtils.checkSerializedEquality(u);
270            s = (MultivariateSummaryStatistics) TestUtils.serializeAndRecover(u);
271            assertEquals(u, s);
272    
273        }
274    
275        public void testEqualsAndHashCode() throws DimensionMismatchException {
276            MultivariateSummaryStatistics u = createMultivariateSummaryStatistics(2, true);
277            MultivariateSummaryStatistics t = null;
278            int emptyHash = u.hashCode();
279            assertTrue(u.equals(u));
280            assertFalse(u.equals(t));
281            assertFalse(u.equals(Double.valueOf(0)));
282            t = createMultivariateSummaryStatistics(2, true);
283            assertTrue(t.equals(u));
284            assertTrue(u.equals(t));
285            assertEquals(emptyHash, t.hashCode());
286    
287            // Add some data to u
288            u.addValue(new double[] { 2d, 1d });
289            u.addValue(new double[] { 1d, 1d });
290            u.addValue(new double[] { 3d, 1d });
291            u.addValue(new double[] { 4d, 1d });
292            u.addValue(new double[] { 5d, 1d });
293            assertFalse(t.equals(u));
294            assertFalse(u.equals(t));
295            assertTrue(u.hashCode() != t.hashCode());
296    
297            //Add data in same order to t
298            t.addValue(new double[] { 2d, 1d });
299            t.addValue(new double[] { 1d, 1d });
300            t.addValue(new double[] { 3d, 1d });
301            t.addValue(new double[] { 4d, 1d });
302            t.addValue(new double[] { 5d, 1d });
303            assertTrue(t.equals(u));
304            assertTrue(u.equals(t));
305            assertEquals(u.hashCode(), t.hashCode());   
306    
307            // Clear and make sure summaries are indistinguishable from empty summary
308            u.clear();
309            t.clear();
310            assertTrue(t.equals(u));
311            assertTrue(u.equals(t));
312            assertEquals(emptyHash, t.hashCode());
313            assertEquals(emptyHash, u.hashCode());
314        }
315    
316    }