All downloads are free. Search and download functionality uses the official Maven repository.

org.ojalgo.matrix.store.SparseStore Maven / Gradle / Ivy

Go to download

oj! Algorithms - ojAlgo - is Open Source Java code that has to do with mathematics, linear algebra and optimisation.

There is a newer version: 55.0.1
Show newest version
/*
 * Copyright 1997-2022 Optimatika
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
package org.ojalgo.matrix.store;

import static org.ojalgo.function.constant.PrimitiveMath.*;

import java.util.Arrays;

import org.ojalgo.ProgrammingError;
import org.ojalgo.array.SparseArray;
import org.ojalgo.array.SparseArray.NonzeroView;
import org.ojalgo.function.BinaryFunction;
import org.ojalgo.function.NullaryFunction;
import org.ojalgo.function.UnaryFunction;
import org.ojalgo.function.VoidFunction;
import org.ojalgo.function.aggregator.Aggregator;
import org.ojalgo.matrix.operation.MultiplyBoth;
import org.ojalgo.scalar.ComplexNumber;
import org.ojalgo.scalar.Quaternion;
import org.ojalgo.scalar.RationalNumber;
import org.ojalgo.scalar.Scalar;
import org.ojalgo.structure.Access1D;
import org.ojalgo.structure.Access2D;
import org.ojalgo.structure.ElementView2D;
import org.ojalgo.structure.Mutate1D;
import org.ojalgo.structure.Structure2D;
import org.ojalgo.type.NumberDefinition;
import org.ojalgo.type.context.NumberContext;

public final class SparseStore> extends FactoryStore implements TransformableRegion {

    /**
     * Creates {@link SparseStore} instances of a requested size. The constants declared on
     * {@link SparseStore} bind this interface to the static {@code make...} methods.
     */
    public interface Factory> {

        /**
         * @param rowsCount the number of rows
         * @param columnsCount the number of columns
         * @return a store with the requested dimensions
         */
        SparseStore make(long rowsCount, long columnsCount);

        /**
         * Convenience overload: reads the dimensions from {@code shape} and delegates to
         * {@link #make(long, long)}.
         *
         * @param shape the structure whose row and column counts are used
         * @return a store with the same dimensions as {@code shape}
         */
        default SparseStore make(final Structure2D shape) {
            long rows = shape.countRows();
            long columns = shape.countColumns();
            return this.make(rows, columns);
        }

    }

    /** Factory bound to {@link #makeComplex(long, long)}. */
    public static final SparseStore.Factory COMPLEX = SparseStore::makeComplex;
    /** Factory bound to {@link #makePrimitive32(long, long)}. */
    public static final SparseStore.Factory PRIMITIVE32 = SparseStore::makePrimitive32;
    /** Factory bound to {@link #makePrimitive(long, long)}. */
    public static final SparseStore.Factory PRIMITIVE64 = SparseStore::makePrimitive;
    /** Factory bound to {@link #makeQuaternion(long, long)}. */
    public static final SparseStore.Factory QUATERNION = SparseStore::makeQuaternion;
    /** Factory bound to {@link #makeRational(long, long)}. */
    public static final SparseStore.Factory RATIONAL = SparseStore::makeRational;

    /**
     * Creates a sparse store of the given size that uses {@code GenericStore.COMPLEX} as its
     * physical store factory.
     *
     * @param rowsCount the number of rows
     * @param columnsCount the number of columns
     * @return the new store
     */
    public static SparseStore makeComplex(final long rowsCount, final long columnsCount) {
        return makeSparse(GenericStore.COMPLEX, rowsCount, columnsCount);
    }

    /**
     * Creates a sparse store of the given size that uses {@code Primitive64Store.FACTORY} as its
     * physical store factory.
     *
     * @param rowsCount the number of rows
     * @param columnsCount the number of columns
     * @return the new store
     */
    public static SparseStore makePrimitive(final long rowsCount, final long columnsCount) {
        return makeSparse(Primitive64Store.FACTORY, rowsCount, columnsCount);
    }

    /**
     * Creates a sparse store of the given size that uses {@code Primitive32Store.FACTORY} as its
     * physical store factory.
     *
     * @param rowsCount the number of rows
     * @param columnsCount the number of columns
     * @return the new store
     */
    public static SparseStore makePrimitive32(final long rowsCount, final long columnsCount) {
        return makeSparse(Primitive32Store.FACTORY, rowsCount, columnsCount);
    }

    /**
     * Creates a sparse store of the given size that uses {@code GenericStore.QUATERNION} as its
     * physical store factory.
     *
     * @param rowsCount the number of rows
     * @param columnsCount the number of columns
     * @return the new store
     */
    public static SparseStore makeQuaternion(final long rowsCount, final long columnsCount) {
        return makeSparse(GenericStore.QUATERNION, rowsCount, columnsCount);
    }

    /**
     * Creates a sparse store of the given size that uses {@code GenericStore.RATIONAL} as its
     * physical store factory.
     *
     * @param rowsCount the number of rows
     * @param columnsCount the number of columns
     * @return the new store
     */
    public static SparseStore makeRational(final long rowsCount, final long columnsCount) {
        return makeSparse(GenericStore.RATIONAL, rowsCount, columnsCount);
    }

    /**
     * For every nonzero stored in column {@code colX} of {@code elements}, adds that value multiplied
     * by {@code a} to column {@code colY} of {@code y}, in the same row.
     * <p>
     * The column's index range is derived from {@code y.countRows()}, so {@code elements} is addressed
     * as if it had the same number of rows as {@code y}.
     *
     * @param elements the sparse element array to read from
     * @param colX the column of {@code elements} to read
     * @param colY the column of {@code y} to add to
     * @param a the factor applied to each visited value
     * @param y the region receiving the additions
     */
    private static > void doGenericColumnAXPY(final SparseArray elements, final long colX, final long colY, final N a,
            final TransformableRegion y) {

        long numberOfRows = y.countRows();

        long rangeFirst = numberOfRows * colX;
        long rangeLimit = rangeFirst + numberOfRows;

        elements.visitReferenceTypeNonzerosInRange(rangeFirst, rangeLimit, (index, value) -> {
            long row = Structure2D.row(index, numberOfRows);
            y.add(row, colY, value.multiply(a));
        });
    }

    /**
     * Primitive ({@code double}) counterpart of the generic column AXPY: for every nonzero stored in
     * column {@code colX} of {@code elements}, adds {@code a * value} to column {@code colY} of
     * {@code y}, in the same row.
     * <p>
     * The column's index range is derived from {@code y.countRows()}, so {@code elements} is addressed
     * as if it had the same number of rows as {@code y}.
     *
     * @param elements the sparse element array to read from
     * @param colX the column of {@code elements} to read
     * @param colY the column of {@code y} to add to
     * @param a the factor applied to each visited value
     * @param y the region receiving the additions
     */
    private static void doPrimitiveColumnAXPY(final SparseArray elements, final long colX, final long colY, final double a,
            final TransformableRegion y) {

        long numberOfRows = y.countRows();

        long rangeFirst = numberOfRows * colX;
        long rangeLimit = rangeFirst + numberOfRows;

        elements.visitPrimitiveNonzerosInRange(rangeFirst, rangeLimit, (index, value) -> {
            long row = Structure2D.row(index, numberOfRows);
            y.add(row, colY, a * value);
        });
    }

    static > SparseStore makeSparse(final PhysicalStore.Factory physical, final long numberOfRows,
            final long numberOfColumns) {
        return new SparseStore<>(physical, Math.toIntExact(numberOfRows), Math.toIntExact(numberOfColumns));
    }

    static > SparseStore makeSparse(final PhysicalStore.Factory physical, final Structure2D shape) {
        return SparseStore.makeSparse(physical, shape.countRows(), shape.countColumns());
    }

    /**
     * Sparse-times-sparse multiplication: resets {@code target} and then, for every nonzero of
     * {@code right} at (row, column), adds that value times column {@code row} of {@code left} into
     * column {@code column} of {@code target}.
     * <p>
     * The element type is chosen from {@code left}: primitive stores take the {@code double} path;
     * otherwise the component type must be assignable from {@link ComplexNumber},
     * {@link RationalNumber} or {@link Quaternion}.
     *
     * @param left the left operand
     * @param right the right operand
     * @param target the region receiving the product; it is reset first
     * @throws IllegalStateException if the element type of {@code left} is none of the supported ones
     */
    static > void multiply(final SparseStore left, final SparseStore right, final TransformableRegion target) {

        target.reset();

        if (left.isPrimitive()) {

            SparseArray leftElements = (SparseArray) left.getElements();
            TransformableRegion primitiveTarget = (TransformableRegion) target;

            right.nonzeros().stream().forEach(
                    nz -> SparseStore.doPrimitiveColumnAXPY(leftElements, nz.row(), nz.column(), nz.doubleValue(), primitiveTarget));
            return;
        }

        if (left.getComponentType().isAssignableFrom(ComplexNumber.class)) {

            SparseArray leftElements = (SparseArray) left.getElements();
            SparseStore complexRight = (SparseStore) right;
            TransformableRegion complexTarget = (TransformableRegion) target;

            complexRight.nonzeros().stream().forEach(
                    nz -> SparseStore.doGenericColumnAXPY(leftElements, nz.row(), nz.column(), nz.get(), complexTarget));
            return;
        }

        if (left.getComponentType().isAssignableFrom(RationalNumber.class)) {

            SparseArray leftElements = (SparseArray) left.getElements();
            SparseStore rationalRight = (SparseStore) right;
            TransformableRegion rationalTarget = (TransformableRegion) target;

            rationalRight.nonzeros().stream().forEach(
                    nz -> SparseStore.doGenericColumnAXPY(leftElements, nz.row(), nz.column(), nz.get(), rationalTarget));
            return;
        }

        if (left.getComponentType().isAssignableFrom(Quaternion.class)) {

            SparseArray leftElements = (SparseArray) left.getElements();
            SparseStore quaternionRight = (SparseStore) right;
            TransformableRegion quaternionTarget = (TransformableRegion) target;

            quaternionRight.nonzeros().stream().forEach(
                    nz -> SparseStore.doGenericColumnAXPY(leftElements, nz.row(), nz.column(), nz.get(), quaternionTarget));
            return;
        }

        throw new IllegalStateException("Unsupported element type!");
    }

    /** Sparse storage of the element values, addressed by {@code Structure2D.index(rowCount, row, col)}. */
    private final SparseArray myElements;
    /** Per row: the smallest column index written so far; starts at the column count. */
    private final int[] myFirsts;
    /** Per row: one past the largest column index written so far; starts at 0. */
    private final int[] myLimits;
    /**
     * Multiplication kernel chosen once in the constructor; used by {@code fillByMultiplying} and
     * passed to the sub-regions. It is never reassigned, hence {@code final} like the other fields.
     */
    private final TransformableRegion.FillByMultiplying myMultiplyer;

    /**
     * Creates a store of the given dimensions.
     * <p>
     * The sparse element array is limited to {@code this.count()} elements with an initial capacity of
     * {@code max(rowsCount, columnsCount)}. The multiplication kernel is the primitive 64-bit one when
     * the factory's scalar zero is a {@link Double}, and the generic one otherwise.
     *
     * @param factory the physical store factory supplying the array and scalar factories
     * @param rowsCount the number of rows
     * @param columnsCount the number of columns
     */
    SparseStore(final PhysicalStore.Factory factory, final int rowsCount, final int columnsCount) {

        super(factory, rowsCount, columnsCount);

        int initialCapacity = Math.max(rowsCount, columnsCount);
        myElements = SparseArray.factory(factory.array()).limit(this.count()).initial(initialCapacity).make();

        // Row bookkeeping for a store without entries: first == columnsCount, limit == 0.
        myFirsts = new int[rowsCount];
        Arrays.fill(myFirsts, columnsCount);
        // No explicit fill needed for the limits: a new int[] is already all zeros.
        myLimits = new int[rowsCount];

        Class elementType = factory.scalar().zero().get().getClass();
        if (Double.class.equals(elementType)) {
            myMultiplyer = (TransformableRegion.FillByMultiplying) MultiplyBoth.newPrimitive64(rowsCount, columnsCount);
        } else {
            myMultiplyer = (TransformableRegion.FillByMultiplying) MultiplyBoth.newGeneric(rowsCount, columnsCount);
        }
    }

    /**
     * Adds {@code addend} to the element at ({@code row}, {@code col}) and widens that row's
     * first/limit column bookkeeping to include {@code col}.
     * <p>
     * Only the update of the element array is performed while holding its monitor.
     *
     * @param row the row index
     * @param col the column index
     * @param addend the value to add
     */
    public void add(final long row, final long col, final Comparable addend) {
        long index = Structure2D.index(myFirsts.length, row, col);
        synchronized (myElements) {
            myElements.add(index, addend);
        }
        this.updateNonZeros(row, col);
    }

    /**
     * Adds the primitive {@code addend} to the element at ({@code row}, {@code col}) and widens that
     * row's first/limit column bookkeeping to include {@code col}.
     * <p>
     * Only the update of the element array is performed while holding its monitor.
     *
     * @param row the row index
     * @param col the column index
     * @param addend the value to add
     */
    public void add(final long row, final long col, final double addend) {
        long index = Structure2D.index(myFirsts.length, row, col);
        synchronized (myElements) {
            myElements.add(index, addend);
        }
        this.updateNonZeros(row, col);
    }

    /**
     * Reads the element at ({@code row}, {@code col}) as a {@code double} from the sparse element array.
     *
     * @param row the row index
     * @param col the column index
     * @return the element value
     */
    public double doubleValue(final long row, final long col) {
        long index = Structure2D.index(myFirsts.length, row, col);
        return myElements.doubleValue(index);
    }

    /**
     * Two sparse stores are equal when the superclass considers them equal and their element arrays
     * and per-row first/limit bookkeeping arrays are equal as well.
     */
    @Override
    public boolean equals(final Object obj) {
        if (this == obj) {
            return true;
        }
        if (super.equals(obj) && obj instanceof SparseStore) {
            SparseStore that = (SparseStore) obj;
            boolean sameElements = myElements == null ? that.myElements == null : myElements.equals(that.myElements);
            return sameElements && Arrays.equals(myFirsts, that.myFirsts) && Arrays.equals(myLimits, that.myLimits);
        }
        return false;
    }

    /**
     * Fills this store with the product of {@code left} and {@code right} using the multiplication
     * kernel selected in the constructor.
     * <p>
     * The shared inner dimension is derived from the operands' element counts and this store's row and
     * column counts; if the two derivations disagree,
     * {@link ProgrammingError#throwForMultiplicationNotPossible()} is invoked.
     *
     * @param left the left operand
     * @param right the right operand
     */
    public void fillByMultiplying(final Access1D left, final Access1D right) {

        int leftComplexity = Math.toIntExact(left.count() / this.countRows());
        int rightComplexity = Math.toIntExact(right.count() / this.countColumns());

        if (leftComplexity != rightComplexity) {
            ProgrammingError.throwForMultiplicationNotPossible();
        }

        myMultiplyer.invoke(this, left, leftComplexity, right);
    }

    /**
     * Sets the element at ({@code row}, {@code col}) to the value found at {@code valueIndex} in
     * {@code values}; delegates to {@code set}.
     *
     * @param row the row index
     * @param col the column index
     * @param values the source of the value
     * @param valueIndex the index in {@code values} to read
     */
    public void fillOne(final long row, final long col, final Access1D values, final long valueIndex) {
        this.set(row, col, values.get(valueIndex));
    }

    /**
     * Fills the element at ({@code row}, {@code col}) with {@code value} and widens that row's
     * first/limit column bookkeeping to include {@code col}.
     * <p>
     * Only the update of the element array is performed while holding its monitor.
     *
     * @param row the row index
     * @param col the column index
     * @param value the value to store
     */
    public void fillOne(final long row, final long col, final N value) {
        long index = Structure2D.index(myFirsts.length, row, col);
        synchronized (myElements) {
            myElements.fillOne(index, value);
        }
        this.updateNonZeros(row, col);
    }

    /**
     * Fills the element at ({@code row}, {@code col}) from {@code supplier} and widens that row's
     * first/limit column bookkeeping to include {@code col}.
     * <p>
     * Only the update of the element array is performed while holding its monitor.
     *
     * @param row the row index
     * @param col the column index
     * @param supplier the function handed to the element array to produce the value
     */
    public void fillOne(final long row, final long col, final NullaryFunction supplier) {
        long index = Structure2D.index(myFirsts.length, row, col);
        synchronized (myElements) {
            myElements.fillOne(index, supplier);
        }
        this.updateNonZeros(row, col);
    }

    /**
     * Asks the element array for the first stored index within column {@code col} and converts it to a
     * row index.
     * <p>
     * Returns 0 when that index is the very first index of the column; otherwise returns the found
     * index modulo the row count.
     *
     * @param col the column index
     * @return the row index derived from the first stored index in the column's range
     */
    public int firstInColumn(final int col) {

        long rowCount = myFirsts.length;

        long begin = rowCount * col;
        long end = rowCount * (col + 1);

        long found = myElements.firstInRange(begin, end);

        // The explicit check also avoids the modulo when the column range starts at the found index.
        return found == begin ? 0 : (int) (found % rowCount);
    }

    /**
     * Returns the recorded first column for {@code row}: the smallest column index written in that
     * row since construction or the last reset, or the column count if nothing has been written.
     *
     * @param row the row index
     * @return the recorded first column of the row
     */
    public int firstInRow(final int row) {
        return myFirsts[row];
    }

    /**
     * Reads the element at ({@code row}, {@code col}) from the sparse element array.
     *
     * @param row the row index
     * @param col the column index
     * @return the element value
     */
    public N get(final long row, final long col) {
        long index = Structure2D.index(myFirsts.length, row, col);
        return myElements.get(index);
    }

    /**
     * Combines the superclass hash with the hashes of the element array and of the per-row
     * first/limit arrays, using the multiplier 31.
     */
    @Override
    public int hashCode() {
        int hash = super.hashCode();
        hash = 31 * hash + (myElements == null ? 0 : myElements.hashCode());
        hash = 31 * hash + Arrays.hashCode(myFirsts);
        return 31 * hash + Arrays.hashCode(myLimits);
    }

    /**
     * Asks the element array for the limit of the stored indices within column {@code col} and
     * converts it to a row limit.
     * <p>
     * Returns the row count when that limit coincides with the end of the column's index range;
     * otherwise returns the limit modulo the row count.
     *
     * @param col the column index
     * @return the row limit derived from the element array's range limit
     */
    @Override
    public int limitOfColumn(final int col) {

        long rowCount = myFirsts.length;

        long begin = rowCount * col;
        long end = begin + rowCount;

        long found = myElements.limitOfRange(begin, end);

        return (int) (found == end ? rowCount : found % rowCount);
    }

    /**
     * Returns the recorded column limit for {@code row}: one past the largest column index written in
     * that row since construction or the last reset, or 0 if nothing has been written.
     *
     * @param row the row index
     * @return the recorded column limit of the row
     */
    @Override
    public int limitOfRow(final int row) {
        return myLimits[row];
    }

    /**
     * Applies {@code modifier} to every element position of this store (all {@code count()} indices,
     * not only the stored nonzeros) and writes the result back.
     *
     * @param modifier the function applied to each element
     */
    public void modifyAll(final UnaryFunction modifier) {
        long total = this.count();

        if (!this.isPrimitive()) {
            for (long index = 0L; index < total; index++) {
                this.set(index, modifier.invoke(this.get(index)));
            }
            return;
        }

        for (long index = 0L; index < total; index++) {
            this.set(index, modifier.invoke(this.doubleValue(index)));
        }
    }

    /**
     * Replaces each element {@code x} of this store with {@code function(left, x)}, where
     * {@code left} is the element of the argument at the same index. Only indices below
     * {@code min(left.count(), this.count())} are touched.
     * <p>
     * If {@code function.invoke(E, ZERO)} is zero the function is assumed not to modify zeros, and
     * only the stored nonzeros are visited. Otherwise every index up to the limit is processed.
     * <p>
     * The nonzero path skips entries at or beyond the limit, so a shorter {@code left} is never read
     * out of range and both paths modify the same index range.
     *
     * @param left the left-hand operands
     * @param function the binary function to apply
     */
    public void modifyMatching(final Access1D left, final BinaryFunction function) {

        long limit = Math.min(left.count(), this.count());
        boolean notModifiesZero = function.invoke(E, ZERO) == ZERO;

        if (this.isPrimitive()) {
            if (notModifiesZero) {
                for (NonzeroView element : myElements.nonzeros()) {
                    long index = element.index();
                    if (index < limit) {
                        element.modify(left.doubleValue(index), function);
                    }
                }
            } else {
                for (long i = 0L; i < limit; i++) {
                    this.set(i, function.invoke(left.doubleValue(i), this.doubleValue(i)));
                }
            }
        } else if (notModifiesZero) {
            for (NonzeroView element : myElements.nonzeros()) {
                long index = element.index();
                if (index < limit) {
                    element.modify(left.get(index), function);
                }
            }
        } else {
            for (long i = 0L; i < limit; i++) {
                this.set(i, function.invoke(left.get(i), this.get(i)));
            }
        }
    }

    /**
     * Replaces each element {@code x} of this store with {@code function(x, right)}, where
     * {@code right} is the element of the argument at the same index. Only indices below
     * {@code min(this.count(), right.count())} are touched.
     * <p>
     * If {@code function.invoke(ZERO, E)} is zero the function is assumed not to modify zeros, and
     * only the stored nonzeros are visited. Otherwise every index up to the limit is processed.
     * <p>
     * The nonzero path skips entries at or beyond the limit, so a shorter {@code right} is never read
     * out of range and both paths modify the same index range.
     *
     * @param function the binary function to apply
     * @param right the right-hand operands
     */
    public void modifyMatching(final BinaryFunction function, final Access1D right) {

        long limit = Math.min(this.count(), right.count());
        boolean notModifiesZero = function.invoke(ZERO, E) == ZERO;

        if (this.isPrimitive()) {
            if (notModifiesZero) {
                for (NonzeroView element : myElements.nonzeros()) {
                    long index = element.index();
                    if (index < limit) {
                        element.modify(function, right.doubleValue(index));
                    }
                }
            } else {
                for (long i = 0L; i < limit; i++) {
                    this.set(i, function.invoke(this.doubleValue(i), right.doubleValue(i)));
                }
            }
        } else if (notModifiesZero) {
            for (NonzeroView element : myElements.nonzeros()) {
                long index = element.index();
                if (index < limit) {
                    element.modify(function, right.get(index));
                }
            }
        } else {
            for (long i = 0L; i < limit; i++) {
                this.set(i, function.invoke(this.get(i), right.get(i)));
            }
        }
    }

    /**
     * Applies {@code modifier} to the single element at ({@code row}, {@code col}) and stores the
     * result, using the {@code double} accessors for primitive stores.
     *
     * @param row the row index
     * @param col the column index
     * @param modifier the function applied to the element
     */
    public void modifyOne(final long row, final long col, final UnaryFunction modifier) {
        if (!this.isPrimitive()) {
            this.set(row, col, modifier.invoke(this.get(row, col)));
            return;
        }
        this.set(row, col, modifier.invoke(this.doubleValue(row, col)));
    }

    /**
     * Multiplies this store by {@code right} and writes the product into {@code target}.
     * <ul>
     * <li>If {@code right} is a {@link SparseStore}, the sparse-times-sparse routine is used.</li>
     * <li>Otherwise, for non-primitive stores, the superclass implementation is used.</li>
     * <li>Otherwise {@code target} is reset, and for each nonzero of this store at (row, col) the
     * products with row {@code col} of {@code right} are added into row {@code row} of
     * {@code target}. Terms that compare equal to zero are not added.</li>
     * </ul>
     *
     * @param right the right operand, addressed as having {@code this.countColumns()} rows
     * @param target the region receiving the product
     */
    public void multiply(final Access1D right, final TransformableRegion target) {

        if (right instanceof SparseStore) {
            SparseStore.multiply(this, (SparseStore) right, target);
            return;
        }

        if (!this.isPrimitive()) {
            super.multiply(right, target);
            return;
        }

        long innerDim = this.countColumns();
        long targetColumns = target.countColumns();

        target.reset();

        this.nonzeros().stream().forEach(nz -> {

            long i = nz.row();
            long k = nz.column();
            double leftValue = nz.doubleValue();

            long jFirst = Structure2D.firstInRow(right, k, 0L);
            long jLimit = Structure2D.limitOfRow(right, k, targetColumns);
            for (long j = jFirst; j < jLimit; j++) {
                double term = leftValue * right.doubleValue(Structure2D.index(innerDim, k, j));
                if (NumberContext.compare(term, ZERO) != 0) {
                    target.add(i, j, term);
                }
            }
        });
    }

    /**
     * Returns a new sparse store of the same shape in which every nonzero of this store has been
     * multiplied by {@code scalar}. For non-primitive stores the scalar is first converted with the
     * physical factory's scalar factory.
     *
     * @param scalar the factor
     * @return the scaled copy
     */
    public MatrixStore multiply(final double scalar) {

        SparseStore product = SparseStore.makeSparse(this.physical(), this);

        if (!this.isPrimitive()) {

            Scalar factor = this.physical().scalar().convert(scalar);

            for (ElementView2D nz : this.nonzeros()) {
                product.set(nz.index(), factor.multiply(nz.get()).get());
            }
            return product;
        }

        for (ElementView2D nz : this.nonzeros()) {
            product.set(nz.index(), nz.doubleValue() * scalar);
        }
        return product;
    }

    /**
     * Returns the product of this store and {@code right}. The result is a {@link SparseStore} when
     * {@code right} is one, and otherwise a physical store made by this store's physical factory.
     *
     * @param right the right operand
     * @return the product, with this store's row count and {@code right}'s column count
     */
    public MatrixStore multiply(final MatrixStore right) {

        long rows = this.countRows();
        long columns = right.countColumns();

        if (!(right instanceof SparseStore)) {
            PhysicalStore dense = this.physical().make(rows, columns);
            this.multiply(right, dense);
            return dense;
        }

        SparseStore sparse = SparseStore.makeSparse(this.physical(), rows, columns);
        SparseStore.multiply(this, (SparseStore) right, sparse);
        return sparse;
    }

    /**
     * Returns a new sparse store of the same shape in which every nonzero of this store has been
     * multiplied by {@code scalar}. Primitive stores use the scalar's {@code double} value; other
     * stores convert it with the physical factory's scalar factory.
     *
     * @param scalar the factor
     * @return the scaled copy
     */
    public MatrixStore multiply(final N scalar) {

        SparseStore product = SparseStore.makeSparse(this.physical(), this);

        if (!this.isPrimitive()) {

            Scalar factor = this.physical().scalar().convert(scalar);

            for (ElementView2D nz : this.nonzeros()) {
                product.set(nz.index(), factor.multiply(nz.get()).get());
            }
            return product;
        }

        double factor = NumberDefinition.doubleValue(scalar);

        for (ElementView2D nz : this.nonzeros()) {
            product.set(nz.index(), nz.doubleValue() * factor);
        }
        return product;
    }

    /**
     * Delegates unchanged to the superclass implementation.
     *
     * @param leftAndRight the argument forwarded to the superclass
     * @return whatever the superclass implementation returns
     */
    @Override
    public N multiplyBoth(final Access1D leftAndRight) {
        // No sparse-specific implementation here; the inherited one is used as-is.
        return super.multiplyBoth(leftAndRight);
    }

    /**
     * Wraps the element array's nonzero view in a two-dimensional element view, using this store's row
     * count as the structure.
     *
     * @return a 2D view over the entries stored in the element array
     */
    public ElementView2D nonzeros() {
        return new Access2D.ElementView<>(myElements.nonzeros(), this.countRows());
    }

    /**
     * Computes {@code left} times this store.
     * <ul>
     * <li>If {@code left} is a {@link SparseStore}, the sparse-times-sparse routine fills a new sparse
     * store.</li>
     * <li>Otherwise, for non-primitive stores, the superclass implementation is used.</li>
     * <li>Otherwise a new sparse store is filled by adding, for each nonzero of this store at
     * (row, col), its products with column {@code row} of {@code left} into column {@code col} of the
     * result. Terms that compare equal to zero are not added.</li>
     * </ul>
     *
     * @param left the left operand, addressed as having {@code left.count() / this.countRows()} rows
     * @return the product
     */
    public ElementsSupplier premultiply(final Access1D left) {

        long innerDim = this.countRows();
        long resultColumns = this.countColumns();
        long resultRows = left.count() / innerDim;

        if (left instanceof SparseStore) {
            SparseStore sparseProduct = SparseStore.makeSparse(this.physical(), resultRows, resultColumns);
            SparseStore.multiply((SparseStore) left, this, sparseProduct);
            return sparseProduct;
        }

        if (!this.isPrimitive()) {
            return super.premultiply(left);
        }

        SparseStore product = SparseStore.makeSparse(this.physical(), resultRows, resultColumns);

        this.nonzeros().stream().forEach(nz -> {

            long k = nz.row();
            long j = nz.column();
            double rightValue = nz.doubleValue();

            long iFirst = Structure2D.firstInColumn(left, k, 0L);
            long iLimit = Structure2D.limitOfColumn(left, k, resultRows);
            for (long i = iFirst; i < iLimit; i++) {
                double term = rightValue * left.doubleValue(Structure2D.index(resultRows, i, k));
                if (NumberContext.compare(term, ZERO) != 0) {
                    product.add(i, j, term);
                }
            }
        });

        return product;
    }

    /**
     * Reduces each column to a single value in {@code receiver}.
     * <p>
     * When the aggregator is {@code SUM} and the receiver is a {@code Mutate1D.Modifiable}, each stored
     * nonzero is added to the receiver at its column index. Every other case is handled by the
     * superclass.
     *
     * @param aggregator the aggregation to perform
     * @param receiver the receiver of the per-column results
     */
    @SuppressWarnings("unchecked")
    public void reduceColumns(final Aggregator aggregator, final Mutate1D receiver) {
        boolean sparseSum = aggregator == Aggregator.SUM && receiver instanceof Mutate1D.Modifiable;
        if (!sparseSum) {
            super.reduceColumns(aggregator, receiver);
            return;
        }
        if (this.isPrimitive()) {
            this.nonzeros().forEach(nz -> ((Modifiable) receiver).add(nz.column(), nz.doubleValue()));
        } else {
            this.nonzeros().forEach(nz -> ((Modifiable) receiver).add(nz.column(), nz.get()));
        }
    }

    /**
     * Reduces each row to a single value in {@code receiver}.
     * <p>
     * When the aggregator is {@code SUM} and the receiver is a {@code Mutate1D.Modifiable}, each stored
     * nonzero is added to the receiver at its row index. Every other case is handled by the
     * superclass row reduction.
     *
     * @param aggregator the aggregation to perform
     * @param receiver the receiver of the per-row results
     */
    @SuppressWarnings("unchecked")
    public void reduceRows(final Aggregator aggregator, final Mutate1D receiver) {
        if (aggregator == Aggregator.SUM && receiver instanceof Mutate1D.Modifiable) {
            if (this.isPrimitive()) {
                this.nonzeros().forEach(element -> ((Modifiable) receiver).add(element.row(), element.doubleValue()));
            } else {
                this.nonzeros().forEach(element -> ((Modifiable) receiver).add(element.row(), element.get()));
            }
        } else {
            // The fallback must reduce rows; it previously delegated to reduceColumns by mistake.
            super.reduceRows(aggregator, receiver);
        }
    }

    /**
     * Creates a {@code Subregion2D.ColumnsRegion} from this store, its multiplication kernel and the
     * given column indices.
     *
     * @param columns the column indices passed to the region
     * @return the new region
     */
    public TransformableRegion regionByColumns(final int... columns) {
        return new Subregion2D.ColumnsRegion<>(this, myMultiplyer, columns);
    }

    /**
     * Creates a {@code Subregion2D.LimitRegion} from this store, its multiplication kernel and the
     * given row and column limits.
     *
     * @param rowLimit the row limit passed to the region
     * @param columnLimit the column limit passed to the region
     * @return the new region
     */
    public TransformableRegion regionByLimits(final int rowLimit, final int columnLimit) {
        return new Subregion2D.LimitRegion<>(this, myMultiplyer, rowLimit, columnLimit);
    }

    /**
     * Creates a {@code Subregion2D.OffsetRegion} from this store, its multiplication kernel and the
     * given row and column offsets.
     *
     * @param rowOffset the row offset passed to the region
     * @param columnOffset the column offset passed to the region
     * @return the new region
     */
    public TransformableRegion regionByOffsets(final int rowOffset, final int columnOffset) {
        return new Subregion2D.OffsetRegion<>(this, myMultiplyer, rowOffset, columnOffset);
    }

    /**
     * Creates a {@code Subregion2D.RowsRegion} from this store, its multiplication kernel and the
     * given row indices.
     *
     * @param rows the row indices passed to the region
     * @return the new region
     */
    public TransformableRegion regionByRows(final int... rows) {
        return new Subregion2D.RowsRegion<>(this, myMultiplyer, rows);
    }

    /**
     * Creates a {@code Subregion2D.TransposedRegion} from this store and its multiplication kernel.
     *
     * @return the new region
     */
    public TransformableRegion regionByTransposing() {
        return new Subregion2D.TransposedRegion<>(this, myMultiplyer);
    }

    /**
     * Resets the element array and restores the per-row bookkeeping to its initial state: every first
     * column becomes the column dimension and every limit becomes 0.
     */
    public void reset() {
        myElements.reset();

        int columnCount = this.getColDim();
        Arrays.fill(myFirsts, columnCount);
        Arrays.fill(myLimits, 0);
    }

    /**
     * Stores {@code value} at ({@code row}, {@code col}) and widens that row's first/limit column
     * bookkeeping to include {@code col}.
     * <p>
     * Only the update of the element array is performed while holding its monitor.
     *
     * @param row the row index
     * @param col the column index
     * @param value the value to store
     */
    public void set(final long row, final long col, final Comparable value) {
        long index = Structure2D.index(myFirsts.length, row, col);
        synchronized (myElements) {
            myElements.set(index, value);
        }
        this.updateNonZeros(row, col);
    }

    /**
     * Stores the primitive {@code value} at ({@code row}, {@code col}) and widens that row's
     * first/limit column bookkeeping to include {@code col}.
     * <p>
     * Only the update of the element array is performed while holding its monitor.
     *
     * @param row the row index
     * @param col the column index
     * @param value the value to store
     */
    public void set(final long row, final long col, final double value) {
        long index = Structure2D.index(myFirsts.length, row, col);
        synchronized (myElements) {
            myElements.set(index, value);
        }
        this.updateNonZeros(row, col);
    }

    /**
     * Resets {@code receiver} and then lets the element array supply its nonzero entries to it.
     *
     * @param receiver the region to reset and fill
     */
    public void supplyTo(final TransformableRegion receiver) {

        // Clear first: only the nonzeros are written afterwards.
        receiver.reset();

        myElements.supplyNonZerosTo(receiver);
    }

    /**
     * Passes {@code visitor} to the element array for the index range that starts at
     * ({@code row}, {@code col}) and ends just before the first element of column {@code col + 1}.
     *
     * @param row the row at which to start
     * @param col the column to visit
     * @param visitor the visitor handed to the element array
     */
    public void visitColumn(final long row, final long col, final VoidFunction visitor) {

        long rowCount = this.countRows();

        long rangeFirst = Structure2D.index(rowCount, row, col);
        long rangeLimit = Structure2D.index(rowCount, 0, col + 1L);

        myElements.visitRange(rangeFirst, rangeLimit, visitor);
    }

    /**
     * Visits row {@code row} from column {@code col} onwards.
     * <p>
     * Every stored nonzero in that row at column {@code col} or later is passed to {@code visitor}.
     * If fewer entries were visited than there are columns from {@code col} to the end of the row,
     * a single {@code 0.0} is passed as well to represent the entries that are not stored.
     * <p>
     * Stored entries in columns before {@code col} are skipped. This matches {@code visitColumn},
     * which honours its starting row, and it keeps the {@code col + counter} zero check meaningful.
     *
     * @param row the row to visit
     * @param col the column at which to start
     * @param visitor the visitor receiving the values
     */
    public void visitRow(final long row, final long col, final VoidFunction visitor) {
        // Counts only the visited entries, i.e. stored entries of this row at column >= col.
        int counter = 0;
        if (this.isPrimitive()) {
            for (ElementView2D nzv : this.nonzeros()) {
                if (nzv.row() == row && nzv.column() >= col) {
                    visitor.accept(nzv.doubleValue());
                    counter++;
                }
            }
        } else {
            for (ElementView2D nzv : this.nonzeros()) {
                if (nzv.row() == row && nzv.column() >= col) {
                    visitor.accept(nzv.get());
                    counter++;
                }
            }
        }
        // At least one position in [col, countColumns) is not stored: report zero once.
        if (col + counter < this.countColumns()) {
            visitor.accept(0.0);
        }
    }

    /**
     * {@code long} convenience overload: narrows both indices to {@code int} with a plain cast and
     * delegates to {@link #updateNonZeros(int, int)}.
     */
    private void updateNonZeros(final long row, final long col) {
        int intRow = (int) row;
        int intCol = (int) col;
        this.updateNonZeros(intRow, intCol);
    }

    /**
     * Exposes the backing sparse element array to package-level code; the static {@code multiply}
     * reads the left operand's elements through it.
     *
     * @return the element array itself, not a copy
     */
    SparseArray getElements() {
        return myElements;
    }

    /**
     * Widens the recorded column range of {@code row} so that it includes {@code col}: the first
     * column is lowered to {@code col} if needed, and the limit is raised to {@code col + 1} if needed.
     *
     * @param row the row whose bookkeeping is updated
     * @param col the column that was written
     */
    void updateNonZeros(final int row, final int col) {
        if (col < myFirsts[row]) {
            myFirsts[row] = col;
        }
        int limit = col + 1;
        if (limit > myLimits[row]) {
            myLimits[row] = limit;
        }
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy