All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.nd4j.linalg.api.ndarray.BaseSparseNDArrayCOO Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.nd4j.linalg.api.ndarray;

import org.nd4j.shade.guava.primitives.Doubles;
import org.nd4j.shade.guava.primitives.Ints;
import org.nd4j.shade.guava.primitives.Longs;
import com.google.flatbuffers.FlatBufferBuilder;
import net.ericaro.neoitertools.Generator;
import org.nd4j.linalg.api.blas.params.MMulTranspose;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.*;
import org.nd4j.linalg.profiler.OpProfiler;
import org.nd4j.linalg.util.ArrayUtil;
import org.nd4j.linalg.util.LongUtils;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.NoSuchElementException;

import static org.nd4j.base.Preconditions.checkArgument;
import static org.nd4j.base.Preconditions.checkNotNull;

/**
 * @author Audrey Loeffel
 */

/*
* TODO :
* - Implement the INDArray methods
* - Sort the databuffers
* - Check at the creation if there are any 0 values and remove them
* - add indexesOrdering in constructor
* - BaseSparseNDArray should extend from BaseNDArray : remove the duplicate methods
* */
public class BaseSparseNDArrayCOO extends BaseSparseNDArray {
    protected static final SparseFormat format = SparseFormat.COO;
    protected transient volatile DataBuffer values;
    protected transient volatile DataBuffer indices;
    protected transient volatile boolean isSorted = false;

    public BaseSparseNDArrayCOO(DataBuffer values, DataBuffer indices, long[] shape) {
        checkArgument(values.length() * shape.length == indices.length());

        this.values = values;
        this.indices = indices;
        setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape, Nd4j.dataType()));
        init(shape);
        this.length = values.length();
        int[] flags = new int[rank()];
        long[] sparseOffsets = new long[rank()];
        int[] hiddenDimension = new int[] {-1};
        this.sparseInformation = Nd4j.getSparseInfoProvider().createSparseInformation(flags, sparseOffsets,
                hiddenDimension, rank());

    }

    public BaseSparseNDArrayCOO(double[] values, long[][] indices, long[] shape) {
        this.indices = createIndiceBuffer(indices, shape);
        this.values = createValueBuffer(values);
        length = values.length;

        setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape, Nd4j.dataType()));
        init(shape);
        this.sparseInformation = createSparseInformationBuffer(rank());
        checkBufferCoherence();
    }

    public BaseSparseNDArrayCOO(float[] values, long[][] indices, long[] shape) {
        this.indices = createIndiceBuffer(indices, shape);
        this.values = createValueBuffer(values);
        length = values.length;

        setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape, Nd4j.dataType()));
        init(shape);
        this.sparseInformation = createSparseInformationBuffer(rank());
        checkBufferCoherence();
    }

    public BaseSparseNDArrayCOO(double[] values, int[][] indices, long[] shape) {
        this.indices = createIndiceBuffer(indices, shape);
        this.values = createValueBuffer(values);
        length = values.length;

        setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape, Nd4j.dataType()));
        init(shape);
        this.sparseInformation = createSparseInformationBuffer(rank());
        checkBufferCoherence();
    }

    public BaseSparseNDArrayCOO(float[] values, int[][] indices, long[] shape) {
        this.indices = createIndiceBuffer(indices, shape);
        this.values = createValueBuffer(values);
        length = values.length;

        setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape, Nd4j.dataType()));
        init(shape);
        this.sparseInformation = createSparseInformationBuffer(rank());
        checkBufferCoherence();
    }


    public BaseSparseNDArrayCOO(DataBuffer values, DataBuffer indices, DataBuffer sparseInformation, long[] shape) {
        this.values = Nd4j.createBuffer(values, 0, values.length());
        this.indices = indices;
        setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape, Nd4j.dataType()));
        init(shape);
        this.sparseInformation = sparseInformation;
        this.length = countNNZ();

    }

    public BaseSparseNDArrayCOO(DataBuffer values, DataBuffer indices, long[] sparseOffsets, int[] flags,
                                int[] hiddenDimensions, int underlyingRank, long[] shape) {
        this(values, indices, Nd4j.getSparseInfoProvider().createSparseInformation(flags, sparseOffsets,
                        hiddenDimensions, underlyingRank), shape);
    }


    /**
     * Check that the length of indices and values are coherent and matches the rank of the matrix.
     */
    protected void checkBufferCoherence(){
        if (values.length() < length){
            throw new IllegalStateException("nnz is larger than capacity of buffers");
        }

        if (values.length() * rank() != indices.length()){
            throw new IllegalArgumentException("Sizes of values, indices and shape are incoherent.");
        }
    }

    /**
     * Create a SparseInfo databuffer given rank if of the sparse matrix.
     * @param rank
     * @return
     */
    protected static DataBuffer createSparseInformationBuffer(int rank){
        int[] flags = new int[rank];
        long[] sparseOffsets = new long[rank];
        int[] hiddenDimension = new int[] {-1};
        return Nd4j.getSparseInfoProvider().createSparseInformation(flags, sparseOffsets,
                hiddenDimension, rank);
    }


    /**
     * Create a DataBuffer for values of given array of values.
     * @param values
     * @return
     */
    protected static DataBuffer createValueBuffer(float[] values) {
        checkNotNull(values);
        if (values.length == 0){
            return Nd4j.createBuffer(1);
        }
        return Nd4j.createBuffer(values);
    }


    /**
     * Create a DataBuffer for values of given array of values.
     * @param values
     * @return
     */
    protected static DataBuffer createValueBuffer(double[] values) {
        checkNotNull(values);
        if (values.length == 0){
            return Nd4j.createBuffer(1);
        }
        return Nd4j.createBuffer(values);
    }



    /**
     * Create a DataBuffer for indices of given arrays of indices.
     * @param indices
     * @param shape
     * @return
     */
    protected static DataBuffer createIndiceBuffer(long[][] indices, long[] shape){
        checkNotNull(indices);
        checkNotNull(shape);
        if(indices.length == 0){
            return Nd4j.getDataBufferFactory().createLong(shape.length);
        }

        if (indices.length == shape.length) {
            return Nd4j.createBuffer(ArrayUtil.flattenF(indices));
        }

        return Nd4j.createBuffer(ArrayUtil.flatten(indices));
    }

    /**
     * Create a DataBuffer for indices of given arrays of indices.
     * @param indices
     * @param shape
     * @return
     */
    protected static DataBuffer createIndiceBuffer(int[][] indices, long[] shape){
        checkNotNull(indices);
        checkNotNull(shape);
        if(indices.length == 0){
            return Nd4j.getDataBufferFactory().createLong(shape.length);
        }

        if (indices.length == shape.length) {
            return Nd4j.createBuffer(ArrayUtil.toLongArray(ArrayUtil.flattenF(indices)));
        }

        return Nd4j.createBuffer(ArrayUtil.toLongArray(ArrayUtil.flatten(indices)));
    }

    @Override
    public int toFlatArray(FlatBufferBuilder builder) {
        throw new UnsupportedOperationException();
    }



    /**
     * Count the number of value that are included in the ndarray (view) according to the sparse offsets and the shape
     * @return nnz
     * */
    public long countNNZ() {
        long count = 0;

        for (int i = 0; i < values.length(); i++) {
            int[] idx = getUnderlyingIndicesOf(i).asInt();
            boolean isIn = true;
            int idxNotFixed = 0;
            for (int dim = 0; dim < idx.length; dim++) {

                if (flags()[dim] == 1) {
                    if (sparseOffsets()[dim] != idx[dim]) {
                        isIn = false;
                        break;
                    }
                } else {
                    int lowerBound = sparseOffsets()[dim];
                    long upperBound = sparseOffsets()[dim] + shape()[idxNotFixed];
                    if (!(idx[dim] >= lowerBound && idx[dim] < upperBound)) {
                        isIn = false;
                        break;
                    }
                    idxNotFixed++;
                }
            }
            count = isIn ? count + 1 : count;
        }
        return count;
    }

    @Override
    public INDArray assign(final INDArray arr) {
        sort();
        // TODO - set via native op
        return this;
    }

    /**
     * Sort the indexes and the values buffers
     * */
    public void sort() {
        if (!isSorted) {
            Nd4j.sparseFactory().sortCooIndices(this);
            isSorted = true;
        }
    }

    /**
     * Translate the view index to the corresponding index of the original ndarray
     * @param virtualIndexes the view indexes
     * @return the original indexes
     * */
    public long[] translateToPhysical(long[] virtualIndexes) {

        long[] physicalIndexes = new long[underlyingRank()];
        int idxPhy = 0;
        int hidden = 0;

        for (int idxVir = 0; idxVir < virtualIndexes.length; idxVir++) {
            if (hidden < getNumHiddenDimension() && hiddenDimensions()[hidden] == idxVir) {
                hidden++;
            } else {
                while (idxPhy < underlyingRank() && isDimensionFixed(idxPhy)) {
                    physicalIndexes[idxPhy] = sparseOffsets()[idxPhy];
                    idxPhy++;
                }
                if (idxPhy < underlyingRank() && !isDimensionFixed(idxPhy)) {
                    physicalIndexes[idxPhy] = sparseOffsets()[idxPhy] + virtualIndexes[idxVir];
                    idxPhy++;
                }
            }
        }
        return physicalIndexes;
    }

    public int[] translateToPhysical(int[] virtualIndexes) {

        int[] physicalIndexes = new int[underlyingRank()];
        int idxPhy = 0;
        int hidden = 0;

        for (int idxVir = 0; idxVir < virtualIndexes.length; idxVir++) {
            if (hidden < getNumHiddenDimension() && hiddenDimensions()[hidden] == idxVir) {
                hidden++;
            } else {
                while (idxPhy < underlyingRank() && isDimensionFixed(idxPhy)) {
                    physicalIndexes[idxPhy] = sparseOffsets()[idxPhy];
                    idxPhy++;
                }
                if (idxPhy < underlyingRank() && !isDimensionFixed(idxPhy)) {
                    physicalIndexes[idxPhy] = sparseOffsets()[idxPhy] + virtualIndexes[idxVir];
                    idxPhy++;
                }
            }
        }
        return physicalIndexes;
    }

    /**
     * Return if the dimension in argument is a fixed dimension.
     * */
    public boolean isDimensionFixed(int i) {
        return flags()[i] == 1;
    }


    @Override
    public INDArray putScalar(long i, double value) {
        if (i < 0)
            i += rank();
        if (isScalar()) {
            if (Nd4j.getExecutioner().getProfilingMode() != OpExecutioner.ProfilingMode.DISABLED && Nd4j.getExecutioner().getProfilingMode() != OpExecutioner.ProfilingMode.SCOPE_PANIC)
                OpProfiler.getInstance().processScalarCall();

            addOrUpdate(new long[] {0, 0}, value);
            return this;
        }
        if (isRowVector()) {
            addOrUpdate(new long[] {0, i}, value);
            return this;
        } else if (isColumnVector()) {
            addOrUpdate(new long[] {i, 0}, value);
            return this;
        }
        long[] indexes = ordering() == 'c' ? Shape.ind2subC(this, i) : Shape.ind2sub(this, i);
        return putScalar(indexes, value);
    }

    @Override
    public INDArray putScalar(long i, float value) {
        return putScalar(i, (double) value); //todo - move in baseSparse?
    }

    @Override
    public INDArray putScalar(long i, int value) {
        return putScalar(i, (double) value); //todo
    }

    @Override
    public INDArray putScalar(int[] indexes, double value) {
        return putScalar(ArrayUtil.toLongArray(indexes), value);
    }

    @Override
    public INDArray putScalar(long[] indexes, double value) {
        for (int i = 0; i < indexes.length; i++) {
            if (indexes[i] < 0)
                indexes[i] += rank();
        }

        if (indexes.length == 1) {
            return putScalar(indexes[0], value);
        }
        if (indexes.length != rank) {
            throw new IllegalStateException(
                    "Cannot use putScalar with indexes length " + indexes.length + " on rank " + rank);
        }
        addOrUpdate(indexes, value);
        return this;
    }

    @Override
    public INDArray putScalar(long[] i, float value) {
        return null;
    }

    @Override
    public INDArray putScalar(long[] i, int value) {
        return null;
    }

    @Override
    public INDArray putScalar(long row, long col, double value) {
        return putScalar(new long[] {row, col}, value);
    }

    @Override
    public INDArray putScalar(long dim0, long dim1, long dim2, double value) {
        return putScalar(new long[] {dim0, dim1, dim2}, value);
    }

    @Override
    public INDArray putScalar(long dim0, long dim1, long dim2, long dim3, double value) {
        return putScalar(new long[] {dim0, dim1, dim2, dim3}, value);
    }

    @Override
    public INDArray putRow(long row, INDArray toPut) {
        if (isRowVector() && toPut.isVector()) {
            return assign(toPut);
        }
        return put(new INDArrayIndex[] {NDArrayIndex.point(row), NDArrayIndex.all()}, toPut);
    }

    @Override
    public INDArray putColumn(int column, INDArray toPut) {
        if (isColumnVector() && toPut.isVector()) {
            return assign(toPut);
        }
        return put(new INDArrayIndex[] {NDArrayIndex.all(), NDArrayIndex.point(column)}, toPut);
    }

    @Override
    public INDArray put(INDArrayIndex[] indices, INDArray element) {
        get(indices).assign(element);
        return this;
    }

    @Override
    public INDArray put(INDArrayIndex[] indices, Number element) {
        INDArray get = get(indices);
        for (int i = 0; i < get.length(); i++)
            get.putScalar(i, element.doubleValue());
        return this;
    }

    @Override
    public INDArray put(int[] indexes, INDArray element) {

        if (!element.isScalar())
            throw new IllegalArgumentException("Unable to insert anything but a scalar");
        if (indexes.length != rank)
            throw new IllegalStateException(
                            "Cannot use putScalar with indexes length " + indexes.length + " on rank " + rank);

        addOrUpdate(ArrayUtil.toLongArray(indexes), element.getDouble(0));
        return this;
    }

    @Override
    public INDArray put(int i, int j, INDArray element) { // TODO in base ?
        return put(new int[] {i, j}, element);
    }

    @Override
    public INDArray put(int i, int j, Number element) { // TODO in base
        return putScalar(new int[] {i, j}, element.doubleValue());
    }

    @Override
    public INDArray put(int i, INDArray element) { // TODO remove and use basendarray method
        if (!element.isScalar())
            throw new IllegalArgumentException("Element must be a scalar");
        return putScalar(i, element.getDouble(0));
    }

    /**
     * Add a new element in the ndarray or update the value if there is already a non-null element at this position
     * @param indexes the indexes of the element to be added
     * @param value the value of the element to be added
     * */
    public void addOrUpdate(long[] indexes, double value) {

        long[] physicalIndexes = isView() ? translateToPhysical(indexes) : indexes;

        for (int i = 0; i < length; i++) {
            long[] idx = getUnderlyingIndicesOf(i).asLong();
            if (Arrays.equals(idx, physicalIndexes)) {
                // There is already a non-null value at this index
                // -> update the current value, the sort is maintained
                if (value == 0) {
                    removeEntry(i);
                    length--;
                } else {
                    values.put(i, value);
                    length++;
                }
                return;
            }
        }

        // If the value is 0 and there is no existing non-null value at the given index
        if (value == 0) {
            return;
        }

        /* It's a new non-null element. We add the value and the indexes at the end of their respective databuffers.
        * The buffers are no longer sorted !
        * /!\ We need to reallocate the buffers if they are full
        */
        while (!canInsert(values, 1)) {
            long size = (long) Math.ceil((values.capacity() * THRESHOLD_MEMORY_ALLOCATION));
            values.reallocate(size);
        }
        values.put(length, value);
        while (!canInsert(indices, physicalIndexes.length)) {
            long size = (long) Math.ceil((indices.capacity() * THRESHOLD_MEMORY_ALLOCATION));
            indices.reallocate(size);
        }
        for (int i = 0; i < physicalIndexes.length; i++) {
            indices.put(length * rank() + i, physicalIndexes[i]);
        }
        length++;
        isSorted = false;
    }

    /**
     * Return if there is enough allocated memory space to add data of a given length in the databuffer
     * @param buffer a databuffer in which we want to add data
     * @param length the length of the data
     * @return a boolean if the insertion is possible
     * */
    public boolean canInsert(DataBuffer buffer, int length) {
        return buffer.capacity() - buffer.length() >= length;
    }

    public DataBuffer shiftLeft(DataBuffer buffer, int from, int offset, long datalength) {
        for (int i = from; i < datalength; i++) {
            buffer.put(i - offset, buffer.getDouble(i));
        }
        return buffer;
    }

    /**
     * Remove an element of the ndarray
     * @param idx the index of the element to be removed
     * @return the ndarray
     * */
    public INDArray removeEntry(int idx) {
        values = shiftLeft(values, idx + 1, 1, length());
        indices = shiftLeft(indices, (int) (idx * shape.length() + shape.length()), (int) shape.length(),
                        indices.length());
        return this;
    }

    @Override
    public INDArray get(INDArrayIndex... indexes) {

        sort();

        if (indexes.length == 1 && indexes[0] instanceof NDArrayIndexAll || (indexes.length == 2 && (isRowVector()
                        && indexes[0] instanceof PointIndex && indexes[0].offset() == 0
                        && indexes[1] instanceof NDArrayIndexAll
                        || isColumnVector() && indexes[1] instanceof PointIndex && indexes[0].offset() == 0
                                        && indexes[0] instanceof NDArrayIndexAll)))
            return this;

        throw new UnsupportedOperationException("Not implemented");
    }

    @Override
    public INDArray assign(boolean value) {
        return assign(value ? 1 : 0);
    }

    @Override
    public INDArray repeat(int dimension, long... repeats) {
        return null;
    }


    public int[] filterOutFixedDimensions(int[] flags, List idx) {
        checkArgument(flags.length == idx.size());
        int lastIdx = idx.size() - 1;
        for (int i = lastIdx; i >= 0; i--) {
            if (flags[i] == 1) {
                idx.remove(i);
            }
        }
        return Ints.toArray(idx);
    }

    /**
     * Return the index of the value corresponding to the indexes
     * @param indexes
     * @return index of the value
     * */
    public int reverseIndexes(int... indexes) {
        long[] idx = translateToPhysical(ArrayUtil.toLongArray(indexes));
        sort();

        // FIXME: int cast
        return indexesBinarySearch(0, (int) length(), ArrayUtil.toInts(idx));
    }

    /**
     * Return the position of the idx array into the indexes buffer between the lower and upper bound.
     * @param idx a set of coordinates
     * @param lowerBound the lower bound of the position
     * @param upperBound the upper bound of the position
     * @return the position of the idx array into the indexes buffers, which corresponds to the position of
     * the corresponding value in the values data.
     * */
    public int indexesBinarySearch(int lowerBound, int upperBound, int[] idx) {
        int min = lowerBound;
        int max = upperBound;

        int mid = (max + min) / 2;
        int[] midIdx = getUnderlyingIndicesOf(mid).asInt();
        if (Arrays.equals(idx, midIdx)) {
            return mid;
        }
        if (ArrayUtil.lessThan(idx, midIdx)) {
            max = mid;
        }
        if (ArrayUtil.greaterThan(idx, midIdx)) {
            min = mid;
        }
        if (min == max) {
            return -1;
        }
        return indexesBinarySearch(min, max, idx);
    }

    @Override
    public INDArray getScalar(int... indices) {
        return super.getScalar(indices);
    }

    @Override
    public INDArray getScalar(long... indices) {
        return null;
    }

    @Override
    public int getInt(int... indices) {
        return super.getInt(indices);
    }

    @Override
    public long getLong(long index) {
        return 0;
    }

    @Override
    public double getDouble(int... indices) {
        int valIdx = reverseIndexes(indices);
        if (valIdx == -1) {
            return 0;
        } else {
            return values.getDouble(valIdx);
        }
    }

    @Override
    public double getDouble(long... indices) {
        return 0;
    }

    @Override
    public float getFloat(int[] indices) {
        return (float) getDouble(indices);
    }

    @Override
    public float getFloat(long[] indices) {
        return 0;
    }

    @Override
    public double getDouble(long i) {
        if (i >= length()) {
            throw new IllegalArgumentException("Unable to get linear index >= " + length());
        }

        if (Nd4j.getExecutioner().getProfilingMode() != OpExecutioner.ProfilingMode.DISABLED && Nd4j.getExecutioner().getProfilingMode() != OpExecutioner.ProfilingMode.SCOPE_PANIC)
            OpProfiler.getInstance().processScalarCall();

        if (i == 0)
            return data().getDouble(i);

        long[] dimensions = ordering() == 'c' ? Shape.ind2subC(this, i) : Shape.ind2sub(this, i);
        Shape.assertShapeLessThan(dimensions, shape());
        return getDouble(dimensions);
    }

    @Override
    public double getDouble(long i, long j) {
        return getDouble(new long[] {i, j});
    }

    @Override
    public float getFloat(long i) {
        return (float) getDouble(i);
    }

    @Override
    public float getFloat(long i, long j) {
        return (float) getDouble(i, j);
    }

    @Override
    public INDArray reshape(char order, int... newShape) {
        return null;
    }

    @Override
    public INDArray reshape(char order, boolean enforceView, long... newShape) {
        return null;
    }

    @Override
    public INDArray reshape(int[] shape) {
        return null;
    }

    @Override
    public LongShapeDescriptor shapeDescriptor() {
        throw new UnsupportedOperationException();
    }

    public SparseFormat getFormat() {
        return format;
    }

    @Override
    public int underlyingRank() {
        return Shape.underlyingRank(sparseInformation);
    }

    @Override
    public DataBuffer data() {
        return values;
    }

    /**
     * Return the indices buffer
     * @return indices
     * */
    public DataBuffer getUnderlyingIndices() {
        return indices;
    }

    /**
    * Return a copy of the indices included in the view.
    * /!\ Change this DataBuffer won't change the ndarray!
    * @return an array containing the virtual indexes of the values (think about views).
    * */
    public DataBuffer getIncludedIndices() {

        if (isScalar()) {
            return Nd4j.createBuffer(new int[] {0, 0});
        }

        List ind = new ArrayList<>();

        for (int i = 0; i < values.length(); i++) {
            boolean isIn = true;
            int idxNotFixed = 0;
            int[] idx = getUnderlyingIndicesOf(i).asInt(); // TODO change for getIndicesOf(i)

            for (int dim = 0; dim < idx.length; dim++) {
                if (flags()[dim] == 1) {
                    if (sparseOffsets()[dim] != idx[dim]) {
                        isIn = false;
                        break;
                    }
                } else {
                    int lowerBound = sparseOffsets()[dim];
                    long upperBound = sparseOffsets()[dim] + shape()[idxNotFixed];
                    if (!(idx[dim] >= lowerBound && idx[dim] < upperBound)) {
                        isIn = false;
                        break;
                    }
                    idxNotFixed++;
                }
            }
            if (isIn) {
                int notFixedDim = 0;
                for (int dim = 0; dim < idx.length; dim++) {
                    if (flags()[dim] == 0) {
                        if (shape()[notFixedDim] == 1) {
                            ind.add(0);
                            notFixedDim++;
                        } else {
                            ind.add(idx[dim] - sparseOffsets()[dim]);
                        }
                    }
                }
            }
        }
        return Nd4j.createBuffer(Ints.toArray(ind));
    }

    /**
     * Return the values buffer
     * @return values
     * */
    public DataBuffer getUnderlyingValues() {
        return values;
    }

    /**
     * Return a copy of the values included in the array.
     * /!\ Change this DataBuffer won't change the ndarray!
     * @return an array containing the values
     * */
    public DataBuffer getIncludedValues() {
        List val = new ArrayList<>();

        for (int i = 0; i < values.length(); i++) {
            boolean isIn = true;
            int idxNotFixed = 0;
            int[] idx = getUnderlyingIndicesOf(i).asInt();
            for (int dim = 0; dim < idx.length; dim++) {
                if (flags()[dim] == 1) {
                    if (sparseOffsets()[dim] != idx[dim]) {
                        isIn = false;
                        break;
                    }
                } else {
                    int lowerBound = sparseOffsets()[dim];
                    long upperBound = sparseOffsets()[dim] + shape()[idxNotFixed];
                    if (!(idx[dim] >= lowerBound && idx[dim] < upperBound)) {
                        isIn = false;
                        break;
                    }
                    idxNotFixed++;
                }
            }
            if (isIn) {
                val.add(values.getDouble(i));
            }
        }
        return Nd4j.createBuffer(Doubles.toArray(val));
    }

    /**
     * Returns the indices of non-zero element of the vector
     *
     * @return indices in Databuffer
     * */
    @Override
    public DataBuffer getVectorCoordinates() {
        int idx;
        if (isRowVector()) {
            idx = 1;
        } else if (isColumnVector()) {
            idx = 0;
        } else {
            throw new UnsupportedOperationException();
        }

        // FIXME: int cast
        int[] temp = new int[(int) length()];
        for (int i = 0; i < length(); i++) {
            temp[i] = getUnderlyingIndicesOf(i).getInt(idx);
        }
        return Nd4j.createBuffer(temp);
    }

    /**
     * Converts the sparse ndarray into a dense one
     * @return a dense ndarray
     */
    @Override
    public INDArray toDense() {
        // TODO support view conversion
        INDArray result = Nd4j.zeros(shape());

        switch (data().dataType()) {
            case DOUBLE:
                for (int i = 0; i < length; i++) {
                    int[] idx = getUnderlyingIndicesOf(i).asInt();
                    double value = values.getDouble(i);
                    result.putScalar(idx, value);
                }
                break;
            case FLOAT:
                for (int i = 0; i < length; i++) {
                    int[] idx = getUnderlyingIndicesOf(i).asInt();
                    float value = values.getFloat(i);
                    result.putScalar(idx, value);
                }
                break;
            default:
                throw new UnsupportedOperationException();
        }
        return result;
    }

    @Override
    public DataBuffer shapeInfoDataBuffer() {
        return shapeInformation;
    }

    /**
     * Compute the sparse offsets of the view we are getting, for each dimension according to the original ndarray
     * @param offset the offset of the view
     * @return an int array containing the sparse offsets
     * */
    private long[] createSparseOffsets(long offset) {

        // resolve the offsets in the view dimension
        int underlyingRank = sparseOffsets().length;
        long[] newOffsets = new long[rank()];
        List shapeList = Longs.asList(shape());
        int penultimate = rank() - 1;
        for (int i = 0; i < penultimate; i++) {
            long prod = ArrayUtil.prodLong(shapeList.subList(i + 1, rank()));
            newOffsets[i] = offset / prod;
            offset = offset - newOffsets[i] * prod;
        }
        newOffsets[rank() - 1] = offset % shape()[rank() - 1];

        // Merge the offsets with the original sparseOffsets
        long[] finalOffsets = new long[underlyingRank];
        int dimNotFixed = 0;
        for (int dim = 0; dim < underlyingRank; dim++) {
            if (flags()[dim] == 1) {
                finalOffsets[dim] = sparseOffsets()[dim];
            } else {
                finalOffsets[dim] = newOffsets[dimNotFixed] + sparseOffsets()[dim];
                dimNotFixed++;
            }
        }
        return finalOffsets;
    }

    private int[] createHiddenDimensions(int[] newAxis) {
        if (newAxis == null || newAxis.length == 0) {
            return hiddenDimensions();
        }
        if (getNumHiddenDimension() == 0) {
            return newAxis;
        }
        int size = newAxis.length + hiddenDimensions().length;
        int[] newHiddenDim = new int[size];
        int newDim = 0;
        int actualArrayIdx = 0;
        for (int oldDim = 0; oldDim < getNumHiddenDimension(); oldDim++) {
            while ((newDim < newAxis.length) && (oldDim >=getNumHiddenDimension() || newAxis[newDim] <= hiddenDimensions()[oldDim])) {
                newHiddenDim[actualArrayIdx] = newAxis[newDim] + oldDim;
                actualArrayIdx++;
                newDim++;
            }
            newHiddenDim[actualArrayIdx] = hiddenDimensions()[oldDim] + newDim;
            actualArrayIdx++;
        }
        return newHiddenDim;
    }

    /**
    * Adjust the flags array according on the current context:
    * - In case of a vector or a scalar, we need to keep flags to 0
    * - There must always be at least 2 non-flags dimensions
    * - We must keep the flags dimensions of the original array
    *
    * @param viewFlags the Fixed array calculate within the view
    * @param newShape the shape of the view
    * */
    private int[] updateFlags(int[] viewFlags, long[] newShape) {
        // Check if flags is well-formed
        int count = 0;
        for (int i = 0; i < viewFlags.length; i++) {
            if (viewFlags[i] == 0) {
                count++;
            }
        }

        for (int dim = 0; dim < viewFlags.length; dim++) {
            if (viewFlags[dim] == 1) {
                if (newShape[dim] == 1 && count < 2) {
                    viewFlags[dim] = 0;
                }

            }
        }
        // Take the original Fixed into account
        int[] extendedFlags = new int[underlyingRank()];
        int notFixedDim = 0;
        for (int dim = 0; dim < underlyingRank(); dim++) {
            int[] temp = flags();
            if (flags()[dim] == 0) {
                extendedFlags[dim] = viewFlags[notFixedDim];
                notFixedDim++;
            } else {
                extendedFlags[dim] = 1;
            }
        }
        return extendedFlags;
    }

    /**
     * Returns the underlying indices of the element of the given index
     * such as there really are in the original ndarray
     *
     * @param i the index of the element+
     * @return a dataBuffer containing the indices of element
     * */
    public DataBuffer getUnderlyingIndicesOf(int i) {
        int from = underlyingRank() * i;
        //int to = from + underlyingRank();
        int[] res = new int[underlyingRank()];
        for(int j = 0; j< underlyingRank(); j++){
            res[j] = indices.getInt(from + j);
        }

        ///int[] arr = Arrays.copyOfRange(indices.asInt(), from, to);
        return Nd4j.getDataBufferFactory().createInt(res);
    }

    /**
     * Returns the indices of the element of the given index in the array context
     *
     * @param i the index of the element
     * @return a dataBuffer containing the indices of element
     * */
    public DataBuffer getIndicesOf(int i) {
        int from = underlyingRank() * i;
        int to = from + underlyingRank(); //not included

        int[] arr = new int[rank];
        int j = 0; // iterator over underlying indices
        int k = 0; //iterator over hiddenIdx
        for (int dim = 0; dim < rank; dim++) {
            if (k < hiddenDimensions().length && hiddenDimensions()[k] == j) {
                arr[dim] = 0;
                k++;
            } else {
                arr[dim] = indices.getInt(j);
                j++;
            }
        }
        return Nd4j.getDataBufferFactory().createInt(arr);
    }


    public boolean isZero(int... indexes) {
        for (int i = 0; i < length(); i++) {
            int[] idx = getUnderlyingIndicesOf(i).asInt();
            if (Arrays.equals(idx, translateToPhysical(indexes))) {
                return false;
            }
        }
        return true;
    }

    @Override
    public boolean isView() {
        return Shape.offset(shapeInformation) > 0 || data().originalDataBuffer() != null; // TODO or if sparseOffset/flags != [0, ..,0]
    }

    @Override
    public String getString(long index) {
        return null;
    }


    public int getNumHiddenDimension() {
        if (hiddenDimensions() == null || hiddenDimensions().length == 0) {
            throw new IllegalStateException("HiddenDimension array is malformed");
        }
        return hiddenDimensions()[0] == -1 ? 0 : hiddenDimensions().length;
    }

    public boolean isSorted() {
        return isSorted;
    }

    public DataBuffer getValues() {
        return values;
    }

    public DataBuffer getIndices() {
        return indices;
    }

    @Override
    public INDArray putiColumnVector(INDArray columnVector) {
        return null;
    }

    @Override
    public INDArray putiRowVector(INDArray rowVector) {
        return null;
    }

    @Override
    public INDArray mmul(INDArray other, MMulTranspose mMulTranspose) {
        return null;
    }

    @Override
    public INDArray mmul(INDArray other, INDArray result, MMulTranspose mMulTranspose) {
        return null;
    }

    @Override
    public INDArray mmuli(INDArray other, MMulTranspose transpose) {
        return null;
    }

    @Override
    public INDArray mmuli(INDArray other, INDArray result, MMulTranspose transpose) {
        return null;
    }

    /**
     * This method returns true if this INDArray is special case: no-value INDArray
     *
     * @return
     */
    @Override
    public boolean isEmpty() {
        return false;
    }

    @Override
    public boolean isR() {
        return false;
    }

    @Override
    public boolean isZ() {
        return false;
    }

    @Override
    public boolean isB() {
        return false;
    }

    @Override
    public INDArray castTo(DataType dataType) {
        return null;
    }

    @Override
    public boolean all() {
        return false;
    }

    @Override
    public boolean any() {
        return false;
    }

    @Override
    public boolean none() {
        return false;
    }

    @Override
    public boolean closeable() {
        return false;
    }

    @Override
    public void close() {

    }

    @Override
    public boolean wasClosed() {
        return false;
    }

    @Override
    public boolean isS() {
        return false;
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy