
org.deeplearning4j.keras.HDF5MiniBatchDataSetIterator

package org.deeplearning4j.keras;

import lombok.extern.slf4j.Slf4j;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;

import java.io.File;
import java.nio.file.Paths;
import java.util.List;

/**
 * Iterator reading mini-batches of data stored in separate HDF5 files. Features and labels are expected to be
 * dumped into separate directories, with file names following the pattern `batch_%d.h5`; the same index pairs a
 * features file with its labels file.
 * Only a narrow subset of the DataSetIterator interface is implemented: a DataSetPreProcessor can be applied,
 * but methods such as next(int), batch(), inputColumns() and getLabels() throw UnsupportedOperationException.
 *
 * @author [email protected]
 */
@Slf4j
public class HDF5MiniBatchDataSetIterator implements DataSetIterator {

    private static final String FILE_NAME_PATTERN = "batch_%d.h5";

    private final NDArrayHDF5Reader ndArrayHDF5Reader = new NDArrayHDF5Reader();

    private final File trainFeaturesDirectory;
    private final File trainLabelsDirectory;
    private final int batchesCount;
    private int currentIdx;
    private DataSetPreProcessor preProcessor;


    public HDF5MiniBatchDataSetIterator(String trainFeaturesDirectory, String trainLabelsDirectory) {
        this.trainFeaturesDirectory = new File(trainFeaturesDirectory);
        this.trainLabelsDirectory = new File(trainLabelsDirectory);
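        // The number of mini-batches is inferred from the file count in the features directory,
        // which is assumed to exist and to contain only batch_%d.h5 files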
        this.batchesCount = this.trainFeaturesDirectory.list().length;
    }

    @Override
    public boolean hasNext() {
        return currentIdx < batchesCount;
    }

    @Override
    public DataSet next() {
        DataSet dataSet = readIdx(currentIdx);
        currentIdx++;

        if (preProcessor != null) {
            if (!dataSet.isPreProcessed()) {
                preProcessor.preProcess(dataSet);
                dataSet.markAsPreProcessed();
            }
        }

        return dataSet;
    }

    private DataSet readIdx(int currentIdx) {
        String batchFileName = fileNameForIdx(currentIdx);

        if (log.isTraceEnabled()) {
            log.trace("Reading: " + batchFileName);
        }

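        // The same batch file name is resolved against both the features and the labels directories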
        INDArray features = ndArrayHDF5Reader
                        .readFromPath(Paths.get(trainFeaturesDirectory.getAbsolutePath(), batchFileName));
        INDArray labels = ndArrayHDF5Reader
                        .readFromPath(Paths.get(trainLabelsDirectory.getAbsolutePath(), batchFileName));

        return new DataSet(features, labels);
    }

    private String fileNameForIdx(int currentIdx) {
        return String.format(FILE_NAME_PATTERN, currentIdx);
    }

    @Override
    public boolean resetSupported() {
        return true;
    }

    @Override
    public boolean asyncSupported() {
        // Async support is deliberately disabled: enabling it causes non-deterministic segfaults in JavaCPP
        // when memory is cleaned up after the HDF5 libraries.
        return false;
    }

    @Override
    public void reset() {
        currentIdx = 0;
    }

    @Override
    public int cursor() {
        return currentIdx;
    }


    @Override
    public DataSet next(int num) {
        throw new UnsupportedOperationException("Can't load custom number of samples");
    }

    @Override
    public int totalExamples() {
        throw new UnsupportedOperationException();
    }

    @Override
    public int inputColumns() {
        throw new UnsupportedOperationException();
    }

    @Override
    public int totalOutcomes() {
        throw new UnsupportedOperationException();
    }

    @Override
    public int batch() {
        throw new UnsupportedOperationException();
    }

    @Override
    public int numExamples() {
        throw new UnsupportedOperationException();
    }

    @Override
    public void setPreProcessor(DataSetPreProcessor preProcessor) {
        this.preProcessor = preProcessor;
    }

    @Override
    public DataSetPreProcessor getPreProcessor() {
        return preProcessor;
    }

    @Override
    public List<String> getLabels() {
        throw new UnsupportedOperationException();
    }

    @Override
    public void remove() {
        // no-op
    }

}
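
A minimal usage sketch, assuming matching batch_%d.h5 feature and label files have already been dumped into two directories; the directory paths and the example class name below are hypothetical:

// Placed in the same package as the iterator for simplicity
package org.deeplearning4j.keras;

import org.nd4j.linalg.dataset.DataSet;

public class HDF5IteratorExample {

    public static void main(String[] args) {
        // Hypothetical directories containing batch_0.h5, batch_1.h5, ... for features and labels respectively
        HDF5MiniBatchDataSetIterator iterator =
                new HDF5MiniBatchDataSetIterator("/data/train/features", "/data/train/labels");

        while (iterator.hasNext()) {
            DataSet batch = iterator.next();
            System.out.println("Loaded batch with " + batch.numExamples() + " examples");
        }

        // The iterator supports rewinding, e.g. for another training epoch
        iterator.reset();
    }
}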