// org.carrot2.mahout.math.AbstractMatrix (artifact listing: Maven / Gradle / Ivy)
/*
* Carrot2 project.
*
* Copyright (C) 2002-2016, Dawid Weiss, Stanisław Osiński.
* All rights reserved.
*
* Refer to the full license file "carrot2.LICENSE"
* in the root folder of the repository checkout or at:
* http://www.carrot2.org/carrot2.LICENSE
*/
package org.carrot2.mahout.math;
import java.util.Iterator;
import java.util.Map;
import org.carrot2.mahout.math.function.DoubleDoubleFunction;
import org.carrot2.mahout.math.function.DoubleFunction;
import org.carrot2.mahout.math.function.Functions;
import org.carrot2.mahout.math.function.PlusMult;
import org.carrot2.mahout.math.function.VectorFunction;
import org.carrot2.shaded.guava.common.collect.AbstractIterator;
import org.carrot2.shaded.guava.common.collect.Maps;
public abstract class AbstractMatrix implements Matrix {
/** Maps a column label to its column index; stays {@code null} until a binding is supplied or created. */
protected Map<String, Integer> columnLabelBindings;
/** Maps a row label to its row index; stays {@code null} until a binding is supplied or created. */
protected Map<String, Integer> rowLabelBindings;
/** Row count, as passed to the constructor and reported by {@link #rowSize()}. */
protected int rows;
/** Column count, as passed to the constructor and reported by {@link #columnSize()}. */
protected int columns;
/**
 * Records the dimensions of the matrix being constructed.
 *
 * @param rows    value stored as the row count
 * @param columns value stored as the column count
 */
protected AbstractMatrix(int rows, int columns) {
  this.columns = columns;
  this.rows = rows;
}
/**
 * @return the column count held in the {@code columns} field
 */
@Override
public int columnSize() {
  return this.columns;
}
/**
 * @return the row count held in the {@code rows} field
 */
@Override
public int rowSize() {
  return this.rows;
}
/**
 * Iterates over the slices of this matrix by delegating to {@link #iterateAll()}.
 *
 * @return an iterator producing one {@link MatrixSlice} per slice
 */
@Override
public Iterator<MatrixSlice> iterator() {
  return iterateAll();
}
/**
 * Iterates over every slice index from 0 up to (but excluding) {@link #numSlices()},
 * wrapping the row view at each index in a {@link MatrixSlice}.
 *
 * @return an iterator over all slices, in increasing index order
 */
@Override
public Iterator<MatrixSlice> iterateAll() {
  // The element type must be explicit: with a raw AbstractIterator, endOfData()
  // is typed as Object and cannot be returned from computeNext().
  return new AbstractIterator<MatrixSlice>() {
    /** Index of the next slice to hand out. */
    private int slice;

    @Override
    protected MatrixSlice computeNext() {
      if (slice >= numSlices()) {
        return endOfData();
      }
      int i = slice++;
      return new MatrixSlice(viewRow(i), i);
    }
  };
}
/**
 * @return the number of slices, which this class defines as {@link #numRows()}
 */
@Override
public int numSlices() {
  final int sliceCount = numRows();
  return sliceCount;
}
/**
 * Reads the cell addressed by a pair of labels.
 *
 * @param rowLabel    label looked up in the row bindings
 * @param columnLabel label looked up in the column bindings
 * @return the value at the resolved (row, column) position
 * @throws IllegalStateException if either binding map is absent or either label is not bound
 */
@Override
public double get(String rowLabel, String columnLabel) {
  if (rowLabelBindings == null || columnLabelBindings == null) {
    throw new IllegalStateException("Unbound label");
  }
  final Integer rowIndex = rowLabelBindings.get(rowLabel);
  final Integer columnIndex = columnLabelBindings.get(columnLabel);
  if (!(rowIndex != null && columnIndex != null)) {
    throw new IllegalStateException("Unbound label");
  }
  // Resolved indices go through the bounds-checked accessor.
  return get(rowIndex, columnIndex);
}
/**
 * @return the column label bindings map itself (not a copy); {@code null} when none is set
 */
@Override
public Map<String, Integer> getColumnLabelBindings() {
  return columnLabelBindings;
}
/**
 * @return the row label bindings map itself (not a copy); {@code null} when none is set
 */
@Override
public Map<String, Integer> getRowLabelBindings() {
  return rowLabelBindings;
}
/**
 * Overwrites the row addressed by {@code rowLabel} with {@code rowData}.
 *
 * @param rowLabel label looked up in the row bindings
 * @param rowData  values handed to {@link #set(int, double[])} for the resolved row
 * @throws IllegalStateException if no row bindings exist or the label is not bound
 */
@Override
public void set(String rowLabel, double[] rowData) {
  // Only the row bindings are consulted here, so that is the map that must exist.
  // (Checking the column bindings instead allowed an NPE on the lookup below and
  // rejected matrices that had row labels but no column labels.)
  if (rowLabelBindings == null) {
    throw new IllegalStateException("Unbound label");
  }
  Integer row = rowLabelBindings.get(rowLabel);
  if (row == null) {
    throw new IllegalStateException("Unbound label");
  }
  set(row, rowData);
}
/**
 * Binds {@code rowLabel} to {@code row} and then overwrites that row with {@code rowData}.
 * The row bindings map is created on first use.
 *
 * @param rowLabel label to associate with the row index
 * @param row      index of the row to bind and overwrite
 * @param rowData  values handed to {@link #set(int, double[])}
 */
@Override
public void set(String rowLabel, int row, double[] rowData) {
  if (null == rowLabelBindings) {
    rowLabelBindings = Maps.newHashMap();
  }
  // The binding is recorded before the data is written.
  rowLabelBindings.put(rowLabel, Integer.valueOf(row));
  set(row, rowData);
}
/**
 * Writes {@code value} into the cell addressed by a pair of labels.
 *
 * @param rowLabel    label looked up in the row bindings
 * @param columnLabel label looked up in the column bindings
 * @param value       value to store
 * @throws IllegalStateException if either binding map is absent or either label is not bound
 */
@Override
public void set(String rowLabel, String columnLabel, double value) {
  if (rowLabelBindings == null || columnLabelBindings == null) {
    throw new IllegalStateException("Unbound label");
  }
  final Integer rowIndex = rowLabelBindings.get(rowLabel);
  final Integer columnIndex = columnLabelBindings.get(columnLabel);
  if (!(rowIndex != null && columnIndex != null)) {
    throw new IllegalStateException("Unbound label");
  }
  // Resolved indices go through the bounds-checked setter.
  set(rowIndex, columnIndex, value);
}
/**
 * Binds both labels to the given indices and then stores {@code value} at that cell.
 * Either bindings map is created on first use.
 *
 * @param rowLabel    label to associate with {@code row}
 * @param columnLabel label to associate with {@code column}
 * @param row         row index of the cell
 * @param column      column index of the cell
 * @param value       value to store
 */
@Override
public void set(String rowLabel, String columnLabel, int row, int column, double value) {
  if (null == rowLabelBindings) {
    rowLabelBindings = Maps.newHashMap();
  }
  if (null == columnLabelBindings) {
    columnLabelBindings = Maps.newHashMap();
  }
  // Both bindings are recorded before the bounds-checked write.
  rowLabelBindings.put(rowLabel, Integer.valueOf(row));
  columnLabelBindings.put(columnLabel, Integer.valueOf(column));
  set(row, column, value);
}
/**
 * Replaces the column label bindings. The map is stored by reference, not copied.
 *
 * @param bindings label-to-column-index map; may be {@code null} to clear the bindings
 */
@Override
public void setColumnLabelBindings(Map<String, Integer> bindings) {
  columnLabelBindings = bindings;
}
/**
 * Replaces the row label bindings. The map is stored by reference, not copied.
 *
 * @param bindings label-to-row-index map; may be {@code null} to clear the bindings
 */
@Override
public void setRowLabelBindings(Map<String, Integer> bindings) {
  rowLabelBindings = bindings;
}
/** Position of the row value inside a two-element {@code int[]}. */
public static final int ROW = 0;
/** Position of the column value inside a two-element {@code int[]}. */
public static final int COL = 1;
/**
 * @return the row count; same value as {@link #rowSize()}
 */
@Override
public int numRows() {
  final int rowCount = rowSize();
  return rowCount;
}
/**
 * @return the column count; same value as {@link #columnSize()}
 */
@Override
public int numCols() {
  final int columnCount = columnSize();
  return columnCount;
}
/**
 * @return the result of {@link #toString()}; this class defines no separate format
 */
@Override
public String asFormatString() {
  return this.toString();
}
/**
 * Stores {@code value} in every cell, visiting cells row by row.
 *
 * @param value value to write into each cell
 * @return this matrix
 */
@Override
public Matrix assign(double value) {
  final int rowCount = rowSize();
  final int columnCount = columnSize();
  for (int r = 0; r < rowCount; r++) {
    for (int c = 0; c < columnCount; c++) {
      setQuick(r, c, value);
    }
  }
  return this;
}
/**
 * Copies a two-dimensional array into this matrix, row by row.
 *
 * <p>Row lengths are validated as each row is reached, so rows preceding a
 * mismatching one have already been written when the exception is thrown.
 *
 * @param values source data; {@code values[r][c]} goes to cell (r, c)
 * @return this matrix
 * @throws CardinalityException if the number of rows, or the length of any row, does not
 *         match this matrix's dimensions
 */
@Override
public Matrix assign(double[][] values) {
  final int rowCount = rowSize();
  if (values.length != rowCount) {
    throw new CardinalityException(rowCount, values.length);
  }
  final int columnCount = columnSize();
  for (int r = 0; r < rowCount; r++) {
    final double[] source = values[r];
    if (source.length != columnCount) {
      throw new CardinalityException(columnCount, source.length);
    }
    for (int c = 0; c < columnCount; c++) {
      setQuick(r, c, source[c]);
    }
  }
  return this;
}
/**
 * Replaces each cell with {@code function.apply(thisCell, otherCell)}.
 *
 * @param other    matrix supplying the second argument; must have the same dimensions
 * @param function combiner invoked once per cell
 * @return this matrix
 * @throws CardinalityException if the row counts or the column counts differ
 */
@Override
public Matrix assign(Matrix other, DoubleDoubleFunction function) {
  final int rowCount = rowSize();
  if (other.rowSize() != rowCount) {
    throw new CardinalityException(rowCount, other.rowSize());
  }
  final int columnCount = columnSize();
  if (other.columnSize() != columnCount) {
    throw new CardinalityException(columnCount, other.columnSize());
  }
  for (int r = 0; r < rowCount; r++) {
    for (int c = 0; c < columnCount; c++) {
      final double mine = getQuick(r, c);
      final double theirs = other.getQuick(r, c);
      setQuick(r, c, function.apply(mine, theirs));
    }
  }
  return this;
}
/**
 * Copies every cell of {@code other} into the same position of this matrix.
 *
 * @param other source matrix; must have the same dimensions
 * @return this matrix
 * @throws CardinalityException if the row counts or the column counts differ
 */
@Override
public Matrix assign(Matrix other) {
  final int rowCount = rowSize();
  if (other.rowSize() != rowCount) {
    throw new CardinalityException(rowCount, other.rowSize());
  }
  final int columnCount = columnSize();
  if (other.columnSize() != columnCount) {
    throw new CardinalityException(columnCount, other.columnSize());
  }
  for (int r = 0; r < rowCount; r++) {
    for (int c = 0; c < columnCount; c++) {
      final double copied = other.getQuick(r, c);
      setQuick(r, c, copied);
    }
  }
  return this;
}
/**
 * Replaces each cell with the result of applying {@code function} to its current value.
 *
 * @param function transformation invoked once per cell
 * @return this matrix
 */
@Override
public Matrix assign(DoubleFunction function) {
  final int rowCount = rowSize();
  final int columnCount = columnSize();
  for (int r = 0; r < rowCount; r++) {
    for (int c = 0; c < columnCount; c++) {
      final double current = getQuick(r, c);
      setQuick(r, c, function.apply(current));
    }
  }
  return this;
}
/**
 * Reduces each row to a single number.
 *
 * @param f function applied to the view of each row
 * @return a new {@link DenseVector} whose element {@code i} is {@code f} applied to row {@code i}
 */
@Override
public Vector aggregateRows(VectorFunction f) {
  final int rowCount = numRows();
  final Vector perRow = new DenseVector(rowCount);
  for (int r = 0; r < rowCount; r++) {
    final double reduced = f.apply(viewRow(r));
    perRow.set(r, reduced);
  }
  return perRow;
}
/**
 * Creates a vector view over one row of this matrix.
 *
 * @param row index of the row to view; not range-checked here
 * @return a {@link MatrixVectorView} backed by this matrix
 */
@Override
public Vector viewRow(int row) {
  // Arguments after `this` are presumably (startRow, startColumn, rowStride,
  // columnStride) - confirm against MatrixVectorView.
  final Vector rowView = new MatrixVectorView(this, row, 0, 0, 1);
  return rowView;
}
/**
 * Creates a vector view over one column of this matrix.
 *
 * @param column index of the column to view; not range-checked here
 * @return a {@link MatrixVectorView} backed by this matrix
 */
@Override
public Vector viewColumn(int column) {
  // Arguments after `this` are presumably (startRow, startColumn, rowStride,
  // columnStride) - confirm against MatrixVectorView.
  final Vector columnView = new MatrixVectorView(this, 0, column, 1, 0);
  return columnView;
}
/**
 * Creates a vector view over the main diagonal of this matrix.
 *
 * @return a {@link MatrixVectorView} backed by this matrix
 */
@Override
public Vector viewDiagonal() {
  // Arguments after `this` are presumably (startRow, startColumn, rowStride,
  // columnStride) - confirm against MatrixVectorView.
  final Vector diagonalView = new MatrixVectorView(this, 0, 0, 1, 1);
  return diagonalView;
}
/**
 * Reduces the whole matrix to one number in two stages: each row is aggregated with
 * {@code combiner} and {@code mapper}, then the per-row results are aggregated with
 * {@code combiner} and the identity mapping.
 *
 * @param combiner function used to fold values together
 * @param mapper   function applied to each cell before folding within a row
 * @return the aggregated value
 */
@Override
public double aggregate(final DoubleDoubleFunction combiner, final DoubleFunction mapper) {
  final VectorFunction rowReducer = new VectorFunction() {
    @Override
    public double apply(Vector v) {
      return v.aggregate(combiner, mapper);
    }
  };
  final Vector rowResults = aggregateRows(rowReducer);
  // Row results are already mapped, so only the combiner is applied again.
  return rowResults.aggregate(combiner, Functions.IDENTITY);
}
/**
 * Reduces each column to a single number.
 *
 * @param f function applied to the view of each column
 * @return a new {@link DenseVector} whose element {@code j} is {@code f} applied to column {@code j}
 */
@Override
public Vector aggregateColumns(VectorFunction f) {
  final Vector perColumn = new DenseVector(numCols());
  int c = 0;
  while (c < numCols()) {
    final double reduced = f.apply(viewColumn(c));
    perColumn.set(c, reduced);
    c++;
  }
  return perColumn;
}
/**
 * Computes the determinant by cofactor expansion along the first row.
 *
 * <p>1x1 and 2x2 matrices are evaluated directly. Larger matrices recurse on
 * {@link DenseMatrix} minors. A 0x0 matrix yields 0 because the expansion loop is empty.
 *
 * @return the determinant of this matrix
 * @throws CardinalityException if the matrix is not square
 */
@Override
public double determinant() {
  int rows = rowSize();
  int columns = columnSize();
  if (rows != columns) {
    throw new CardinalityException(rows, columns);
  }
  if (rows == 1) {
    // Without this case the expansion below multiplies the single element by the
    // "determinant" of a 0x0 minor, which evaluates to 0, so every 1x1 matrix
    // would report a determinant of 0.
    return getQuick(0, 0);
  }
  if (rows == 2) {
    return getQuick(0, 0) * getQuick(1, 1) - getQuick(0, 1) * getQuick(1, 0);
  } else {
    // TODO: this really should just be one line:
    // TODO: new CholeskyDecomposition(this).getL().viewDiagonal().aggregate(Functions.TIMES)
    int sign = 1;
    double ret = 0;
    for (int i = 0; i < columns; i++) {
      // Minor obtained by deleting row 0 and column i.
      Matrix minor = new DenseMatrix(rows - 1, columns - 1);
      for (int j = 1; j < rows; j++) {
        boolean flag = false; /* column offset flag: true once column i has been skipped */
        for (int k = 0; k < columns; k++) {
          if (k == i) {
            flag = true;
            continue;
          }
          minor.set(j - 1, flag ? k - 1 : k, getQuick(j, k));
        }
      }
      ret += getQuick(0, i) * sign * minor.determinant();
      sign *= -1; // cofactor signs alternate along the row
    }
    return ret;
  }
}
/**
 * Clones this matrix via {@code super.clone()} and gives the copy its own
 * {@code HashMap} copies of whichever label binding maps are present.
 *
 * @return the cloned matrix
 * @throws IllegalStateException if {@code super.clone()} reports that cloning is unsupported
 */
@Override
public Matrix clone() {
  final AbstractMatrix copy;
  try {
    copy = (AbstractMatrix) super.clone();
  } catch (CloneNotSupportedException cnse) {
    throw new IllegalStateException(cnse); // can't happen
  }
  // Absent bindings stay absent; present ones are copied so the two matrices
  // do not share a map.
  if (columnLabelBindings != null) {
    copy.columnLabelBindings = Maps.newHashMap(columnLabelBindings);
  }
  if (rowLabelBindings != null) {
    copy.rowLabelBindings = Maps.newHashMap(rowLabelBindings);
  }
  return copy;
}
/**
 * Produces a new matrix in which every cell of this matrix is divided by {@code x}.
 * This matrix is not modified.
 *
 * @param x divisor applied to each cell
 * @return a matrix obtained from {@link #like()} holding the quotients
 */
@Override
public Matrix divide(double x) {
  final Matrix quotient = like();
  for (int r = 0; r < rowSize(); r++) {
    for (int c = 0; c < columnSize(); c++) {
      final double scaled = getQuick(r, c) / x;
      quotient.setQuick(r, c, scaled);
    }
  }
  return quotient;
}
/**
 * Bounds-checked read of a single cell.
 *
 * @param row    row index, {@code 0 <= row < rowSize()}
 * @param column column index, {@code 0 <= column < columnSize()}
 * @return the value at (row, column)
 * @throws IndexException if either index is out of range; the row is checked first
 */
@Override
public double get(int row, int column) {
  if (!(row >= 0 && row < rowSize())) {
    throw new IndexException(row, rowSize());
  }
  if (!(column >= 0 && column < columnSize())) {
    throw new IndexException(column, columnSize());
  }
  return getQuick(row, column);
}
/**
 * Produces a new matrix holding the element-wise difference {@code this - other}.
 * Neither operand is modified.
 *
 * @param other matrix to subtract; must have the same dimensions
 * @return a matrix obtained from {@link #like()} holding the differences
 * @throws CardinalityException if the row counts or the column counts differ
 */
@Override
public Matrix minus(Matrix other) {
  final int rowCount = rowSize();
  if (other.rowSize() != rowCount) {
    throw new CardinalityException(rowCount, other.rowSize());
  }
  final int columnCount = columnSize();
  if (other.columnSize() != columnCount) {
    throw new CardinalityException(columnCount, other.columnSize());
  }
  final Matrix difference = like();
  for (int r = 0; r < rowCount; r++) {
    for (int c = 0; c < columnCount; c++) {
      final double delta = getQuick(r, c) - other.getQuick(r, c);
      difference.setQuick(r, c, delta);
    }
  }
  return difference;
}
/**
 * Produces a new matrix in which {@code x} has been added to every cell of this matrix.
 * This matrix is not modified.
 *
 * @param x value added to each cell
 * @return a matrix obtained from {@link #like()} holding the sums
 */
@Override
public Matrix plus(double x) {
  final Matrix shifted = like();
  final int rowCount = rowSize();
  final int columnCount = columnSize();
  for (int r = 0; r < rowCount; r++) {
    for (int c = 0; c < columnCount; c++) {
      final double sum = getQuick(r, c) + x;
      shifted.setQuick(r, c, sum);
    }
  }
  return shifted;
}
/**
 * Produces a new matrix holding the element-wise sum {@code this + other}.
 * Neither operand is modified.
 *
 * @param other matrix to add; must have the same dimensions
 * @return a matrix obtained from {@link #like()} holding the sums
 * @throws CardinalityException if the row counts or the column counts differ
 */
@Override
public Matrix plus(Matrix other) {
  final int rowCount = rowSize();
  if (other.rowSize() != rowCount) {
    throw new CardinalityException(rowCount, other.rowSize());
  }
  final int columnCount = columnSize();
  if (other.columnSize() != columnCount) {
    throw new CardinalityException(columnCount, other.columnSize());
  }
  final Matrix total = like();
  for (int r = 0; r < rowCount; r++) {
    for (int c = 0; c < columnCount; c++) {
      final double sum = getQuick(r, c) + other.getQuick(r, c);
      total.setQuick(r, c, sum);
    }
  }
  return total;
}
/**
 * Bounds-checked write of a single cell.
 *
 * @param row    row index, {@code 0 <= row < rowSize()}
 * @param column column index, {@code 0 <= column < columnSize()}
 * @param value  value to store
 * @throws IndexException if either index is out of range; the row is checked first
 */
@Override
public void set(int row, int column, double value) {
  if (!(row >= 0 && row < rowSize())) {
    throw new IndexException(row, rowSize());
  }
  if (!(column >= 0 && column < columnSize())) {
    throw new IndexException(column, columnSize());
  }
  setQuick(row, column, value);
}
/**
 * Writes {@code data} into the leading cells of the given row.
 *
 * <p>{@code data} may be shorter than the row; in that case only the first
 * {@code data.length} cells are overwritten and the remaining cells keep their values.
 *
 * @param row  index of the row to write, {@code 0 <= row < rowSize()}
 * @param data values to store, at most {@code columnSize()} of them
 * @throws CardinalityException if {@code data} has more elements than the matrix has columns
 * @throws IndexException if {@code row} is out of range
 */
@Override
public void set(int row, double[] data) {
  int columns = columnSize();
  if (columns < data.length) {
    throw new CardinalityException(columns, data.length);
  }
  int rows = rowSize();
  if (row < 0 || row >= rows) {
    throw new IndexException(row, rows);
  }
  // Iterate over the supplied data, not the column count: the guard above only
  // rejects data that is too long, so looping to `columns` would run past the end
  // of a shorter array after partially overwriting the row.
  for (int i = 0; i < data.length; i++) {
    setQuick(row, i, data[i]);
  }
}
/**
 * Produces a new matrix in which every cell of this matrix is multiplied by {@code x}.
 * This matrix is not modified.
 *
 * @param x scalar factor
 * @return a matrix obtained from {@link #like()} holding the products
 */
@Override
public Matrix times(double x) {
  final Matrix scaled = like();
  final int rowCount = rowSize();
  final int columnCount = columnSize();
  for (int r = 0; r < rowCount; r++) {
    for (int c = 0; c < columnCount; c++) {
      final double product = getQuick(r, c) * x;
      scaled.setQuick(r, c, product);
    }
  }
  return scaled;
}
/**
 * Computes the matrix product {@code this * other} with the straightforward
 * triple loop. Neither operand is modified.
 *
 * @param other right-hand operand; its row count must equal this matrix's column count
 * @return a matrix obtained from {@code like(rowSize(), other.columnSize())} holding the product
 * @throws CardinalityException if the inner dimensions differ
 */
@Override
public Matrix times(Matrix other) {
  final int innerSize = columnSize();
  if (other.rowSize() != innerSize) {
    throw new CardinalityException(innerSize, other.rowSize());
  }
  final int rowCount = rowSize();
  final int resultColumns = other.columnSize();
  final Matrix product = like(rowCount, resultColumns);
  for (int r = 0; r < rowCount; r++) {
    for (int c = 0; c < resultColumns; c++) {
      // Dot product of row r of this matrix with column c of the other.
      double dot = 0.0;
      for (int k = 0; k < innerSize; k++) {
        dot += getQuick(r, k) * other.getQuick(k, c);
      }
      product.setQuick(r, c, dot);
    }
  }
  return product;
}
/**
 * Computes the matrix-vector product {@code this * v}.
 *
 * @param v vector whose size must equal this matrix's column count
 * @return a new {@link DenseVector} whose element {@code i} is the dot product of
 *         {@code v} with row {@code i}
 * @throws CardinalityException if {@code v.size()} differs from the column count
 */
@Override
public Vector times(Vector v) {
  final int columnCount = columnSize();
  if (v.size() != columnCount) {
    throw new CardinalityException(columnCount, v.size());
  }
  final int rowCount = rowSize();
  final Vector product = new DenseVector(rowCount);
  for (int r = 0; r < rowCount; r++) {
    final double dot = v.dot(viewRow(r));
    product.setQuick(r, dot);
  }
  return product;
}
/**
 * Accumulates, over all rows, each row scaled by its dot product with {@code v}.
 * Rows whose dot product is exactly zero are skipped.
 *
 * @param v vector whose size must equal this matrix's column count
 * @return a new {@link DenseVector} of length {@code columnSize()} holding the accumulation
 * @throws CardinalityException if {@code v.size()} differs from the column count
 */
@Override
public Vector timesSquared(Vector v) {
  final int columnCount = columnSize();
  if (v.size() != columnCount) {
    throw new CardinalityException(columnCount, v.size());
  }
  final int rowCount = rowSize();
  final Vector accumulated = new DenseVector(columnCount);
  for (int r = 0; r < rowCount; r++) {
    final Vector rowView = viewRow(r);
    final double weight = rowView.dot(v);
    if (weight == 0.0) {
      continue; // nothing to add for this row
    }
    accumulated.assign(rowView, new PlusMult(weight));
  }
  return accumulated;
}
/**
 * Produces a new matrix that is the transpose of this one. This matrix is not modified.
 *
 * @return a matrix obtained from {@code like(columnSize(), rowSize())} in which
 *         cell (c, r) holds this matrix's cell (r, c)
 */
@Override
public Matrix transpose() {
  final int rowCount = rowSize();
  final int columnCount = columnSize();
  final Matrix flipped = like(columnCount, rowCount);
  for (int r = 0; r < rowCount; r++) {
    for (int c = 0; c < columnCount; c++) {
      final double cell = getQuick(r, c);
      flipped.setQuick(c, r, cell);
    }
  }
  return flipped;
}
/**
 * Convenience overload that packs its arguments into two {@code int[2]} arrays
 * ({row, column} order) and delegates to the array-based {@code viewPart}.
 *
 * @param rowOffset        first element of the offset array
 * @param rowsRequested    first element of the size array
 * @param columnOffset     second element of the offset array
 * @param columnsRequested second element of the size array
 * @return whatever the array-based {@code viewPart} returns
 */
@Override
public Matrix viewPart(int rowOffset, int rowsRequested, int columnOffset, int columnsRequested) {
  final int[] offset = {rowOffset, columnOffset};
  final int[] size = {rowsRequested, columnsRequested};
  return viewPart(offset, size);
}
/**
 * Adds up every cell, visiting cells row by row.
 *
 * @return the sum of all cell values; 0 for a matrix with no cells
 */
@Override
public double zSum() {
  double total = 0;
  for (int r = 0; r < rowSize(); r++) {
    for (int c = 0; c < columnSize(); c++) {
      total = total + getQuick(r, c);
    }
  }
  return total;
}
/**
 * @return a two-element array holding {@code rowSize()} at index {@link #ROW} and
 *         {@code columnSize()} at index {@link #COL}
 */
@Override
public int[] getNumNondefaultElements() {
  final int[] counts = new int[2];
  counts[ROW] = rowSize();
  counts[COL] = columnSize();
  return counts;
}
/**
 * Vector whose element {@code i} is element {@code transposeOffset} of the
 * {@code i}-th column view (when {@code rowToColumn} is true) or the {@code i}-th
 * row view (otherwise) of the wrapped matrix. Reads and writes go to that matrix.
 */
protected class TransposeViewVector extends AbstractVector {
  /** Matrix that backs this view. */
  private final Matrix matrix;
  /** Fixed position read from / written to inside each column or row view. */
  private final int transposeOffset;
  /** Length used when a replacement vector has to be created in {@link #setQuick}. */
  private final int numCols;
  /** {@code true}: index selects a column view; {@code false}: index selects a row view. */
  private final boolean rowToColumn;

  /**
   * Creates a view with {@code rowToColumn} set to {@code true}.
   *
   * @param m      matrix to view
   * @param offset fixed position inside each column view
   */
  protected TransposeViewVector(Matrix m, int offset) {
    this(m, offset, true);
  }

  /**
   * @param m           matrix to view
   * @param offset      fixed position inside each column or row view
   * @param rowToColumn whether the vector index selects a column view (true) or a row view (false)
   */
  protected TransposeViewVector(Matrix m, int offset, boolean rowToColumn) {
    super(rowToColumn ? m.numRows() : m.numCols());
    matrix = m;
    this.transposeOffset = offset;
    this.rowToColumn = rowToColumn;
    numCols = rowToColumn ? m.numCols() : m.numRows();
  }

  /**
   * @return a detached {@link DenseVector} holding this view's current values
   */
  @Override
  public Vector clone() {
    Vector v = new DenseVector(size());
    // Adding this view onto a freshly created dense vector copies the values.
    v.assign(this, Functions.PLUS);
    return v;
  }

  /** @return always {@code true} */
  @Override
  public boolean isDense() {
    return true;
  }

  /** @return always {@code true} */
  @Override
  public boolean isSequentialAccess() {
    return true;
  }

  /**
   * @return an iterator over the elements at indices {@code 0 .. size() - 1}, in order
   */
  @Override
  public Iterator<Element> iterator() {
    // The element type must be explicit: with a raw AbstractIterator, endOfData()
    // is typed as Object and cannot be returned from computeNext().
    return new AbstractIterator<Element>() {
      /** Index of the next element to hand out. */
      private int i;

      @Override
      protected Element computeNext() {
        if (i >= size()) {
          return endOfData();
        }
        return getElement(i++);
      }
    };
  }

  /**
   * @return the same iterator as {@link #iterator()}; zero entries are not skipped
   */
  @Override
  public Iterator<Element> iterateNonZero() {
    return iterator();
  }

  /**
   * @param i index the returned element is bound to
   * @return an element whose get/set delegate to {@link #getQuick} / {@link #setQuick} at {@code i}
   */
  @Override
  public Element getElement(final int i) {
    return new Element() {
      @Override
      public double get() {
        return getQuick(i);
      }

      @Override
      public int index() {
        return i;
      }

      @Override
      public void set(double value) {
        setQuick(i, value);
      }
    };
  }

  /**
   * @param index selects the column view or row view of the backing matrix
   * @return element {@code transposeOffset} of that view, or 0.0 if the matrix returns no view
   */
  @Override
  public double getQuick(int index) {
    Vector v = rowToColumn ? matrix.viewColumn(index) : matrix.viewRow(index);
    return v == null ? 0.0 : v.getQuick(transposeOffset);
  }

  /**
   * Writes {@code value} at position {@code transposeOffset} of the selected column or
   * row view. If the matrix returns no view, a new vector is created, assigned into the
   * matrix at {@code index}, and then written to.
   *
   * @param index selects the column view or row view of the backing matrix
   * @param value value to store
   */
  @Override
  public void setQuick(int index, double value) {
    Vector v = rowToColumn ? matrix.viewColumn(index) : matrix.viewRow(index);
    if (v == null) {
      v = newVector(numCols);
      if (rowToColumn) {
        matrix.assignColumn(index, v);
      } else {
        matrix.assignRow(index, v);
      }
    }
    v.setQuick(transposeOffset, value);
  }

  /**
   * Factory for the replacement vector used by {@link #setQuick}.
   *
   * @param cardinality length of the vector to create
   * @return a new {@link DenseVector} of that length
   */
  protected Vector newVector(int cardinality) {
    return new DenseVector(cardinality);
  }

  /** @return a new {@link DenseVector} with the same size as this view */
  @Override
  public Vector like() {
    return new DenseVector(size());
  }

  /**
   * @param cardinality length of the vector to create
   * @return a new {@link DenseVector} of that length
   */
  public Vector like(int cardinality) {
    return new DenseVector(cardinality);
  }

  /** @return {@code size()}; every position is counted */
  @Override
  public int getNumNondefaultElements() {
    return size();
  }
}
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy