org.deeplearning4j.keras.HDF5MiniBatchDataSetIterator Maven / Gradle / Ivy
The newest version!
package org.deeplearning4j.keras;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import java.io.File;
import java.nio.file.Paths;
import java.util.List;
/**
* Iterator reading mini batches of data stored in separate files. Labels and features are expected to be dumped into
* separate directories. Filenames are expected to adhere to a predefined pattern: `batch_%d.h5`.
* This class supports only a very narrow subset of the DataSetIterator interface! (e.g.
* there is no support for pre-processing of data)
*
* @author [email protected]
*/
@Slf4j
public class HDF5MiniBatchDataSetIterator implements DataSetIterator {
private static final String FILE_NAME_PATTERN = "batch_%d.h5";
private final NDArrayHDF5Reader ndArrayHDF5Reader = new NDArrayHDF5Reader();
private final File trainFeaturesDirectory;
private final File trainLabelsDirectory;
private final int batchesCount;
private int currentIdx;
private DataSetPreProcessor preProcessor;
public HDF5MiniBatchDataSetIterator(String trainFeaturesDirectory, String trainLabelsDirectory) {
this.trainFeaturesDirectory = new File(trainFeaturesDirectory);
this.trainLabelsDirectory = new File(trainLabelsDirectory);
this.batchesCount = this.trainFeaturesDirectory.list().length;
}
@Override
public boolean hasNext() {
return currentIdx < batchesCount;
}
@Override
public DataSet next() {
DataSet dataSet = readIdx(currentIdx);
currentIdx++;
if (preProcessor != null) {
if (!dataSet.isPreProcessed()) {
preProcessor.preProcess(dataSet);
dataSet.markAsPreProcessed();
}
}
return dataSet;
}
private DataSet readIdx(int currentIdx) {
String batchFileName = fileNameForIdx(currentIdx);
if (log.isTraceEnabled()) {
log.trace("Reading: " + batchFileName);
}
INDArray features = ndArrayHDF5Reader
.readFromPath(Paths.get(trainFeaturesDirectory.getAbsolutePath(), batchFileName));
INDArray labels = ndArrayHDF5Reader
.readFromPath(Paths.get(trainLabelsDirectory.getAbsolutePath(), batchFileName));
return new DataSet(features, labels);
}
private String fileNameForIdx(int currentIdx) {
return String.format(FILE_NAME_PATTERN, currentIdx);
}
@Override
public boolean resetSupported() {
return true;
}
@Override
public boolean asyncSupported() {
/**
* Async support is turned off on purpose: otherwise there are indeterministic segfaults in JavaCPP
* when cleaning memory after HDF5 libs.
*/
return false;
}
@Override
public void reset() {
currentIdx = 0;
}
@Override
public int cursor() {
return currentIdx;
}
@Override
public DataSet next(int num) {
throw new UnsupportedOperationException("Can't load custom number of samples");
}
@Override
public int totalExamples() {
throw new UnsupportedOperationException();
}
@Override
public int inputColumns() {
throw new UnsupportedOperationException();
}
@Override
public int totalOutcomes() {
throw new UnsupportedOperationException();
}
@Override
public int batch() {
throw new UnsupportedOperationException();
}
@Override
public int numExamples() {
throw new UnsupportedOperationException();
}
@Override
public void setPreProcessor(DataSetPreProcessor preProcessor) {
this.preProcessor = preProcessor;
}
@Override
public DataSetPreProcessor getPreProcessor() {
return preProcessor;
}
@Override
public List getLabels() {
throw new UnsupportedOperationException();
}
@Override
public void remove() {
// no-op
}
}