net.maizegenetics.matrixalgebra.Matrix.EJMLDoubleMatrix Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of tassel Show documentation
Show all versions of tassel Show documentation
TASSEL is a software package to evaluate trait associations, evolutionary patterns, and linkage
disequilibrium.
package net.maizegenetics.matrixalgebra.Matrix;
import org.ejml.factory.DecompositionFactory;
import org.ejml.factory.LinearSolver;
import org.ejml.factory.LinearSolverFactory;
import org.ejml.data.DenseMatrix64F;
import org.ejml.ops.CommonOps;
import org.ejml.ops.SpecializedOps;
import net.maizegenetics.matrixalgebra.Matrix.DoubleMatrix;
import net.maizegenetics.matrixalgebra.decomposition.ColtEigenvalueDecomposition;
import net.maizegenetics.matrixalgebra.decomposition.EJMLEigenvalueDecomposition;
import net.maizegenetics.matrixalgebra.decomposition.EJMLSingularValueDecomposition;
import net.maizegenetics.matrixalgebra.decomposition.EigenvalueDecomposition;
import net.maizegenetics.matrixalgebra.decomposition.QRDecomposition;
import net.maizegenetics.matrixalgebra.decomposition.SingularValueDecomposition;
import net.maizegenetics.taxa.distance.DistanceMatrix;
/**
 * A {@link DoubleMatrix} implementation that wraps an EJML {@link DenseMatrix64F}.
 *
 * <p>Every method that accepts another {@link DoubleMatrix} casts it to
 * {@code EJMLDoubleMatrix}; passing a different implementation results in a
 * {@link ClassCastException}.
 */
public class EJMLDoubleMatrix implements DoubleMatrix {
    /** Maximum number of rows and of columns rendered by {@link #toString()}. */
    private static final int MAX_PRINTED = 25;

    /** The wrapped EJML matrix. The reference is final; its contents are mutable. */
    public final DenseMatrix64F myMatrix;

    /**
     * Wraps an existing EJML matrix without copying it.
     *
     * @param aMatrix the matrix to wrap; later changes to it are visible through this object
     */
    public EJMLDoubleMatrix(DenseMatrix64F aMatrix) {
        myMatrix = aMatrix;
    }

    /**
     * Creates a {@code row} x {@code col} matrix using EJML's default initialization.
     *
     * @param row number of rows
     * @param col number of columns
     */
    public EJMLDoubleMatrix(int row, int col) {
        myMatrix = new DenseMatrix64F(row, col);
    }

    /**
     * Creates a {@code row} x {@code col} matrix from a flat array of values.
     * The values are handed to EJML with its row-major flag set to {@code true}.
     *
     * @param row    number of rows
     * @param col    number of columns
     * @param values the matrix elements
     */
    public EJMLDoubleMatrix(int row, int col, double[] values) {
        myMatrix = new DenseMatrix64F(row, col, true, values);
    }

    /**
     * Creates a {@code row} x {@code col} matrix with every element set to {@code value}.
     *
     * @param row   number of rows
     * @param col   number of columns
     * @param value the value assigned to every element
     */
    public EJMLDoubleMatrix(int row, int col, double value) {
        myMatrix = new DenseMatrix64F(row, col);
        // A freshly allocated matrix is left untouched when the fill value is zero.
        if (value != 0) {
            for (int r = 0; r < row; r++) {
                for (int c = 0; c < col; c++) {
                    myMatrix.set(r, c, value);
                }
            }
        }
    }

    /**
     * Creates a matrix from a two-dimensional array.
     *
     * @param values the matrix elements, passed straight to the EJML constructor
     */
    public EJMLDoubleMatrix(double[][] values) {
        myMatrix = new DenseMatrix64F(values);
    }

    /**
     * Creates a matrix from the distances held by a {@link DistanceMatrix}.
     *
     * @param values the source of the distances, read via {@code getDistances()}
     */
    public EJMLDoubleMatrix(DistanceMatrix values) {
        myMatrix = new DenseMatrix64F(values.getDistances());
    }

    /**
     * Creates a {@code size} x {@code size} identity matrix.
     *
     * @param size number of rows and columns
     */
    public EJMLDoubleMatrix(int size) {
        myMatrix = CommonOps.identity(size);
    }

    /**
     * Creates a square diagonal matrix.
     *
     * @param diagonal the values placed on the diagonal
     */
    public EJMLDoubleMatrix(double[] diagonal) {
        myMatrix = CommonOps.diag(diagonal);
    }

    /**
     * Extracts a column.
     *
     * @param j index of the column
     * @return a new (number of rows) x 1 matrix holding column {@code j}
     */
    @Override
    public DoubleMatrix column(int j) {
        int n = myMatrix.numRows;
        DenseMatrix64F result = new DenseMatrix64F(n, 1);
        SpecializedOps.subvector(myMatrix, 0, j, n, false, 0, result);
        return new EJMLDoubleMatrix(result);
    }

    /**
     * Reports the rank of this matrix.
     *
     * @return 1 when the matrix has exactly one column, otherwise the rank reported
     *         by an {@link EJMLSingularValueDecomposition} of the matrix
     */
    @Override
    public int columnRank() {
        // NOTE(review): the shortcut reports 1 even for a single all-zero column - confirm intended.
        if (myMatrix.numCols == 1) {
            return 1;
        }
        EJMLSingularValueDecomposition svd = new EJMLSingularValueDecomposition(myMatrix);
        return svd.getRank();
    }

    /**
     * Joins this matrix and {@code dm} into a new matrix.
     *
     * @param dm   the matrix to append; must be an {@code EJMLDoubleMatrix}
     * @param rows {@code true} to stack {@code dm} below this matrix (column counts must
     *             match), {@code false} to place it to the right (row counts must match)
     * @return a new matrix with this matrix first and {@code dm} after it
     * @throws IllegalArgumentException if the dimensions do not conform
     */
    @Override
    public DoubleMatrix concatenate(DoubleMatrix dm, boolean rows) {
        DenseMatrix64F otherMatrix = ((EJMLDoubleMatrix) dm).myMatrix;
        int myNumberOfRows = myMatrix.numRows;
        int myNumberOfCols = myMatrix.numCols;
        int dmNumberOfRows = otherMatrix.numRows;
        int dmNumberOfCols = otherMatrix.numCols;

        DenseMatrix64F result;
        if (rows) {
            if (myNumberOfCols != dmNumberOfCols) {
                throw new IllegalArgumentException(nonConformableMessage(
                        "Non-conformable matrices in concatenate rows: ", myMatrix, otherMatrix));
            }
            result = new DenseMatrix64F(myNumberOfRows + dmNumberOfRows, myNumberOfCols);
            CommonOps.insert(myMatrix, result, 0, 0);
            // The second matrix starts directly below the last row of this matrix.
            CommonOps.insert(otherMatrix, result, myNumberOfRows, 0);
        } else {
            if (myNumberOfRows != dmNumberOfRows) {
                throw new IllegalArgumentException(nonConformableMessage(
                        "Non-conformable matrices in concatenate columns: ", myMatrix, otherMatrix));
            }
            result = new DenseMatrix64F(myNumberOfRows, myNumberOfCols + dmNumberOfCols);
            CommonOps.insert(myMatrix, result, 0, 0);
            // The second matrix starts directly right of the last column of this matrix.
            CommonOps.insert(otherMatrix, result, 0, myNumberOfCols);
        }
        return new EJMLDoubleMatrix(result);
    }

    /** Builds the "prefix rows x cols, rows x cols" text used for dimension errors. */
    private static String nonConformableMessage(String prefix, DenseMatrix64F first, DenseMatrix64F second) {
        StringBuilder sb = new StringBuilder(prefix);
        sb.append(first.numRows).append(" x ").append(first.numCols);
        sb.append(", ").append(second.numRows).append(" x ").append(second.numCols);
        return sb.toString();
    }

    /**
     * @return a new matrix wrapping a copy of the underlying EJML matrix
     */
    @Override
    public DoubleMatrix copy() {
        return new EJMLDoubleMatrix(myMatrix.copy());
    }

    /**
     * @return a new matrix holding X'X, where X is this matrix
     */
    @Override
    public DoubleMatrix crossproduct() {
        int n = myMatrix.numCols;
        DenseMatrix64F result = new DenseMatrix64F(n, n);
        CommonOps.multTransA(myMatrix, myMatrix, result);
        return new EJMLDoubleMatrix(result);
    }

    /**
     * @param dm the right-hand matrix Y; must be an {@code EJMLDoubleMatrix}
     * @return a new matrix holding X'Y, where X is this matrix
     */
    @Override
    public DoubleMatrix crossproduct(DoubleMatrix dm) {
        int nrow = myMatrix.numCols;
        int ncol = dm.numberOfColumns();
        DenseMatrix64F otherMatrix = ((EJMLDoubleMatrix) dm).myMatrix;
        DenseMatrix64F result = new DenseMatrix64F(nrow, ncol);
        CommonOps.multTransA(myMatrix, otherMatrix, result);
        return new EJMLDoubleMatrix(result);
    }

    /**
     * Computes the pseudo-inverse using {@code CommonOps.pinv}.
     *
     * @return a new matrix holding the pseudo-inverse of this matrix
     */
    @Override
    public DoubleMatrix generalizedInverse() {
        // The pseudo-inverse of an m x n matrix is n x m.
        DenseMatrix64F result = new DenseMatrix64F(myMatrix.numCols, myMatrix.numRows);
        CommonOps.pinv(myMatrix, result);
        return new EJMLDoubleMatrix(result);
    }

    /**
     * Computes a generalized inverse from a singular value decomposition and reports
     * how many singular values were treated as non-zero.
     *
     * @param rank output parameter; {@code rank[0]} is overwritten with the number of
     *             singular values that are not below the tolerance of 1e-10
     * @return a new matrix holding V * W+ * U', where W+ carries the reciprocals of the
     *         retained singular values on its diagonal and zeros elsewhere
     */
    @Override
    public DoubleMatrix generalizedInverseWithRank(int[] rank) {
        org.ejml.factory.SingularValueDecomposition myDecomposition = DecompositionFactory.svd(myMatrix.numRows, myMatrix.numCols, true, true, false);
        myDecomposition.decompose(myMatrix);
        double tol = 1e-10;
        rank[0] = 0;

        DenseMatrix64F W = myDecomposition.getW(null);
        int wRows = W.getNumRows();
        int wCols = W.getNumCols();
        // W+ has the transposed shape of W, so rectangular inputs multiply through correctly.
        DenseMatrix64F Wplus = new DenseMatrix64F(wCols, wRows);
        int diagonalLength = Math.min(wRows, wCols);
        for (int i = 0; i < diagonalLength; i++) {
            double val = W.get(i, i);
            if (val < tol) {
                val = 0;
            } else {
                val = 1 / val;
                rank[0]++;
            }
            Wplus.set(i, i, val);
        }

        DenseMatrix64F V = myDecomposition.getV(null, false);
        DenseMatrix64F UT = myDecomposition.getU(null, true);
        int nrows = V.getNumRows();
        DenseMatrix64F VW = new DenseMatrix64F(nrows, Wplus.getNumCols());
        CommonOps.mult(V, Wplus, VW);
        DenseMatrix64F inv = new DenseMatrix64F(nrows, UT.getNumCols());
        CommonOps.mult(VW, UT, inv);
        return new EJMLDoubleMatrix(inv);
    }

    /**
     * @param row row index
     * @param col column index
     * @return the element at ({@code row}, {@code col})
     */
    @Override
    public double get(int row, int col) {
        return myMatrix.get(row, col);
    }

    /**
     * Same lookup as {@link #get(int, int)}; both delegate to {@code DenseMatrix64F.get}.
     *
     * @param row row index
     * @param col column index
     * @return the element at ({@code row}, {@code col})
     */
    @Override
    public double getChecked(int row, int col) {
        return myMatrix.get(row, col);
    }

    /**
     * Decomposes this matrix into eigenvalues and eigenvectors.
     *
     * @return the EJML-based decomposition when it reports success, otherwise a
     *         {@link ColtEigenvalueDecomposition} built from the same matrix
     */
    @Override
    public EigenvalueDecomposition getEigenvalueDecomposition() {
        EJMLEigenvalueDecomposition decomposition = new EJMLEigenvalueDecomposition(myMatrix);
        if (decomposition.wasSuccessful()) {
            return decomposition;
        }
        return new ColtEigenvalueDecomposition(myMatrix);
    }

    /**
     * Computes three matrices derived from this matrix X.
     *
     * @return an array of length 3: [0] X'X, [1] the inverse of X'X as produced by
     *         {@code CommonOps.invert}, and [2] I - X * inverse(X'X) * X'
     */
    @Override
    public DoubleMatrix[] getXtXGM() {
        DoubleMatrix[] dmarray = new DoubleMatrix[3];
        int ncol = myMatrix.numCols;
        int nrow = myMatrix.numRows;

        DenseMatrix64F xtx = new DenseMatrix64F(ncol, ncol);
        CommonOps.multTransA(myMatrix, myMatrix, xtx);
        dmarray[0] = new EJMLDoubleMatrix(xtx);

        DenseMatrix64F inverse = new DenseMatrix64F(ncol, ncol);
        // NOTE(review): the boolean returned by invert is ignored - confirm callers never pass singular X'X.
        CommonOps.invert(xtx, inverse);
        dmarray[1] = new EJMLDoubleMatrix(inverse);

        DenseMatrix64F inverseXt = new DenseMatrix64F(ncol, nrow);
        DenseMatrix64F ident = CommonOps.identity(nrow);
        CommonOps.multTransB(inverse, myMatrix, inverseXt);
        // ident starts as I and accumulates -1 * X * (inverse * X').
        CommonOps.multAdd(-1, myMatrix, inverseXt, ident);
        dmarray[2] = new EJMLDoubleMatrix(ident);
        return dmarray;
    }

    /**
     * Not implemented for this matrix type.
     *
     * @return always {@code null}
     */
    @Override
    public QRDecomposition getQRDecomposition() {
        // TODO Auto-generated method stub
        return null;
    }

    /**
     * @return a new {@link EJMLSingularValueDecomposition} of this matrix
     */
    @Override
    public SingularValueDecomposition getSingularValueDecomposition() {
        return new EJMLSingularValueDecomposition(myMatrix);
    }

    /**
     * @return a new matrix holding the result of {@code CommonOps.invert} applied to this matrix
     */
    @Override
    public DoubleMatrix inverse() {
        DenseMatrix64F inverse = new DenseMatrix64F(myMatrix.numRows, myMatrix.numCols);
        CommonOps.invert(myMatrix, inverse);
        return new EJMLDoubleMatrix(inverse);
    }

    /** Inverts this matrix in place using {@code CommonOps.invert}. */
    @Override
    public void invert() {
        CommonOps.invert(myMatrix);
    }

    /**
     * @param dm the matrix to subtract; must be an {@code EJMLDoubleMatrix}
     * @return a new matrix holding this matrix minus {@code dm}
     */
    @Override
    public DoubleMatrix minus(DoubleMatrix dm) {
        DenseMatrix64F result = new DenseMatrix64F(myMatrix.numRows, myMatrix.numCols);
        CommonOps.sub(myMatrix, ((EJMLDoubleMatrix) dm).myMatrix, result);
        return new EJMLDoubleMatrix(result);
    }

    /**
     * Subtracts {@code dm} from this matrix in place.
     *
     * @param dm the matrix to subtract; must be an {@code EJMLDoubleMatrix}
     */
    @Override
    public void minusEquals(DoubleMatrix dm) {
        CommonOps.subEquals(myMatrix, ((EJMLDoubleMatrix) dm).myMatrix);
    }

    /**
     * Multiplies this matrix by {@code dm}, optionally transposing either operand.
     *
     * @param dm          the right-hand matrix; must be an {@code EJMLDoubleMatrix}
     * @param transpose   whether this matrix is transposed before multiplying
     * @param transposedm whether {@code dm} is transposed before multiplying
     * @return a new matrix holding the product
     */
    @Override
    public DoubleMatrix mult(DoubleMatrix dm, boolean transpose,
            boolean transposedm) {
        DenseMatrix64F b = ((EJMLDoubleMatrix) dm).myMatrix;
        // The result shape follows from which operands are transposed.
        int nrow = transpose ? myMatrix.numCols : myMatrix.numRows;
        int ncol = transposedm ? b.numRows : b.numCols;
        DenseMatrix64F result = new DenseMatrix64F(nrow, ncol);
        if (transpose) {
            if (transposedm) {
                CommonOps.multTransAB(myMatrix, b, result);
            } else {
                CommonOps.multTransA(myMatrix, b, result);
            }
        } else {
            if (transposedm) {
                CommonOps.multTransB(myMatrix, b, result);
            } else {
                CommonOps.mult(myMatrix, b, result);
            }
        }
        return new EJMLDoubleMatrix(result);
    }

    /**
     * @param dm the right-hand matrix; must be an {@code EJMLDoubleMatrix}
     * @return a new matrix holding this matrix times {@code dm}
     */
    @Override
    public DoubleMatrix mult(DoubleMatrix dm) {
        DenseMatrix64F b = ((EJMLDoubleMatrix) dm).myMatrix;
        int nrow = myMatrix.numRows;
        int ncol = b.numCols;
        DenseMatrix64F result = new DenseMatrix64F(nrow, ncol);
        CommonOps.mult(myMatrix, b, result);
        return new EJMLDoubleMatrix(result);
    }

    /**
     * Computes {@code alpha * op(this) * op(A) + beta * B}.
     *
     * @param A          the right-hand factor of the product
     * @param B          the matrix added to the scaled product; when {@code null} only the
     *                   scaled product is returned
     * @param alpha      scale applied to the product
     * @param beta       scale applied to {@code B}
     * @param transpose  whether this matrix is transposed in the product
     * @param transposeA whether {@code A} is transposed in the product
     * @return a new matrix holding the result; {@code B} is not modified
     */
    @Override
    public DoubleMatrix multadd(DoubleMatrix A, DoubleMatrix B, double alpha,
            double beta, boolean transpose, boolean transposeA) {
        DoubleMatrix result = mult(A, transpose, transposeA);
        if (alpha != 1) {
            result.scalarMultEquals(alpha);
        }
        if (B == null) {
            return result;
        }
        if (beta == 1) {
            return result.plus(B);
        }
        return result.plus(B.scalarMult(beta));
    }

    /**
     * @return the number of columns
     */
    @Override
    public int numberOfColumns() {
        return myMatrix.numCols;
    }

    /**
     * @return the number of rows
     */
    @Override
    public int numberOfRows() {
        return myMatrix.numRows;
    }

    /**
     * @param dm the matrix to add; must be an {@code EJMLDoubleMatrix}
     * @return a new matrix holding this matrix plus {@code dm}
     */
    @Override
    public DoubleMatrix plus(DoubleMatrix dm) {
        DenseMatrix64F result = new DenseMatrix64F(myMatrix.numRows, myMatrix.numCols);
        CommonOps.add(myMatrix, ((EJMLDoubleMatrix) dm).myMatrix, result);
        return new EJMLDoubleMatrix(result);
    }

    /**
     * Adds {@code dm} to this matrix in place.
     *
     * @param dm the matrix to add; must be an {@code EJMLDoubleMatrix}
     */
    @Override
    public void plusEquals(DoubleMatrix dm) {
        CommonOps.addEquals(myMatrix, ((EJMLDoubleMatrix) dm).myMatrix);
    }

