All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.deeplearning4j.keras.NDArrayHDF5Reader Maven / Gradle / Ivy

The newest version!
package org.deeplearning4j.keras;

import org.bytedeco.javacpp.FloatPointer;
import org.bytedeco.javacpp.hdf5;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.ArrayUtil;

import java.nio.file.Path;

import static org.bytedeco.javacpp.hdf5.H5F_ACC_RDONLY;

/**
 * Reads an INDArray from a HDF5 dataset. The array is expected to be included as the "data" dataset inside the file.
 * The shape of the output array is the same as the one stored in the HDF5 file. Element values are read using the
 * HDF5 native float type.
 *
 * @author [email protected]
 */
public class NDArrayHDF5Reader {

    /** Name of the dataset inside the HDF5 file that holds the array. */
    private static final String DATA_SET_NAME = "data";

    /** HDF5 type descriptor used for every read; values are fetched as native floats. */
    private final hdf5.DataType dataType = new hdf5.DataType(hdf5.PredType.NATIVE_FLOAT());

    /**
     * Reads an HDF5 file into an NDArray.
     *
     * @param inputFilePath Path of the HDF5 file
     * @return NDArray with data and a correct shape
     * @throws IllegalStateException if a dimension or the total number of elements of the dataset
     *                               does not fit into an {@code int}
     */
    public INDArray readFromPath(Path inputFilePath) {
        try (hdf5.H5File h5File = new hdf5.H5File()) {
            h5File.openFile(inputFilePath.toString(), H5F_ACC_RDONLY);
            // Close the dataset handle explicitly instead of waiting for garbage collection.
            try (hdf5.DataSet dataSet = h5File.asCommonFG().openDataSet(DATA_SET_NAME)) {
                int[] shape = extractShape(dataSet);
                long totalSize = ArrayUtil.prodLong(shape);
                // A plain (int) cast would silently truncate here; the native read would then
                // write the full dataset into a buffer that is too small.
                if (totalSize < 0 || totalSize > Integer.MAX_VALUE) {
                    throw new IllegalStateException("Dataset \"" + DATA_SET_NAME + "\" in " + inputFilePath
                            + " holds " + totalSize + " elements, which exceeds the supported maximum of "
                            + Integer.MAX_VALUE);
                }
                DataBuffer dataBuffer = readFromDataSet(dataSet, (int) totalSize);

                INDArray input = Nd4j.create(shape);
                input.setData(dataBuffer);

                return input;
            }
        }
    }

    /**
     * Reads {@code total} float values from the dataset into a newly created data buffer.
     *
     * @param dataSet open dataset to read from
     * @param total   number of float elements to read
     * @return buffer holding the values read from the dataset
     */
    private DataBuffer readFromDataSet(hdf5.DataSet dataSet, int total) {
        float[] values = new float[total];
        // The pointer owns native memory; release it as soon as the values are copied back.
        try (FloatPointer fp = new FloatPointer(values)) {
            dataSet.read(fp, dataType);
            // Copy the native contents back into the Java array.
            fp.get(values);
        }
        return Nd4j.createBuffer(values);
    }

    /**
     * Extracts the dimensions of the dataset's data space.
     *
     * @param dataSet open dataset to inspect
     * @return dimensions of the dataset, one entry per axis
     * @throws IllegalStateException if a dimension does not fit into an {@code int}
     */
    private int[] extractShape(hdf5.DataSet dataSet) {
        try (hdf5.DataSpace space = dataSet.getSpace()) {
            int nbDims = space.getSimpleExtentNdims();
            long[] shape = new long[nbDims];
            space.getSimpleExtentDims(shape);
            // Guard the long -> int narrowing done below.
            for (long dim : shape) {
                if (dim < 0 || dim > Integer.MAX_VALUE) {
                    throw new IllegalStateException(
                            "Dataset dimension " + dim + " does not fit into an int");
                }
            }
            return ArrayUtil.toInts(shape);
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy