org.ojalgo.matrix.store.SparseStore Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of ojalgo Show documentation
Show all versions of ojalgo Show documentation
oj! Algorithms - ojAlgo - is Open Source Java code that has to do with mathematics, linear algebra and optimisation.
/*
* Copyright 1997-2021 Optimatika
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
package org.ojalgo.matrix.store;
import static org.ojalgo.function.constant.PrimitiveMath.*;
import java.util.Arrays;
import org.ojalgo.ProgrammingError;
import org.ojalgo.array.SparseArray;
import org.ojalgo.array.SparseArray.NonzeroView;
import org.ojalgo.function.BinaryFunction;
import org.ojalgo.function.NullaryFunction;
import org.ojalgo.function.UnaryFunction;
import org.ojalgo.function.VoidFunction;
import org.ojalgo.function.aggregator.Aggregator;
import org.ojalgo.matrix.operation.MultiplyBoth;
import org.ojalgo.scalar.ComplexNumber;
import org.ojalgo.scalar.Quaternion;
import org.ojalgo.scalar.RationalNumber;
import org.ojalgo.scalar.Scalar;
import org.ojalgo.structure.Access1D;
import org.ojalgo.structure.Access2D;
import org.ojalgo.structure.ElementView2D;
import org.ojalgo.structure.Mutate1D;
import org.ojalgo.structure.Structure2D;
import org.ojalgo.type.NumberDefinition;
import org.ojalgo.type.context.NumberContext;
public final class SparseStore> extends FactoryStore implements TransformableRegion {
public interface Factory> {
SparseStore make(long rowsCount, long columnsCount);
default SparseStore make(final Structure2D shape) {
return this.make(shape.countRows(), shape.countColumns());
}
}
public static final SparseStore.Factory COMPLEX = SparseStore::makeComplex;
public static final SparseStore.Factory PRIMITIVE32 = SparseStore::makePrimitive32;
public static final SparseStore.Factory PRIMITIVE64 = SparseStore::makePrimitive;
public static final SparseStore.Factory QUATERNION = SparseStore::makeQuaternion;
public static final SparseStore.Factory RATIONAL = SparseStore::makeRational;
public static SparseStore makeComplex(final long rowsCount, final long columnsCount) {
return SparseStore.makeSparse(GenericStore.COMPLEX, rowsCount, columnsCount);
}
public static SparseStore makePrimitive(final long rowsCount, final long columnsCount) {
return SparseStore.makeSparse(Primitive64Store.FACTORY, rowsCount, columnsCount);
}
public static SparseStore makePrimitive32(final long rowsCount, final long columnsCount) {
return SparseStore.makeSparse(Primitive32Store.FACTORY, rowsCount, columnsCount);
}
public static SparseStore makeQuaternion(final long rowsCount, final long columnsCount) {
return SparseStore.makeSparse(GenericStore.QUATERNION, rowsCount, columnsCount);
}
public static SparseStore makeRational(final long rowsCount, final long columnsCount) {
return SparseStore.makeSparse(GenericStore.RATIONAL, rowsCount, columnsCount);
}
private static > void doGenericColumnAXPY(final SparseArray elements, final long colX, final long colY, final N a,
final TransformableRegion y) {
long structure = y.countRows();
long first = structure * colX;
long limit = first + structure;
elements.visitReferenceTypeNonzerosInRange(first, limit, (index, value) -> y.add(Structure2D.row(index, structure), colY, value.multiply(a)));
}
private static void doPrimitiveColumnAXPY(final SparseArray elements, final long colX, final long colY, final double a,
final TransformableRegion y) {
long structure = y.countRows();
long first = structure * colX;
long limit = first + structure;
elements.visitPrimitiveNonzerosInRange(first, limit, (index, value) -> y.add(Structure2D.row(index, structure), colY, a * value));
}
static > SparseStore makeSparse(final PhysicalStore.Factory physical, final long numberOfRows,
final long numberOfColumns) {
return new SparseStore<>(physical, Math.toIntExact(numberOfRows), Math.toIntExact(numberOfColumns));
}
static > SparseStore makeSparse(final PhysicalStore.Factory physical, final Structure2D shape) {
return SparseStore.makeSparse(physical, shape.countRows(), shape.countColumns());
}
static > void multiply(final SparseStore left, final SparseStore right, final TransformableRegion target) {
target.reset();
if (left.isPrimitive()) {
SparseArray tmpLeft = (SparseArray) left.getElements();
TransformableRegion tmpTarget = (TransformableRegion) target;
right.nonzeros().stream().forEach(element -> {
SparseStore.doPrimitiveColumnAXPY(tmpLeft, element.row(), element.column(), element.doubleValue(), tmpTarget);
});
} else if (left.getComponentType().isAssignableFrom(ComplexNumber.class)) {
SparseArray tmpLeft = (SparseArray) left.getElements();
SparseStore tmpRight = (SparseStore) right;
TransformableRegion tmpTarget = (TransformableRegion) target;
tmpRight.nonzeros().stream().forEach(element -> {
SparseStore.doGenericColumnAXPY(tmpLeft, element.row(), element.column(), element.get(), tmpTarget);
});
} else if (left.getComponentType().isAssignableFrom(RationalNumber.class)) {
SparseArray tmpLeft = (SparseArray) left.getElements();
SparseStore tmpRight = (SparseStore) right;
TransformableRegion tmpTarget = (TransformableRegion) target;
tmpRight.nonzeros().stream().forEach(element -> {
SparseStore.doGenericColumnAXPY(tmpLeft, element.row(), element.column(), element.get(), tmpTarget);
});
} else if (left.getComponentType().isAssignableFrom(Quaternion.class)) {
SparseArray tmpLeft = (SparseArray) left.getElements();
SparseStore tmpRight = (SparseStore) right;
TransformableRegion tmpTarget = (TransformableRegion) target;
tmpRight.nonzeros().stream().forEach(element -> {
SparseStore.doGenericColumnAXPY(tmpLeft, element.row(), element.column(), element.get(), tmpTarget);
});
} else {
throw new IllegalStateException("Unsupported element type!");
}
}
private final SparseArray myElements;
private final int[] myFirsts;
private final int[] myLimits;
private TransformableRegion.FillByMultiplying myMultiplyer;
SparseStore(final PhysicalStore.Factory factory, final int rowsCount, final int columnsCount) {
super(factory, rowsCount, columnsCount);
myElements = SparseArray.factory(factory.array()).limit(this.count()).initial(Math.max(rowsCount, columnsCount)).make();
myFirsts = new int[rowsCount];
myLimits = new int[rowsCount];
Arrays.fill(myFirsts, columnsCount);
// Arrays.fill(myLimits, 0); // Behövs inte, redan 0
Class extends Comparable> tmpType = factory.scalar().zero().get().getClass();
if (tmpType.equals(Double.class)) {
myMultiplyer = (TransformableRegion.FillByMultiplying) MultiplyBoth.newPrimitive64(rowsCount, columnsCount);
} else {
myMultiplyer = (TransformableRegion.FillByMultiplying) MultiplyBoth.newGeneric(rowsCount, columnsCount);
}
}
public void add(final long row, final long col, final Comparable> addend) {
synchronized (myElements) {
myElements.add(Structure2D.index(myFirsts.length, row, col), addend);
}
this.updateNonZeros(row, col);
}
public void add(final long row, final long col, final double addend) {
synchronized (myElements) {
myElements.add(Structure2D.index(myFirsts.length, row, col), addend);
}
this.updateNonZeros(row, col);
}
public double doubleValue(final long row, final long col) {
return myElements.doubleValue(Structure2D.index(myFirsts.length, row, col));
}
@Override
public boolean equals(final Object obj) {
if (this == obj) {
return true;
}
if (!super.equals(obj) || !(obj instanceof SparseStore)) {
return false;
}
SparseStore> other = (SparseStore>) obj;
if (myElements == null) {
if (other.myElements != null) {
return false;
}
} else if (!myElements.equals(other.myElements)) {
return false;
}
if (!Arrays.equals(myFirsts, other.myFirsts) || !Arrays.equals(myLimits, other.myLimits)) {
return false;
}
return true;
}
public void fillByMultiplying(final Access1D left, final Access1D right) {
int complexity = Math.toIntExact(left.count() / this.countRows());
if (complexity != Math.toIntExact(right.count() / this.countColumns())) {
ProgrammingError.throwForMultiplicationNotPossible();
}
myMultiplyer.invoke(this, left, complexity, right);
}
public void fillOne(final long row, final long col, final Access1D> values, final long valueIndex) {
this.set(row, col, values.get(valueIndex));
}
public void fillOne(final long row, final long col, final N value) {
synchronized (myElements) {
myElements.fillOne(Structure2D.index(myFirsts.length, row, col), value);
}
this.updateNonZeros(row, col);
}
public void fillOne(final long row, final long col, final NullaryFunction> supplier) {
synchronized (myElements) {
myElements.fillOne(Structure2D.index(myFirsts.length, row, col), supplier);
}
this.updateNonZeros(row, col);
}
public int firstInColumn(final int col) {
long structure = myFirsts.length;
long rangeFirst = structure * col;
long rangeLimit = structure * (col + 1);
long firstInRange = myElements.firstInRange(rangeFirst, rangeLimit);
if (rangeFirst == firstInRange) {
return 0;
}
return (int) (firstInRange % structure);
}
public int firstInRow(final int row) {
return myFirsts[row];
}
public N get(final long row, final long col) {
return myElements.get(Structure2D.index(myFirsts.length, row, col));
}
@Override
public int hashCode() {
int prime = 31;
int result = super.hashCode();
result = prime * result + (myElements == null ? 0 : myElements.hashCode());
result = prime * result + Arrays.hashCode(myFirsts);
result = prime * result + Arrays.hashCode(myLimits);
return result;
}
@Override
public int limitOfColumn(final int col) {
long structure = myFirsts.length;
long rangeFirst = structure * col;
long rangeLimit = rangeFirst + structure;
long limitOfRange = myElements.limitOfRange(rangeFirst, rangeLimit);
if (rangeLimit == limitOfRange) {
return (int) structure;
}
return (int) (limitOfRange % structure);
}
@Override
public int limitOfRow(final int row) {
return myLimits[row];
}
public void modifyAll(final UnaryFunction modifier) {
long tmpLimit = this.count();
if (this.isPrimitive()) {
for (long i = 0L; i < tmpLimit; i++) {
this.set(i, modifier.invoke(this.doubleValue(i)));
}
} else {
for (long i = 0L; i < tmpLimit; i++) {
this.set(i, modifier.invoke(this.get(i)));
}
}
}
public void modifyMatching(final Access1D left, final BinaryFunction function) {
long limit = Math.min(left.count(), this.count());
boolean notModifiesZero = function.invoke(E, ZERO) == ZERO;
if (this.isPrimitive()) {
if (notModifiesZero) {
for (NonzeroView element : myElements.nonzeros()) {
element.modify(left.doubleValue(element.index()), function);
}
} else {
for (long i = 0L; i < limit; i++) {
this.set(i, function.invoke(left.doubleValue(i), this.doubleValue(i)));
}
}
} else if (notModifiesZero) {
for (NonzeroView element : myElements.nonzeros()) {
element.modify(left.get(element.index()), function);
}
} else {
for (long i = 0L; i < limit; i++) {
this.set(i, function.invoke(left.get(i), this.get(i)));
}
}
}
public void modifyMatching(final BinaryFunction function, final Access1D right) {
long limit = Math.min(this.count(), right.count());
boolean notModifiesZero = function.invoke(ZERO, E) == ZERO;
if (this.isPrimitive()) {
if (notModifiesZero) {
for (NonzeroView element : myElements.nonzeros()) {
element.modify(function, right.doubleValue(element.index()));
}
} else {
for (long i = 0L; i < limit; i++) {
this.set(i, function.invoke(this.doubleValue(i), right.doubleValue(i)));
}
}
} else if (notModifiesZero) {
for (NonzeroView element : myElements.nonzeros()) {
element.modify(function, right.get(element.index()));
}
} else {
for (long i = 0L; i < limit; i++) {
this.set(i, function.invoke(this.get(i), right.get(i)));
}
}
}
public void modifyOne(final long row, final long col, final UnaryFunction modifier) {
if (this.isPrimitive()) {
this.set(row, col, modifier.invoke(this.doubleValue(row, col)));
} else {
this.set(row, col, modifier.invoke(this.get(row, col)));
}
}
public void multiply(final Access1D right, final TransformableRegion target) {
if (right instanceof SparseStore>) {
SparseStore.multiply(this, (SparseStore) right, target);
} else if (this.isPrimitive()) {
long complexity = this.countColumns();
long numberOfColumns = target.countColumns();
target.reset();
this.nonzeros().stream().forEach(element -> {
long row = element.row();
long col = element.column();
double value = element.doubleValue();
long first = MatrixStore.firstInRow(right, col, 0L);
long limit = MatrixStore.limitOfRow(right, col, numberOfColumns);
for (long j = first; j < limit; j++) {
long index = Structure2D.index(complexity, col, j);
double addition = value * right.doubleValue(index);
if (NumberContext.compare(addition, ZERO) != 0) {
target.add(row, j, addition);
}
}
});
} else {
super.multiply(right, target);
}
}
public MatrixStore multiply(final double scalar) {
SparseStore retVal = SparseStore.makeSparse(this.physical(), this);
if (this.isPrimitive()) {
for (ElementView2D nonzero : this.nonzeros()) {
retVal.set(nonzero.index(), nonzero.doubleValue() * scalar);
}
} else {
Scalar sclr = this.physical().scalar().convert(scalar);
for (ElementView2D nonzero : this.nonzeros()) {
retVal.set(nonzero.index(), sclr.multiply(nonzero.get()).get());
}
}
return retVal;
}
public MatrixStore multiply(final MatrixStore right) {
long numberOfRows = this.countRows();
long numberOfColumns = right.countColumns();
if (right instanceof SparseStore) {
SparseStore retVal = SparseStore.makeSparse(this.physical(), numberOfRows, numberOfColumns);
SparseStore.multiply(this, (SparseStore) right, retVal);
return retVal;
}
PhysicalStore retVal = this.physical().make(numberOfRows, numberOfColumns);
this.multiply(right, retVal);
return retVal;
}
public MatrixStore multiply(final N scalar) {
SparseStore retVal = SparseStore.makeSparse(this.physical(), this);
if (this.isPrimitive()) {
double sclr = NumberDefinition.doubleValue(scalar);
for (ElementView2D nonzero : this.nonzeros()) {
retVal.set(nonzero.index(), nonzero.doubleValue() * sclr);
}
} else {
Scalar sclr = this.physical().scalar().convert(scalar);
for (ElementView2D nonzero : this.nonzeros()) {
retVal.set(nonzero.index(), sclr.multiply(nonzero.get()).get());
}
}
return retVal;
}
@Override
public N multiplyBoth(final Access1D leftAndRight) {
// TODO Auto-generated method stub
return super.multiplyBoth(leftAndRight);
}
public ElementView2D nonzeros() {
return new Access2D.ElementView<>(myElements.nonzeros(), this.countRows());
}
public ElementsSupplier premultiply(final Access1D left) {
long complexity = this.countRows();
long numberOfColumns = this.countColumns();
long numberOfRows = left.count() / complexity;
if (left instanceof SparseStore>) {
SparseStore retVal = SparseStore.makeSparse(this.physical(), numberOfRows, numberOfColumns);
SparseStore.multiply((SparseStore) left, this, retVal);
return retVal;
}
if (!this.isPrimitive()) {
return super.premultiply(left);
}
SparseStore retVal = SparseStore.makeSparse(this.physical(), numberOfRows, numberOfColumns);
this.nonzeros().stream().forEach(element -> {
long row = element.row();
long col = element.column();
double value = element.doubleValue();
long first = MatrixStore.firstInColumn(left, row, 0L);
long limit = MatrixStore.limitOfColumn(left, row, numberOfRows);
for (long i = first; i < limit; i++) {
long index = Structure2D.index(numberOfRows, i, row);
double addition = value * left.doubleValue(index);
if (NumberContext.compare(addition, ZERO) != 0) {
retVal.add(i, col, addition);
}
}
});
return retVal;
}
@SuppressWarnings("unchecked")
public void reduceColumns(final Aggregator aggregator, final Mutate1D receiver) {
if (aggregator == Aggregator.SUM && receiver instanceof Mutate1D.Modifiable) {
if (this.isPrimitive()) {
this.nonzeros().forEach(element -> ((Modifiable>) receiver).add(element.column(), element.doubleValue()));
} else {
this.nonzeros().forEach(element -> ((Modifiable) receiver).add(element.column(), element.get()));
}
} else {
super.reduceColumns(aggregator, receiver);
}
}
@SuppressWarnings("unchecked")
public void reduceRows(final Aggregator aggregator, final Mutate1D receiver) {
if (aggregator == Aggregator.SUM && receiver instanceof Mutate1D.Modifiable) {
if (this.isPrimitive()) {
this.nonzeros().forEach(element -> ((Modifiable>) receiver).add(element.row(), element.doubleValue()));
} else {
this.nonzeros().forEach(element -> ((Modifiable) receiver).add(element.row(), element.get()));
}
} else {
super.reduceColumns(aggregator, receiver);
}
}
public TransformableRegion regionByColumns(final int... columns) {
return new Subregion2D.ColumnsRegion<>(this, myMultiplyer, columns);
}
public TransformableRegion regionByLimits(final int rowLimit, final int columnLimit) {
return new Subregion2D.LimitRegion<>(this, myMultiplyer, rowLimit, columnLimit);
}
public TransformableRegion regionByOffsets(final int rowOffset, final int columnOffset) {
return new Subregion2D.OffsetRegion<>(this, myMultiplyer, rowOffset, columnOffset);
}
public TransformableRegion regionByRows(final int... rows) {
return new Subregion2D.RowsRegion<>(this, myMultiplyer, rows);
}
public TransformableRegion regionByTransposing() {
return new Subregion2D.TransposedRegion<>(this, myMultiplyer);
}
public void reset() {
myElements.reset();
Arrays.fill(myFirsts, this.getColDim());
Arrays.fill(myLimits, 0);
}
public void set(final long row, final long col, final Comparable> value) {
synchronized (myElements) {
myElements.set(Structure2D.index(myFirsts.length, row, col), value);
}
this.updateNonZeros(row, col);
}
public void set(final long row, final long col, final double value) {
synchronized (myElements) {
myElements.set(Structure2D.index(myFirsts.length, row, col), value);
}
this.updateNonZeros(row, col);
}
public void supplyTo(final TransformableRegion receiver) {
receiver.reset();
myElements.supplyNonZerosTo(receiver);
}
public void visitColumn(final long row, final long col, final VoidFunction visitor) {
long structure = this.countRows();
long first = Structure2D.index(structure, row, col);
long limit = Structure2D.index(structure, 0, col + 1L);
myElements.visitRange(first, limit, visitor);
}
public void visitRow(final long row, final long col, final VoidFunction visitor) {
int counter = 0;
if (this.isPrimitive()) {
for (ElementView2D nzv : this.nonzeros()) {
if (nzv.row() == row) {
visitor.accept(nzv.doubleValue());
counter++;
}
}
} else {
for (ElementView2D nzv : this.nonzeros()) {
if (nzv.row() == row) {
visitor.accept(nzv.get());
counter++;
}
}
}
if (col + counter < this.countColumns()) {
visitor.accept(0.0);
}
}
private void updateNonZeros(final long row, final long col) {
this.updateNonZeros((int) row, (int) col);
}
SparseArray getElements() {
return myElements;
}
void updateNonZeros(final int row, final int col) {
myFirsts[row] = Math.min(col, myFirsts[row]);
myLimits[row] = Math.max(col + 1, myLimits[row]);
}
}