// org.carrot2.mahout.math.matrix.DoubleMatrix2D (Maven / Gradle / Ivy)
/* Imported from Mahout. */package org.carrot2.mahout.math.matrix;
import org.carrot2.mahout.math.function.DoubleDoubleFunction;
import org.carrot2.mahout.math.function.DoubleFunction;
import org.carrot2.mahout.math.function.Functions;
import org.carrot2.mahout.math.function.IntIntDoubleFunction;
import org.carrot2.mahout.math.matrix.impl.AbstractMatrix2D;
import org.carrot2.mahout.math.matrix.impl.DenseDoubleMatrix1D;
import org.carrot2.mahout.math.matrix.impl.DenseDoubleMatrix2D;
public abstract class DoubleMatrix2D extends AbstractMatrix2D implements Cloneable {
protected DoubleMatrix2D() {
}
public double aggregate(DoubleDoubleFunction aggr,
DoubleFunction f) {
if (size() == 0) {
return Double.NaN;
}
double a = f.apply(getQuick(rows - 1, columns - 1));
int d = 1; // last cell already done
for (int row = rows; --row >= 0;) {
for (int column = columns - d; --column >= 0;) {
a = aggr.apply(a, f.apply(getQuick(row, column)));
}
d = 0;
}
return a;
}
public double aggregate(DoubleMatrix2D other, DoubleDoubleFunction aggr,
DoubleDoubleFunction f) {
checkShape(other);
if (size() == 0) {
return Double.NaN;
}
double a = f.apply(getQuick(rows - 1, columns - 1), other.getQuick(rows - 1, columns - 1));
int d = 1; // last cell already done
for (int row = rows; --row >= 0;) {
for (int column = columns - d; --column >= 0;) {
a = aggr.apply(a, f.apply(getQuick(row, column), other.getQuick(row, column)));
}
d = 0;
}
return a;
}
public void assign(double[][] values) {
if (values.length != rows) {
throw new IllegalArgumentException("Must have same number of rows: rows=" + values.length + "rows()=" + rows());
}
for (int row = rows; --row >= 0;) {
double[] currentRow = values[row];
if (currentRow.length != columns) {
throw new IllegalArgumentException(
"Must have same number of columns in every row: columns=" + currentRow.length + "columns()=" + columns());
}
for (int column = columns; --column >= 0;) {
setQuick(row, column, currentRow[column]);
}
}
}
public DoubleMatrix2D assign(double value) {
int r = rows;
int c = columns;
//for (int row=rows; --row >= 0;) {
// for (int column=columns; --column >= 0;) {
for (int row = 0; row < r; row++) {
for (int column = 0; column < c; column++) {
setQuick(row, column, value);
}
}
return this;
}
public void assign(DoubleFunction function) {
for (int row = rows; --row >= 0;) {
for (int column = columns; --column >= 0;) {
setQuick(row, column, function.apply(getQuick(row, column)));
}
}
}
public DoubleMatrix2D assign(DoubleMatrix2D other) {
if (other == this) {
return this;
}
checkShape(other);
if (haveSharedCells(other)) {
other = other.copy();
}
//for (int row=0; row= 0;) {
for (int column = columns; --column >= 0;) {
setQuick(row, column, other.getQuick(row, column));
}
}
return this;
}
public DoubleMatrix2D assign(DoubleMatrix2D y, DoubleDoubleFunction function) {
checkShape(y);
for (int row = rows; --row >= 0;) {
for (int column = columns; --column >= 0;) {
setQuick(row, column, function.apply(getQuick(row, column), y.getQuick(row, column)));
}
}
return this;
}
public int cardinality() {
int cardinality = 0;
for (int row = rows; --row >= 0;) {
for (int column = columns; --column >= 0;) {
if (getQuick(row, column) != 0) {
cardinality++;
}
}
}
return cardinality;
}
public DoubleMatrix2D copy() {
return like().assign(this);
}
public boolean equals(double value) {
return org.carrot2.mahout.math.matrix.linalg.Property.DEFAULT.equals(this, value);
}
@Override
public boolean equals(Object obj) {
if (this == obj) {
return true;
}
if (obj == null) {
return false;
}
if (!(obj instanceof DoubleMatrix2D)) {
return false;
}
return org.carrot2.mahout.math.matrix.linalg.Property.DEFAULT.equals(this, (DoubleMatrix2D) obj);
}
public void forEachNonZero(IntIntDoubleFunction function) {
for (int row = rows; --row >= 0;) {
for (int column = columns; --column >= 0;) {
double value = getQuick(row, column);
if (value != 0) {
double r = function.apply(row, column, value);
if (r != value) {
setQuick(row, column, r);
}
}
}
}
}
public double get(int row, int column) {
if (column < 0 || column >= columns || row < 0 || row >= rows) {
throw new IndexOutOfBoundsException("row:" + row + ", column:" + column);
}
return getQuick(row, column);
}
protected DoubleMatrix2D getContent() {
return this;
}
public abstract double getQuick(int row, int column);
protected boolean haveSharedCells(DoubleMatrix2D other) {
if (other == null) {
return false;
}
if (this == other) {
return true;
}
return getContent().haveSharedCellsRaw(other.getContent());
}
protected boolean haveSharedCellsRaw(DoubleMatrix2D other) {
return false;
}
public DoubleMatrix2D like() {
return like(rows, columns);
}
public abstract DoubleMatrix2D like(int rows, int columns);
public abstract DoubleMatrix1D like1D(int size);
protected abstract DoubleMatrix1D like1D(int size, int zero, int stride);
public void set(int row, int column, double value) {
if (column < 0 || column >= columns || row < 0 || row >= rows) {
throw new IndexOutOfBoundsException("row:" + row + ", column:" + column);
}
setQuick(row, column, value);
}
public abstract void setQuick(int row, int column, double value);
public double[][] toArray() {
double[][] values = new double[rows][columns];
for (int row = rows; --row >= 0;) {
double[] currentRow = values[row];
for (int column = columns; --column >= 0;) {
currentRow[column] = getQuick(row, column);
}
}
return values;
}
protected DoubleMatrix2D view() {
try {
return (DoubleMatrix2D) clone();
} catch (CloneNotSupportedException cnse) {
throw new IllegalStateException();
}
}
public DoubleMatrix1D viewColumn(int column) {
checkColumn(column);
int viewSize = this.rows;
int viewZero = index(0, column);
int viewStride = this.rowStride;
return like1D(viewSize, viewZero, viewStride);
}
public DoubleMatrix2D viewColumnFlip() {
return (DoubleMatrix2D) view().vColumnFlip();
}
public DoubleMatrix2D viewDice() {
return (DoubleMatrix2D) view().vDice();
}
public DoubleMatrix2D viewPart(int row, int column, int height, int width) {
return (DoubleMatrix2D) view().vPart(row, column, height, width);
}
public DoubleMatrix1D viewRow(int row) {
checkRow(row);
int viewSize = this.columns;
int viewZero = index(row, 0);
int viewStride = this.columnStride;
return like1D(viewSize, viewZero, viewStride);
}
public DoubleMatrix2D viewRowFlip() {
return (DoubleMatrix2D) view().vRowFlip();
}
public DoubleMatrix2D viewSelection(int[] rowIndexes, int[] columnIndexes) {
// check for "all"
if (rowIndexes == null) {
rowIndexes = new int[rows];
for (int i = rows; --i >= 0;) {
rowIndexes[i] = i;
}
}
if (columnIndexes == null) {
columnIndexes = new int[columns];
for (int i = columns; --i >= 0;) {
columnIndexes[i] = i;
}
}
checkRowIndexes(rowIndexes);
checkColumnIndexes(columnIndexes);
int[] rowOffsets = new int[rowIndexes.length];
int[] columnOffsets = new int[columnIndexes.length];
for (int i = rowIndexes.length; --i >= 0;) {
rowOffsets[i] = rowOffset(rowRank(rowIndexes[i]));
}
for (int i = columnIndexes.length; --i >= 0;) {
columnOffsets[i] = columnOffset(columnRank(columnIndexes[i]));
}
return viewSelectionLike(rowOffsets, columnOffsets);
}
/*
public DoubleMatrix2D viewSelection(DoubleMatrix1DProcedure condition) {
IntArrayList matches = new IntArrayList();
for (int i = 0; i < rows; i++) {
if (condition.apply(viewRow(i))) {
matches.add(i);
}
}
matches.trimToSize();
return viewSelection(matches.elements(), null); // take all columns
}
*/
protected abstract DoubleMatrix2D viewSelectionLike(int[] rowOffsets, int[] columnOffsets);
public DoubleMatrix1D zMult(DoubleMatrix1D y, DoubleMatrix1D z, double alpha, double beta, boolean transposeA) {
if (transposeA) {
return viewDice().zMult(y, z, alpha, beta, false);
}
//boolean ignore = (z==null);
if (z == null) {
z = new DenseDoubleMatrix1D(this.rows);
}
if (columns != y.size() || rows > z.size()) {
throw new IllegalArgumentException("Incompatible args");
}
for (int i = rows; --i >= 0;) {
double s = 0;
for (int j = columns; --j >= 0;) {
s += getQuick(i, j) * y.getQuick(j);
}
z.setQuick(i, alpha * s + beta * z.getQuick(i));
}
return z;
}
public DoubleMatrix2D zMult(DoubleMatrix2D B, DoubleMatrix2D C, double alpha, double beta, boolean transposeA,
boolean transposeB) {
if (transposeA) {
return viewDice().zMult(B, C, alpha, beta, false, transposeB);
}
if (transposeB) {
return this.zMult(B.viewDice(), C, alpha, beta, transposeA, false);
}
int m = rows;
int n = columns;
int p = B.columns;
if (C == null) {
C = new DenseDoubleMatrix2D(m, p);
}
if (B.rows != n) {
throw new IllegalArgumentException("Matrix2D inner dimensions must agree");
}
if (C.rows != m || C.columns != p) {
throw new IllegalArgumentException("Incompatible result matrix");
}
if (this == C || B == C) {
throw new IllegalArgumentException("Matrices must not be identical");
}
for (int j = p; --j >= 0;) {
for (int i = m; --i >= 0;) {
double s = 0;
for (int k = n; --k >= 0;) {
s += getQuick(i, k) * B.getQuick(k, j);
}
C.setQuick(i, j, alpha * s + beta * C.getQuick(i, j));
}
}
return C;
}
public double zSum() {
if (size() == 0) {
return 0;
}
return aggregate(Functions.PLUS, Functions.IDENTITY);
}
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy