All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.nd4j.linalg.indexing.IntervalIndex Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
package org.nd4j.linalg.indexing;

import org.nd4j.linalg.api.ndarray.INDArray;

/**
 * An index representing an interval of positions
 * {@code begin, begin + stride, begin + 2 * stride, ...} that are strictly less
 * than {@link #end()}.
 *
 * <p>The constructor only records whether the upper bound is inclusive and the
 * stride; the bounds themselves are set by one of the {@code init} methods.
 * The index is stateful: {@link #next()} returns the current position and
 * advances it by {@code stride}, and {@link #reset()} rewinds it to the start.
 *
 * @author Adam Gibson
 */
public class IntervalIndex implements INDArrayIndex {

    protected int begin, end;
    protected boolean inclusive;
    protected int stride = 1;
    protected int index = 0;
    protected int length = 0;

    /**
     * Creates an interval index whose bounds are supplied later via {@code init}.
     *
     * @param inclusive whether to include the last number
     * @param stride    the stride for the interval; must be positive
     * @throws IllegalArgumentException if {@code stride <= 0}
     */
    public IntervalIndex(boolean inclusive, int stride) {
        if (stride <= 0) {
            // A zero stride never advances next(), and a negative stride moves away
            // from end, so iteration could never terminate normally.
            throw new IllegalArgumentException("Stride must be positive, got " + stride);
        }
        this.inclusive = inclusive;
        this.stride = stride;
        // length stays 0 until init(...) establishes the bounds.
    }

    /** Returns the exclusive upper bound of the interval (already adjusted if inclusive). */
    @Override
    public int end() {
        return end;
    }

    /** Returns the first position of the interval. */
    @Override
    public int offset() {
        return begin;
    }

    /** Returns the number of positions in the interval, as computed by the last {@code init}. */
    @Override
    public int length() {
        return length;
    }

    /** Returns the step between consecutive positions. */
    @Override
    public int stride() {
        return stride;
    }

    /** Returns the current cursor position without advancing it. */
    @Override
    public int current() {
        return index;
    }

    /** Returns whether the cursor is still below {@link #end()}. */
    @Override
    public boolean hasNext() {
        return index < end();
    }

    /** Returns the current position and advances the cursor by {@code stride}. */
    @Override
    public int next() {
        int ret = index;
        index += stride;
        return ret;
    }

    /**
     * Swaps {@code begin} and {@code end}. The cursor, stride and length are left
     * unchanged.
     */
    @Override
    public void reverse() {
        int oldEnd = end;
        int oldBegin = begin;
        this.end = oldBegin;
        this.begin = oldEnd;
    }

    /** Always {@code true}: this index type is an interval by definition. */
    @Override
    public boolean isInterval() {
        return true;
    }

    /** Ignored: an {@code IntervalIndex} is always an interval. */
    @Override
    public void setInterval(boolean isInterval) {
        // Intentionally a no-op; see isInterval().
    }

    /**
     * Initializes the interval from {@code begin} to the size of {@code arr} along
     * {@code dimension}, and positions the cursor at {@code begin}.
     *
     * @param arr       the array whose dimension size bounds the interval
     * @param begin     the first position
     * @param dimension the dimension whose size is used as the upper bound
     */
    @Override
    public void init(INDArray arr, int begin, int dimension) {
        this.begin = begin;
        this.index = begin;
        // NOTE(review): when inclusive, end becomes size + 1, which is presumably one
        // past the last valid position along the dimension — confirm this is intended.
        this.end = inclusive ? arr.size(dimension) + 1 : arr.size(dimension);
        // Assign (not accumulate) so repeated init calls report the correct length.
        this.length = countPositions(this.begin, this.end, stride);
    }

    /**
     * Initializes the interval over the whole of {@code dimension} of {@code arr},
     * starting at position 0.
     */
    @Override
    public void init(INDArray arr, int dimension) {
        init(arr, 0, dimension);
    }

    /**
     * Initializes the interval {@code [begin, end)} (or {@code [begin, end]} when
     * inclusive) and positions the cursor at {@code begin}.
     *
     * @param begin the first position
     * @param end   the upper bound, exclusive unless this index is inclusive
     */
    @Override
    public void init(int begin, int end) {
        this.begin = begin;
        this.index = begin;
        this.end = inclusive ? end + 1 : end;
        // Assign (not accumulate) so repeated init calls report the correct length.
        this.length = countPositions(this.begin, this.end, stride);
    }

    /** Rewinds the cursor to the beginning of the interval. */
    @Override
    public void reset() {
        index = begin;
    }

    /**
     * Two interval indexes are equal when their bounds, inclusivity, stride and
     * current cursor position match.
     *
     * <p>NOTE(review): equality and hash code depend on the mutable cursor
     * {@code index}, so an instance should not be used as a hash key while iterating.
     */
    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (!(o instanceof IntervalIndex)) return false;

        IntervalIndex that = (IntervalIndex) o;

        if (begin != that.begin) return false;
        if (end != that.end) return false;
        if (inclusive != that.inclusive) return false;
        if (stride != that.stride) return false;
        return index == that.index;
    }

    @Override
    public int hashCode() {
        int result = begin;
        result = 31 * result + end;
        result = 31 * result + (inclusive ? 1 : 0);
        result = 31 * result + stride;
        result = 31 * result + index;
        return result;
    }

    /**
     * Counts the positions {@code begin, begin + stride, ...} that are strictly less
     * than {@code end}, i.e. {@code ceil((end - begin) / stride)} when
     * {@code end > begin} and 0 otherwise.
     *
     * <p>Computed in {@code long} arithmetic so that stepping past
     * {@code Integer.MAX_VALUE} cannot wrap around the way an {@code int} loop would.
     *
     * @throws ArithmeticException if the count does not fit in an {@code int}
     */
    private static int countPositions(int begin, int end, int stride) {
        if (end <= begin) {
            return 0;
        }
        long span = (long) end - begin;
        return Math.toIntExact((span + stride - 1) / stride);
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy