All downloads are free. Search and download functionality uses the official Maven repository.

org.nd4j.linalg.indexing.NDArrayIndex Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
package org.nd4j.linalg.indexing;

import com.google.common.primitives.Ints;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.util.ArrayUtil;

import java.util.Arrays;

/**
 * NDArray indexing.
 *
 * Wraps an array of integer indices, plus a flag recording whether the
 * indices were produced by one of the {@code interval} factory methods.
 * The wrapped array is stored and returned by reference (no copies are made),
 * so changes made through {@link #indices()} or {@link #reverse()} are
 * visible to every holder of the same array.
 *
 * @author Adam Gibson
 */
public class NDArrayIndex {

    /** The wrapped indices; may be null if the constructor was handed a null array. */
    private final int[] indices;
    /** True when this index was built by an {@code interval} factory or flagged via {@link #setInterval(boolean)}. */
    private boolean isInterval = false;


    /**
     * Creates an index over the given positions.
     * The array is kept by reference, not copied.
     *
     * @param indices the indices to wrap
     */
    public NDArrayIndex(int...indices) {
        this.indices = indices;
    }


    /**
     * Whether there is nothing to read from the wrapped array.
     * Shared by {@link #end()}, {@link #offset()} and {@link #length()} so
     * that all three treat a null or empty array the same way.
     */
    private boolean hasNoIndices() {
        return indices == null || indices.length < 1;
    }

    /**
     * Returns the last stored index.
     *
     * @return the last index, or 0 if there are no indices
     */
    public int end() {
        if (hasNoIndices()) {
            return 0;
        }
        return indices[indices.length - 1];
    }

    /**
     * Returns the first stored index.
     *
     * @return the first index, or 0 if there are no indices
     */
    public int offset() {
        if (hasNoIndices()) {
            return 0;
        }
        return indices[0];
    }


    /**
     * Returns the span covered by the indices: the last stored index minus
     * the first stored index.
     * NOTE(review): this is a difference of values, not the element count of
     * the wrapped array - confirm that this is what callers expect.
     *
     * @return last index minus first index, or 0 if there are no indices
     */
    public int length() {
        if (hasNoIndices()) {
            return 0;
        }
        return indices[indices.length - 1] - indices[0];
    }

    /**
     * Returns the wrapped index array itself (not a copy).
     *
     * @return the wrapped indices
     */
    public int[] indices() {
        return indices;
    }

    /**
     * Passes the wrapped array to {@code ArrayUtil.reverse}
     * (presumably an in-place reversal - verify against ArrayUtil).
     */
    public void reverse() {
        ArrayUtil.reverse(indices);
    }


    @Override
    public String toString() {
        return "NDArrayIndex{" +
                "indices=" + Arrays.toString(indices) +
                '}';
    }

    /**
     * Two indexes are equal when their index arrays have equal contents.
     * The interval flag is not part of the comparison.
     */
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (!(o instanceof NDArrayIndex)) {
            return false;
        }

        NDArrayIndex that = (NDArrayIndex) o;
        return Arrays.equals(indices, that.indices);
    }

    /** Hash of the index array contents, consistent with {@link #equals(Object)}. */
    @Override
    public int hashCode() {
        return Arrays.hashCode(indices);
    }


    /**
     * Create from a matrix or a vector.
     * For a matrix, each row becomes one NDArrayIndex and the columns of that
     * row are its individual elements (each value is read as a float and cast
     * to int). For a vector, a single NDArrayIndex is built from
     * {@code ArrayUtil.toInts(index)}.
     *
     * @param index the matrix or vector to read the indices from
     * @return one NDArrayIndex per matrix row, or a one element array for a vector
     * @throws IllegalArgumentException if index is neither a matrix nor a vector
     */
    public static NDArrayIndex[] create(INDArray index) {

        if (index.isMatrix()) {
            int numRows = index.rows();
            NDArrayIndex[] ret = new NDArrayIndex[numRows];
            for (int i = 0; i < numRows; i++) {
                // fetch the row once and reuse it for both sizing and reading
                INDArray row = index.getRow(i);
                int[] nums = new int[row.columns()];
                for (int j = 0; j < nums.length; j++) {
                    nums[j] = (int) row.getFloat(j);
                }

                ret[i] = new NDArrayIndex(nums);
            }

            return ret;
        }
        else if (index.isVector()) {
            int[] indices = ArrayUtil.toInts(index);
            return new NDArrayIndex[]{new NDArrayIndex(indices)};
        }

        throw new IllegalArgumentException("Passed in ndarray must be a matrix or a vector");

    }

    /**
     * @return whether this index has been flagged as an interval
     */
    public boolean isInterval() {
        return isInterval;
    }

    /**
     * Sets the interval flag.
     *
     * @param isInterval the new value of the flag
     */
    public void setInterval(boolean isInterval) {
        this.isInterval = isInterval;
    }

    /**
     * Concatenate all of the given indices in to one.
     * The index arrays are joined in argument order; the result is a new
     * NDArrayIndex whose interval flag is not set.
     *
     * @param indexes the indexes to concatenate
     * @return the merged indices
     */
    public static NDArrayIndex concat(NDArrayIndex...indexes) {
        int[][] indices = new int[indexes.length][];
        for (int i = 0; i < indexes.length; i++) {
            indices[i] = indexes[i].indices();
        }
        return new NDArrayIndex(Ints.concat(indices));
    }

    /**
     * Generates an interval from begin (inclusive) to end (exclusive).
     *
     * @param begin the begin
     * @param end the end index
     * @return the interval
     */
    public static NDArrayIndex interval(int begin,int end) {
        return interval(begin,end,false);
    }


    /**
     * Generates an interval starting at begin (inclusive). The upper bound
     * handed to {@code ArrayUtil.range} is {@code end + 1} when inclusive is
     * true and {@code end} otherwise. The returned index has its interval
     * flag set.
     *
     * @param begin the begin
     * @param end the end index
     * @param inclusive whether the end should be inclusive or not
     * @return the interval
     */
    public static NDArrayIndex interval(int begin,int end,boolean inclusive) {
        // only checked when the JVM runs with assertions enabled (-ea)
        assert begin <= end : "Beginning index in range must be less than end";
        int upperBound = inclusive ? end + 1 : end;
        NDArrayIndex ret = new NDArrayIndex(ArrayUtil.range(begin, upperBound));
        ret.isInterval = true;
        return ret;
    }


}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy