mikera.matrixx.AMatrix Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of vectorz Show documentation
Show all versions of vectorz Show documentation
Fast double-precision vector and matrix maths library for Java, supporting N-dimensional numeric arrays.
package mikera.matrixx;
import java.nio.DoubleBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import mikera.arrayz.Array;
import mikera.arrayz.Arrayz;
import mikera.arrayz.INDArray;
import mikera.arrayz.ISparse;
import mikera.arrayz.impl.AbstractArray;
import mikera.arrayz.impl.IDense;
import mikera.arrayz.impl.JoinedArray;
import mikera.arrayz.impl.SliceArray;
import mikera.indexz.AIndex;
import mikera.indexz.Index;
import mikera.matrixx.algo.Definite;
import mikera.matrixx.algo.Determinant;
import mikera.matrixx.algo.Inverse;
import mikera.matrixx.algo.Multiplications;
import mikera.matrixx.algo.Rank;
import mikera.matrixx.impl.ADenseArrayMatrix;
import mikera.matrixx.impl.ARectangularMatrix;
import mikera.matrixx.impl.AStridedMatrix;
import mikera.matrixx.impl.IdentityMatrix;
import mikera.matrixx.impl.ImmutableMatrix;
import mikera.matrixx.impl.MatrixBandView;
import mikera.matrixx.impl.MatrixColumnList;
import mikera.matrixx.impl.MatrixColumnView;
import mikera.matrixx.impl.MatrixElementIterator;
import mikera.matrixx.impl.MatrixRowIterator;
import mikera.matrixx.impl.MatrixRowList;
import mikera.matrixx.impl.MatrixRowView;
import mikera.matrixx.impl.SparseColumnMatrix;
import mikera.matrixx.impl.SparseRowMatrix;
import mikera.matrixx.impl.SubMatrixView;
import mikera.matrixx.impl.TransposedMatrix;
import mikera.matrixx.impl.ZeroMatrix;
import mikera.randomz.Hash;
import mikera.transformz.AAffineTransform;
import mikera.transformz.AffineMN;
import mikera.transformz.impl.IdentityTranslation;
import mikera.util.Maths;
import mikera.vectorz.AScalar;
import mikera.vectorz.AVector;
import mikera.vectorz.IOperator;
import mikera.vectorz.Op;
import mikera.vectorz.Op2;
import mikera.vectorz.Tools;
import mikera.vectorz.Vector;
import mikera.vectorz.Vectorz;
import mikera.vectorz.impl.ADenseArrayVector;
import mikera.vectorz.impl.MatrixViewVector;
import mikera.vectorz.impl.Vector0;
import mikera.vectorz.util.Constants;
import mikera.vectorz.util.DoubleArrays;
import mikera.vectorz.util.ErrorMessages;
import mikera.vectorz.util.IntArrays;
import mikera.vectorz.util.VectorzException;
/**
* Abstract 2D matrix class. All Vectorz 2D matrices inherit from this class.
*
* Implements generic version of most key matrix operations.
*
* @author Mike
*/
public abstract class AMatrix extends AbstractArray implements IMatrix {
// ==============================================
// Abstract interface
// Serialization version identifier for this class.
private static final long serialVersionUID = 4854869374064155441L;
/**
 * Returns the element at row i, column j. Every concrete matrix class must implement this.
 */
@Override
public abstract double get(int i, int j);
/**
 * Sets the element at row i, column j to the given value. Every concrete matrix class must
 * implement this.
 */
@Override
public abstract void set(int i, int j, double value);
// =============================================
// Standard implementations
/**
 * 1D element access is not meaningful for a 2D matrix.
 *
 * @throws VectorzException always
 */
@Override
public final double get(int row) {
    throw new VectorzException("1D get not supported on matrix!");
}
/**
 * 0D element access is not meaningful for a 2D matrix.
 *
 * @throws VectorzException always
 */
@Override
public final double get() {
    throw new VectorzException("0D get not supported on matrix!");
}
/**
 * 1D element assignment is not meaningful for a 2D matrix.
 *
 * @throws VectorzException always, in this default implementation
 */
@Override
public void set(int row, double value) {
    // The message previously said "get", which misdescribed the failing operation.
    throw new VectorzException("1D set not supported on matrix!");
}
/**
 * Sets every element of this matrix to the given value by delegating to {@link #fill(double)}.
 */
@Override
public final void set(double value) {
    fill(value);
}
/**
 * Fills every element of this matrix with the given value, delegating to the fill of each
 * row view in turn.
 */
@Override
public void fill(double value) {
    for (int r = 0, rows = rowCount(); r < rows; r++) {
        getRowView(r).fill(value);
    }
}
/**
 * Sets the element at (i, j) without performing bounds checks.
 * The result is undefined if the row or column is out of bounds.
 * This default implementation delegates to {@link #set(int, int, double)}.
 *
 * @param i row index
 * @param j column index
 * @param value the new element value
 */
public void unsafeSet(int i, int j, double value) {
    this.set(i, j, value);
}
/**
 * Gets the element at (i, j) without performing bounds checks.
 * The result is undefined if the row or column is out of bounds.
 * This default implementation delegates to {@link #get(int, int)}.
 *
 * @param i row index
 * @param j column index
 * @return the element value at (i, j)
 */
public double unsafeGet(int i, int j) {
    return this.get(i, j);
}
/**
 * Applies {@code clamp(min, max)} to each row view of this matrix in turn.
 */
@Override
public void clamp(double min, double max) {
    for (int r = 0, rows = rowCount(); r < rows; r++) {
        getRowView(r).clamp(min, max);
    }
}
/**
 * Applies {@code pow(exponent)} to each row view of this matrix in turn.
 */
@Override
public void pow(double exponent) {
    for (int r = 0, rows = rowCount(); r < rows; r++) {
        getRowView(r).pow(exponent);
    }
}
/**
 * Applies {@code square()} to each row view of this matrix in turn.
 */
@Override
public void square() {
    for (int r = 0, rows = rowCount(); r < rows; r++) {
        getRowView(r).square();
    }
}
/**
 * Sets the element addressed by a two-element int index array {row, column}.
 *
 * @throws VectorzException if the index array does not have exactly two entries
 */
@Override
public void set(int[] indexes, double value) {
    if (indexes.length != 2) {
        throw new VectorzException(""+indexes.length+"D set not supported on AMatrix");
    }
    set(indexes[0], indexes[1], value);
}
/**
 * Sets the element addressed by a two-element long index array {row, column}. The indices
 * are converted to int via {@code Tools.toInt}.
 *
 * @throws VectorzException if the index array does not have exactly two entries
 */
@Override
public void set(long[] indexes, double value) {
    if (indexes.length != 2) {
        throw new VectorzException(""+indexes.length+"D set not supported on AMatrix");
    }
    int row = Tools.toInt(indexes[0]);
    int col = Tools.toInt(indexes[1]);
    set(row, col, value);
}
/**
 * Matrices are always 2-dimensional arrays.
 */
@Override
public final int dimensionality() {
    final int matrixRank = 2;
    return matrixRank;
}
/**
 * Returns the number of dimensions required for input vectors, i.e. the column count.
 * @return
 */
public final int inputDimensions() {
    final int inputs = columnCount();
    return inputs;
}
/**
 * Returns the number of dimensions required for output vectors, i.e. the row count.
 * @return
 */
public final int outputDimensions() {
    final int outputs = rowCount();
    return outputs;
}
/**
 * Total number of elements, computed in long arithmetic to avoid int overflow.
 */
@Override
public long elementCount() {
    long rows = rowCount();
    return rows * columnCount();
}
/**
 * A slice along the first dimension of a matrix is a row view.
 */
@Override
public final AVector slice(int row) {
    return getRowView(row);
}
/**
 * Returns a row view (dimension 0) or a column view (dimension 1).
 */
@Override
public AVector slice(int dimension, int index) {
    checkDimension(dimension);
    if (dimension == 0) {
        return getRowView(index);
    }
    return getColumnView(index);
}
/**
 * The number of slices along the first dimension, i.e. the row count.
 */
@Override
public int sliceCount() {
    final int rows = rowCount();
    return rows;
}
/**
 * Slices of a matrix are its rows.
 */
@Override
public final List getSlices() {
    return getRows();
}
/**
 * Returns the rows of this matrix as a MatrixRowList.
 */
@Override
public List getRows() {
    List rows = new MatrixRowList(this);
    return rows;
}
/**
 * Returns the columns of this matrix as a MatrixColumnList.
 */
@Override
public List getColumns() {
    List columns = new MatrixColumnList(this);
    return columns;
}
/**
 * Returns the rows (dimension 0) or columns (dimension 1) of this matrix.
 */
@Override
public final List getSlices(int dimension) {
    checkDimension(dimension);
    if (dimension == 0) {
        return getRows();
    }
    return getColumns();
}
// Builds a list of slice views, pre-sized to the row count.
@Override
public List getSliceViews() {
int rc=rowCount();
ArrayList al=new ArrayList(rc);
// NOTE(review): the next line appears corrupted by text extraction (text between '<' and
// '>' was likely stripped). It fuses the start of getSliceViews()'s loop with the bounds
// check of a separate flat element-index accessor; the code between them is missing here.
for (int i=0; i=(rc*cc))) throw new IndexOutOfBoundsException(ErrorMessages.invalidElementIndex(this,i));
// Maps a flat (row-major) element index i to (row = i / cc, column = i % cc).
return unsafeGet((int)(i/cc),(int)(i%cc));
}
/**
 * Returns a vector view of the leading diagonal values of the matrix (band 0).
 * @return
 */
