001    /*
002     * Licensed to the Apache Software Foundation (ASF) under one or more
003     * contributor license agreements.  See the NOTICE file distributed with
004     * this work for additional information regarding copyright ownership.
005     * The ASF licenses this file to You under the Apache License, Version 2.0
006     * (the "License"); you may not use this file except in compliance with
007     * the License.  You may obtain a copy of the License at
008     *
009     *      http://www.apache.org/licenses/LICENSE-2.0
010     *
011     * Unless required by applicable law or agreed to in writing, software
012     * distributed under the License is distributed on an "AS IS" BASIS,
013     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014     * See the License for the specific language governing permissions and
015     * limitations under the License.
016     */
017    
018    package org.apache.commons.math.linear;
019    
020    import java.util.Arrays;
021    
022    import org.apache.commons.math.linear.InvalidMatrixException;
023    import org.apache.commons.math.linear.MatrixUtils;
024    import org.apache.commons.math.linear.RealMatrix;
025    import org.apache.commons.math.linear.TriDiagonalTransformer;
026    
027    import junit.framework.Test;
028    import junit.framework.TestCase;
029    import junit.framework.TestSuite;
030    
031    public class TriDiagonalTransformerTest extends TestCase {
032    
033        private double[][] testSquare5 = {
034                { 1, 2, 3, 1, 1 },
035                { 2, 1, 1, 3, 1 },
036                { 3, 1, 1, 1, 2 },
037                { 1, 3, 1, 2, 1 },
038                { 1, 1, 2, 1, 3 }
039        };
040    
041        private double[][] testSquare3 = {
042                { 1, 3, 4 },
043                { 3, 2, 2 },
044                { 4, 2, 0 }
045        };
046    
047        public TriDiagonalTransformerTest(String name) {
048            super(name);
049        }
050    
051        public void testNonSquare() {
052            try {
053                new TriDiagonalTransformer(MatrixUtils.createRealMatrix(new double[3][2]));
054                fail("an exception should have been thrown");
055            } catch (InvalidMatrixException ime) {
056                // expected behavior
057            } catch (Exception e) {
058                fail("wrong exception caught");
059            }
060        }
061    
062        public void testAEqualQTQt() {
063            checkAEqualQTQt(MatrixUtils.createRealMatrix(testSquare5));
064            checkAEqualQTQt(MatrixUtils.createRealMatrix(testSquare3));
065        }
066    
067        private void checkAEqualQTQt(RealMatrix matrix) {
068            TriDiagonalTransformer transformer = new TriDiagonalTransformer(matrix);
069            RealMatrix q  = transformer.getQ();
070            RealMatrix qT = transformer.getQT();
071            RealMatrix t  = transformer.getT();
072            double norm = q.multiply(t).multiply(qT).subtract(matrix).getNorm();
073            assertEquals(0, norm, 4.0e-15);
074        }
075    
076        public void testNoAccessBelowDiagonal() {
077            checkNoAccessBelowDiagonal(testSquare5);
078            checkNoAccessBelowDiagonal(testSquare3);
079        }
080    
081        private void checkNoAccessBelowDiagonal(double[][] data) {
082            double[][] modifiedData = new double[data.length][];
083            for (int i = 0; i < data.length; ++i) {
084                modifiedData[i] = data[i].clone();
085                Arrays.fill(modifiedData[i], 0, i, Double.NaN);
086            }
087            RealMatrix matrix = MatrixUtils.createRealMatrix(modifiedData);
088            TriDiagonalTransformer transformer = new TriDiagonalTransformer(matrix);
089            RealMatrix q  = transformer.getQ();
090            RealMatrix qT = transformer.getQT();
091            RealMatrix t  = transformer.getT();
092            double norm = q.multiply(t).multiply(qT).subtract(MatrixUtils.createRealMatrix(data)).getNorm();
093            assertEquals(0, norm, 4.0e-15);
094        }
095    
096        public void testQOrthogonal() {
097            checkOrthogonal(new TriDiagonalTransformer(MatrixUtils.createRealMatrix(testSquare5)).getQ());
098            checkOrthogonal(new TriDiagonalTransformer(MatrixUtils.createRealMatrix(testSquare3)).getQ());
099        }
100    
101        public void testQTOrthogonal() {
102            checkOrthogonal(new TriDiagonalTransformer(MatrixUtils.createRealMatrix(testSquare5)).getQT());
103            checkOrthogonal(new TriDiagonalTransformer(MatrixUtils.createRealMatrix(testSquare3)).getQT());
104        }
105    
106        private void checkOrthogonal(RealMatrix m) {
107            RealMatrix mTm = m.transpose().multiply(m);
108            RealMatrix id  = MatrixUtils.createRealIdentityMatrix(mTm.getRowDimension());
109            assertEquals(0, mTm.subtract(id).getNorm(), 1.0e-15);        
110        }
111    
112        public void testTTriDiagonal() {
113            checkTriDiagonal(new TriDiagonalTransformer(MatrixUtils.createRealMatrix(testSquare5)).getT());
114            checkTriDiagonal(new TriDiagonalTransformer(MatrixUtils.createRealMatrix(testSquare3)).getT());
115        }
116    
117        private void checkTriDiagonal(RealMatrix m) {
118            final int rows = m.getRowDimension();
119            final int cols = m.getColumnDimension();
120            for (int i = 0; i < rows; ++i) {
121                for (int j = 0; j < cols; ++j) {
122                    if ((i < j - 1) || (i > j + 1)) {
123                        assertEquals(0, m.getEntry(i, j), 1.0e-16);
124                    }                    
125                }
126            }
127        }
128    
129        public void testMatricesValues5() {
130            checkMatricesValues(testSquare5,
131                                new double[][] {
132                                    { 1.0,  0.0,                 0.0,                  0.0,                   0.0 },
133                                    { 0.0, -0.5163977794943222,  0.016748280772542083, 0.839800693771262,     0.16669620021405473 },
134                                    { 0.0, -0.7745966692414833, -0.4354553000860955,  -0.44989322880603355,  -0.08930153582895772 },
135                                    { 0.0, -0.2581988897471611,  0.6364346693566014,  -0.30263204032131164,   0.6608313651342882 },
136                                    { 0.0, -0.2581988897471611,  0.6364346693566009,  -0.027289660803112598, -0.7263191580755246 }
137                                },
138                                new double[] { 1, 4.4, 1.433099579242636, -0.89537362758743, 2.062274048344794 },
139                                new double[] { -Math.sqrt(15), -3.0832882879592476, 0.6082710842351517, 1.1786086405912128 });
140        }
141    
142        public void testMatricesValues3() {
143            checkMatricesValues(testSquare3,
144                                new double[][] {
145                                    {  1.0,  0.0,  0.0 },
146                                    {  0.0, -0.6,  0.8 },
147                                    {  0.0, -0.8, -0.6 },
148                                },
149                                new double[] { 1, 2.64, -0.64 },
150                                new double[] { -5, -1.52 });
151        }
152    
153        private void checkMatricesValues(double[][] matrix, double[][] qRef,
154                                         double[] mainDiagnonal,
155                                         double[] secondaryDiagonal) {
156            TriDiagonalTransformer transformer =
157                new TriDiagonalTransformer(MatrixUtils.createRealMatrix(matrix));
158    
159            // check values against known references
160            RealMatrix q = transformer.getQ();
161            assertEquals(0, q.subtract(MatrixUtils.createRealMatrix(qRef)).getNorm(), 1.0e-14);
162    
163            RealMatrix t = transformer.getT();
164            double[][] tData = new double[mainDiagnonal.length][mainDiagnonal.length];
165            for (int i = 0; i < mainDiagnonal.length; ++i) {
166                tData[i][i] = mainDiagnonal[i];
167                if (i > 0) {
168                    tData[i][i - 1] = secondaryDiagonal[i - 1];
169                }
170                if (i < secondaryDiagonal.length) {
171                    tData[i][i + 1] = secondaryDiagonal[i];
172                }
173            }
174            assertEquals(0, t.subtract(MatrixUtils.createRealMatrix(tData)).getNorm(), 1.0e-14);
175    
176            // check the same cached instance is returned the second time
177            assertTrue(q == transformer.getQ());
178            assertTrue(t == transformer.getT());
179            
180        }
181    
182        public static Test suite() {
183            return new TestSuite(TriDiagonalTransformerTest.class);
184        }
185    
186    }