// org.deeplearning4j.keras.NeuralNetworkReader (Maven / Gradle / Ivy)
// The newest version!
package org.deeplearning4j.keras;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.modelimport.keras.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.KerasModelImport;
import org.deeplearning4j.nn.modelimport.keras.UnsupportedKerasConfigurationException;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import java.io.IOException;
/**
 * Loads a neural network model exported from Keras, as described by the supplied parameters.
 * The actual parsing is delegated to the -modelimport code ({@link KerasModelImport}).
 *
 * @author [email protected]
 */
@Slf4j
public class NeuralNetworkReader {

    /**
     * Imports the Keras model referenced by the given parameters.
     *
     * <p>Only {@code KerasModelType.SEQUENTIAL} is handled: the model and its weights are read via
     * {@link KerasModelImport#importKerasSequentialModelAndWeights} from the path returned by
     * {@code getModelFilePath()}, and {@code init()} is then invoked on the resulting network.
     *
     * @param entryPointFitParameters parameters providing the model type and the model file path
     * @return the imported network, after {@code init()} has been called on it
     * @throws IOException declared by the underlying import call
     * @throws InvalidKerasConfigurationException declared by the underlying import call
     * @throws UnsupportedKerasConfigurationException declared by the underlying import call
     * @throws RuntimeException if the model type is anything other than {@code SEQUENTIAL}
     */
    public MultiLayerNetwork readNeuralNetwork(EntryPointFitParameters entryPointFitParameters)
            throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {

        // Guard clause: reject every model type except SEQUENTIAL up front.
        if (!KerasModelType.SEQUENTIAL.equals(entryPointFitParameters.getType())) {
            throw new RuntimeException("Model type unsupported! (" + entryPointFitParameters.getType() + ")");
        }

        final MultiLayerNetwork importedNetwork =
                KerasModelImport.importKerasSequentialModelAndWeights(entryPointFitParameters.getModelFilePath());
        importedNetwork.init();
        return importedNetwork;
    }
}