1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18 package org.apache.commons.math.linear;
19
20 import java.lang.reflect.Array;
21
22 import org.apache.commons.math.Field;
23 import org.apache.commons.math.FieldElement;
24 import org.apache.commons.math.MathRuntimeException;
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40 public class FieldLUDecompositionImpl<T extends FieldElement<T>> implements FieldLUDecomposition<T> {
41
42
43 private final Field<T> field;
44
45
46 private T lu[][];
47
48
49 private int[] pivot;
50
51
52 private boolean even;
53
54
55 private boolean singular;
56
57
58 private FieldMatrix<T> cachedL;
59
60
61 private FieldMatrix<T> cachedU;
62
63
64 private FieldMatrix<T> cachedP;
65
66
67
68
69
70
71 public FieldLUDecompositionImpl(FieldMatrix<T> matrix)
72 throws NonSquareMatrixException {
73
74 if (!matrix.isSquare()) {
75 throw new NonSquareMatrixException(matrix.getRowDimension(), matrix.getColumnDimension());
76 }
77
78 final int m = matrix.getColumnDimension();
79 field = matrix.getField();
80 lu = matrix.getData();
81 pivot = new int[m];
82 cachedL = null;
83 cachedU = null;
84 cachedP = null;
85
86
87 for (int row = 0; row < m; row++) {
88 pivot[row] = row;
89 }
90 even = true;
91 singular = false;
92
93
94 for (int col = 0; col < m; col++) {
95
96 T sum = field.getZero();
97
98
99 for (int row = 0; row < col; row++) {
100 final T[] luRow = lu[row];
101 sum = luRow[col];
102 for (int i = 0; i < row; i++) {
103 sum = sum.subtract(luRow[i].multiply(lu[i][col]));
104 }
105 luRow[col] = sum;
106 }
107
108
109 int nonZero = col;
110 for (int row = col; row < m; row++) {
111 final T[] luRow = lu[row];
112 sum = luRow[col];
113 for (int i = 0; i < col; i++) {
114 sum = sum.subtract(luRow[i].multiply(lu[i][col]));
115 }
116 luRow[col] = sum;
117
118 if (lu[nonZero][col].equals(field.getZero())) {
119
120 ++nonZero;
121 }
122 }
123
124
125 if (nonZero >= m) {
126 singular = true;
127 return;
128 }
129
130
131 if (nonZero != col) {
132 T tmp = field.getZero();
133 for (int i = 0; i < m; i++) {
134 tmp = lu[nonZero][i];
135 lu[nonZero][i] = lu[col][i];
136 lu[col][i] = tmp;
137 }
138 int temp = pivot[nonZero];
139 pivot[nonZero] = pivot[col];
140 pivot[col] = temp;
141 even = !even;
142 }
143
144
145 final T luDiag = lu[col][col];
146 for (int row = col + 1; row < m; row++) {
147 final T[] luRow = lu[row];
148 luRow[col] = luRow[col].divide(luDiag);
149 }
150 }
151
152 }
153
154
155 public FieldMatrix<T> getL() {
156 if ((cachedL == null) && !singular) {
157 final int m = pivot.length;
158 cachedL = new Array2DRowFieldMatrix<T>(field, m, m);
159 for (int i = 0; i < m; ++i) {
160 final T[] luI = lu[i];
161 for (int j = 0; j < i; ++j) {
162 cachedL.setEntry(i, j, luI[j]);
163 }
164 cachedL.setEntry(i, i, field.getOne());
165 }
166 }
167 return cachedL;
168 }
169
170
171 public FieldMatrix<T> getU() {
172 if ((cachedU == null) && !singular) {
173 final int m = pivot.length;
174 cachedU = new Array2DRowFieldMatrix<T>(field, m, m);
175 for (int i = 0; i < m; ++i) {
176 final T[] luI = lu[i];
177 for (int j = i; j < m; ++j) {
178 cachedU.setEntry(i, j, luI[j]);
179 }
180 }
181 }
182 return cachedU;
183 }
184
185
186 public FieldMatrix<T> getP() {
187 if ((cachedP == null) && !singular) {
188 final int m = pivot.length;
189 cachedP = new Array2DRowFieldMatrix<T>(field, m, m);
190 for (int i = 0; i < m; ++i) {
191 cachedP.setEntry(i, pivot[i], field.getOne());
192 }
193 }
194 return cachedP;
195 }
196
197
198 public int[] getPivot() {
199 return pivot.clone();
200 }
201
202
203 public T getDeterminant() {
204 if (singular) {
205 return field.getZero();
206 } else {
207 final int m = pivot.length;
208 T determinant = even ? field.getOne() : field.getZero().subtract(field.getOne());
209 for (int i = 0; i < m; i++) {
210 determinant = determinant.multiply(lu[i][i]);
211 }
212 return determinant;
213 }
214 }
215
216
217 public FieldDecompositionSolver<T> getSolver() {
218 return new Solver<T>(field, lu, pivot, singular);
219 }
220
221
222 private static class Solver<T extends FieldElement<T>> implements FieldDecompositionSolver<T> {
223
224
225 private static final long serialVersionUID = -6353105415121373022L;
226
227
228 private final Field<T> field;
229
230
231 private final T lu[][];
232
233
234 private final int[] pivot;
235
236
237 private final boolean singular;
238
239
240
241
242
243
244
245
246 private Solver(final Field<T> field, final T[][] lu,
247 final int[] pivot, final boolean singular) {
248 this.field = field;
249 this.lu = lu;
250 this.pivot = pivot;
251 this.singular = singular;
252 }
253
254
255 public boolean isNonSingular() {
256 return !singular;
257 }
258
259
260 @SuppressWarnings("unchecked")
261 public T[] solve(T[] b)
262 throws IllegalArgumentException, InvalidMatrixException {
263
264 final int m = pivot.length;
265 if (b.length != m) {
266 throw MathRuntimeException.createIllegalArgumentException(
267 "vector length mismatch: got {0} but expected {1}",
268 b.length, m);
269 }
270 if (singular) {
271 throw new SingularMatrixException();
272 }
273
274 final T[] bp = (T[]) Array.newInstance(field.getZero().getClass(), m);
275
276
277 for (int row = 0; row < m; row++) {
278 bp[row] = b[pivot[row]];
279 }
280
281
282 for (int col = 0; col < m; col++) {
283 final T bpCol = bp[col];
284 for (int i = col + 1; i < m; i++) {
285 bp[i] = bp[i].subtract(bpCol.multiply(lu[i][col]));
286 }
287 }
288
289
290 for (int col = m - 1; col >= 0; col--) {
291 bp[col] = bp[col].divide(lu[col][col]);
292 final T bpCol = bp[col];
293 for (int i = 0; i < col; i++) {
294 bp[i] = bp[i].subtract(bpCol.multiply(lu[i][col]));
295 }
296 }
297
298 return bp;
299
300 }
301
302
303 @SuppressWarnings("unchecked")
304 public FieldVector<T> solve(FieldVector<T> b)
305 throws IllegalArgumentException, InvalidMatrixException {
306 try {
307 return solve((ArrayFieldVector<T>) b);
308 } catch (ClassCastException cce) {
309
310 final int m = pivot.length;
311 if (b.getDimension() != m) {
312 throw MathRuntimeException.createIllegalArgumentException(
313 "vector length mismatch: got {0} but expected {1}",
314 b.getDimension(), m);
315 }
316 if (singular) {
317 throw new SingularMatrixException();
318 }
319
320 final T[] bp = (T[]) Array.newInstance(field.getZero().getClass(), m);
321
322
323 for (int row = 0; row < m; row++) {
324 bp[row] = b.getEntry(pivot[row]);
325 }
326
327
328 for (int col = 0; col < m; col++) {
329 final T bpCol = bp[col];
330 for (int i = col + 1; i < m; i++) {
331 bp[i] = bp[i].subtract(bpCol.multiply(lu[i][col]));
332 }
333 }
334
335
336 for (int col = m - 1; col >= 0; col--) {
337 bp[col] = bp[col].divide(lu[col][col]);
338 final T bpCol = bp[col];
339 for (int i = 0; i < col; i++) {
340 bp[i] = bp[i].subtract(bpCol.multiply(lu[i][col]));
341 }
342 }
343
344 return new ArrayFieldVector<T>(bp, false);
345
346 }
347 }
348
349
350
351
352
353
354
355
356 public ArrayFieldVector<T> solve(ArrayFieldVector<T> b)
357 throws IllegalArgumentException, InvalidMatrixException {
358 return new ArrayFieldVector<T>(solve(b.getDataRef()), false);
359 }
360
361
362 @SuppressWarnings("unchecked")
363 public FieldMatrix<T> solve(FieldMatrix<T> b)
364 throws IllegalArgumentException, InvalidMatrixException {
365
366 final int m = pivot.length;
367 if (b.getRowDimension() != m) {
368 throw MathRuntimeException.createIllegalArgumentException(
369 "dimensions mismatch: got {0}x{1} but expected {2}x{3}",
370 b.getRowDimension(), b.getColumnDimension(), m, "n");
371 }
372 if (singular) {
373 throw new SingularMatrixException();
374 }
375
376 final int nColB = b.getColumnDimension();
377
378
379 final T[][] bp = (T[][]) Array.newInstance(field.getZero().getClass(), new int[] { m, nColB });
380 for (int row = 0; row < m; row++) {
381 final T[] bpRow = bp[row];
382 final int pRow = pivot[row];
383 for (int col = 0; col < nColB; col++) {
384 bpRow[col] = b.getEntry(pRow, col);
385 }
386 }
387
388
389 for (int col = 0; col < m; col++) {
390 final T[] bpCol = bp[col];
391 for (int i = col + 1; i < m; i++) {
392 final T[] bpI = bp[i];
393 final T luICol = lu[i][col];
394 for (int j = 0; j < nColB; j++) {
395 bpI[j] = bpI[j].subtract(bpCol[j].multiply(luICol));
396 }
397 }
398 }
399
400
401 for (int col = m - 1; col >= 0; col--) {
402 final T[] bpCol = bp[col];
403 final T luDiag = lu[col][col];
404 for (int j = 0; j < nColB; j++) {
405 bpCol[j] = bpCol[j].divide(luDiag);
406 }
407 for (int i = 0; i < col; i++) {
408 final T[] bpI = bp[i];
409 final T luICol = lu[i][col];
410 for (int j = 0; j < nColB; j++) {
411 bpI[j] = bpI[j].subtract(bpCol[j].multiply(luICol));
412 }
413 }
414 }
415
416 return new Array2DRowFieldMatrix<T>(bp, false);
417
418 }
419
420
421 public FieldMatrix<T> getInverse() throws InvalidMatrixException {
422 final int m = pivot.length;
423 final T one = field.getOne();
424 FieldMatrix<T> identity = new Array2DRowFieldMatrix<T>(field, m, m);
425 for (int i = 0; i < m; ++i) {
426 identity.setEntry(i, i, one);
427 }
428 return solve(identity);
429 }
430
431 }
432
433 }