// org.deeplearning4j.datasets.iterator.ExistingDataSetIterator (Maven / Gradle / Ivy)
package org.deeplearning4j.datasets.iterator;
import lombok.Getter;
import lombok.NonNull;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;
/**
 * This wrapper provides the DataSetIterator interface on top of an existing java
 * {@link Iterable} or {@link Iterator} of {@link DataSet} instances.
 *
 * <p>When constructed from an {@link Iterable}, the wrapper can be {@link #reset() reset}
 * by requesting a fresh iterator from it. When constructed from a bare {@link Iterator},
 * reset is not supported.
 *
 * @author [email protected]
 */
public class ExistingDataSetIterator implements DataSetIterator {
    /** Optional pre-processor applied to each DataSet returned by {@link #next()}. */
    @Getter
    private DataSetPreProcessor preProcessor;

    /** Source used to re-create the iterator on reset; null when built from a bare Iterator. */
    private transient Iterable<DataSet> iterable;

    /** Iterator currently being consumed. */
    private transient Iterator<DataSet> iterator;

    private int totalExamples = 0;
    private int numFeatures = 0;
    private int numLabels = 0;

    /** Label names, or null if none were supplied at construction time. */
    private List<String> labels;

    /**
     * Wraps a single-pass iterator. {@link #reset()} is not supported for instances
     * created this way.
     *
     * @param iterator the iterator to read DataSets from
     */
    public ExistingDataSetIterator(@NonNull Iterator<DataSet> iterator) {
        this.iterator = iterator;
    }

    /**
     * Wraps a single-pass iterator and records the label names.
     *
     * @param iterator the iterator to read DataSets from
     * @param labels   the label names reported by {@link #getLabels()}
     */
    public ExistingDataSetIterator(@NonNull Iterator<DataSet> iterator, @NonNull List<String> labels) {
        this(iterator);
        this.labels = labels;
    }

    /**
     * Wraps an iterable; an iterator is obtained from it immediately and again on
     * every {@link #reset()}.
     *
     * @param iterable the source of DataSets
     */
    public ExistingDataSetIterator(@NonNull Iterable<DataSet> iterable) {
        this.iterable = iterable;
        this.iterator = iterable.iterator();
    }

    /**
     * Wraps an iterable and records the label names.
     *
     * @param iterable the source of DataSets
     * @param labels   the label names reported by {@link #getLabels()}
     */
    public ExistingDataSetIterator(@NonNull Iterable<DataSet> iterable, @NonNull List<String> labels) {
        this(iterable);
        this.labels = labels;
    }

    /**
     * Wraps an iterable and records caller-supplied dataset dimensions. The values are
     * stored as given and are not checked against the actual data.
     *
     * @param iterable      the source of DataSets
     * @param totalExamples value reported by {@link #totalExamples()} and {@link #numExamples()}
     * @param numFeatures   value reported by {@link #inputColumns()}
     * @param numLabels     value reported by {@link #totalOutcomes()} when no label list is set
     */
    public ExistingDataSetIterator(@NonNull Iterable<DataSet> iterable, int totalExamples, int numFeatures,
                    int numLabels) {
        this(iterable);
        this.totalExamples = totalExamples;
        this.numFeatures = numFeatures;
        this.numLabels = numLabels;
    }

    /**
     * Not supported by this wrapper.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public DataSet next(int num) {
        // TODO: this might be changed
        throw new UnsupportedOperationException("next(int) isn't supported");
    }

    /** @return the total example count supplied at construction time, or 0 if none was given */
    @Override
    public int totalExamples() {
        return totalExamples;
    }

    /** @return the feature count supplied at construction time, or 0 if none was given */
    @Override
    public int inputColumns() {
        return numFeatures;
    }

    /**
     * @return the size of the label list when one was supplied, otherwise the label
     *         count supplied at construction time (0 if none was given)
     */
    @Override
    public int totalOutcomes() {
        if (labels != null) {
            return labels.size();
        }
        return numLabels;
    }

    /** @return true only when this instance holds an Iterable to re-create the iterator from */
    @Override
    public boolean resetSupported() {
        return iterable != null;
    }

    @Override
    public boolean asyncSupported() {
        //No need to asynchronously prefetch here: already in memory
        return false;
    }

    /**
     * Restarts iteration by requesting a fresh iterator from the wrapped Iterable.
     *
     * @throws IllegalStateException if no Iterable is available
     */
    @Override
    public void reset() {
        if (iterable == null) {
            throw new IllegalStateException(
                            "To use reset() method you need to provide Iterable, not Iterator");
        }
        this.iterator = iterable.iterator();
    }

    /** @return always 0; this wrapper does not track a batch size */
    @Override
    public int batch() {
        return 0;
    }

    /** @return always 0; this wrapper does not track a cursor position */
    @Override
    public int cursor() {
        return 0;
    }

    /** @return the same value as {@link #totalExamples()} */
    @Override
    public int numExamples() {
        return totalExamples;
    }

    /**
     * Sets the pre-processor applied to DataSets returned by {@link #next()}.
     *
     * @param preProcessor the pre-processor to use, or null to disable pre-processing
     */
    @Override
    public void setPreProcessor(DataSetPreProcessor preProcessor) {
        this.preProcessor = preProcessor;
    }

    /** @return the label names supplied at construction time, or null if none were given */
    @Override
    public List<String> getLabels() {
        return labels;
    }

    /** @return true if an underlying iterator is present and has more elements */
    @Override
    public boolean hasNext() {
        return iterator != null && iterator.hasNext();
    }

    /**
     * Returns the next DataSet from the underlying iterator, running the configured
     * pre-processor on it first unless the DataSet is already marked as pre-processed.
     *
     * @return the next DataSet
     * @throws NoSuchElementException if no underlying iterator is present, or if the
     *                                underlying iterator itself reports exhaustion this way
     */
    @Override
    public DataSet next() {
        // Mirror hasNext(): a missing iterator (the field is transient, so presumably
        // possible on a deserialized instance) means "no elements", not an NPE.
        if (iterator == null) {
            throw new NoSuchElementException("No underlying iterator is available");
        }
        DataSet ds = iterator.next();
        if (preProcessor != null && !ds.isPreProcessed()) {
            preProcessor.preProcess(ds);
            // Flag the DataSet so it is not pre-processed again if it is returned a second time.
            ds.markAsPreProcessed();
        }
        return ds;
    }

    /** Does nothing; nothing is removed from the underlying source. */
    @Override
    public void remove() {
        // no-op
    }
}
// © 2015 - 2024 Weber Informatics LLC | Privacy Policy