org.nd4j.linalg.indexing.IntervalIndex Maven / Gradle / Ivy
package org.nd4j.linalg.indexing;
import org.nd4j.linalg.api.ndarray.INDArray;
/**
* An index representing
* an interval, traversed with a fixed stride
*
* @author Adam Gibson
*/
public class IntervalIndex implements INDArrayIndex {
    /** First position of the interval, and the (exclusive) position at which iteration stops. */
    protected int begin, end;
    /** Whether the upper bound passed to {@code init} is itself part of the interval. */
    protected boolean inclusive;
    /** Step between consecutive positions; validated to be positive by the constructor. */
    protected int stride = 1;
    /** Current iteration position, advanced by {@link #next()}. */
    protected int index = 0;
    /** Number of positions in the interval, computed by the {@code init} methods. */
    protected int length = 0;

    /**
     * Creates an interval index whose bounds are supplied later by one of the {@code init} methods.
     *
     * @param inclusive whether to include the last number
     * @param stride the stride for the interval; must be positive
     * @throws IllegalArgumentException if {@code stride} is zero or negative
     */
    public IntervalIndex(boolean inclusive, int stride) {
        // A non-positive stride can never move index past end. Previously a zero stride failed
        // here with "/ by zero", and a negative one made init() spin until int overflow.
        if (stride <= 0) {
            throw new IllegalArgumentException("Stride must be positive, got " + stride);
        }
        this.inclusive = inclusive;
        this.stride = stride;
        // begin and end are both still 0 until init(...) runs, so there is nothing to count yet.
        this.length = 0;
    }

    /** @return the exclusive end position (already adjusted by one when {@code inclusive}) */
    @Override
    public int end() {
        return end;
    }

    /** @return the first position of the interval */
    @Override
    public int offset() {
        return begin;
    }

    /** @return the number of positions computed by the last {@code init} call */
    @Override
    public int length() {
        return length;
    }

    /** @return the step between consecutive positions */
    @Override
    public int stride() {
        return stride;
    }

    /** @return the current iteration position */
    @Override
    public int current() {
        return index;
    }

    /** @return {@code true} while the current position is before {@link #end()} */
    @Override
    public boolean hasNext() {
        return index < end();
    }

    /**
     * Returns the current position and advances it by {@link #stride()}.
     * Does not check {@link #hasNext()} first.
     */
    @Override
    public int next() {
        int ret = index;
        index += stride;
        return ret;
    }

    /**
     * Swaps {@code begin} and {@code end}. The current position, stride and length are left
     * unchanged.
     */
    @Override
    public void reverse() {
        int oldEnd = end;
        int oldBegin = begin;
        this.end = oldBegin;
        this.begin = oldEnd;
    }

    /** @return always {@code true}; this index type is an interval by definition */
    @Override
    public boolean isInterval() {
        return true;
    }

    /** Intentionally a no-op: {@link #isInterval()} is fixed to {@code true} for this type. */
    @Override
    public void setInterval(boolean isInterval) {
    }

    /**
     * Initializes the interval to run from {@code begin} up to the size of {@code dimension}
     * in {@code arr}.
     *
     * @param arr the array whose dimension size bounds the interval
     * @param begin the first position
     * @param dimension the dimension of {@code arr} whose size is used as the upper bound
     */
    @Override
    public void init(INDArray arr, int begin, int dimension) {
        this.begin = begin;
        this.index = begin;
        // NOTE(review): an inclusive interval ends one past arr.size(dimension); confirm that
        // reaching beyond the last position of the dimension is intended.
        this.end = inclusive ? arr.size(dimension) + 1 : arr.size(dimension);
        // Assign (not accumulate) so re-initializing the same instance yields a correct length.
        this.length = countSteps(this.begin, this.end, stride);
    }

    /**
     * Initializes the interval to cover the whole of {@code dimension} in {@code arr},
     * starting at position 0.
     */
    @Override
    public void init(INDArray arr, int dimension) {
        init(arr, 0, dimension);
    }

    /**
     * Initializes the interval to run from {@code begin} to {@code end}. {@code end} is
     * included only when this index was constructed as inclusive.
     *
     * @param begin the first position
     * @param end the upper bound
     */
    @Override
    public void init(int begin, int end) {
        this.begin = begin;
        this.index = begin;
        this.end = inclusive ? end + 1 : end;
        // Assign (not accumulate) so re-initializing the same instance yields a correct length.
        this.length = countSteps(this.begin, this.end, stride);
    }

    /** Rewinds iteration to {@code begin} so the interval can be traversed again. */
    @Override
    public void reset() {
        index = begin;
    }

    /**
     * Counts the values visited by {@code for (i = from; i < to; i += step)} for a positive
     * {@code step}, without looping and without int overflow.
     */
    private static int countSteps(int from, int to, int step) {
        if (to <= from) {
            return 0;
        }
        // Widen to long: to - from can exceed Integer.MAX_VALUE when from is negative.
        long span = (long) to - from;
        return (int) ((span + step - 1) / step);
    }

    /**
     * Two interval indexes are equal when their bounds, inclusivity, stride and current
     * iteration position all match. {@code length} is not compared.
     */
    @Override
    public boolean equals(Object o) {
        if (this == o)
            return true;
        if (!(o instanceof IntervalIndex))
            return false;
        IntervalIndex that = (IntervalIndex) o;
        if (begin != that.begin)
            return false;
        if (end != that.end)
            return false;
        if (inclusive != that.inclusive)
            return false;
        if (stride != that.stride)
            return false;
        return index == that.index;
    }

    /** Hash over the same fields compared by {@link #equals(Object)}. */
    @Override
    public int hashCode() {
        int result = begin;
        result = 31 * result + end;
        result = 31 * result + (inclusive ? 1 : 0);
        result = 31 * result + stride;
        result = 31 * result + index;
        return result;
    }
}