public AVector getLeadingDiagonal() {
    final int mainBand = 0;
    return getBand(mainBand);
}
/**
 * Wraps this matrix as an affine transform with an identity translation sized to the
 * number of rows (the output dimension).
 */
public AAffineTransform toAffineTransform() {
    int outputs = rowCount();
    return new AffineMN(this, IdentityTranslation.create(outputs));
}
/**
 * Returns false for any non-square matrix. The remaining element checks of this method are
 * lost to text corruption in this copy (see NOTE below).
 */
@Override
public boolean isIdentity() {
int rc=this.rowCount();
int cc=this.columnCount();
if (rc!=cc) return false;
// NOTE(review): the next line appears corrupted by text extraction (text between '<' and
// '>' was likely stripped). It fuses the start of isIdentity()'s element loop with a later
// method that fetches the column list; `n` and `tolerance` below are declared in the lost text.
for (int i=0; i cols=getColumns();
// Presumably the body of a column-orthogonality test: every column must be unit length
// within `tolerance`, and every pair of distinct columns must have |dot product| <= tolerance.
for( int i = 0; i < n; i++ ) {
AVector a = cols.get(i);
if (!a.isUnitLengthVector(tolerance)) return false;
// Only pairs with j > i are checked, so each pair is tested once.
for( int j = i+1; j < n; j++ ) {
double val = a.dotProduct(cols.get(j));
if ((Math.abs(val) > tolerance)) return false;
}
}
return true;
}
/**
 * Tests whether all columns in the matrix are orthonormal vectors, by checking that
 * transpose(A) * A is (epsilon-)equal to the identity of size columnCount().
 * @return
 */
public boolean hasOrthonormalColumns() {
    int n = columnCount();
    return getTranspose().innerProduct(this).epsilonEquals(IdentityMatrix.create(n));
}
/**
 * Tests whether all rows in the matrix are orthonormal vectors, by checking that
 * A * transpose(A) is (epsilon-)equal to the identity of size rowCount().
 * @return
 */
public boolean hasOrthonormalRows() {
    int n = rowCount();
    return innerProduct(getTranspose()).epsilonEquals(IdentityMatrix.create(n));
}
/**
 * Reshapes this matrix into an array with the given dimensions:
 * 1 dimension takes a prefix of the element vector, 2 dimensions builds a Matrix, and
 * higher dimensions build a general array.
 */
@Override
public INDArray reshape(int... dimensions) {
    switch (dimensions.length) {
        case 1:
            return toVector().subVector(0, dimensions[0]);
        case 2:
            return Matrixx.createFromVector(asVector(), dimensions[0], dimensions[1]);
        default:
            return Arrayz.createFromVector(toVector(), dimensions);
    }
}
/**
 * Reshapes this matrix into a rows x cols Matrix built from its element vector.
 */
public Matrix reshape(int rows, int cols) {
    return Matrixx.createFromVector(this.asVector(), rows, cols);
}
/**
 * Reorders the rows of this matrix according to the given index array.
 *
 * @param order row indices, in the order they should appear in the result
 */
@Override
public AMatrix reorder(int[] order) {
    return reorder(0, order);
}
/**
 * Reorders this matrix along the given dimension (0 = rows, 1 = columns).
 *
 * An empty order yields a ZeroMatrix with zero extent in that dimension. When
 * {@code IntArrays.isRange(order)} holds (presumably consecutive ascending indices), a
 * subMatrix is returned. Otherwise the selected row/column slices are wrapped in a
 * SparseRowMatrix / SparseColumnMatrix.
 *
 * @throws IndexOutOfBoundsException if dim is neither 0 nor 1
 */
@Override
public AMatrix reorder(int dim, int[] order) {
    final int n = order.length;
    if (dim == 0) {
        if (n == 0) return ZeroMatrix.create(0, columnCount());
        if (IntArrays.isRange(order)) return subMatrix(order[0], n, 0, columnCount());
        ArrayList rows = new ArrayList(n);
        for (int k = 0; k < n; k++) {
            rows.add(slice(order[k]));
        }
        return SparseRowMatrix.wrap(rows);
    }
    if (dim == 1) {
        if (n == 0) return ZeroMatrix.create(rowCount(), 0);
        if (IntArrays.isRange(order)) return subMatrix(0, rowCount(), order[0], n);
        ArrayList cols = new ArrayList(n);
        for (int k = 0; k < n; k++) {
            cols.add(slice(1, order[k]));
        }
        return SparseColumnMatrix.wrap(cols);
    }
    throw new IndexOutOfBoundsException(ErrorMessages.invalidDimension(this, dim));
}
/**
 * Returns a SubMatrixView over the given region, or a ZeroMatrix when the region has no
 * rows or no columns.
 */
@Override
public AMatrix subMatrix(int rowStart, int rows, int colStart, int cols) {
    boolean empty = (rows == 0) || (cols == 0);
    if (empty) {
        return ZeroMatrix.create(rows, cols);
    }
    return new SubMatrixView(this, rowStart, colStart, rows, cols);
}
/**
 * N-dimensional sub-array accessor; for a matrix both arrays must have exactly two entries.
 *
 * @throws IllegalArgumentException if offsets or shape do not have length 2
 */
@Override
public final AMatrix subArray(int[] offsets, int[] shape) {
    if (offsets.length != 2) {
        throw new IllegalArgumentException(ErrorMessages.invalidIndex(this, offsets));
    }
    if (shape.length != 2) {
        throw new IllegalArgumentException(ErrorMessages.illegalSize(shape));
    }
    int rowOffset = offsets[0];
    int colOffset = offsets[1];
    return subMatrix(rowOffset, shape[0], colOffset, shape[1]);
}
/**
 * Returns a view of this matrix rotated along the given dimension by {@code shift}
 * positions: the part starting at index shift (mod n) comes first, followed by the leading part.
 * Returns this matrix unchanged when the dimension is empty or the effective shift is zero.
 */
@Override
public INDArray rotateView(int dimension, int shift) {
    int n = getShape(dimension);
    if (n == 0) return this;
    int s = Maths.mod(shift, n);
    if (s == 0) return this;
    // leading part [0, s) along `dimension`
    int[] headOffset = new int[2];
    int[] headShape = getShapeClone();
    headShape[dimension] = s;
    INDArray head = subArray(headOffset, headShape);
    // trailing part [s, n) along `dimension`
    int[] tailOffset = new int[2];
    tailOffset[dimension] = s;
    int[] tailShape = getShapeClone();
    tailShape[dimension] = n - s;
    INDArray tail = subArray(tailOffset, tailShape);
    return tail.join(head, dimension);
}
/**
 * Computes dest = this * source, one row dot product per destination element.
 * Dispatches to the Vector/Vector overload when both arguments are Vector instances.
 *
 * @throws IllegalArgumentException if source length != columnCount() or dest length != rowCount()
 */
@Override
public void transform(AVector source, AVector dest) {
    boolean bothVector = (source instanceof Vector) && (dest instanceof Vector);
    if (bothVector) {
        transform((Vector) source, (Vector) dest);
        return;
    }
    int rows = rowCount();
    int cols = columnCount();
    if (source.length() != cols) throw new IllegalArgumentException(ErrorMessages.wrongSourceLength(source));
    if (dest.length() != rows) throw new IllegalArgumentException(ErrorMessages.wrongDestLength(dest));
    for (int r = 0; r < rows; r++) {
        dest.unsafeSet(r, rowDotProduct(r, source));
    }
}
/**
 * Computes dest = this * source for Vector arguments.
 *
 * @throws IllegalArgumentException if source length != columnCount() or dest length != rowCount()
 */
public void transform(Vector source, Vector dest) {
    int rows = rowCount();
    int cols = columnCount();
    if (source.length() != cols) throw new IllegalArgumentException(ErrorMessages.wrongSourceLength(source));
    if (dest.length() != rows) throw new IllegalArgumentException(ErrorMessages.wrongDestLength(dest));
    for (int r = 0; r < rows; r++) {
        dest.unsafeSet(r, rowDotProduct(r, source));
    }
}
/**
 * Replaces v with this * v. Results go into a temporary array first, so every row uses the
 * original values of v. ADenseArrayVector arguments are dispatched to the array-based overload.
 *
 * @throws IllegalArgumentException if v.length() != rowCount()
 * @throws UnsupportedOperationException if this matrix is not square
 */
@Override
public void transformInPlace(AVector v) {
    if (v instanceof ADenseArrayVector) {
        transformInPlace((ADenseArrayVector) v);
        return;
    }
    final int n = v.length();
    final int rows = rowCount();
    final int cols = columnCount();
    if (n != rows) throw new IllegalArgumentException(ErrorMessages.incompatibleShapes(this, v));
    if (rows != cols) throw new UnsupportedOperationException("Cannot transform in place with a non-square transformation");
    double[] result = new double[n];
    for (int r = 0; r < rows; r++) {
        result[r] = getRow(r).dotProduct(v);
    }
    v.setElements(result);
}
/**
 * Replaces v with this * v, reading v through its backing array and offset.
 *
 * @throws IllegalArgumentException if v.length() != rowCount()
 * @throws UnsupportedOperationException if this matrix is not square
 */
public void transformInPlace(ADenseArrayVector v) {
    final int n = v.length();
    final int rows = rowCount();
    final int cols = columnCount();
    if (n != rows) throw new IllegalArgumentException(ErrorMessages.incompatibleShapes(this, v));
    if (rows != cols) throw new UnsupportedOperationException("Cannot transform in place with a non-square transformation");
    final double[] data = v.getArray();
    final int offset = v.getArrayOffset();
    double[] result = new double[n];
    for (int r = 0; r < rows; r++) {
        result[r] = getRow(r).dotProduct(data, offset);
    }
    v.setElements(result);
}
/**
 * Default row accessor: returns the row view from {@link #getRowView(int)}.
 */
@Override
public AVector getRow(int row) {
    AVector view = getRowView(row);
    return view;
}
/**
 * Default column accessor: returns the column view from {@link #getColumnView(int)}.
 */
@Override
public AVector getColumn(int column) {
    AVector view = getColumnView(column);
    return view;
}
/**
 * Returns a MatrixRowView onto the given row of this matrix.
 */
@Override
public AVector getRowView(int row) {
    AVector view = new MatrixRowView(this, row);
    return view;
}
/**
 * Returns a MatrixColumnView onto the given column of this matrix.
 */
@Override
public AVector getColumnView(int column) {
    AVector view = new MatrixColumnView(this, column);
    return view;
}
/**
 * Returns a new Vector holding a copy of the given row.
 */
@Override
public AVector getRowClone(int row) {
    Vector copy = Vector.createLength(columnCount());
    copyRowTo(row, copy.getArray(), 0);
    return copy;
}
/**
 * Returns a new Vector holding a copy of the given column.
 */
@Override
public AVector getColumnClone(int column) {
    Vector copy = Vector.createLength(rowCount());
    copyColumnTo(column, copy.getArray(), 0);
    return copy;
}
/**
 * Sets this matrix with the element values from another matrix, row by row.
 * The source matrix must have the same shape.
 * @param a source matrix
 */
public void set(AMatrix a) {
    final int rows = rowCount();
    a.checkShape(rows, columnCount());
    for (int r = 0; r < rows; r++) {
        setRow(r, a.getRow(r));
    }
}
/**
 * Sets this matrix from an arbitrary array. Matrices, vectors and scalars use their specific
 * overloads; anything else falls back to the superclass implementation.
 */
@Override
public void set(INDArray a) {
    if (a instanceof AMatrix) {
        set((AMatrix) a);
    } else if (a instanceof AVector) {
        set((AVector) a);
    } else if (a instanceof AScalar) {
        set(a.get());
    } else {
        super.set(a);
    }
}
/**
* Sets every row of this matrix with the element values from a vector.
* @param v the vector whose values are assigned to each row
*/
public void set(AVector v) {
int rc=rowCount();
// NOTE(review): the next line appears corrupted by text extraction (text between '<' and
// '>' was likely stripped). It fuses the row loop of set(AVector) with the signature of
// elementIterator(); the rest of set(AVector) and the iterator's generic type are missing.
for (int i=0; i elementIterator() {
return new MatrixElementIterator(this);
}
@Override
public boolean isBoolean() {
int rc=rowCount();
// NOTE(review): corrupted line: fuses isBoolean()'s row loop with the signature of
// iterator(); the remainder of isBoolean() and iterator()'s generic return type are missing.
for (int i=0; i iterator() {
// Iteration over a matrix uses a MatrixRowIterator (see outerProduct, which iterates AVector rows).
return new MatrixRowIterator(this);
}
/**
 * Compares this matrix with another array to within epsilon. AMatrix arguments use the
 * AMatrix overload, and arrays that are not 2-dimensional are never equal.
 */
@Override
public final boolean epsilonEquals(INDArray a, double epsilon) {
if (a instanceof AMatrix) {
return epsilonEquals((AMatrix) a,epsilon);
} if (a.dimensionality()!=2) {
return false;
} else {
int sc=rowCount();
// Slice counts must match before comparing slice by slice.
if (a.sliceCount()!=sc) return false;
// NOTE(review): the next line appears corrupted by text extraction (text between '<' and
// '>' was likely stripped). It fuses the slice-comparison loop of epsilonEquals() with the
// size check at the start of toString(); the code in between is missing from this view.
for (int i=0; iConstants.PRINT_THRESHOLD) {
// Matrices above the print threshold are summarised by their shape instead of printed in full.
Index shape=Index.create(getShape());
return "Large matrix with shape: "+shape.toString();
}
return toStringFull();
}
/**
 * Renders every row of this matrix, separated by ",\n" and wrapped in square brackets.
 */
@Override
public String toStringFull() {
    final int rows = rowCount();
    StringBuilder out = new StringBuilder("[");
    for (int r = 0; r < rows; r++) {
        if (r > 0) {
            out.append(",\n");
        }
        out.append(getRow(r).toString());
    }
    return out.append("]").toString();
}
/**
 * Hash of all elements in row-major order, combined List-style (h = 31 * h + elementHash),
 * starting from 1.
 */
@Override
public int hashCode() {
    final int rows = rowCount();
    final int cols = columnCount();
    int h = 1;
    for (int r = 0; r < rows; r++) {
        for (int c = 0; c < cols; c++) {
            h = h * 31 + Hash.hashCode(unsafeGet(r, c));
        }
    }
    return h;
}
/**
 * Returns the matrix values as a single reference Vector in the form [row0 row1 row2....].
 * Degenerate shapes get cheaper views: no rows gives Vector0.INSTANCE, a single row gives that
 * row's view, and a single column gives that column's view.
 *
 * @return
 */
@Override
public AVector asVector() {
    final int rows = rowCount();
    if (rows == 0) {
        return Vector0.INSTANCE;
    }
    if (rows == 1) {
        return getRowView(0);
    }
    if (columnCount() == 1) {
        return getColumnView(0);
    }
    return new MatrixViewVector(this);
}
/**
 * Returns the elements as a list, via {@link #asVector()}.
 */
@Override
public List asElementList() {
    AVector flat = asVector();
    return flat.asElementList();
}
/**
 * Matrix-matrix product this * a, delegated to {@code Multiplications.multiply}.
 */
@Override
public AMatrix innerProduct(AMatrix a) {
    AMatrix product = Multiplications.multiply(this, a);
    return product;
}
/**
 * Multiplies this matrix by a Vector. It checks that v has length columnCount() and allocates
 * a result Vector of length rowCount(); the rest of the body is corrupted (see NOTE below).
 */
@Override
public Vector innerProduct(Vector v) {
int cc=this.columnCount();
int rc=this.rowCount();
v.checkLength(cc);
Vector r=Vector.createLength(rc);
// NOTE(review): the next line appears corrupted by text extraction (text between '<' and
// '>' was likely stripped). It fuses the per-row loop of innerProduct(Vector) with a later
// method (presumably an innerProduct with a general array `a`) that starts by fetching the
// rows; the code between them is missing from this view.
for (int i=0; i al=getRows();
// Takes the inner product of each row with `a` and assembles the results as slices.
List rl=new ArrayList(rc);
for (AVector v: al ) {
rl.add(v.innerProduct(a));
}
return SliceArray.create(rl);
}
/**
 * Outer product with an arbitrary array: each row's outer product with {@code a} becomes a
 * slice of the result array.
 */
@Override
public INDArray outerProduct(INDArray a) {
    ArrayList products = new ArrayList(sliceCount());
    for (AVector row : this) {
        products.add(row.outerProduct(a));
    }
    return Arrayz.create(products);
}
/**
 * Returns the dot product of a specific row with a vector.
 *
 * Unsafe operation: callers must ensure {@code i} is a valid row index; this method does
 * no bounds checking of its own.
 * @param i row index
 * @param a vector to dot with row i
 * @return the dot product
 */
public double rowDotProduct(int i, AVector a) {
    AVector row = getRow(i);
    return row.dotProduct(a);
}
/**
 * Computes the inverse of this matrix by delegating to {@code Inverse.calculate}.
 * Returns null if the matrix is singular.
 *
 * Throws an exception if the matrix is not square.
 * @return the inverse matrix, or null if singular
 */
@Override
public AMatrix inverse() {
    AMatrix result = Inverse.calculate(this);
    return result;
}
/**
 * Trace of the matrix. Presumably it sums the diagonal over min(rowCount(), columnCount())
 * positions; the loop body is corrupted in this copy (see NOTE below).
 */
@Override
public double trace() {
int rc=Math.min(rowCount(), columnCount());
double result=0.0;
// NOTE(review): the next line appears corrupted by text extraction (text between '<' and
// '>' was likely stripped). It fuses the diagonal loop of trace() with a dimensionality
// guard (apparently `dims > 2`) in a later add(INDArray)-style method.
for (int i=0; i2) throw new IllegalArgumentException(ErrorMessages.incompatibleShapes(this, a));
// Dispatch on dimensionality: 0 = scalar, 1 = vector, 2 = matrix.
if (dims==0) {
add(a.get());
} else if (dims==1) {
add(Vectorz.toVector(a));
} else if (dims==2) {
add(Matrixx.toMatrix(a));
}
}
}
/**
 * Multiplies each row of this matrix by the given vector, via each row view's multiply.
 * Mutates this matrix.
 * @param v
 */
public void multiply(AVector v) {
    for (int r = 0, rows = rowCount(); r < rows; r++) {
        getRowView(r).multiply(v);
    }
}
/**
 * Multiplies this matrix by an arbitrary array. AMatrix and AVector arguments use their
 * overloads directly; other arrays are converted based on their dimensionality.
 *
 * @throws IllegalArgumentException if the array has more than 2 dimensions
 */
@Override
public void multiply(INDArray a) {
    if (a instanceof AMatrix) {
        multiply((AMatrix) a);
        return;
    }
    if (a instanceof AVector) {
        multiply((AVector) a);
        return;
    }
    switch (a.dimensionality()) {
        case 0:
            multiply(a.get());
            break;
        case 1:
            multiply(Vectorz.toVector(a));
            break;
        case 2:
            multiply(Matrixx.toMatrix(a));
            break;
        default:
            throw new IllegalArgumentException(ErrorMessages.incompatibleShapes(this, a));
    }
}
/**
* Divides every row of this matrix by the given vector
* @param v the vector each row is divided by
*/
public void divide(AVector v) {
int rc=rowCount();
// NOTE(review): the next line appears corrupted by text extraction (text between '<' and
// '>' was likely stripped). It fuses divide()'s row loop with the tail of a later broadcast
// method (apparently guarded by `target.dimensionality() > 2`); the code between is missing.
for (int i=0; i2) r=r.broadcastLike(target);
// The surviving tail returns a clone of the (possibly broadcast) array.
return r.clone();
}
/**
 * Returns a copy of this matrix, broadcast to the shape of {@code target} when the target
 * has more than two dimensions.
 */
@Override
public INDArray broadcastCopyLike(INDArray target) {
    INDArray result = this.copy();
    if (target.dimensionality() <= 2) {
        return result;
    }
    return result.broadcastLike(target);
}
/**
* Returns true if the matrix is the zero matrix (all components zero)
*/
@Override
public boolean isZero() {
int rc=rowCount();
// NOTE(review): the next line appears corrupted by text extraction (text between '<' and
// '>' was likely stripped). It fuses isZero()'s row loop with the tail of a later javadoc
// comment (whose opening was lost), which is closed by the following "*/" line.
for (int i=0; i j
*/
// Presumably tests whether all elements below the leading diagonal are zero; the loop body
// is corrupted below, so the exact condition cannot be confirmed here.
public boolean isUpperTriangular() {
int rc=rowCount();
int cc=columnCount();
// NOTE(review): corrupted line: fuses isUpperTriangular()'s loop with the final return of a
// later method that clamps a band limit to be non-negative; the code between is missing.
for (int i=1; i0)?band:0;
}
// Presumably computes the number of elements in a given band of an rc x cc matrix.
protected final static int bandLength(int rc, int cc, int band) {
if (band>0) {
// NOTE(review): corrupted line: fuses the positive-band branch of bandLength() with the
// loop header of a later upper-bandwidth method that scans bands from the limit down to 1.
return (band0; w--) {
// The first non-zero band found, scanning outward-in, determines the bandwidth.
if (!getBand(w).isZero()) return w;
}
return 0;
}
/**
 * Computes the lower bandwidth of a matrix, i.e. the number of bands below the leading diagonal
 * that cover all non-zero values. Bands are scanned from {@code lowerBandwidthLimit()} inward,
 * and the first band that is not entirely zero determines the result.
 * @return
 */
public int lowerBandwidth() {
    int w = lowerBandwidthLimit();
    while (w > 0 && getBand(-w).isZero()) {
        w--;
    }
    return (w > 0) ? w : 0;
}
/**
 * Gets a specific band of the matrix, as a view vector (band 0 is the leading diagonal).
 * The band is truncated at the edges of the matrix, i.e. it does not wrap around the matrix.
 *
 * @param band band offset from the leading diagonal
 * @return a MatrixBandView over the requested band
 */
@Override
public AVector getBand(int band) {
    AVector view = MatrixBandView.create(this, band);
    return view;
}
/**
 * Builds a vector by joining bands obtained from getBand, starting from Vector0.INSTANCE.
 * Presumably this produces a band that wraps around the matrix, unlike getBand.
 * NOTE(review): the body is corrupted in this copy (see below), so the exact band sequence
 * cannot be confirmed here.
 */
public AVector getBandWrapped(int band) {
AVector result=Vector0.INSTANCE;
int rc=rowCount();
int cc=columnCount();
// NOTE(review): the next two lines appear corrupted by text extraction (text between '<' and
// '>' was likely stripped). The comparisons and loop bounds that select which bands to join
// are missing, so the code below is not compilable as shown.
if (rc0) si-=rc;
for (;si-rc; si-=cc) {
result=result.join(getBand(si));
}
}
return result;
}
/**
 * Sets a row in a matrix to the values of the given vector, by assigning it into this
 * matrix's row view.
 *
 * @param i index of the row to set
 * @param row vector holding the new row values
 */
public void setRow(int i, AVector row) {
    getRowView(i).set(row);
}
/**
 * Replaces a row in a matrix, adding the row to the internal structure of the matrix.
 *
 * Will throw UnsupportedOperationException if not possible for the given matrix type.
 * This default implementation always throws.
 *
 * @param i index of the row to replace
 * @param row vector to install as the new row
 * @throws UnsupportedOperationException always, in this default implementation
 */
public void replaceRow(int i, AVector row) {
    throw new UnsupportedOperationException("replaceRow not supported for "+this.getClass()+". Consider using an AVectorMatrix or SparseRowMatrix instance instead.");
}
/**
 * Replaces a column in a matrix, adding the column to the internal structure of the matrix.
 *
 * Will throw UnsupportedOperationException if not possible for the given matrix type.
 * This default implementation always throws.
 *
 * @param i index of the column to replace
 * @param column vector to install as the new column
 * @throws UnsupportedOperationException always, in this default implementation
 */
public void replaceColumn(int i, AVector column) {
    throw new UnsupportedOperationException("replaceColumn not supported for "+this.getClass()+". Consider using a SparseColumnMatrix instance instead.");
}
/**
 * Sets a column in a matrix to the values of the given vector, by assigning it into this
 * matrix's column view.
 *
 * @param i index of the column to set
 * @param col vector holding the new column values
 */
public void setColumn(int i, AVector col) {
    getColumnView(i).set(col);
}
/**
 * Returns a clone of this matrix that preserves its exact concrete type.
 */
@Override
public abstract AMatrix exactClone();
/**
 * Returns this matrix if it is already immutable, otherwise an ImmutableMatrix wrapping it.
 */
@Override
public INDArray immutable() {
    if (isMutable()) {
        return ImmutableMatrix.create(this);
    }
    return this;
}
/**
 * Returns this matrix if it is fully mutable, otherwise a clone.
 */
@Override
public AMatrix mutable() {
    if (!isFullyMutable()) {
        return clone();
    }
    return this;
}
/**
 * Returns this matrix if it already implements ISparse, otherwise a sparse copy.
 */
@Override
public AMatrix sparse() {
    if (!(this instanceof ISparse)) {
        return Matrixx.createSparse(this);
    }
    return this;
}
/**
 * Returns this matrix if it already implements IDense, otherwise a dense Matrix copy.
 */
@Override
public AMatrix dense() {
    if (!(this instanceof IDense)) {
        return Matrix.create(this);
    }
    return this;
}
/**
 * Always returns a new dense Matrix copy of this matrix.
 */
@Override
public final Matrix denseClone() {
    Matrix copy = Matrix.create(this);
    return copy;
}
/**
 * Checks that rowCount() * columnCount() matches elementCount(), then runs the superclass checks.
 *
 * @throws VectorzException if the shape and element count disagree
 */
@Override
public void validate() {
    long expected = (long) rowCount() * columnCount();
    if (expected != elementCount()) {
        throw new VectorzException("Invalid Array shape?");
    }
    super.validate();
}
/**
* Copies the elements in a selected row of this matrix to a double array
* @param i The index of the selected row
* @param dest Destination double[] array
* @param destOffset Offset into destination array
*/
public void copyRowTo(int i, double[] dest, int destOffset) {
// note: using getRow() may be faster when overriding
int cc=columnCount();
// NOTE(review): the next line appears corrupted by text extraction (text between '<' and
// '>' was likely stripped). It fuses copyRowTo()'s column loop with the tail of a dimension
// check (valid dimensions appear to be 0 and 1); the code between is missing from this view.
for (int j=0; j= 2))
throw new IndexOutOfBoundsException(ErrorMessages.invalidDimension(this,dimension));
}
/**
 * Checks if this matrix has the same shape as another matrix. Throws an exception if not.
 * @param m matrix to compare against
 * @throws IndexOutOfBoundsException if the row or column counts differ
 */
protected void checkSameShape(AMatrix m) {
    if (rowCount() == m.rowCount() && columnCount() == m.columnCount()) {
        return;
    }
    throw new IndexOutOfBoundsException(ErrorMessages.mismatch(this, m));
}
/**
 * Checks if this matrix has the same shape as a rectangular matrix. Throws an exception if not.
 * @param m matrix to compare against
 * @throws IndexOutOfBoundsException if the row or column counts differ
 */
protected void checkSameShape(ARectangularMatrix m) {
    if (rowCount() == m.rowCount() && columnCount() == m.columnCount()) {
        return;
    }
    throw new IndexOutOfBoundsException(ErrorMessages.mismatch(this, m));
}
/**
 * Checks if this matrix has the specified shape. Throws an exception if not.
 *
 * @param rows expected number of rows
 * @param cols expected number of columns
 * @throws IllegalArgumentException if the actual shape differs. The message reports both
 *         the actual and the expected shape in [rows,cols] order.
 */
protected void checkShape(int rows, int cols) {
    int rc = rowCount();
    int cc = columnCount();
    if ((rc != rows) || (cc != cols)) {
        // Report the actual shape as [rows,cols] to match the expected shape; it was
        // previously printed as [cols,rows].
        throw new IllegalArgumentException("Unexpected shape: [" + rc + "," + cc + "] expected: [" + rows + "," + cols + "]");
    }
}
/**
 * Checks if the given (row, column) index is valid for this matrix. Throws an exception if not.
 *
 * @throws IndexOutOfBoundsException if either index is outside the matrix
 */
protected void checkIndex(int i, int j) {
    final int rows = rowCount();
    final int cols = columnCount();
    boolean inside = (i >= 0) && (i < rows) && (j >= 0) && (j < cols);
    if (!inside) {
        throw new IndexOutOfBoundsException(ErrorMessages.invalidIndex(this, i,j));
    }
}
/**
 * Checks if the given column index is valid for this matrix. Throws an exception if not.
 *
 * @return the number of columns in this matrix
 * @throws IndexOutOfBoundsException if the column index is out of range
 */
public int checkColumn(int column) {
    final int cols = columnCount();
    boolean valid = (column >= 0) && (column < cols);
    if (!valid) {
        throw new IndexOutOfBoundsException(ErrorMessages.invalidSlice(this, 1, column));
    }
    return cols;
}
/**
 * Checks if the given row index is valid for this matrix. Throws an exception if not.
 *
 * @return the number of rows in this matrix
 * @throws IndexOutOfBoundsException if the row index is out of range
 */
public int checkRow(int row) {
    final int rows = rowCount();
    boolean valid = (row >= 0) && (row < rows);
    if (!valid) {
        throw new IndexOutOfBoundsException(ErrorMessages.invalidSlice(this, 0, row));
    }
    return rows;
}
/**
 * Adds two matrices to this one in sequence: first {@code a}, then {@code b}.
 */
@Override
public void add2(AMatrix a, AMatrix b) {
    this.add(a);
    this.add(b);
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy