All downloads are free. Search and download functionality uses the official Maven repository.

org.nd4j.linalg.indexing.IntervalIndex Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
package org.nd4j.linalg.indexing;

import org.nd4j.linalg.api.ndarray.INDArray;

/**
 * An index representing an interval of positions, traversed with a fixed positive stride.
 *
 * <p>The bounds are not known at construction time; they are supplied by one of the
 * {@code init} methods. After initialisation the interval covers {@code begin, begin + stride,
 * ...} up to, but excluding, the stored {@code end}. For an inclusive index, the stored
 * {@code end} is the caller's end plus one, so the caller's end is included.
 *
 * <p>Instances hold mutable iteration state ({@link #current()}, {@link #next()}) without any
 * synchronization.
 *
 * @author Adam Gibson
 */
public class IntervalIndex implements INDArrayIndex {

    /** First position of the interval; iteration starts here. */
    protected int begin;
    /** Exclusive upper bound (already incremented by one when {@link #inclusive} is set). */
    protected int end;
    /** Whether the caller-supplied end position belongs to the interval. */
    protected boolean inclusive;
    /** Step between consecutive positions; validated to be positive by the constructor. */
    protected int stride = 1;
    /** Current iteration position. */
    protected int index = 0;
    /** Number of positions in the interval, recomputed by every {@code init} call. */
    protected int length = 0;

    /**
     * Creates an interval index whose bounds are supplied later via one of the {@code init}
     * methods.
     *
     * @param inclusive whether to include the last number
     * @param stride the stride for the interval; must be positive
     * @throws IllegalArgumentException if {@code stride} is zero or negative
     */
    public IntervalIndex(boolean inclusive, int stride) {
        if (stride <= 0) {
            // A zero stride never advances and a negative one moves away from end; both made the
            // original length computation divide by zero or loop (practically) forever.
            throw new IllegalArgumentException("Stride must be positive, got " + stride);
        }
        this.inclusive = inclusive;
        this.stride = stride;
        // begin and end are still 0 here, so the interval stays empty until init(...) runs.
        this.length = 0;
    }

    /**
     * Returns the exclusive end of the interval. For an inclusive index this is one past the
     * caller-supplied end.
     */
    @Override
    public int end() {
        return end;
    }

    /** Returns the first position of the interval. */
    @Override
    public int offset() {
        return begin;
    }

    /** Returns the number of positions in the interval, or 0 before any {@code init} call. */
    @Override
    public int length() {
        return length;
    }

    /** Returns the step between consecutive positions. */
    @Override
    public int stride() {
        return stride;
    }

    /** Returns the position that the next call to {@link #next()} will return. */
    @Override
    public int current() {
        return index;
    }

    /** Returns whether the current position is still below {@link #end()}. */
    @Override
    public boolean hasNext() {
        return index < end();
    }

    /**
     * Returns the current position and advances it by {@link #stride()}.
     *
     * @return the position before advancing
     */
    @Override
    public int next() {
        int ret = index;
        index += stride;
        return ret;
    }

    /**
     * Swaps {@code begin} and {@code end}. The current position and {@link #length()} are left
     * unchanged.
     */
    @Override
    public void reverse() {
        int oldEnd = end;
        int oldBegin = begin;
        this.end = oldBegin;
        this.begin = oldEnd;
    }

    /** Always {@code true}: this index always describes an interval. */
    @Override
    public boolean isInterval() {
        return true;
    }

    /** Intentionally a no-op: an {@code IntervalIndex} is always an interval. */
    @Override
    public void setInterval(boolean isInterval) {
        // nothing to do
    }

    /**
     * Initialises the interval to span dimension {@code dimension} of {@code arr}, starting at
     * {@code begin}.
     *
     * @param arr the array whose dimension size bounds the interval
     * @param begin the first position
     * @param dimension the dimension whose size is used as the end
     */
    @Override
    public void init(INDArray arr, int begin, int dimension) {
        this.begin = begin;
        this.index = begin;
        // NOTE(review): for an inclusive index this makes end one past the dimension size, so
        // iteration may reach arr.size(dimension) itself; confirm this is intended.
        this.end = inclusive ? arr.size(dimension) + 1 : arr.size(dimension);
        // Recompute instead of accumulating, so repeated init calls stay correct.
        this.length = countPositions(this.begin, this.end, stride);
    }

    /**
     * Initialises the interval to span the whole of dimension {@code dimension} of {@code arr},
     * starting at 0.
     *
     * @param arr the array whose dimension size bounds the interval
     * @param dimension the dimension whose size is used as the end
     */
    @Override
    public void init(INDArray arr, int dimension) {
        init(arr, 0, dimension);
    }

    /**
     * Initialises the interval to {@code [begin, end)}, or {@code [begin, end]} if this index
     * is inclusive.
     *
     * @param begin the first position
     * @param end the last position (inclusive index) or one past it (exclusive index)
     */
    @Override
    public void init(int begin, int end) {
        this.begin = begin;
        this.index = begin;
        this.end = inclusive ? end + 1 : end;
        // Recompute instead of accumulating, so repeated init calls stay correct.
        this.length = countPositions(this.begin, this.end, stride);
    }

    /** Rewinds iteration so that the next call to {@link #next()} returns {@code begin} again. */
    @Override
    public void reset() {
        index = begin;
    }

    /**
     * Counts the positions {@code begin, begin + stride, ...} that are strictly below
     * {@code end}. This gives the same count as stepping a loop, but in constant time, and it
     * uses long arithmetic so the span cannot overflow.
     */
    private static int countPositions(int begin, int end, int stride) {
        if (stride <= 0) {
            // Only reachable if a subclass overwrote the validated stride field.
            throw new IllegalStateException("Stride must be positive, got " + stride);
        }
        if (end <= begin) {
            return 0;
        }
        long span = (long) end - begin;
        long count = (span + stride - 1) / stride;
        return (int) Math.min(Integer.MAX_VALUE, count);
    }

    /**
     * Two interval indexes are equal when their bounds, inclusivity, stride and current
     * position all match.
     */
    @Override
    public boolean equals(Object o) {
        if (this == o)
            return true;
        if (!(o instanceof IntervalIndex))
            return false;

        IntervalIndex that = (IntervalIndex) o;

        if (begin != that.begin)
            return false;
        if (end != that.end)
            return false;
        if (inclusive != that.inclusive)
            return false;
        if (stride != that.stride)
            return false;
        return index == that.index;

    }

    /**
     * Hashes the same fields compared by {@link #equals(Object)}. The hash includes the mutable
     * current position, so it changes as the index is iterated.
     */
    @Override
    public int hashCode() {
        int result = begin;
        result = 31 * result + end;
        result = 31 * result + (inclusive ? 1 : 0);
        result = 31 * result + stride;
        result = 31 * result + index;
        return result;
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy