// org.carrot2.mahout.math.matrix.impl.SparseDoubleMatrix2D (page title of the Maven / Gradle / Ivy listing this file was copied from)
/* Imported from Mahout. */package org.carrot2.mahout.math.matrix.impl;
import org.carrot2.mahout.math.function.DoubleDoubleFunction;
import org.carrot2.mahout.math.function.DoubleFunction;
import org.carrot2.mahout.math.function.Functions;
import org.carrot2.mahout.math.function.IntDoubleProcedure;
import org.carrot2.mahout.math.function.IntIntDoubleFunction;
import org.carrot2.mahout.math.function.Mult;
import org.carrot2.mahout.math.function.PlusMult;
import org.carrot2.mahout.math.map.AbstractIntDoubleMap;
import org.carrot2.mahout.math.map.OpenIntDoubleHashMap;
import org.carrot2.mahout.math.matrix.DoubleMatrix1D;
import org.carrot2.mahout.math.matrix.DoubleMatrix2D;
/**
 * Sparse 2-d matrix of <tt>double</tt> values backed by an int-to-double hash map.
 *
 * <p>Each cell is stored under the linear index computed by {@link #index(int, int)}. Writing a
 * zero through {@link #setQuick(int, int, double)} removes the key instead of storing it.
 *
 * <p>Several methods take a fast path when {@code isNoView} is true. On that path a map key is
 * decoded back into coordinates as {@code row = key / columns} and {@code column = key % columns}.
 */
public final class SparseDoubleMatrix2D extends DoubleMatrix2D {

  /** The elements of the matrix, keyed by the linear cell index (see {@link #index(int, int)}). */
  final AbstractIntDoubleMap elements;

  /**
   * Constructs a matrix with a copy of the given values.
   *
   * @param values the cell values, indexed as {@code values[row][column]}; the column count is
   *               taken from the first row (0 if there are no rows)
   */
  public SparseDoubleMatrix2D(double[][] values) {
    this(values.length, values.length == 0 ? 0 : values[0].length);
    assign(values);
  }

  /**
   * Constructs a matrix of the given shape with an initial capacity of
   * {@code rows * (columns / 1000)} and load factors 0.2 (min) and 0.5 (max).
   *
   * @param rows    the number of rows
   * @param columns the number of columns
   */
  public SparseDoubleMatrix2D(int rows, int columns) {
    this(rows, columns, rows * (columns / 1000), 0.2, 0.5);
  }

  /**
   * Constructs a matrix of the given shape backed by a new {@link OpenIntDoubleHashMap}.
   *
   * @param rows            the number of rows
   * @param columns         the number of columns
   * @param initialCapacity initial capacity handed to the backing hash map
   * @param minLoadFactor   minimum load factor handed to the backing hash map
   * @param maxLoadFactor   maximum load factor handed to the backing hash map
   */
  public SparseDoubleMatrix2D(int rows, int columns, int initialCapacity, double minLoadFactor, double maxLoadFactor) {
    setUp(rows, columns);
    this.elements = new OpenIntDoubleHashMap(initialCapacity, minLoadFactor, maxLoadFactor);
  }

  /**
   * Sets all cells to the given value.
   *
   * @param value the value to fill the cells with
   * @return <tt>this</tt> (for convenience only)
   */
  @Override
  public DoubleMatrix2D assign(double value) {
    // overridden for performance only: zero-filling a non-view just empties the map
    if (this.isNoView && value == 0) {
      this.elements.clear();
    } else {
      super.assign(value);
    }
    return this;
  }

  /**
   * Applies a function to each cell: {@code x[row,col] = function(x[row,col])}.
   *
   * @param function the function to apply; a {@link Mult} on a non-view matrix is applied
   *                 directly to the backing map, anything else goes through the superclass
   */
  @Override
  public void assign(DoubleFunction function) {
    if (this.isNoView && function instanceof Mult) { // x[i] = mult*x[i]
      this.elements.assign(function);
    } else {
      super.assign(function);
    }
  }

  /**
   * Replaces all cell values of the receiver with the values of another matrix.
   *
   * @param source the matrix to copy from
   * @return <tt>this</tt>, or whatever the superclass implementation returns when it is used
   */
  @Override
  public DoubleMatrix2D assign(DoubleMatrix2D source) {
    // overridden for performance only
    if (!(source instanceof SparseDoubleMatrix2D)) {
      return super.assign(source);
    }
    SparseDoubleMatrix2D other = (SparseDoubleMatrix2D) source;
    if (other == this) {
      return this; // nothing to do
    }
    checkShape(other);

    if (this.isNoView && other.isNoView) { // quickest: copy map to map
      this.elements.assign(other.elements);
      return this;
    }
    return super.assign(source);
  }

  /**
   * Assigns the result of a binary function to each cell:
   * {@code x[row,col] = function(x[row,col], y[row,col])}.
   *
   * <p>For a non-view receiver three functions are special-cased so that only non-zero cells are
   * visited:
   * <ul>
   *   <li>{@link PlusMult}: iterates over the non-zero cells of {@code y};</li>
   *   <li>{@code Functions.MULT} and {@code Functions.DIV}: iterate over the cells stored in this
   *       matrix. Cells that are zero in this matrix are left untouched, so for {@code DIV} a
   *       {@code 0 / 0} cell stays {@code 0} instead of becoming {@code NaN}.</li>
   * </ul>
   * Every other function, and every view, is handled by the superclass.
   *
   * @param y        the second argument matrix
   * @param function the function taking the receiver's cell as first and {@code y}'s cell as
   *                 second argument
   * @return <tt>this</tt>, or whatever the superclass implementation returns when it is used
   */
  @Override
  public DoubleMatrix2D assign(final DoubleMatrix2D y,
                               DoubleDoubleFunction function) {
    if (!this.isNoView) {
      return super.assign(y, function);
    }

    checkShape(y);

    if (function instanceof PlusMult) { // x[i] = x[i] + alpha*y[i]
      final double alpha = ((PlusMult) function).getMultiplicator();
      if (alpha == 0) {
        return this; // nothing to do
      }
      y.forEachNonZero(
          new IntIntDoubleFunction() {
            @Override
            public double apply(int i, int j, double value) {
              setQuick(i, j, getQuick(i, j) + alpha * value);
              return value; // leave y unchanged
            }
          }
      );
      return this;
    }

    if (function == Functions.MULT) { // x[i] = x[i] * y[i]
      this.elements.forEachPair(
          new IntDoubleProcedure() {
            @Override
            public boolean apply(int key, double value) {
              int i = key / columns;
              int j = key % columns;
              double r = value * y.getQuick(i, j);
              if (r != value) {
                elements.put(key, r);
              }
              return true;
            }
          }
      );
      // The product has been applied. Falling through to super.assign would apply it a second time.
      return this;
    }

    if (function == Functions.DIV) { // x[i] = x[i] / y[i]
      this.elements.forEachPair(
          new IntDoubleProcedure() {
            @Override
            public boolean apply(int key, double value) {
              int i = key / columns;
              int j = key % columns;
              double r = value / y.getQuick(i, j);
              if (r != value) {
                elements.put(key, r);
              }
              return true;
            }
          }
      );
      // The quotient has been applied. Falling through to super.assign would apply it a second time.
      return this;
    }

    return super.assign(y, function);
  }

  /**
   * Returns the number of cells having non-zero values.
   *
   * @return the size of the backing map for a non-view matrix, otherwise the superclass result
   */
  @Override
  public int cardinality() {
    return this.isNoView ? this.elements.size() : super.cardinality();
  }

  /**
   * Asks the backing map to make room for at least the given number of entries.
   *
   * @param minCapacity the desired minimum number of entries
   */
  @Override
  public void ensureCapacity(int minCapacity) {
    this.elements.ensureCapacity(minCapacity);
  }

  /**
   * Applies {@code function(row, column, value)} to the stored cells and writes the result back
   * whenever it differs from the current value.
   *
   * <p>A non-view matrix iterates directly over the backing map; a view defers to the superclass.
   *
   * @param function the function receiving row, column and value, and returning the new value
   */
  @Override
  public void forEachNonZero(final IntIntDoubleFunction function) {
    if (!this.isNoView) {
      super.forEachNonZero(function);
      return;
    }
    this.elements.forEachPair(
        new IntDoubleProcedure() {
          @Override
          public boolean apply(int key, double value) {
            int i = key / columns;
            int j = key % columns;
            double r = function.apply(i, j, value);
            if (r != value) {
              elements.put(key, r);
            }
            return true;
          }
        }
    );
  }

  /**
   * Returns the value of the cell at {@code [row, column]} without checking the coordinates
   * against the matrix bounds.
   *
   * @param row    the row index
   * @param column the column index
   * @return the value stored for that cell in the backing map
   */
  @Override
  public double getQuick(int row, int column) {
    // index(row, column), manually inlined
    return this.elements.get(rowZero + row * rowStride + columnZero + column * columnStride);
  }

  /**
   * Tells whether the receiver and the given matrix are backed by the same map.
   *
   * @param other the matrix to compare against
   * @return {@code true} if {@code other} is a {@link SelectedSparseDoubleMatrix2D} or a
   *         {@link SparseDoubleMatrix2D} whose {@code elements} is the same object as this one's;
   *         {@code false} otherwise
   */
  @Override
  protected boolean haveSharedCellsRaw(DoubleMatrix2D other) {
    if (other instanceof SelectedSparseDoubleMatrix2D) {
      SelectedSparseDoubleMatrix2D otherMatrix = (SelectedSparseDoubleMatrix2D) other;
      return this.elements == otherMatrix.elements;
    }
    if (other instanceof SparseDoubleMatrix2D) {
      SparseDoubleMatrix2D otherMatrix = (SparseDoubleMatrix2D) other;
      return this.elements == otherMatrix.elements;
    }
    return false;
  }

  /**
   * Returns the linear index under which the given cell is stored in the backing map.
   *
   * @param row    the row index
   * @param column the column index
   * @return {@code rowZero + row * rowStride + columnZero + column * columnStride}
   */
  @Override
  protected int index(int row, int column) {
    // manually inlined for speed instead of delegating to super.index(row, column)
    return rowZero + row * rowStride + columnZero + column * columnStride;
  }

  /**
   * Constructs a new, empty sparse matrix with the given shape.
   *
   * @param rows    the number of rows of the new matrix
   * @param columns the number of columns of the new matrix
   * @return a new {@link SparseDoubleMatrix2D}
   */
  @Override
  public DoubleMatrix2D like(int rows, int columns) {
    return new SparseDoubleMatrix2D(rows, columns);
  }

  /**
   * Constructs a new, empty sparse 1-d matrix of the given size.
   *
   * @param size the number of cells of the new matrix
   * @return a new {@link SparseDoubleMatrix1D}
   */
  @Override
  public DoubleMatrix1D like1D(int size) {
    return new SparseDoubleMatrix1D(size);
  }

  /**
   * Constructs a sparse 1-d matrix over this matrix's backing map.
   *
   * @param size   the number of cells of the new matrix
   * @param offset the offset passed on to the {@link SparseDoubleMatrix1D} constructor
   * @param stride the stride passed on to the {@link SparseDoubleMatrix1D} constructor
   * @return a new {@link SparseDoubleMatrix1D} constructed with this matrix's {@code elements}
   */
  @Override
  protected DoubleMatrix1D like1D(int size, int offset, int stride) {
    return new SparseDoubleMatrix1D(size, this.elements, offset, stride);
  }

  /**
   * Sets the cell at {@code [row, column]} without checking the coordinates against the matrix
   * bounds. A zero value removes the cell from the backing map.
   *
   * @param row    the row index
   * @param column the column index
   * @param value  the value to store
   */
  @Override
  public void setQuick(int row, int column, double value) {
    // index(row, column), manually inlined
    int index = rowZero + row * rowStride + columnZero + column * columnStride;
    if (value == 0) {
      this.elements.removeKey(index);
    } else {
      this.elements.put(index, value);
    }
  }

  /**
   * Constructs a selection view over this matrix's backing map.
   *
   * @param rowOffsets    the row offsets passed on to the view
   * @param columnOffsets the column offsets passed on to the view
   * @return a new {@link SelectedSparseDoubleMatrix2D} constructed with this matrix's
   *         {@code elements}
   */
  @Override
  protected DoubleMatrix2D viewSelectionLike(int[] rowOffsets, int[] columnOffsets) {
    return new SelectedSparseDoubleMatrix2D(this.elements, rowOffsets, columnOffsets, 0);
  }

  /**
   * Matrix-vector multiplication: {@code z = alpha * A * y + beta * z}, where {@code A} is this
   * matrix or, if {@code transposeA} is set, its transpose.
   *
   * <p>The sparse fast path is used only when this matrix is not a view and both {@code y} and
   * {@code z} are {@link DenseDoubleMatrix1D}; otherwise the superclass implementation is used.
   *
   * @param y          the source vector
   * @param z          the vector holding the result; if {@code null} a new dense vector is
   *                   allocated and {@code beta} is ignored
   * @param alpha      scale factor for the product
   * @param beta       scale factor for the previous content of {@code z}
   * @param transposeA whether to multiply with the transpose of this matrix
   * @return {@code z}, or the newly allocated vector if {@code z} was {@code null}
   * @throws IllegalArgumentException if the sizes of {@code y} or {@code z} do not fit this matrix
   * @throws IllegalStateException    if {@code y} or {@code z} has no backing array
   */
  @Override
  public DoubleMatrix1D zMult(DoubleMatrix1D y, DoubleMatrix1D z, double alpha, double beta, final boolean transposeA) {
    int m = rows;
    int n = columns;
    if (transposeA) {
      m = columns;
      n = rows;
    }

    boolean ignore = z == null;
    if (ignore) {
      z = new DenseDoubleMatrix1D(m);
    }

    if (!(this.isNoView && y instanceof DenseDoubleMatrix1D && z instanceof DenseDoubleMatrix1D)) {
      return super.zMult(y, z, alpha, beta, transposeA);
    }

    if (n != y.size() || m > z.size()) {
      throw new IllegalArgumentException("Incompatible args");
    }

    if (alpha == 0) {
      // The product term vanishes. Handle it here, because the pre-scaling below divides by alpha
      // and would turn every entry of z into Infinity/NaN.
      if (!ignore) {
        z.assign(Functions.mult(beta));
      }
      return z;
    }

    // Pre-scale z by beta/alpha so that one final multiplication by alpha yields
    // alpha*A*y + beta*z.
    if (!ignore) {
      z.assign(Functions.mult(beta / alpha));
    }

    DenseDoubleMatrix1D zz = (DenseDoubleMatrix1D) z;
    final double[] zElements = zz.elements;
    final int zStride = zz.stride;
    final int zi = z.index(0);

    DenseDoubleMatrix1D yy = (DenseDoubleMatrix1D) y;
    final double[] yElements = yy.elements;
    final int yStride = yy.stride;
    final int yi = y.index(0);

    if (yElements == null || zElements == null) {
      throw new IllegalStateException();
    }

    // Accumulate A*y (or A'*y) into z, visiting only the stored cells.
    this.elements.forEachPair(
        new IntDoubleProcedure() {
          @Override
          public boolean apply(int key, double value) {
            int i = key / columns;
            int j = key % columns;
            if (transposeA) {
              int tmp = i;
              i = j;
              j = tmp;
            }
            zElements[zi + zStride * i] += value * yElements[yi + yStride * j];
            return true;
          }
        }
    );

    if (alpha != 1.0) {
      z.assign(Functions.mult(alpha));
    }
    return z;
  }

  /**
   * Matrix-matrix multiplication: {@code C = alpha * A x B + beta * C}, where {@code A} is this
   * matrix (or its transpose) and {@code B} is optionally transposed as well.
   *
   * <p>The sparse fast path is used only when this matrix is not a view; otherwise the superclass
   * implementation is used.
   *
   * @param B          the second source matrix
   * @param C          the matrix holding the result; if {@code null} a new dense matrix is
   *                   allocated and {@code beta} is ignored
   * @param alpha      scale factor for the product
   * @param beta       scale factor for the previous content of {@code C}
   * @param transposeA whether to multiply with the transpose of this matrix
   * @param transposeB whether to multiply with the transpose of {@code B}
   * @return {@code C}, or the newly allocated matrix if {@code C} was {@code null}
   * @throws IllegalArgumentException if the inner dimensions do not agree, if {@code C} has the
   *                                  wrong shape, or if {@code C} is the same object as this
   *                                  matrix or as {@code B}
   */
  @Override
  public DoubleMatrix2D zMult(DoubleMatrix2D B, DoubleMatrix2D C, final double alpha, double beta,
                              final boolean transposeA, boolean transposeB) {
    if (!this.isNoView) {
      return super.zMult(B, C, alpha, beta, transposeA, transposeB);
    }
    if (transposeB) {
      B = B.viewDice();
    }
    int m = rows;
    int n = columns;
    if (transposeA) {
      m = columns;
      n = rows;
    }
    int p = B.columns;
    boolean ignore = C == null;
    if (ignore) {
      C = new DenseDoubleMatrix2D(m, p);
    }

    if (B.rows != n) {
      throw new IllegalArgumentException("Matrix2D inner dimensions must agree");
    }
    if (C.rows != m || C.columns != p) {
      throw new IllegalArgumentException("Incompatible result matrix");
    }
    if (this == C || B == C) {
      throw new IllegalArgumentException("Matrices must not be identical");
    }

    if (!ignore) {
      C.assign(Functions.mult(beta));
    }

    // cache row views so the per-cell callback does not have to create them
    final DoubleMatrix1D[] Brows = new DoubleMatrix1D[n];
    for (int i = n; --i >= 0;) {
      Brows[i] = B.viewRow(i);
    }
    final DoubleMatrix1D[] Crows = new DoubleMatrix1D[m];
    for (int i = m; --i >= 0;) {
      Crows[i] = C.viewRow(i);
    }

    // One reusable function object; its multiplicator is reset for every stored cell.
    final PlusMult fun = PlusMult.plusMult(0);

    this.elements.forEachPair(
        new IntDoubleProcedure() {
          @Override
          public boolean apply(int key, double value) {
            int i = key / columns;
            int j = key % columns;
            fun.setMultiplicator(value * alpha);
            // C[row] += (alpha * a) * B[row'], with the roles of i and j swapped for A'
            if (transposeA) {
              Crows[j].assign(Brows[i], fun);
            } else {
              Crows[i].assign(Brows[j], fun);
            }
            return true;
          }
        }
    );

    return C;
  }
}
// (web-page footer of the listing site, not part of the source file) © 2015 - 2025 Weber Informatics LLC | Privacy Policy