edu.cmu.tetrad.util.Matrix Maven / Gradle / Ivy
The newest version!
///////////////////////////////////////////////////////////////////////////////
// For information as to what this class does, see the Javadoc, below. //
// Copyright (C) 1998, 1999, 2000, 2001, 2002, 2003, 2004, 2005, 2006, //
// 2007, 2008, 2009, 2010, 2014, 2015, 2022 by Peter Spirtes, Richard //
// Scheines, Joseph Ramsey, and Clark Glymour. //
// //
// This program is free software; you can redistribute it and/or modify //
// it under the terms of the GNU General Public License as published by //
// the Free Software Foundation; either version 2 of the License, or //
// (at your option) any later version. //
// //
// This program is distributed in the hope that it will be useful, //
// but WITHOUT ANY WARRANTY; without even the implied warranty of //
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the //
// GNU General Public License for more details. //
// //
// You should have received a copy of the GNU General Public License //
// along with this program; if not, write to the Free Software //
// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA //
///////////////////////////////////////////////////////////////////////////////
package edu.cmu.tetrad.util;
import cern.colt.matrix.impl.DenseDoubleMatrix2D;
import org.apache.commons.math3.linear.*;
import org.apache.commons.math3.util.FastMath;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serial;
/**
 * Wraps the Apache Commons Math 3 linear algebra library for most uses in Tetrad. Specialized uses will still have
 * to use the library directly.
 * <p>
 * One issue this class works around is that a {@link BlockRealMatrix} cannot represent a matrix with zero rows or
 * zero columns. For such "zero-dimension" matrices an empty {@link Array2DRowRealMatrix} is stored as a placeholder
 * and the logical dimensions are tracked separately (see {@link #getNumRows()} and {@link #getNumColumns()}). Methods
 * in this class size their results and checks from the logical dimensions, not from the backing matrix.
 *
 * @author josephramsey
 * @version $Id: $Id
 */
public class Matrix implements TetradSerializable {
    @Serial
    private static final long serialVersionUID = 23L;

    /**
     * The backing Apache matrix. For a zero-dimension matrix this is an empty placeholder whose dimensions need not
     * match the logical ones.
     */
    private final RealMatrix apacheData;

    /**
     * The logical number of rows.
     */
    private int m;

    /**
     * The logical number of columns.
     */
    private int n;

    /**
     * Constructs a matrix holding a copy of the given row-major data.
     * <p>
     * If {@code data} has no rows, or its first row is empty, a zero-dimension matrix is created with
     * {@code data.length} rows and (if a first row exists) {@code data[0].length} columns.
     *
     * @param data the entries, indexed as {@code data[row][column]}; rows are expected to have equal length.
     */
    public Matrix(double[][] data) {
        this.m = data.length;
        this.n = this.m == 0 ? 0 : data[0].length;

        if (this.m == 0 || this.n == 0) {
            // BlockRealMatrix rejects zero-sized dimensions, so store an empty placeholder instead.
            this.apacheData = new Array2DRowRealMatrix();
        } else {
            this.apacheData = new BlockRealMatrix(data);
        }
    }

    /**
     * Wraps the given Apache matrix without copying it, so later changes to {@code data} are visible through this
     * matrix and vice versa. The logical dimensions are taken from {@code data}.
     *
     * @param data the matrix to wrap.
     */
    public Matrix(RealMatrix data) {
        this.apacheData = data;
        this.m = data.getRowDimension();
        this.n = data.getColumnDimension();
    }

    /**
     * Constructs an {@code m x n} matrix of zeros.
     *
     * @param m the number of rows; must be non-negative.
     * @param n the number of columns; must be non-negative.
     * @throws IllegalArgumentException if either dimension is negative.
     */
    public Matrix(int m, int n) {
        if (m < 0 || n < 0) {
            throw new IllegalArgumentException("Matrix dimensions must be non-negative: " + m + " x " + n);
        }

        if (m == 0 || n == 0) {
            // BlockRealMatrix rejects zero-sized dimensions, so store an empty placeholder instead.
            this.apacheData = new Array2DRowRealMatrix();
        } else {
            this.apacheData = new BlockRealMatrix(m, n);
        }

        this.m = m;
        this.n = n;
    }

    /**
     * Copy constructor: creates an independent deep copy of {@code m}, preserving its logical dimensions.
     *
     * @param m the matrix to copy.
     */
    public Matrix(Matrix m) {
        this.m = m.getNumRows();
        this.n = m.getNumColumns();

        // The placeholder of a zero-dimension matrix holds no data, and Apache rejects copying an empty
        // Array2DRowRealMatrix, so build a fresh placeholder instead of copying it.
        this.apacheData = m.zeroDimension() ? new Array2DRowRealMatrix() : m.apacheData.copy();
    }

    /**
     * Returns the {@code rows x rows} identity matrix.
     *
     * @param rows the number of rows (and columns).
     * @return a new identity matrix.
     */
    public static Matrix identity(int rows) {
        Matrix identity = new Matrix(rows, rows);

        for (int i = 0; i < rows; i++) {
            identity.set(i, i, 1.0);
        }

        return identity;
    }

    /**
     * Returns an {@code m x n} zero matrix backed by a sparse (hash-map based) Apache matrix. Zero-dimension requests
     * return an ordinary empty matrix, since the sparse Apache implementation requires positive dimensions.
     *
     * @param m the number of rows.
     * @param n the number of columns.
     * @return a new zero matrix with sparse storage.
     */
    public static Matrix sparseMatrix(int m, int n) {
        if (m == 0 || n == 0) {
            return new Matrix(m, n);
        }

        return new Matrix(new OpenMapRealMatrix(m, n));
    }

    /**
     * Generates a simple exemplar of this class to test serialization.
     *
     * @return an empty 0 x 0 matrix.
     */
    public static Matrix serializableInstance() {
        return new Matrix(0, 0);
    }

    /**
     * Copies every entry of {@code matrix} into this matrix.
     *
     * @param matrix the source matrix; must have the same dimensions as this one.
     * @throws IllegalArgumentException if the dimensions differ.
     */
    public void assign(Matrix matrix) {
        if (getNumRows() != matrix.getNumRows() || getNumColumns() != matrix.getNumColumns()) {
            throw new IllegalArgumentException("Mismatched matrix size.");
        }

        for (int i = 0; i < getNumRows(); i++) {
            for (int j = 0; j < getNumColumns(); j++) {
                this.apacheData.setEntry(i, j, matrix.get(i, j));
            }
        }
    }

    /**
     * Returns the logical number of columns.
     *
     * @return the number of columns.
     */
    public int getNumColumns() {
        return this.n;
    }

    /**
     * Returns the main diagonal, i.e. the entries {@code (i, i)} for {@code i < min(rows, columns)}.
     *
     * @return a new vector holding the diagonal entries.
     */
    public Vector diag() {
        int size = FastMath.min(getNumRows(), getNumColumns());
        double[] diag = new double[size];

        for (int i = 0; i < size; i++) {
            diag[i] = this.apacheData.getEntry(i, i);
        }

        return new Vector(diag);
    }

    /**
     * Returns a new matrix whose {@code (i, j)} entry is this matrix's {@code (rows[i], cols[j])} entry.
     *
     * @param rows the row indices to select, in order (repeats allowed).
     * @param cols the column indices to select, in order (repeats allowed).
     * @return a new {@code rows.length x cols.length} matrix.
     */
    public Matrix getSelection(int[] rows, int[] cols) {
        Matrix selection = new Matrix(rows.length, cols.length);

        for (int i = 0; i < rows.length; i++) {
            for (int j = 0; j < cols.length; j++) {
                selection.set(i, j, this.apacheData.getEntry(rows[i], cols[j]));
            }
        }

        return selection;
    }

    /**
     * Returns an independent deep copy of this matrix with the same logical dimensions.
     *
     * @return a copy of this matrix.
     */
    public Matrix copy() {
        return new Matrix(this);
    }

    /**
     * Returns column {@code j} as a new vector. For a zero-dimension matrix this returns a zero vector of length
     * {@link #getNumRows()} without checking {@code j}.
     *
     * @param j the column index.
     * @return the column as a vector.
     */
    public Vector getColumn(int j) {
        if (zeroDimension()) {
            return new Vector(getNumRows());
        }

        return new Vector(this.apacheData.getColumn(j));
    }

    /**
     * Returns the matrix product {@code this * m}.
     *
     * @param m the right-hand factor.
     * @return a new {@code getNumRows() x m.getNumColumns()} matrix.
     */
    public Matrix times(Matrix m) {
        if (this.zeroDimension() || m.zeroDimension()) {
            // Any product involving an empty matrix is a zero matrix of the outer dimensions.
            return new Matrix(this.getNumRows(), m.getNumColumns());
        }

        return new Matrix(this.apacheData.multiply(m.apacheData));
    }

    /**
     * Returns the matrix-vector product {@code this * v}.
     *
     * @param v the vector; its size must equal {@link #getNumColumns()}.
     * @return a new vector of length {@link #getNumRows()}.
     * @throws IllegalArgumentException if the sizes do not match.
     */
    public Vector times(Vector v) {
        if (v.size() != getNumColumns()) {
            throw new IllegalArgumentException("Mismatched dimensions.");
        }

        double[] y = new double[getNumRows()];

        for (int i = 0; i < getNumRows(); i++) {
            double sum = 0.0;

            for (int j = 0; j < getNumColumns(); j++) {
                sum += this.apacheData.getEntry(i, j) * v.get(j);
            }

            y[i] = sum;
        }

        return new Vector(y);
    }

    /**
     * Returns the entries as a new {@code double[rows][columns]} array.
     *
     * @return a copy of the entries.
     */
    public double[][] toArray() {
        if (zeroDimension()) {
            // The placeholder has no data; report the logical shape instead of the placeholder's 0 x 0.
            return new double[getNumRows()][getNumColumns()];
        }

        return this.apacheData.getData();
    }

    /**
     * Returns the backing Apache matrix (not a copy). For a zero-dimension matrix this is an empty placeholder whose
     * dimensions may differ from the logical ones.
     *
     * @return the backing {@link RealMatrix}.
     */
    public RealMatrix getApacheData() {
        return this.apacheData;
    }

    /**
     * Returns the entry at row {@code i}, column {@code j}.
     *
     * @param i the row index.
     * @param j the column index.
     * @return the entry.
     */
    public double get(int i, int j) {
        return this.apacheData.getEntry(i, j);
    }

    /**
     * Returns a new zero matrix with the same logical dimensions as this one.
     *
     * @return a new zero matrix.
     */
    public Matrix like() {
        return new Matrix(getNumRows(), getNumColumns());
    }

    /**
     * Sets the entry at row {@code i}, column {@code j}.
     *
     * @param i the row index.
     * @param j the column index.
     * @param v the new value.
     */
    public void set(int i, int j, double v) {
        this.apacheData.setEntry(i, j, v);
    }

    /**
     * Returns row {@code i} as a new vector. For a zero-dimension matrix this returns a zero vector of length
     * {@link #getNumColumns()} without checking {@code i}.
     *
     * @param i the row index.
     * @return the row as a vector.
     */
    public Vector getRow(int i) {
        if (zeroDimension()) {
            return new Vector(getNumColumns());
        }

        return new Vector(this.apacheData.getRow(i));
    }

    /**
     * Returns the submatrix spanning rows {@code i..j} and columns {@code k..l}, all bounds inclusive (as in Apache's
     * {@code RealMatrix.getSubMatrix(int, int, int, int)}).
     *
     * @param i the first row.
     * @param j the last row (inclusive).
     * @param k the first column.
     * @param l the last column (inclusive).
     * @return the submatrix.
     */
    public Matrix getPart(int i, int j, int k, int l) {
        return new Matrix(this.apacheData.getSubMatrix(i, j, k, l));
    }

    /**
     * Returns the inverse, computed by LU decomposition with a singularity threshold of 1e-10.
     *
     * @return the inverse matrix; a 0 x 0 matrix for a 0 x 0 input.
     * @throws IllegalArgumentException if this matrix is not square.
     * @throws org.apache.commons.math3.linear.SingularMatrixException if the matrix is singular.
     */
    public Matrix inverse() throws SingularMatrixException {
        if (!isSquare()) {
            throw new IllegalArgumentException("I can only invert square matrices.");
        }

        if (getNumRows() == 0) {
            return new Matrix(0, 0);
        }

        return new Matrix(new LUDecomposition(this.apacheData, 1e-10).getSolver().getInverse());
    }

    /**
     * Returns the inverse of a symmetric positive-definite matrix, computed by Cholesky decomposition. Apache throws
     * if the matrix is not symmetric or not positive definite.
     *
     * @return the inverse matrix; a 0 x 0 matrix for a 0 x 0 input.
     * @throws IllegalArgumentException if this matrix is not square.
     */
    public Matrix symmetricInverse() {
        if (!isSquare()) {
            throw new IllegalArgumentException("I can only invert square matrices.");
        }

        if (getNumRows() == 0) {
            return new Matrix(0, 0);
        }

        return new Matrix(new CholeskyDecomposition(this.apacheData).getSolver().getInverse());
    }

    /**
     * Returns the Moore-Penrose pseudo-inverse (delegating to Tetrad's {@code MatrixUtils.pseudoInverse}).
     *
     * @return the pseudo-inverse, of dimension {@code getNumColumns() x getNumRows()}.
     */
    public Matrix ginverse() {
        if (zeroDimension()) {
            // The pseudo-inverse of an m x n matrix is n x m, even when one of the dimensions is zero.
            return new Matrix(getNumColumns(), getNumRows());
        }

        return new Matrix(MatrixUtils.pseudoInverse(this.apacheData.getData()));
    }

    /**
     * Overwrites row {@code row} with the entries of {@code doubles}.
     *
     * @param row     the row index.
     * @param doubles the new row values; its size must equal {@link #getNumColumns()}.
     */
    public void assignRow(int row, Vector doubles) {
        this.apacheData.setRow(row, doubles.toArray());
    }

    /**
     * Overwrites column {@code col} with the entries of {@code doubles}.
     *
     * @param col     the column index.
     * @param doubles the new column values; its size must equal {@link #getNumRows()}.
     */
    public void assignColumn(int col, Vector doubles) {
        this.apacheData.setColumn(col, doubles.toArray());
    }

    /**
     * Returns the trace (sum of the diagonal entries), as computed by Apache's {@code getTrace()}.
     *
     * @return the trace.
     */
    public double trace() {
        return this.apacheData.getTrace();
    }

    /**
     * Returns the determinant, computed by LU decomposition. The decomposition uses a singularity threshold of
     * 1e-6; a matrix Apache judges singular at that threshold yields 0.
     *
     * @return the determinant.
     */
    public double det() {
        return new LUDecomposition(this.apacheData, 1e-6D).getDeterminant();
    }

    /**
     * Returns the transpose.
     *
     * @return a new {@code getNumColumns() x getNumRows()} matrix.
     */
    public Matrix transpose() {
        if (zeroDimension()) {
            return new Matrix(getNumColumns(), getNumRows());
        }

        return new Matrix(this.apacheData.transpose());
    }

    /**
     * Tests whether {@code m} has the same dimensions as this matrix and every pair of corresponding entries differs
     * by at most {@code tolerance} in absolute value.
     *
     * @param m         the matrix to compare against.
     * @param tolerance the maximum allowed absolute difference per entry.
     * @return true if the matrices have equal dimensions and are entry-wise within tolerance.
     */
    public boolean equals(Matrix m, double tolerance) {
        if (getNumRows() != m.getNumRows() || getNumColumns() != m.getNumColumns()) {
            return false;
        }

        // Loop bounds are the logical dimensions; for a zero-dimension matrix no entry is read.
        for (int i = 0; i < getNumRows(); i++) {
            for (int j = 0; j < getNumColumns(); j++) {
                if (FastMath.abs(this.apacheData.getEntry(i, j) - m.apacheData.getEntry(i, j)) > tolerance) {
                    return false;
                }
            }
        }

        return true;
    }

    /**
     * Returns whether this matrix has as many rows as columns.
     *
     * @return true if square.
     */
    public boolean isSquare() {
        return getNumRows() == getNumColumns();
    }

    /**
     * Returns whether this matrix is symmetric within {@code tolerance} (delegating to Tetrad's
     * {@code MatrixUtils.isSymmetric}).
     *
     * @param tolerance the allowed asymmetry.
     * @return true if symmetric within tolerance.
     */
    public boolean isSymmetric(double tolerance) {
        return MatrixUtils.isSymmetric(this.apacheData.getData(), tolerance);
    }

    /**
     * Returns {@code this - mb}. Subtracting a zero-dimension matrix is treated as a no-op and returns a copy of this
     * matrix, so the result never aliases the receiver.
     *
     * @param mb the matrix to subtract.
     * @return a new matrix holding the difference.
     */
    public Matrix minus(Matrix mb) {
        if (mb.zeroDimension()) {
            return copy();
        }

        return new Matrix(this.apacheData.subtract(mb.apacheData));
    }

    /**
     * Returns the maximum absolute column sum (the L1 operator norm), as computed by Apache's {@code getNorm()}.
     *
     * @return the L1 norm.
     */
    public double norm1() {
        return this.apacheData.getNorm();
    }

    /**
     * Returns {@code this + mb}. Adding a zero-dimension matrix is treated as a no-op and returns a copy of this
     * matrix, so the result never aliases the receiver.
     *
     * @param mb the matrix to add.
     * @return a new matrix holding the sum.
     */
    public Matrix plus(Matrix mb) {
        if (mb.zeroDimension()) {
            return copy();
        }

        return new Matrix(this.apacheData.add(mb.apacheData));
    }

    /**
     * Returns the numerical rank as computed by Apache's {@link SingularValueDecomposition}.
     *
     * @return the rank; 0 for a zero-dimension matrix.
     */
    public int rank() {
        if (zeroDimension()) {
            return 0;
        }

        return new SingularValueDecomposition(this.apacheData).getRank();
    }

    /**
     * Returns the logical number of rows.
     *
     * @return the number of rows.
     */
    public int getNumRows() {
        return this.m;
    }

    /**
     * Returns a new matrix with every entry multiplied by {@code scalar}.
     *
     * @param scalar the multiplier.
     * @return the scaled matrix.
     */
    public Matrix scalarMult(double scalar) {
        Matrix scaled = copy();

        for (int i = 0; i < getNumRows(); i++) {
            for (int j = 0; j < getNumColumns(); j++) {
                scaled.set(i, j, get(i, j) * scalar);
            }
        }

        return scaled;
    }

    /**
     * Returns a matrix square root computed from the singular value decomposition {@code A = U S V^T} as
     * {@code U sqrt(S) V^T}.
     * <p>
     * For a symmetric positive semidefinite matrix this is the principal square root {@code R}, which satisfies
     * {@code R * R = A}. For other square matrices the result is generally not a square root, so callers should only
     * use this on symmetric positive semidefinite input.
     *
     * @return the square root; a 0 x 0 matrix for a 0 x 0 input.
     * @throws IllegalArgumentException if this matrix is not square.
     */
    public Matrix sqrt() {
        if (!isSquare()) {
            throw new IllegalArgumentException("I can only take the square root of square matrices.");
        }

        if (zeroDimension()) {
            return new Matrix(0, 0);
        }

        SingularValueDecomposition svd = new SingularValueDecomposition(this.apacheData);
        double[] s = svd.getSingularValues();

        RealMatrix rootS = new BlockRealMatrix(s.length, s.length);

        for (int i = 0; i < s.length; i++) {
            rootS.setEntry(i, i, FastMath.sqrt(s[i]));
        }

        return new Matrix(svd.getU().multiply(rootS).multiply(svd.getVT()));
    }

    /**
     * Sums entries along one direction.
     *
     * @param direction 1 to sum each column (result has length {@link #getNumColumns()}), or 2 to sum each row
     *                  (result has length {@link #getNumRows()}).
     * @return the vector of sums.
     * @throws IllegalArgumentException if {@code direction} is neither 1 nor 2.
     */
    public Vector sum(int direction) {
        if (direction == 1) {
            Vector sums = new Vector(getNumColumns());

            for (int j = 0; j < getNumColumns(); j++) {
                double sum = 0.0;

                for (int i = 0; i < getNumRows(); i++) {
                    sum += this.apacheData.getEntry(i, j);
                }

                sums.set(j, sum);
            }

            return sums;
        } else if (direction == 2) {
            Vector sums = new Vector(getNumRows());

            for (int i = 0; i < getNumRows(); i++) {
                double sum = 0.0;

                for (int j = 0; j < getNumColumns(); j++) {
                    sum += this.apacheData.getEntry(i, j);
                }

                sums.set(i, sum);
            }

            return sums;
        } else {
            throw new IllegalArgumentException("Expecting 1 (sum columns) or 2 (sum rows).");
        }
    }

    /**
     * Returns the sum of all entries, computed by Colt's {@code DenseDoubleMatrix2D.zSum()} over a copy of the data.
     *
     * @return the sum of all entries.
     */
    public double zSum() {
        return new DenseDoubleMatrix2D(this.apacheData.getData()).zSum();
    }

    /**
     * Returns whether this matrix has zero rows or zero columns (and is therefore backed by an empty placeholder).
     */
    private boolean zeroDimension() {
        return getNumRows() == 0 || getNumColumns() == 0;
    }

    /**
     * Returns a printable representation: "Empty" for a zero-dimension matrix, otherwise the formatting produced by
     * Tetrad's {@code MatrixUtils.toString}.
     *
     * @return a string representation of this matrix.
     */
    @Override
    public String toString() {
        if (zeroDimension()) {
            return "Empty";
        }

        return MatrixUtils.toString(toArray());
    }

    /**
     * Writes the object to the specified ObjectOutputStream, logging and rethrowing any I/O failure.
     *
     * @param out The ObjectOutputStream to write the object to.
     * @throws IOException If an I/O error occurs.
     */
    @Serial
    private void writeObject(ObjectOutputStream out) throws IOException {
        try {
            out.defaultWriteObject();
        } catch (IOException e) {
            TetradLogger.getInstance().log("Failed to serialize object: " + getClass().getCanonicalName()
                                           + ", " + e.getMessage());
            throw e;
        }
    }

    /**
     * Reads the object from the specified ObjectInputStream, logging and rethrowing any I/O failure. This method is
     * used during deserialization to restore the state of the object.
     *
     * @param in The ObjectInputStream to read the object from.
     * @throws IOException            If an I/O error occurs.
     * @throws ClassNotFoundException If the class of the serialized object cannot be found.
     */
    @Serial
    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
        try {
            in.defaultReadObject();
        } catch (IOException e) {
            TetradLogger.getInstance().log("Failed to deserialize object: " + getClass().getCanonicalName()
                                           + ", " + e.getMessage());
            throw e;
        }
    }
}