1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package org.apache.commons.math.random;
19
20 import junit.framework.Test;
21 import junit.framework.TestCase;
22 import junit.framework.TestSuite;
23
24 import org.apache.commons.math.DimensionMismatchException;
25 import org.apache.commons.math.linear.MatrixUtils;
26 import org.apache.commons.math.linear.NotPositiveDefiniteMatrixException;
27 import org.apache.commons.math.linear.RealMatrix;
28 import org.apache.commons.math.stat.descriptive.moment.VectorialCovariance;
29 import org.apache.commons.math.stat.descriptive.moment.VectorialMean;
30
31 public class CorrelatedRandomVectorGeneratorTest
32 extends TestCase {
33
34 public CorrelatedRandomVectorGeneratorTest(String name) {
35 super(name);
36 mean = null;
37 covariance = null;
38 generator = null;
39 }
40
41 public void testRank() {
42 assertEquals(3, generator.getRank());
43 }
44
45 public void testMath226()
46 throws DimensionMismatchException, NotPositiveDefiniteMatrixException {
47 double[] mean = { 1, 1, 10, 1 };
48 double[][] cov = {
49 { 1, 3, 2, 6 },
50 { 3, 13, 16, 2 },
51 { 2, 16, 38, -1 },
52 { 6, 2, -1, 197 }
53 };
54 RealMatrix covRM = MatrixUtils.createRealMatrix(cov);
55 JDKRandomGenerator jg = new JDKRandomGenerator();
56 jg.setSeed(5322145245211l);
57 NormalizedRandomGenerator rg = new GaussianRandomGenerator(jg);
58 CorrelatedRandomVectorGenerator sg =
59 new CorrelatedRandomVectorGenerator(mean, covRM, 0.00001, rg);
60
61 for (int i = 0; i < 10; i++) {
62 double[] generated = sg.nextVector();
63 assertTrue(Math.abs(generated[0] - 1) > 0.1);
64 }
65
66 }
67
68 public void testRootMatrix() {
69 RealMatrix b = generator.getRootMatrix();
70 RealMatrix bbt = b.multiply(b.transpose());
71 for (int i = 0; i < covariance.getRowDimension(); ++i) {
72 for (int j = 0; j < covariance.getColumnDimension(); ++j) {
73 assertEquals(covariance.getEntry(i, j), bbt.getEntry(i, j), 1.0e-12);
74 }
75 }
76 }
77
78 public void testMeanAndCovariance() throws DimensionMismatchException {
79
80 VectorialMean meanStat = new VectorialMean(mean.length);
81 VectorialCovariance covStat = new VectorialCovariance(mean.length, true);
82 for (int i = 0; i < 5000; ++i) {
83 double[] v = generator.nextVector();
84 meanStat.increment(v);
85 covStat.increment(v);
86 }
87
88 double[] estimatedMean = meanStat.getResult();
89 RealMatrix estimatedCovariance = covStat.getResult();
90 for (int i = 0; i < estimatedMean.length; ++i) {
91 assertEquals(mean[i], estimatedMean[i], 0.07);
92 for (int j = 0; j <= i; ++j) {
93 assertEquals(covariance.getEntry(i, j),
94 estimatedCovariance.getEntry(i, j),
95 0.1 * (1.0 + Math.abs(mean[i])) * (1.0 + Math.abs(mean[j])));
96 }
97 }
98
99 }
100
101 @Override
102 public void setUp() {
103 try {
104 mean = new double[] { 0.0, 1.0, -3.0, 2.3};
105
106 RealMatrix b = MatrixUtils.createRealMatrix(4, 3);
107 int counter = 0;
108 for (int i = 0; i < b.getRowDimension(); ++i) {
109 for (int j = 0; j < b.getColumnDimension(); ++j) {
110 b.setEntry(i, j, 1.0 + 0.1 * ++counter);
111 }
112 }
113 RealMatrix bbt = b.multiply(b.transpose());
114 covariance = MatrixUtils.createRealMatrix(mean.length, mean.length);
115 for (int i = 0; i < covariance.getRowDimension(); ++i) {
116 covariance.setEntry(i, i, bbt.getEntry(i, i));
117 for (int j = 0; j < covariance.getColumnDimension(); ++j) {
118 double s = bbt.getEntry(i, j);
119 covariance.setEntry(i, j, s);
120 covariance.setEntry(j, i, s);
121 }
122 }
123
124 RandomGenerator rg = new JDKRandomGenerator();
125 rg.setSeed(17399225432l);
126 GaussianRandomGenerator rawGenerator = new GaussianRandomGenerator(rg);
127 generator = new CorrelatedRandomVectorGenerator(mean,
128 covariance,
129 1.0e-12 * covariance.getNorm(),
130 rawGenerator);
131 } catch (DimensionMismatchException e) {
132 fail(e.getMessage());
133 } catch (NotPositiveDefiniteMatrixException e) {
134 fail("not positive definite matrix");
135 }
136 }
137
138 @Override
139 public void tearDown() {
140 mean = null;
141 covariance = null;
142 generator = null;
143 }
144
145 public static Test suite() {
146 return new TestSuite(CorrelatedRandomVectorGeneratorTest.class);
147 }
148
149 private double[] mean;
150 private RealMatrix covariance;
151 private CorrelatedRandomVectorGenerator generator;
152
153 }