    /**
     * Extracts a row.
     *
     * @param i index of the row
     * @return a new (number of columns) x 1 matrix holding the elements of row {@code i}
     */
    @Override
    public DoubleMatrix row(int i) {
        int n = myMatrix.numCols;
        DenseMatrix64F result = new DenseMatrix64F(n, 1);
        SpecializedOps.subvector(myMatrix, i, 0, n, true, 0, result);
        return new EJMLDoubleMatrix(result);
    }

    /**
     * @param s the scalar to add
     * @return a new matrix with {@code s} added to every element
     */
    @Override
    public DoubleMatrix scalarAdd(double s) {
        DenseMatrix64F result = new DenseMatrix64F(myMatrix.numRows, myMatrix.numCols);
        CommonOps.add(myMatrix, s, result);
        return new EJMLDoubleMatrix(result);
    }

    /**
     * Adds {@code s} to every element in place.
     *
     * @param s the scalar to add
     */
    @Override
    public void scalarAddEquals(double s) {
        CommonOps.add(myMatrix, s);
    }

    /**
     * @param s the scale factor
     * @return a new matrix with every element multiplied by {@code s}
     */
    @Override
    public DoubleMatrix scalarMult(double s) {
        DenseMatrix64F result = new DenseMatrix64F(myMatrix.numRows, myMatrix.numCols);
        CommonOps.scale(s, myMatrix, result);
        return new EJMLDoubleMatrix(result);
    }

    /**
     * Multiplies every element by {@code s} in place.
     *
     * @param s the scale factor
     */
    @Override
    public void scalarMultEquals(double s) {
        CommonOps.scale(s, myMatrix);
    }

    /**
     * Stores {@code value} at ({@code row}, {@code col}).
     *
     * @param row   row index
     * @param col   column index
     * @param value the value to store
     */
    @Override
    public void set(int row, int col, double value) {
        myMatrix.set(row, col, value);
    }

    /**
     * Same store as {@link #set(int, int, double)}; both delegate to {@code DenseMatrix64F.set}.
     *
     * @param row   row index
     * @param col   column index
     * @param value the value to store
     */
    @Override
    public void setChecked(int row, int col, double value) {
        myMatrix.set(row, col, value);
    }

    /**
     * Solves for B in (this) * B = Y using the least-squares solver supplied by
     * {@code LinearSolverFactory}.
     *
     * @param Y the right-hand side; must be an {@code EJMLDoubleMatrix}
     * @return a new (number of columns of this) x (number of columns of Y) matrix
     */
    @Override
    public DoubleMatrix solve(DoubleMatrix Y) {
        DenseMatrix64F data = ((EJMLDoubleMatrix) Y).myMatrix;
        DenseMatrix64F result = new DenseMatrix64F(myMatrix.numCols, data.numCols);
        // NOTE(review): the factory arguments look like size hints for this matrix (rows, cols);
        // the values passed here are (cols of this, cols of Y) - confirm before changing.
        LinearSolver solver = LinearSolverFactory.leastSquares(myMatrix.numCols, data.numCols);
        solver.setA(myMatrix);
        solver.solve(data, result);
        return new EJMLDoubleMatrix(result);
    }

    /**
     * @return a new matrix holding XX', where X is this matrix
     */
    @Override
    public DoubleMatrix tcrossproduct() {
        int n = myMatrix.numRows;
        DenseMatrix64F result = new DenseMatrix64F(n, n);
        CommonOps.multTransB(myMatrix, myMatrix, result);
        return new EJMLDoubleMatrix(result);
    }

    /**
     * @param dm the right-hand matrix Y; must be an {@code EJMLDoubleMatrix}
     * @return a new matrix holding XY', where X is this matrix
     */
    @Override
    public DoubleMatrix tcrossproduct(DoubleMatrix dm) {
        DenseMatrix64F other = ((EJMLDoubleMatrix) dm).myMatrix;
        int nrow = myMatrix.numRows;
        int ncol = other.numRows;
        DenseMatrix64F result = new DenseMatrix64F(nrow, ncol);
        CommonOps.multTransB(myMatrix, other, result);
        return new EJMLDoubleMatrix(result);
    }

    /**
     * @return a new matrix holding the transpose of this matrix
     */
    @Override
    public DoubleMatrix transpose() {
        DenseMatrix64F result = new DenseMatrix64F(myMatrix.numCols, myMatrix.numRows);
        CommonOps.transpose(myMatrix, result);
        return new EJMLDoubleMatrix(result);
    }

    /**
     * Builds a sub-matrix from the given row and column indices.
     *
     * @param rows    indices of the rows to keep, in output order; {@code null} keeps all rows
     * @param columns indices of the columns to keep, in output order; {@code null} keeps all columns
     * @return {@link #copy()} when both arguments are {@code null}; otherwise a matrix created
     *         by {@code DoubleMatrixFactory.DEFAULT} holding the selected elements
     */
    @Override
    public DoubleMatrix getSelection(int[] rows, int[] columns) {
        if (rows == null && columns == null) {
            return copy();
        }
        int nrow = (rows == null) ? myMatrix.numRows : rows.length;
        int ncol = (columns == null) ? myMatrix.numCols : columns.length;
        DoubleMatrix result = DoubleMatrixFactory.DEFAULT.make(nrow, ncol);
        for (int r = 0; r < nrow; r++) {
            // A null index array means "take every row/column in its natural order".
            int sourceRow = (rows == null) ? r : rows[r];
            for (int c = 0; c < ncol; c++) {
                int sourceCol = (columns == null) ? c : columns[c];
                result.set(r, c, myMatrix.get(sourceRow, sourceCol));
            }
        }
        return result;
    }

    /**
     * Renders at most the first 25 rows and 25 columns, with elements separated by a
     * single space and each row terminated by a newline.
     */
    @Override
    public String toString() {
        int nrows = Math.min(MAX_PRINTED, myMatrix.numRows);
        int ncols = Math.min(MAX_PRINTED, myMatrix.numCols);
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < nrows; i++) {
            for (int j = 0; j < ncols; j++) {
                if (j > 0) {
                    sb.append(" ");
                }
                sb.append(myMatrix.get(i, j));
            }
            sb.append("\n");
        }
        return sb.toString();
    }

    /**
     * @param column index of the column
     * @return the sum of the elements in that column
     */
    @Override
    public double columnSum(int column) {
        int n = myMatrix.numRows;
        DenseMatrix64F vector = new DenseMatrix64F(n, 1);
        SpecializedOps.subvector(myMatrix, 0, column, n, false, 0, vector);
        return CommonOps.elementSum(vector);
    }

    /**
     * @param row index of the row
     * @return the sum of the elements in that row
     */
    @Override
    public double rowSum(int row) {
        int n = myMatrix.numCols;
        DenseMatrix64F vector = new DenseMatrix64F(n, 1);
        SpecializedOps.subvector(myMatrix, row, 0, n, true, 0, vector);
        return CommonOps.elementSum(vector);
    }

    /**
     * @return the {@code data} array of the wrapped EJML matrix itself, not a copy;
     *         writes to it change this matrix
     */
    @Override
    public double[] to1DArray() {
        return myMatrix.data;
    }

    /**
     * @return a new (number of rows) x (number of columns) array holding a copy of
     *         every element, indexed as {@code array[row][column]}
     */
    @Override
    public double[][] toArray() {
        int nrows = myMatrix.getNumRows();
        int ncols = myMatrix.getNumCols();
        double[][] array = new double[nrows][ncols];
        for (int r = 0; r < nrows; r++) {
            for (int c = 0; c < ncols; c++) {
                array[r][c] = myMatrix.get(r, c);
            }
        }
        return array;
    }
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy