All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.deeplearning4j.keras.DeepLearning4jEntryPoint Maven / Gradle / Ivy

The newest version!
package org.deeplearning4j.keras;

import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;

/**
 * API exposed to the Python side. This class contains the methods invoked by the Python wrapper.
 * It is instantiated directly in the server code.
 */
@Slf4j
public class DeepLearning4jEntryPoint {

    /** Reads the {@link MultiLayerNetwork} referenced by the fit parameters. */
    private final NeuralNetworkReader neuralNetworkReader = new NeuralNetworkReader();

    /**
     * Performs fitting of the model which is referenced in the parameters according to the learning
     * parameters specified.
     *
     * <p>The network is obtained from {@link NeuralNetworkReader}, the training data is read through an
     * {@code HDF5MiniBatchDataSetIterator} over the train features/labels directories, and
     * {@link MultiLayerNetwork#fit(DataSetIterator)} is called once per epoch. Any failure is logged here,
     * so it is visible on the server side, and is then rethrown unchanged to the caller.
     *
     * @param entryPointFitParameters definition of the model and learning process
     * @throws Exception if reading the model, creating the training-data iterator, or fitting fails
     */
    public void fit(EntryPointFitParameters entryPointFitParameters) throws Exception {
        try {
            MultiLayerNetwork multiLayerNetwork = neuralNetworkReader.readNeuralNetwork(entryPointFitParameters);

            DataSetIterator dataSetIterator =
                            new HDF5MiniBatchDataSetIterator(entryPointFitParameters.getTrainFeaturesDirectory(),
                                            entryPointFitParameters.getTrainLabelsDirectory());

            // The same iterator instance is passed to every epoch.
            // NOTE(review): this relies on MultiLayerNetwork.fit resetting an exhausted iterator;
            // confirm for the DL4J version in use, otherwise later epochs may see no data.
            for (int epoch = 0; epoch < entryPointFitParameters.getNbEpoch(); epoch++) {
                // Parameterized SLF4J message instead of string concatenation; output is unchanged.
                log.info("Fitting: {}", epoch);

                multiLayerNetwork.fit(dataSetIterator);
            }

            log.info("Learning model finished");
        } catch (Throwable e) {
            // Log on the server before propagating to the Python side. Java 7 precise rethrow lets the
            // original exception escape under the method's "throws Exception" clause unchanged.
            log.error("Error while handling request!", e);
            throw e;
        }
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy