1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 package org.apache.commons.math.stat.descriptive;
18
19
20 import java.util.Locale;
21
22 import junit.framework.Test;
23 import junit.framework.TestCase;
24 import junit.framework.TestSuite;
25
26 import org.apache.commons.math.DimensionMismatchException;
27 import org.apache.commons.math.TestUtils;
28 import org.apache.commons.math.stat.descriptive.moment.Mean;
29
30
31
32
33
34
35
36 public class MultivariateSummaryStatisticsTest extends TestCase {
37
38 public MultivariateSummaryStatisticsTest(String name) {
39 super(name);
40 }
41
42 public static Test suite() {
43 TestSuite suite = new TestSuite(MultivariateSummaryStatisticsTest.class);
44 suite.setName("MultivariateSummaryStatistics tests");
45 return suite;
46 }
47
48 protected MultivariateSummaryStatistics createMultivariateSummaryStatistics(int k, boolean isCovarianceBiasCorrected) {
49 return new MultivariateSummaryStatistics(k, isCovarianceBiasCorrected);
50 }
51
52 public void testSetterInjection() throws Exception {
53 MultivariateSummaryStatistics u = createMultivariateSummaryStatistics(2, true);
54 u.setMeanImpl(new StorelessUnivariateStatistic[] {
55 new sumMean(), new sumMean()
56 });
57 u.addValue(new double[] { 1, 2 });
58 u.addValue(new double[] { 3, 4 });
59 assertEquals(4, u.getMean()[0], 1E-14);
60 assertEquals(6, u.getMean()[1], 1E-14);
61 u.clear();
62 u.addValue(new double[] { 1, 2 });
63 u.addValue(new double[] { 3, 4 });
64 assertEquals(4, u.getMean()[0], 1E-14);
65 assertEquals(6, u.getMean()[1], 1E-14);
66 u.clear();
67 u.setMeanImpl(new StorelessUnivariateStatistic[] {
68 new Mean(), new Mean()
69 });
70 u.addValue(new double[] { 1, 2 });
71 u.addValue(new double[] { 3, 4 });
72 assertEquals(2, u.getMean()[0], 1E-14);
73 assertEquals(3, u.getMean()[1], 1E-14);
74 assertEquals(2, u.getDimension());
75 }
76
77 public void testSetterIllegalState() throws Exception {
78 MultivariateSummaryStatistics u = createMultivariateSummaryStatistics(2, true);
79 u.addValue(new double[] { 1, 2 });
80 u.addValue(new double[] { 3, 4 });
81 try {
82 u.setMeanImpl(new StorelessUnivariateStatistic[] {
83 new sumMean(), new sumMean()
84 });
85 fail("Expecting IllegalStateException");
86 } catch (IllegalStateException ex) {
87
88 }
89 }
90
91 public void testToString() throws DimensionMismatchException {
92 MultivariateSummaryStatistics stats = createMultivariateSummaryStatistics(2, true);
93 stats.addValue(new double[] {1, 3});
94 stats.addValue(new double[] {2, 2});
95 stats.addValue(new double[] {3, 1});
96 Locale d = Locale.getDefault();
97 Locale.setDefault(Locale.US);
98 assertEquals("MultivariateSummaryStatistics:\n" +
99 "n: 3\n" +
100 "min: 1.0, 1.0\n" +
101 "max: 3.0, 3.0\n" +
102 "mean: 2.0, 2.0\n" +
103 "geometric mean: 1.817..., 1.817...\n" +
104 "sum of squares: 14.0, 14.0\n" +
105 "sum of logarithms: 1.791..., 1.791...\n" +
106 "standard deviation: 1.0, 1.0\n" +
107 "covariance: Array2DRowRealMatrix{{1.0,-1.0},{-1.0,1.0}}\n",
108 stats.toString().replaceAll("([0-9]+\\.[0-9][0-9][0-9])[0-9]+", "$1..."));
109 Locale.setDefault(d);
110 }
111
112 public void testShuffledStatistics() throws DimensionMismatchException {
113
114
115
116 MultivariateSummaryStatistics reference = createMultivariateSummaryStatistics(2, true);
117 MultivariateSummaryStatistics shuffled = createMultivariateSummaryStatistics(2, true);
118
119 StorelessUnivariateStatistic[] tmp = shuffled.getGeoMeanImpl();
120 shuffled.setGeoMeanImpl(shuffled.getMeanImpl());
121 shuffled.setMeanImpl(shuffled.getMaxImpl());
122 shuffled.setMaxImpl(shuffled.getMinImpl());
123 shuffled.setMinImpl(shuffled.getSumImpl());
124 shuffled.setSumImpl(shuffled.getSumsqImpl());
125 shuffled.setSumsqImpl(shuffled.getSumLogImpl());
126 shuffled.setSumLogImpl(tmp);
127
128 for (int i = 100; i > 0; --i) {
129 reference.addValue(new double[] {i, i});
130 shuffled.addValue(new double[] {i, i});
131 }
132
133 TestUtils.assertEquals(reference.getMean(), shuffled.getGeometricMean(), 1.0e-10);
134 TestUtils.assertEquals(reference.getMax(), shuffled.getMean(), 1.0e-10);
135 TestUtils.assertEquals(reference.getMin(), shuffled.getMax(), 1.0e-10);
136 TestUtils.assertEquals(reference.getSum(), shuffled.getMin(), 1.0e-10);
137 TestUtils.assertEquals(reference.getSumSq(), shuffled.getSum(), 1.0e-10);
138 TestUtils.assertEquals(reference.getSumLog(), shuffled.getSumSq(), 1.0e-10);
139 TestUtils.assertEquals(reference.getGeometricMean(), shuffled.getSumLog(), 1.0e-10);
140
141 }
142
143
144
145
146
147 static class sumMean implements StorelessUnivariateStatistic {
148 private double sum = 0;
149 private long n = 0;
150 public double evaluate(double[] values, int begin, int length) {
151 return 0;
152 }
153 public double evaluate(double[] values) {
154 return 0;
155 }
156 public void clear() {
157 sum = 0;
158 n = 0;
159 }
160 public long getN() {
161 return n;
162 }
163 public double getResult() {
164 return sum;
165 }
166 public void increment(double d) {
167 sum += d;
168 n++;
169 }
170 public void incrementAll(double[] values, int start, int length) {
171 }
172 public void incrementAll(double[] values) {
173 }
174 public StorelessUnivariateStatistic copy() {
175 return new sumMean();
176 }
177 }
178
179 public void testDimension() {
180 try {
181 createMultivariateSummaryStatistics(2, true).addValue(new double[3]);
182 } catch (DimensionMismatchException dme) {
183
184 } catch (Exception e) {
185 fail("wrong exception caught");
186 }
187 }
188
189
190 public void testStats() throws DimensionMismatchException {
191 MultivariateSummaryStatistics u = createMultivariateSummaryStatistics(2, true);
192 assertEquals(0, u.getN());
193 u.addValue(new double[] { 1, 2 });
194 u.addValue(new double[] { 2, 3 });
195 u.addValue(new double[] { 2, 3 });
196 u.addValue(new double[] { 3, 4 });
197 assertEquals( 4, u.getN());
198 assertEquals( 8, u.getSum()[0], 1.0e-10);
199 assertEquals(12, u.getSum()[1], 1.0e-10);
200 assertEquals(18, u.getSumSq()[0], 1.0e-10);
201 assertEquals(38, u.getSumSq()[1], 1.0e-10);
202 assertEquals( 1, u.getMin()[0], 1.0e-10);
203 assertEquals( 2, u.getMin()[1], 1.0e-10);
204 assertEquals( 3, u.getMax()[0], 1.0e-10);
205 assertEquals( 4, u.getMax()[1], 1.0e-10);
206 assertEquals(2.4849066497880003102, u.getSumLog()[0], 1.0e-10);
207 assertEquals( 4.276666119016055311, u.getSumLog()[1], 1.0e-10);
208 assertEquals( 1.8612097182041991979, u.getGeometricMean()[0], 1.0e-10);
209 assertEquals( 2.9129506302439405217, u.getGeometricMean()[1], 1.0e-10);
210 assertEquals( 2, u.getMean()[0], 1.0e-10);
211 assertEquals( 3, u.getMean()[1], 1.0e-10);
212 assertEquals(Math.sqrt(2.0 / 3.0), u.getStandardDeviation()[0], 1.0e-10);
213 assertEquals(Math.sqrt(2.0 / 3.0), u.getStandardDeviation()[1], 1.0e-10);
214 assertEquals(2.0 / 3.0, u.getCovariance().getEntry(0, 0), 1.0e-10);
215 assertEquals(2.0 / 3.0, u.getCovariance().getEntry(0, 1), 1.0e-10);
216 assertEquals(2.0 / 3.0, u.getCovariance().getEntry(1, 0), 1.0e-10);
217 assertEquals(2.0 / 3.0, u.getCovariance().getEntry(1, 1), 1.0e-10);
218 u.clear();
219 assertEquals(0, u.getN());
220 }
221
222 public void testN0andN1Conditions() throws Exception {
223 MultivariateSummaryStatistics u = createMultivariateSummaryStatistics(1, true);
224 assertTrue(Double.isNaN(u.getMean()[0]));
225 assertTrue(Double.isNaN(u.getStandardDeviation()[0]));
226
227
228 u.addValue(new double[] { 1 });
229 assertEquals(1.0, u.getMean()[0], 1.0e-10);
230 assertEquals(1.0, u.getGeometricMean()[0], 1.0e-10);
231 assertEquals(0.0, u.getStandardDeviation()[0], 1.0e-10);
232
233
234 u.addValue(new double[] { 2 });
235 assertTrue(u.getStandardDeviation()[0] > 0);
236
237 }
238
239 public void testNaNContracts() throws DimensionMismatchException {
240 MultivariateSummaryStatistics u = createMultivariateSummaryStatistics(1, true);
241 assertTrue(Double.isNaN(u.getMean()[0]));
242 assertTrue(Double.isNaN(u.getMin()[0]));
243 assertTrue(Double.isNaN(u.getStandardDeviation()[0]));
244 assertTrue(Double.isNaN(u.getGeometricMean()[0]));
245
246 u.addValue(new double[] { 1.0 });
247 assertFalse(Double.isNaN(u.getMean()[0]));
248 assertFalse(Double.isNaN(u.getMin()[0]));
249 assertFalse(Double.isNaN(u.getStandardDeviation()[0]));
250 assertFalse(Double.isNaN(u.getGeometricMean()[0]));
251
252 }
253
254 public void testSerialization() throws DimensionMismatchException {
255 MultivariateSummaryStatistics u = createMultivariateSummaryStatistics(2, true);
256
257 TestUtils.checkSerializedEquality(u);
258 MultivariateSummaryStatistics s = (MultivariateSummaryStatistics) TestUtils.serializeAndRecover(u);
259 assertEquals(u, s);
260
261
262 u.addValue(new double[] { 2d, 1d });
263 u.addValue(new double[] { 1d, 1d });
264 u.addValue(new double[] { 3d, 1d });
265 u.addValue(new double[] { 4d, 1d });
266 u.addValue(new double[] { 5d, 1d });
267
268
269 TestUtils.checkSerializedEquality(u);
270 s = (MultivariateSummaryStatistics) TestUtils.serializeAndRecover(u);
271 assertEquals(u, s);
272
273 }
274
275 public void testEqualsAndHashCode() throws DimensionMismatchException {
276 MultivariateSummaryStatistics u = createMultivariateSummaryStatistics(2, true);
277 MultivariateSummaryStatistics t = null;
278 int emptyHash = u.hashCode();
279 assertTrue(u.equals(u));
280 assertFalse(u.equals(t));
281 assertFalse(u.equals(Double.valueOf(0)));
282 t = createMultivariateSummaryStatistics(2, true);
283 assertTrue(t.equals(u));
284 assertTrue(u.equals(t));
285 assertEquals(emptyHash, t.hashCode());
286
287
288 u.addValue(new double[] { 2d, 1d });
289 u.addValue(new double[] { 1d, 1d });
290 u.addValue(new double[] { 3d, 1d });
291 u.addValue(new double[] { 4d, 1d });
292 u.addValue(new double[] { 5d, 1d });
293 assertFalse(t.equals(u));
294 assertFalse(u.equals(t));
295 assertTrue(u.hashCode() != t.hashCode());
296
297
298 t.addValue(new double[] { 2d, 1d });
299 t.addValue(new double[] { 1d, 1d });
300 t.addValue(new double[] { 3d, 1d });
301 t.addValue(new double[] { 4d, 1d });
302 t.addValue(new double[] { 5d, 1d });
303 assertTrue(t.equals(u));
304 assertTrue(u.equals(t));
305 assertEquals(u.hashCode(), t.hashCode());
306
307
308 u.clear();
309 t.clear();
310 assertTrue(t.equals(u));
311 assertTrue(u.equals(t));
312 assertEquals(emptyHash, t.hashCode());
313 assertEquals(emptyHash, u.hashCode());
314 }
315
316 }