All downloads are free. Search and download functionality uses the official Maven repository.

org.deeplearning4j.earlystopping.saver.InMemoryModelSaver Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
package org.deeplearning4j.earlystopping.saver;

import org.deeplearning4j.earlystopping.EarlyStoppingModelSaver;
import org.deeplearning4j.nn.api.Model;

import java.io.IOException;
import java.lang.reflect.Method;

/**
 * Saves the best (and latest) models for early stopping training to memory for later retrieval.
 *
 * <p>Note: assumes that the network is cloneable via a no-argument {@code clone()} method. Each
 * save stores a clone of the supplied network, not the network instance itself.
 *
 * @param <T> Type of model. For example, {@link org.deeplearning4j.nn.multilayer.MultiLayerNetwork}
 *            or {@link org.deeplearning4j.nn.graph.ComputationGraph}
 */
public class InMemoryModelSaver<T extends Model> implements EarlyStoppingModelSaver<T> {

    private transient T bestModel;
    private transient T latestModel;

    /**
     * Stores a clone of {@code net} as the best model seen so far.
     *
     * @param net   network to clone and keep
     * @param score score associated with the network; not used by this implementation
     * @throws RuntimeException if the network cannot be cloned reflectively
     */
    @Override
    public void saveBestModel(T net, double score) throws IOException {
        bestModel = cloneModel(net);
    }

    /**
     * Stores a clone of {@code net} as the most recent model.
     *
     * @param net   network to clone and keep
     * @param score score associated with the network; not used by this implementation
     * @throws RuntimeException if the network cannot be cloned reflectively
     */
    @Override
    public void saveLatestModel(T net, double score) throws IOException {
        latestModel = cloneModel(net);
    }

    /**
     * @return the clone stored by the last call to {@link #saveBestModel}, or {@code null} if
     *         none has been stored
     */
    @Override
    public T getBestModel() throws IOException {
        return bestModel;
    }

    /**
     * @return the clone stored by the last call to {@link #saveLatestModel}, or {@code null} if
     *         none has been stored
     */
    @Override
    public T getLatestModel() throws IOException {
        return latestModel;
    }

    @Override
    public String toString() {
        return "InMemoryModelSaver()";
    }

    /**
     * Clones {@code net} by reflectively invoking its no-argument {@code clone()} method.
     * Reflection is necessary because clone may be protected on the model type.
     * Any failure (including a {@code null} network) is rethrown wrapped in a
     * {@link RuntimeException}, with the original exception as its cause.
     */
    private T cloneModel(T net) {
        try {
            Method cloneMethod = findCloneMethod(net.getClass());
            // Safe as long as clone() honours its contract of returning the receiver's own type.
            @SuppressWarnings("unchecked")
            T copy = (T) cloneMethod.invoke(net);
            return copy;
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Locates a no-argument {@code clone()} method for {@code type}: first a public one
     * (declared or inherited), then a non-public one declared anywhere in the class hierarchy
     * below {@link Object}, which is made accessible before being returned.
     */
    private static Method findCloneMethod(Class<?> type) throws NoSuchMethodException {
        try {
            return type.getMethod("clone");
        } catch (NoSuchMethodException ignored) {
            // No public clone(); fall through and look for a protected/package-private override.
        }
        // Object.clone() itself is deliberately excluded: it only works for Cloneable types and
        // cannot be made accessible on module-enforcing JVMs.
        for (Class<?> c = type; c != null && c != Object.class; c = c.getSuperclass()) {
            try {
                Method declared = c.getDeclaredMethod("clone");
                declared.setAccessible(true);
                return declared;
            } catch (NoSuchMethodException ignored) {
                // Not declared at this level; keep climbing the hierarchy.
            }
        }
        throw new NoSuchMethodException(type.getName() + ".clone()");
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy