All downloads are free. Search and download functionality uses the official Maven repository.

org.deeplearning4j.earlystopping.EarlyStoppingConfiguration Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
/*-
 *
 *  * Copyright 2015 Skymind,Inc.
 *  *
 *  *    Licensed under the Apache License, Version 2.0 (the "License");
 *  *    you may not use this file except in compliance with the License.
 *  *    You may obtain a copy of the License at
 *  *
 *  *        http://www.apache.org/licenses/LICENSE-2.0
 *  *
 *  *    Unless required by applicable law or agreed to in writing, software
 *  *    distributed under the License is distributed on an "AS IS" BASIS,
 *  *    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  *    See the License for the specific language governing permissions and
 *  *    limitations under the License.
 *
 */

package org.deeplearning4j.earlystopping;

import lombok.Data;
import org.deeplearning4j.earlystopping.saver.InMemoryModelSaver;
import org.deeplearning4j.earlystopping.scorecalc.ScoreCalculator;
import org.deeplearning4j.earlystopping.termination.EpochTerminationCondition;
import org.deeplearning4j.earlystopping.termination.IterationTerminationCondition;
import org.deeplearning4j.nn.api.Model;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

/**
 * Early stopping configuration: specifies the various configuration options for running training with
 * early stopping.
 * <p>
 * Users need to specify the following:
 * <ol type="a">
 *   <li>EarlyStoppingModelSaver: how models will be saved (to disk, to memory, etc.) (Default: in memory)</li>
 *   <li>Termination conditions: at least one termination condition must be specified
 *     <ol type="i">
 *       <li>Iteration termination conditions: calculated once for each minibatch. For example, maxTime or
 *           invalid (NaN/infinite) scores</li>
 *       <li>Epoch termination conditions: calculated once per epoch. For example, maxEpochs or no
 *           improvement for N epochs</li>
 *     </ol>
 *   </li>
 *   <li>Score calculator: what score should be calculated at every epoch? (For example: test set loss or
 *       test set accuracy)</li>
 *   <li>How frequently (every N epochs) should scores be calculated? (Default: every epoch)</li>
 * </ol>
 * <p>
 * Instances are created through {@link Builder}. A built configuration holds its own copies of the
 * termination-condition lists. Later changes to the builder, or to a list passed into it, therefore do not
 * affect configurations that were already built.
 *
 * @param <T> Type of model. For example, {@link org.deeplearning4j.nn.multilayer.MultiLayerNetwork} or
 *            {@link org.deeplearning4j.nn.graph.ComputationGraph}
 * @author Alex Black
 */
@Data
public class EarlyStoppingConfiguration<T extends Model> implements Serializable {

    /** How models are saved (Default from the builder: in memory). */
    private EarlyStoppingModelSaver<T> modelSaver;
    /** Termination conditions evaluated at epoch granularity. */
    private List<EpochTerminationCondition> epochTerminationConditions;
    /** Termination conditions evaluated every iteration (minibatch). */
    private List<IterationTerminationCondition> iterationTerminationConditions;
    /** If true, the most recent model is also saved, in addition to the best model. */
    private boolean saveLastModel;
    /** How often (in epochs) scores are calculated. */
    private int evaluateEveryNEpochs;
    /** Calculates the score (for example, loss on a test set) every N epochs. */
    private ScoreCalculator<T> scoreCalculator;

    private EarlyStoppingConfiguration(Builder<T> builder) {
        this.modelSaver = builder.modelSaver;
        // Copy the lists so that reusing or modifying the builder after build() cannot
        // silently change this configuration.
        this.epochTerminationConditions = new ArrayList<>(builder.epochTerminationConditions);
        this.iterationTerminationConditions = new ArrayList<>(builder.iterationTerminationConditions);
        this.saveLastModel = builder.saveLastModel;
        this.evaluateEveryNEpochs = builder.evaluateEveryNEpochs;
        this.scoreCalculator = builder.scoreCalculator;
    }

    /**
     * Builder for {@link EarlyStoppingConfiguration}.
     * <p>
     * Defaults: in-memory model saver, no termination conditions, {@code saveLastModel = false},
     * {@code evaluateEveryNEpochs = 1}, and no score calculator.
     *
     * @param <T> Type of model being trained
     */
    public static class Builder<T extends Model> {

        private EarlyStoppingModelSaver<T> modelSaver = new InMemoryModelSaver<>();
        // Builder-owned lists: never aliased to caller-supplied lists, so clear() is always safe.
        private final List<EpochTerminationCondition> epochTerminationConditions = new ArrayList<>();
        private final List<IterationTerminationCondition> iterationTerminationConditions = new ArrayList<>();
        private boolean saveLastModel = false;
        private int evaluateEveryNEpochs = 1;
        private ScoreCalculator<T> scoreCalculator;

        /** How should models be saved? (Default: in memory) */
        public Builder<T> modelSaver(EarlyStoppingModelSaver<T> modelSaver) {
            this.modelSaver = modelSaver;
            return this;
        }

        /**
         * Termination conditions to be evaluated every N epochs, with N set by the evaluateEveryNEpochs option.
         * Replaces any previously set epoch termination conditions.
         */
        public Builder<T> epochTerminationConditions(EpochTerminationCondition... terminationConditions) {
            epochTerminationConditions.clear();
            Collections.addAll(epochTerminationConditions, terminationConditions);
            return this;
        }

        /**
         * Termination conditions to be evaluated every N epochs, with N set by the evaluateEveryNEpochs option.
         * Replaces any previously set epoch termination conditions. The elements are copied, and the supplied
         * list is neither retained nor modified. A {@code null} list clears the conditions.
         */
        public Builder<T> epochTerminationConditions(List<EpochTerminationCondition> terminationConditions) {
            epochTerminationConditions.clear();
            if (terminationConditions != null) {
                epochTerminationConditions.addAll(terminationConditions);
            }
            return this;
        }

        /**
         * Termination conditions to be evaluated every iteration (minibatch).
         * Replaces any previously set iteration termination conditions.
         */
        public Builder<T> iterationTerminationConditions(IterationTerminationCondition... terminationConditions) {
            iterationTerminationConditions.clear();
            Collections.addAll(iterationTerminationConditions, terminationConditions);
            return this;
        }

        /**
         * Termination conditions to be evaluated every iteration (minibatch).
         * Replaces any previously set iteration termination conditions. The elements are copied, and the
         * supplied list is neither retained nor modified. A {@code null} list clears the conditions.
         */
        public Builder<T> iterationTerminationConditions(List<IterationTerminationCondition> terminationConditions) {
            iterationTerminationConditions.clear();
            if (terminationConditions != null) {
                iterationTerminationConditions.addAll(terminationConditions);
            }
            return this;
        }

        /**
         * Save the last model? If true, save the most recent model at each epoch, in addition to the best
         * model (whenever the best model improves). If false, only save the best model. Default: false.
         * This is useful, for example, if you might want to continue training after a max-time termination
         * condition occurs.
         */
        public Builder<T> saveLastModel(boolean saveLastModel) {
            this.saveLastModel = saveLastModel;
            return this;
        }

        /** How frequently should evaluations be conducted (in terms of epochs)? Defaults to every (1) epoch. */
        public Builder<T> evaluateEveryNEpochs(int everyNEpochs) {
            this.evaluateEveryNEpochs = everyNEpochs;
            return this;
        }

        /**
         * Score calculator. Used to calculate a score (such as the loss function on a test set) every N epochs,
         * where N is set by {@link #evaluateEveryNEpochs}.
         */
        public Builder<T> scoreCalculator(ScoreCalculator<T> scoreCalculator) {
            this.scoreCalculator = scoreCalculator;
            return this;
        }

        /**
         * Create the early stopping configuration. The result takes its own copies of the termination
         * condition lists, so this builder may be safely reused afterwards.
         * NOTE(review): no validation is done here (for example, that at least one termination condition or a
         * score calculator was set). Callers rely on the trainer to enforce those requirements.
         */
        public EarlyStoppingConfiguration<T> build() {
            return new EarlyStoppingConfiguration<>(this);
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy