All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.deeplearning4j.keras.NeuralNetworkReader Maven / Gradle / Ivy

The newest version!
package org.deeplearning4j.keras;

import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.modelimport.keras.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.KerasModelImport;
import org.deeplearning4j.nn.modelimport.keras.UnsupportedKerasConfigurationException;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;

import java.io.IOException;
import java.util.Objects;

/**
 * Reads a neural network model exported from Keras, as described by {@link EntryPointFitParameters}.
 * Delegates the actual parsing to the DL4J Keras model-import module ({@link KerasModelImport}).
 *
 * @author [email protected]
 */
@Slf4j
public class NeuralNetworkReader {

    /**
     * Imports the Keras model described by {@code entryPointFitParameters} as a DL4J
     * {@link MultiLayerNetwork}.
     *
     * <p>Only {@link KerasModelType#SEQUENTIAL} models are supported. The path returned by
     * {@code entryPointFitParameters.getModelFilePath()} is passed to
     * {@link KerasModelImport#importKerasSequentialModelAndWeights}, and the resulting network
     * has {@link MultiLayerNetwork#init()} called on it before being returned.
     *
     * @param entryPointFitParameters parameters naming the model type and model file; must not be null
     * @return the imported network, after {@code init()} has been called on it
     * @throws IOException propagated from the Keras model import (e.g. the model file cannot be read)
     * @throws InvalidKerasConfigurationException propagated from the Keras model import
     * @throws UnsupportedKerasConfigurationException propagated from the Keras model import
     * @throws NullPointerException if {@code entryPointFitParameters} is null
     * @throws IllegalArgumentException if the model type is anything other than
     *         {@link KerasModelType#SEQUENTIAL} (including null)
     */
    public MultiLayerNetwork readNeuralNetwork(EntryPointFitParameters entryPointFitParameters)
                    throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        Objects.requireNonNull(entryPointFitParameters, "entryPointFitParameters");

        // Constant-first equals() so a null type is reported as unsupported rather than NPE-ing.
        if (!KerasModelType.SEQUENTIAL.equals(entryPointFitParameters.getType())) {
            // IllegalArgumentException extends RuntimeException, so callers catching
            // RuntimeException (the previously thrown type) keep working.
            throw new IllegalArgumentException(
                            "Model type unsupported! (" + entryPointFitParameters.getType() + ")");
        }

        MultiLayerNetwork multiLayerNetwork = KerasModelImport
                        .importKerasSequentialModelAndWeights(entryPointFitParameters.getModelFilePath());
        multiLayerNetwork.init();
        return multiLayerNetwork;
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy