All downloads are FREE. Search and download functionality uses the official Maven repository.

net.maizegenetics.matrixalgebra.Matrix.EJMLDoubleMatrix Maven / Gradle / Ivy

Go to download

TASSEL is a software package to evaluate trait associations, evolutionary patterns, and linkage disequilibrium.

There is a newer version: 5.2.94
Show newest version
package net.maizegenetics.matrixalgebra.Matrix;

import org.ejml.factory.DecompositionFactory;
import org.ejml.factory.LinearSolver;
import org.ejml.factory.LinearSolverFactory;
import org.ejml.data.DenseMatrix64F;
import org.ejml.ops.CommonOps;
import org.ejml.ops.SpecializedOps;

import net.maizegenetics.matrixalgebra.Matrix.DoubleMatrix;
import net.maizegenetics.matrixalgebra.decomposition.ColtEigenvalueDecomposition;
import net.maizegenetics.matrixalgebra.decomposition.EJMLEigenvalueDecomposition;
import net.maizegenetics.matrixalgebra.decomposition.EJMLSingularValueDecomposition;
import net.maizegenetics.matrixalgebra.decomposition.EigenvalueDecomposition;
import net.maizegenetics.matrixalgebra.decomposition.QRDecomposition;
import net.maizegenetics.matrixalgebra.decomposition.SingularValueDecomposition;
import net.maizegenetics.taxa.distance.DistanceMatrix;

/**
 * A {@link DoubleMatrix} implementation backed by an EJML {@link DenseMatrix64F}.
 * <p>
 * Methods that combine this matrix with another {@code DoubleMatrix} cast the argument to
 * {@code EJMLDoubleMatrix}; passing a different implementation fails with a
 * {@link ClassCastException}.
 */
public class EJMLDoubleMatrix implements DoubleMatrix {

	/** The wrapped EJML matrix that holds this matrix's values. */
	public final DenseMatrix64F myMatrix;
	
	/**
	 * Wraps an existing EJML matrix. The matrix is stored by reference, not copied.
	 * @param aMatrix	the EJML matrix to wrap
	 */
	public EJMLDoubleMatrix(DenseMatrix64F aMatrix) {
		myMatrix = aMatrix;
	}
	
	/**
	 * Creates a new matrix with the given dimensions.
	 * @param row	number of rows
	 * @param col	number of columns
	 */
	public EJMLDoubleMatrix(int row, int col) {
		myMatrix = new DenseMatrix64F(row, col);
	}
	
	/**
	 * Creates a matrix with the given dimensions from a one-dimensional array.
	 * The array is handed to EJML with its row-major flag set to {@code true}.
	 * @param row	number of rows
	 * @param col	number of columns
	 * @param values	the matrix elements
	 */
	public EJMLDoubleMatrix(int row, int col, double[] values) {
		myMatrix = new DenseMatrix64F(row, col, true, values);
	}

	/**
	 * Creates a matrix with the given dimensions in which every element equals {@code value}.
	 * @param row	number of rows
	 * @param col	number of columns
	 * @param value	the value assigned to every element
	 */
	public EJMLDoubleMatrix(int row, int col, double value) {
		myMatrix = new DenseMatrix64F(row, col);
		// The explicit fill is skipped when value compares equal to zero.
		if (value != 0) {
			for (int r = 0; r < row; r++) {
				for (int c = 0; c < col; c++) {
					myMatrix.set(r, c, value);
				}
			}
		}
	}

	/**
	 * Creates a matrix from a two-dimensional array.
	 * @param values	the matrix elements
	 */
	public EJMLDoubleMatrix(double[][] values) {
		myMatrix = new DenseMatrix64F(values);
	}

	/**
	 * Creates a matrix from the distances held by a {@link DistanceMatrix}.
	 * @param values	the distance matrix whose {@code getDistances()} array supplies the elements
	 */
	public EJMLDoubleMatrix(DistanceMatrix values) {
		myMatrix = new DenseMatrix64F(values.getDistances());
	}
	
	/**
	 * Creates an identity matrix.
	 * @param size	the number of rows and columns
	 */
	public EJMLDoubleMatrix(int size) {
		myMatrix = CommonOps.identity(size);
	}
	
	/**
	 * Creates a diagonal matrix.
	 * @param diagonal	the diagonal elements
	 */
	public EJMLDoubleMatrix(double[] diagonal) {
		myMatrix = CommonOps.diag(diagonal);
	}
	
	/**
	 * @param j	a column index
	 * @return	a new (number of rows) x 1 matrix holding column j of this matrix
	 */
	@Override
	public DoubleMatrix column(int j) {
		int n = myMatrix.numRows;
		DenseMatrix64F result = new DenseMatrix64F(n,1);
		SpecializedOps.subvector(myMatrix, 0, j, n, false, 0, result);
		return new EJMLDoubleMatrix(result);
	}

	/**
	 * @return	1 for any single-column matrix (no decomposition is performed in that case),
	 * 			otherwise the rank reported by a singular value decomposition of this matrix
	 */
	@Override
	public int columnRank() {
		if (myMatrix.numCols == 1) return 1;
		EJMLSingularValueDecomposition svd = new EJMLSingularValueDecomposition(myMatrix);
		return svd.getRank();
	}

	/**
	 * Appends another matrix to this one, returning the result as a new matrix.
	 * @param dm	the matrix to append; must be an {@code EJMLDoubleMatrix}
	 * @param rows	if true, dm is stacked below this matrix (column counts must match);
	 * 				if false, dm is placed to the right of this matrix (row counts must match)
	 * @return	the concatenated matrix
	 * @throws IllegalArgumentException	if the shared dimension of the two matrices differs
	 */
	@Override
	public DoubleMatrix concatenate(DoubleMatrix dm, boolean rows) {
		DenseMatrix64F otherMatrix = ((EJMLDoubleMatrix) dm).myMatrix;
		int myNumberOfRows = myMatrix.numRows;
		int myNumberOfCols = myMatrix.numCols;
		int dmNumberOfRows = otherMatrix.numRows;
		int dmNumberOfCols = otherMatrix.numCols;
		DenseMatrix64F result;
		if (rows) {
			if (myNumberOfCols != dmNumberOfCols) {
				StringBuilder sb = new StringBuilder("Non-conformable matrices in concatenate rows: ");
				sb.append(myNumberOfRows).append(" x ").append(myNumberOfCols);
				sb.append(", ").append(dmNumberOfRows).append(" x ").append(dmNumberOfCols);
				throw new IllegalArgumentException(sb.toString());
			}
			int totalRows = myNumberOfRows + dmNumberOfRows;
			result = new DenseMatrix64F(totalRows, myNumberOfCols);
			CommonOps.insert(myMatrix, result, 0, 0);
			// dm starts on the row immediately after this matrix's last row.
			CommonOps.insert(otherMatrix, result, myNumberOfRows, 0);
		} else {
			if (myNumberOfRows != dmNumberOfRows) {
				StringBuilder sb = new StringBuilder("Non-conformable matrices in concatenate columns: ");
				sb.append(myNumberOfRows).append(" x ").append(myNumberOfCols);
				sb.append(", ").append(dmNumberOfRows).append(" x ").append(dmNumberOfCols);
				throw new IllegalArgumentException(sb.toString());
			}
			int totalCol = myNumberOfCols + dmNumberOfCols;
			result = new DenseMatrix64F(myNumberOfRows, totalCol);
			CommonOps.insert(myMatrix, result, 0, 0);
			CommonOps.insert(otherMatrix, result, 0, myNumberOfCols);
		}
		return new EJMLDoubleMatrix(result);
	}

	/**
	 * @return	a new matrix wrapping a copy of the underlying EJML matrix
	 */
	@Override
	public DoubleMatrix copy() {
		return new EJMLDoubleMatrix(myMatrix.copy());
	}

	/**
	 * @return	X'X, where X is this matrix
	 */
	@Override
	public DoubleMatrix crossproduct() {
		int n = myMatrix.numCols;
		DenseMatrix64F result = new DenseMatrix64F(n,n);
		CommonOps.multTransA(myMatrix, myMatrix, result);
		return new EJMLDoubleMatrix(result);
	}

	/**
	 * @param dm	another matrix; must be an {@code EJMLDoubleMatrix}
	 * @return	X'dm, where X is this matrix
	 */
	@Override
	public DoubleMatrix crossproduct(DoubleMatrix dm) {
		int nrow = myMatrix.numCols;
		int ncol = dm.numberOfColumns();
		DenseMatrix64F otherMatrix = ((EJMLDoubleMatrix) dm).myMatrix;
		DenseMatrix64F result = new DenseMatrix64F(nrow,ncol);
		CommonOps.multTransA(myMatrix, otherMatrix, result);
		return new EJMLDoubleMatrix(result);
	}

	/**
	 * @return	the pseudo-inverse of this matrix as computed by EJML's {@code CommonOps.pinv}
	 */
	@Override
	public DoubleMatrix generalizedInverse() {
		// The pseudo-inverse of an m x n matrix is n x m.
		DenseMatrix64F result = new DenseMatrix64F(myMatrix.numCols, myMatrix.numRows);
		CommonOps.pinv(myMatrix, result);
		return new EJMLDoubleMatrix(result);
	}

	/**
	 * Computes a generalized inverse from the singular value decomposition of this matrix,
	 * as V * W+ * U', where W+ holds the reciprocals of the singular values that are
	 * at least 1e-10 and zero elsewhere.
	 * @param rank	an array of length at least one; on return rank[0] holds the number of
	 * 				singular values that were inverted (that is, not treated as zero)
	 * @return	the generalized inverse, with dimensions (number of columns) x (number of rows)
	 */
	@Override
	public DoubleMatrix generalizedInverseWithRank(int[] rank) {
		org.ejml.factory.SingularValueDecomposition myDecomposition = DecompositionFactory.svd(myMatrix.numRows, myMatrix.numCols, true, true, false);
		myDecomposition.decompose(myMatrix);
		double tol = 1e-10;
		rank[0] = 0;
		DenseMatrix64F W = myDecomposition.getW(null);
		int wRows = W.getNumRows();
		int wCols = W.getNumCols();
		
		// W+ has the transposed shape of W so that V * W+ * U' conforms for non-square input.
		// Only min(rows, cols) diagonal entries exist, so the loop must not run past that.
		DenseMatrix64F Wplus = new DenseMatrix64F(wCols, wRows);
		int nDiagonal = Math.min(wRows, wCols);
		for (int i = 0; i < nDiagonal; i++) {
			double val = W.get(i,i);
			if (val < tol) {
				val = 0;
			} else {
				val = 1/val;
				rank[0]++;
			}
			Wplus.set(i, i, val);
		}
		
		DenseMatrix64F V = myDecomposition.getV(null, false);
		DenseMatrix64F UT  = myDecomposition.getU(null, true);
		
		DenseMatrix64F VW = new DenseMatrix64F(V.getNumRows(), Wplus.getNumCols());
		CommonOps.mult(V, Wplus, VW);
		
		DenseMatrix64F inv = new DenseMatrix64F(VW.getNumRows(), UT.getNumCols());
		CommonOps.mult(VW, UT, inv);
		return new EJMLDoubleMatrix(inv);
	}

	/**
	 * @param row	a row index
	 * @param col	a column index
	 * @return	the element at (row, col)
	 */
	@Override
	public double get(int row, int col) {
		return myMatrix.get(row, col);
	}

	/**
	 * Delegates to the same EJML accessor as {@link #get(int, int)}.
	 * @param row	a row index
	 * @param col	a column index
	 * @return	the element at (row, col)
	 */
	@Override
	public double getChecked(int row, int col) {
		return myMatrix.get(row,col);
	}

	/**
	 * @return	the EJML-based eigenvalue decomposition if it reports success, otherwise a
	 * 			Colt-based decomposition of the same matrix
	 */
	@Override
	public EigenvalueDecomposition getEigenvalueDecomposition() {
		EJMLEigenvalueDecomposition decomposition = new EJMLEigenvalueDecomposition(myMatrix);
		if (decomposition.wasSuccessful()) return decomposition;
		return new ColtEigenvalueDecomposition(myMatrix);
	}

	/**
	 * Computes three matrices from this matrix, X.
	 * @return	an array holding, in order: X'X, the inverse of X'X as computed by
	 * 			{@code CommonOps.invert}, and I - X(X'X)^-1 X'
	 */
	@Override
	public DoubleMatrix[] getXtXGM() {
		DoubleMatrix[] dmarray = new DoubleMatrix[3];
		int ncol = myMatrix.numCols;
		int nrow = myMatrix.numRows;
		DenseMatrix64F result1 = new DenseMatrix64F(ncol, ncol);
		CommonOps.multTransA(myMatrix, myMatrix, result1);
		dmarray[0] = new EJMLDoubleMatrix(result1);
		
		// NOTE(review): the boolean returned by invert is not checked, so a failed
		// inversion of X'X is not reported to the caller.
		DenseMatrix64F inverse = new DenseMatrix64F(ncol, ncol);
		CommonOps.invert(result1, inverse);
		dmarray[1] = new EJMLDoubleMatrix(inverse);
		
		// result2 = (X'X)^-1 X'; ident then becomes I - X * result2.
		DenseMatrix64F result2 = new DenseMatrix64F(ncol, nrow);
		DenseMatrix64F ident = CommonOps.identity(nrow);
		CommonOps.multTransB(inverse, myMatrix, result2);
		CommonOps.multAdd(-1, myMatrix, result2, ident);
		dmarray[2] = new EJMLDoubleMatrix(ident);
		return dmarray;
	}

	/**
	 * Not implemented for this class.
	 * @return	always null
	 */
	@Override
	public QRDecomposition getQRDecomposition() {
		// TODO implement a QR decomposition for EJML-backed matrices
		return null;
	}

	/**
	 * @return	a singular value decomposition of this matrix
	 */
	@Override
	public SingularValueDecomposition getSingularValueDecomposition() {
		return new EJMLSingularValueDecomposition(myMatrix);
	}

	/**
	 * @return	a new matrix holding the result of {@code CommonOps.invert} applied to this matrix;
	 * 			this matrix is not modified
	 */
	@Override
	public DoubleMatrix inverse() {
		DenseMatrix64F inverse = new DenseMatrix64F(myMatrix.numRows, myMatrix.numCols);
		CommonOps.invert(myMatrix, inverse);
		return new EJMLDoubleMatrix(inverse);
	}

	/**
	 * Inverts this matrix in place using {@code CommonOps.invert}.
	 */
	@Override
	public void invert() {
		CommonOps.invert(myMatrix);
	}

	/**
	 * @param dm	the matrix to subtract; must be an {@code EJMLDoubleMatrix}
	 * @return	a new matrix equal to this matrix minus dm
	 */
	@Override
	public DoubleMatrix minus(DoubleMatrix dm) {
		DenseMatrix64F result = new DenseMatrix64F(myMatrix.numRows, myMatrix.numCols);
		CommonOps.sub(myMatrix, ((EJMLDoubleMatrix) dm).myMatrix, result);
		return new EJMLDoubleMatrix(result);
	}

	/**
	 * Subtracts dm from this matrix in place.
	 * @param dm	the matrix to subtract; must be an {@code EJMLDoubleMatrix}
	 */
	@Override
	public void minusEquals(DoubleMatrix dm) {
		CommonOps.subEquals(myMatrix, ((EJMLDoubleMatrix) dm).myMatrix);
	}

	/**
	 * Multiplies this matrix by dm, optionally transposing either operand.
	 * @param dm	the right-hand matrix; must be an {@code EJMLDoubleMatrix}
	 * @param transpose	if true, this matrix is transposed before multiplying
	 * @param transposedm	if true, dm is transposed before multiplying
	 * @return	a new matrix holding the product
	 */
	@Override
	public DoubleMatrix mult(DoubleMatrix dm, boolean transpose,
			boolean transposedm) {
		DenseMatrix64F b = ((EJMLDoubleMatrix) dm).myMatrix;
		
		// The result has the row count of op(this) and the column count of op(dm).
		int nrow = transpose ? myMatrix.numCols : myMatrix.numRows;
		int ncol = transposedm ? b.numRows : b.numCols;
		DenseMatrix64F result = new DenseMatrix64F(nrow, ncol);
		
		if (transpose) {
			if (transposedm) {
				CommonOps.multTransAB(myMatrix, b, result);
			} else {
				CommonOps.multTransA(myMatrix, b, result);
			}
		} else {
			if (transposedm) {
				CommonOps.multTransB(myMatrix, b, result);
			} else {
				CommonOps.mult(myMatrix, b, result);
			}
		}
		
		return new EJMLDoubleMatrix(result);
	}

	/**
	 * @param dm	the right-hand matrix; must be an {@code EJMLDoubleMatrix}
	 * @return	a new matrix holding this matrix times dm
	 */
	@Override
	public DoubleMatrix mult(DoubleMatrix dm) {
		DenseMatrix64F b = ((EJMLDoubleMatrix) dm).myMatrix;
		int nrow = myMatrix.numRows;
		int ncol = b.numCols;
		DenseMatrix64F result = new DenseMatrix64F(nrow, ncol);
		CommonOps.mult(myMatrix, b, result);
		return new EJMLDoubleMatrix(result);
	}

	/**
	 * Computes alpha * op(this) * op(A) + beta * B, where op optionally transposes its operand.
	 * @param A	the right-hand matrix of the product
	 * @param B	the matrix to add; if null, only the scaled product is returned
	 * @param alpha	scalar applied to the product
	 * @param beta	scalar applied to B
	 * @param transpose	if true, this matrix is transposed before multiplying
	 * @param transposeA	if true, A is transposed before multiplying
	 * @return	a new matrix holding the result
	 */
	@Override
	public DoubleMatrix multadd(DoubleMatrix A, DoubleMatrix B, double alpha,
			double beta, boolean transpose, boolean transposeA) {
		
		DoubleMatrix result = mult(A, transpose, transposeA);
		if (alpha != 1) result.scalarMultEquals(alpha);
		if (B == null) return result;

		if (beta == 1) return result.plus(B);
		return result.plus(B.scalarMult(beta));
	}

	/**
	 * @return	the number of columns in this matrix
	 */
	@Override
	public int numberOfColumns() {
		return myMatrix.numCols;
	}

	/**
	 * @return	the number of rows in this matrix
	 */
	@Override
	public int numberOfRows() {
		return myMatrix.numRows;
	}

	/**
	 * @param dm	the matrix to add; must be an {@code EJMLDoubleMatrix}
	 * @return	a new matrix equal to this matrix plus dm
	 */
	@Override
	public DoubleMatrix plus(DoubleMatrix dm) {
		DenseMatrix64F result = new DenseMatrix64F(myMatrix.numRows, myMatrix.numCols);
		CommonOps.add(myMatrix, ((EJMLDoubleMatrix) dm).myMatrix, result);
		return new EJMLDoubleMatrix(result);
	}

	/**
	 * Adds dm to this matrix in place.
	 * @param dm	the matrix to add; must be an {@code EJMLDoubleMatrix}
	 */
	@Override
	public void plusEquals(DoubleMatrix dm) {
		CommonOps.addEquals(myMatrix, ((EJMLDoubleMatrix) dm).myMatrix);
	}

	/**
	 * @param i	a row index
	 * @return	a new (number of columns) x 1 matrix holding the elements of row i
	 */
	@Override
	public DoubleMatrix row(int i) {
		int n = myMatrix.numCols;
		DenseMatrix64F result = new DenseMatrix64F(n,1);
		SpecializedOps.subvector(myMatrix, i, 0, n, true, 0, result);
		return new EJMLDoubleMatrix(result);
	}
	
	/**
	 * @param s	a scalar
	 * @return	a new matrix with s added to every element of this matrix
	 */
	@Override
	public DoubleMatrix scalarAdd(double s) {
		DenseMatrix64F result = new DenseMatrix64F(myMatrix.numRows, myMatrix.numCols);
		CommonOps.add(myMatrix, s, result);
		return new EJMLDoubleMatrix(result);
	}

	/**
	 * Adds s to every element of this matrix in place.
	 * @param s	a scalar
	 */
	@Override
	public void scalarAddEquals(double s) {
		CommonOps.add(myMatrix, s);
	}

	/**
	 * @param s	a scalar
	 * @return	a new matrix with every element of this matrix multiplied by s
	 */
	@Override
	public DoubleMatrix scalarMult(double s) {
		DenseMatrix64F result = new DenseMatrix64F(myMatrix.numRows, myMatrix.numCols);
		CommonOps.scale(s, myMatrix, result);
		return new EJMLDoubleMatrix(result);
	}

	/**
	 * Multiplies every element of this matrix by s in place.
	 * @param s	a scalar
	 */
	@Override
	public void scalarMultEquals(double s) {
		CommonOps.scale(s, myMatrix);
	}

	/**
	 * Sets the element at (row, col).
	 * @param row	a row index
	 * @param col	a column index
	 * @param value	the new value
	 */
	@Override
	public void set(int row, int col, double value) {
		myMatrix.set(row, col, value);
	}

	/**
	 * Delegates to the same EJML mutator as {@link #set(int, int, double)}.
	 * @param row	a row index
	 * @param col	a column index
	 * @param value	the new value
	 */
	@Override
	public void setChecked(int row, int col, double value) {
		myMatrix.set(row, col, value);
	}

	/**
	 * Solves this * B = Y for B using an EJML least-squares solver.
	 * @param Y	the right-hand side; must be an {@code EJMLDoubleMatrix}
	 * @return	a new (number of columns of this) x (number of columns of Y) matrix holding the solution
	 */
	@Override
	public DoubleMatrix solve(DoubleMatrix Y) {
		DenseMatrix64F data = ((EJMLDoubleMatrix) Y).myMatrix;
		DenseMatrix64F result = new DenseMatrix64F(myMatrix.numCols, data.numCols);
		// The factory takes the dimensions of the coefficient matrix, not of the solution.
		LinearSolver solver = LinearSolverFactory.leastSquares(myMatrix.numRows, myMatrix.numCols);
		solver.setA(myMatrix);
		solver.solve(data, result);
		return new EJMLDoubleMatrix(result);
	}

	/**
	 * @return	XX', where X is this matrix
	 */
	@Override
	public DoubleMatrix tcrossproduct() {
		int n = myMatrix.numRows;
		DenseMatrix64F result = new DenseMatrix64F(n, n);
		CommonOps.multTransB(myMatrix, myMatrix, result);
		return new EJMLDoubleMatrix(result);
	}

	/**
	 * @param dm	another matrix; must be an {@code EJMLDoubleMatrix}
	 * @return	X dm', where X is this matrix
	 */
	@Override
	public DoubleMatrix tcrossproduct(DoubleMatrix dm) {
		DenseMatrix64F other = ((EJMLDoubleMatrix) dm).myMatrix;
		int nrow = myMatrix.numRows;
		int ncol = other.numRows;
		DenseMatrix64F result = new DenseMatrix64F(nrow, ncol);
		CommonOps.multTransB(myMatrix, other, result);
		return new EJMLDoubleMatrix(result);
	}

	/**
	 * @return	a new matrix holding the transpose of this matrix
	 */
	@Override
	public DoubleMatrix transpose() {
		DenseMatrix64F result = new DenseMatrix64F(myMatrix.numCols, myMatrix.numRows);
		CommonOps.transpose(myMatrix, result);
		return new EJMLDoubleMatrix(result);
	}

	/**
	 * Extracts the listed rows and columns, in the order given, into a new matrix.
	 * @param rows	the row indices to select, or null to select all rows
	 * @param columns	the column indices to select, or null to select all columns
	 * @return	a copy of this matrix if both arguments are null, otherwise a matrix created by
	 * 			{@code DoubleMatrixFactory.DEFAULT} holding the selected elements
	 */
	@Override
	public DoubleMatrix getSelection(int[] rows, int[] columns) {
		if (rows == null && columns == null) return copy();
		
		int nrow = (rows == null) ? myMatrix.numRows : rows.length;
		int ncol = (columns == null) ? myMatrix.numCols : columns.length;
		DoubleMatrix result = DoubleMatrixFactory.DEFAULT.make(nrow, ncol);
		for (int r = 0; r < nrow; r++) {
			// A null index array means "use every index in its natural order".
			int sourceRow = (rows == null) ? r : rows[r];
			for (int c = 0; c < ncol; c++) {
				int sourceCol = (columns == null) ? c : columns[c];
				result.set(r, c, myMatrix.get(sourceRow, sourceCol));
			}
		}
		return result;
	}

	/**
	 * @return	a text rendering of at most the first 25 rows and 25 columns, with values in a
	 * 			row separated by a single space and each row terminated by a newline
	 */
	@Override
	public String toString() {
		int nrows = Math.min(25, myMatrix.numRows);
		int ncols = Math.min(25, myMatrix.numCols);
		StringBuilder sb = new StringBuilder();
		for (int i = 0; i < nrows; i++) {
			for (int j = 0; j < ncols; j++) {
				if (j > 0) sb.append(" ");
				sb.append(myMatrix.get(i,j));
			}
			sb.append("\n");
		}
		return sb.toString();
	}

	/**
	 * @param column	a column index
	 * @return	the sum of the elements in that column
	 */
	@Override
	public double columnSum(int column) {
		int n = myMatrix.numRows;
		DenseMatrix64F vector = new DenseMatrix64F(n, 1);
		SpecializedOps.subvector(myMatrix, 0, column, n, false, 0, vector);
		return CommonOps.elementSum(vector);
	}

	/**
	 * @param row	a row index
	 * @return	the sum of the elements in that row
	 */
	@Override
	public double rowSum(int row) {
		int n = myMatrix.numCols;
		DenseMatrix64F vector = new DenseMatrix64F(n, 1);
		SpecializedOps.subvector(myMatrix, row, 0, n, true, 0, vector);
		return CommonOps.elementSum(vector);
	}

	/**
	 * @return	the data array of the underlying EJML matrix itself, not a copy, so changes
	 * 			made to the returned array change this matrix
	 */
	@Override
	public double[] to1DArray() {
		return myMatrix.data;
	}

	/**
	 * @return	a new two-dimensional array, indexed as [row][column], holding a copy of the
	 * 			elements of this matrix
	 */
	@Override
	public double[][] toArray() {
		int nrows = myMatrix.getNumRows();
		int ncols = myMatrix.getNumCols();
		double[][] array = new double[nrows][ncols];
		for (int r = 0; r < nrows; r++) {
			for (int c = 0; c < ncols; c++) {
				array[r][c] = myMatrix.get(r,c);
			}
		}
		return array;
	}

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy