All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.nd4j.linalg.indexing.Indices Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
package org.nd4j.linalg.indexing;

import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.ArrayUtil;

import java.util.ArrayList;
import java.util.List;

/**
 * Static helpers for working with {@link NDArrayIndex} arrays: computing
 * offsets, shapes and strides, padding or truncating index arrays to the rank
 * of a shape, and testing whether index sequences are contiguous.
 *
 * @author Adam Gibson
 */
public class Indices {

    /**
     * Returns the offset (begin index) of each index.
     *
     * <p>An index reporting an offset of {@code 0} that is neither the first nor
     * the last entry is reported as {@code 1} instead.
     * NOTE(review): the reason for this substitution is not evident from this
     * class — confirm against callers before relying on it.
     *
     * @param indices the indices
     * @return one offset per index, in the same order
     */
    public static int[] offsets(NDArrayIndex... indices) {
        int[] ret = new int[indices.length];
        for (int i = 0; i < indices.length; i++) {
            int offset = indices[i].offset();
            boolean interior = i > 0 && i < indices.length - 1;
            ret[i] = (offset == 0 && interior) ? 1 : offset;
        }
        return ret;
    }


    /**
     * Pads (or truncates) {@code indexes} so there is exactly one index per
     * dimension of {@code shape}.
     *
     * <p>This mirrors numpy/matlab indexing: for an array {@code a} of shape
     * (4,3,2), {@code a[1:3]} implicitly selects everything along the trailing
     * dimensions. Each missing trailing dimension {@code i} is filled with
     * {@code NDArrayIndex.interval(0, shape[i])}.
     *
     * @param shape   the original shape
     * @param indexes the indexes to start from
     * @return {@code indexes} itself when its length already equals
     *         {@code shape.length}; otherwise a new array of length
     *         {@code shape.length} (surplus trailing indexes are dropped)
     */
    public static NDArrayIndex[] fillIn(int[] shape, NDArrayIndex... indexes) {
        if (shape.length == indexes.length) {
            return indexes;
        }

        NDArrayIndex[] newIndexes = new NDArrayIndex[shape.length];
        // Copy only what fits so that a longer index array is truncated
        // instead of overflowing newIndexes.
        int copied = Math.min(indexes.length, shape.length);
        System.arraycopy(indexes, 0, newIndexes, 0, copied);

        for (int i = copied; i < shape.length; i++) {
            newIndexes[i] = NDArrayIndex.interval(0, shape[i]);
        }
        return newIndexes;
    }

    /**
     * Makes {@code indexes} match the rank of {@code originalShape}: missing
     * trailing indices are filled in via {@link #fillIn(int[], NDArrayIndex...)}
     * and surplus trailing indices are dropped.
     *
     * <p>The individual index ranges are not inspected, so an index whose end
     * exceeds the corresponding dimension is returned unchanged.
     *
     * @param originalShape the original shape to adjust to
     * @param indexes       the indexes to adjust
     * @return an index array of length {@code originalShape.length}; this is
     *         {@code indexes} itself when the lengths already match
     */
    public static NDArrayIndex[] adjustIndices(int[] originalShape, NDArrayIndex... indexes) {
        if (indexes.length < originalShape.length) {
            return fillIn(originalShape, indexes);
        }
        if (indexes.length > originalShape.length) {
            NDArrayIndex[] ret = new NDArrayIndex[originalShape.length];
            System.arraycopy(indexes, 0, ret, 0, originalShape.length);
            return ret;
        }
        return indexes;
    }


    /**
     * Calculates strides for an array whose shape is
     * {@link #shape(NDArrayIndex...)} of the given indices.
     *
     * @param ordering the ordering to calculate strides for
     * @param indexes  the indices to calculate strides for
     * @return the strides for the given indices
     */
    public static int[] strides(char ordering, NDArrayIndex... indexes) {
        return Nd4j.getStrides(shape(indexes), ordering);
    }

    /**
     * Calculates the shape selected by the given indices.
     *
     * <p>For each index the extent is {@code |(last + 1) - first|}, where
     * {@code first} and {@code last} are the first and last entries of
     * {@code indices()}. Dimensions whose extent is zero (including indices
     * that select nothing) are omitted from the result.
     *
     * @param indices the indices to calculate the shape for
     * @return the non-zero extents, in index order
     */
    public static int[] shape(NDArrayIndex... indices) {
        int[] ret = new int[indices.length];
        for (int i = 0; i < ret.length; i++) {
            int[] currIndices = indices[i].indices();
            // An empty index contributes zero (dropped below) rather than
            // throwing, matching shape(int[], NDArrayIndex...).
            if (currIndices.length < 1) {
                continue;
            }
            int end = currIndices[currIndices.length - 1] + 1;
            int begin = currIndices[0];
            ret[i] = Math.abs(end - begin);
        }
        return positiveOnly(ret);
    }

    /**
     * Returns whether every index is contiguous with a maximum step of one.
     *
     * @param indexes the indices to test
     * @return whether the indices are contiguous by one
     */
    public static boolean isContiguous(NDArrayIndex... indexes) {
        return isContiguous(1, indexes);
    }

    /**
     * Returns whether every index is contiguous with a maximum step of
     * {@code diff} (see {@link #isContiguous(int[], int)}).
     *
     * @param diff    the largest step still considered contiguous
     * @param indexes the indices to test
     * @return {@code true} if all indices are contiguous (or none are given)
     */
    public static boolean isContiguous(int diff, NDArrayIndex... indexes) {
        for (NDArrayIndex index : indexes) {
            if (!isContiguous(index.indices(), diff)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Returns whether the positions are contiguous with a maximum step of one.
     *
     * @param indices the positions to test
     * @return whether the positions are contiguous by one
     */
    public static boolean isContiguous(int[] indices) {
        return isContiguous(indices, 1);
    }

    /**
     * Returns whether no two consecutive positions are more than {@code diff}
     * apart (in either direction). Smaller steps, including repeated
     * positions, count as contiguous.
     *
     * @param indices the positions to test
     * @param diff    the largest step still considered contiguous
     * @return {@code true} if contiguous; always {@code true} for fewer than
     *         two positions
     */
    public static boolean isContiguous(int[] indices, int diff) {
        for (int i = 1; i < indices.length; i++) {
            if (Math.abs(indices[i] - indices[i - 1]) > diff) {
                return false;
            }
        }
        return true;
    }

    /**
     * Calculates the shape selected by {@code indices} within an array of the
     * given {@code shape}.
     *
     * <p>For an interval index the extent is {@code |end - begin| + 1}, where
     * an {@code end} at or beyond {@code shape[i]} is first clamped to
     * {@code shape[i] - 1}. For any other index the extent is the number of
     * positions it selects. Zero extents (including empty indices) are omitted.
     *
     * @param shape   the original shape
     * @param indices the indices to calculate the shape for
     * @return the non-zero extents, in index order; if more indices than
     *         dimensions are supplied, {@code shape} itself is returned
     */
    public static int[] shape(int[] shape, NDArrayIndex... indices) {
        if (indices.length > shape.length) {
            return shape;
        }

        int[] ret = new int[indices.length];
        for (int i = 0; i < ret.length; i++) {
            int[] currIndices = indices[i].indices();
            if (currIndices.length < 1) {
                continue;
            }
            int begin = currIndices[0];
            int end = currIndices[currIndices.length - 1];
            // shape[i] - 1 is the last valid position, so an end equal to
            // shape[i] is already out of range and must be clamped too.
            if (end >= shape[i]) {
                end = shape[i] - 1;
            }
            ret[i] = indices[i].isInterval() ? Math.abs(end - begin) + 1 : currIndices.length;
        }
        return positiveOnly(ret);
    }

    /** Returns the strictly positive entries of {@code values}, preserving order. */
    private static int[] positiveOnly(int[] values) {
        int count = 0;
        for (int v : values) {
            if (v > 0) {
                count++;
            }
        }
        int[] result = new int[count];
        int next = 0;
        for (int v : values) {
            if (v > 0) {
                result[next++] = v;
            }
        }
        return result;
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy