// org.deeplearning4j.earlystopping.EarlyStoppingConfiguration Maven / Gradle / Ivy
/*-
*
* * Copyright 2015 Skymind,Inc.
* *
* * Licensed under the Apache License, Version 2.0 (the "License");
* * you may not use this file except in compliance with the License.
* * You may obtain a copy of the License at
* *
* * http://www.apache.org/licenses/LICENSE-2.0
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS,
* * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* * See the License for the specific language governing permissions and
* * limitations under the License.
*
*/
package org.deeplearning4j.earlystopping;
import lombok.Data;
import org.deeplearning4j.earlystopping.saver.InMemoryModelSaver;
import org.deeplearning4j.earlystopping.scorecalc.ScoreCalculator;
import org.deeplearning4j.earlystopping.termination.EpochTerminationCondition;
import org.deeplearning4j.earlystopping.termination.IterationTerminationCondition;
import org.deeplearning4j.nn.api.Model;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
/**
 * Early stopping configuration: specifies the various configuration options for running training
 * with early stopping.<br>
 * Users need to specify the following:<br>
 * (a) EarlyStoppingModelSaver: how models will be saved (to disk, to memory, etc.) (Default: in memory)<br>
 * (b) Termination conditions: at least one termination condition must be specified<br>
 * &nbsp;&nbsp;(i) Iteration termination conditions: calculated once for each minibatch. For example, maxTime or
 * invalid (NaN/infinite) scores<br>
 * &nbsp;&nbsp;(ii) Epoch termination conditions: calculated once per epoch. For example, maxEpochs or no
 * improvement for N epochs<br>
 * (c) Score calculator: what score should be calculated at every epoch? (For example: test set loss or
 * test set accuracy)<br>
 * (d) How frequently (every N epochs) should scores be calculated? (Default: every epoch)<br>
 * <p>
 * Instances are created via {@link Builder}.
 *
 * @param <T> Type of model. For example, {@link org.deeplearning4j.nn.multilayer.MultiLayerNetwork} or
 *            {@link org.deeplearning4j.nn.graph.ComputationGraph}
 * @author Alex Black
 */
@Data
public class EarlyStoppingConfiguration<T extends Model> implements Serializable {

    private EarlyStoppingModelSaver<T> modelSaver;
    private List<EpochTerminationCondition> epochTerminationConditions;
    private List<IterationTerminationCondition> iterationTerminationConditions;
    private boolean saveLastModel;
    private int evaluateEveryNEpochs;
    private ScoreCalculator<T> scoreCalculator;

    /**
     * Creates a configuration from the current state of the given builder.
     *
     * @param builder builder holding the values to copy into this configuration
     */
    private EarlyStoppingConfiguration(Builder<T> builder) {
        this.modelSaver = builder.modelSaver;
        // Copy the lists: the builder's varargs setters clear() its lists in place, so sharing the
        // same instances would let later use of the builder alter an already-built configuration.
        this.epochTerminationConditions = new ArrayList<>(builder.epochTerminationConditions);
        this.iterationTerminationConditions = new ArrayList<>(builder.iterationTerminationConditions);
        this.saveLastModel = builder.saveLastModel;
        this.evaluateEveryNEpochs = builder.evaluateEveryNEpochs;
        this.scoreCalculator = builder.scoreCalculator;
    }

    /**
     * Fluent builder for {@link EarlyStoppingConfiguration}.
     * <p>
     * Defaults: in-memory model saver, no termination conditions, {@code saveLastModel = false},
     * {@code evaluateEveryNEpochs = 1}, and no score calculator.
     *
     * @param <T> Type of model the resulting configuration applies to
     */
    public static class Builder<T extends Model> {

        private EarlyStoppingModelSaver<T> modelSaver = new InMemoryModelSaver<>();
        private List<EpochTerminationCondition> epochTerminationConditions = new ArrayList<>();
        private List<IterationTerminationCondition> iterationTerminationConditions = new ArrayList<>();
        private boolean saveLastModel = false;
        private int evaluateEveryNEpochs = 1;
        private ScoreCalculator<T> scoreCalculator;

        /**
         * How should models be saved? (Default: in memory)
         *
         * @param modelSaver the model saver to use
         * @return this builder
         */
        public Builder<T> modelSaver(EarlyStoppingModelSaver<T> modelSaver) {
            this.modelSaver = modelSaver;
            return this;
        }

        /**
         * Termination conditions to be evaluated every N epochs, with N set by the
         * {@link #evaluateEveryNEpochs(int)} option. Replaces any previously set epoch termination
         * conditions.
         *
         * @param terminationConditions the conditions to use; a null array is treated as no conditions
         * @return this builder
         */
        public Builder<T> epochTerminationConditions(EpochTerminationCondition... terminationConditions) {
            epochTerminationConditions.clear();
            if (terminationConditions != null) {
                Collections.addAll(epochTerminationConditions, terminationConditions);
            }
            return this;
        }

        /**
         * Termination conditions to be evaluated every N epochs, with N set by the
         * {@link #evaluateEveryNEpochs(int)} option. Replaces any previously set epoch termination
         * conditions. The list is copied, so the caller's list is never modified by this builder.
         *
         * @param terminationConditions the conditions to use; null is treated as an empty list
         * @return this builder
         */
        public Builder<T> epochTerminationConditions(List<EpochTerminationCondition> terminationConditions) {
            // Defensive copy: keeping the caller's list would let a later varargs call clear() it
            // (or fail outright if the caller passed an unmodifiable list).
            List<EpochTerminationCondition> copy = new ArrayList<>();
            if (terminationConditions != null) {
                copy.addAll(terminationConditions);
            }
            this.epochTerminationConditions = copy;
            return this;
        }

        /**
         * Termination conditions to be evaluated every iteration (minibatch). Replaces any
         * previously set iteration termination conditions.
         *
         * @param terminationConditions the conditions to use; a null array is treated as no conditions
         * @return this builder
         */
        public Builder<T> iterationTerminationConditions(IterationTerminationCondition... terminationConditions) {
            iterationTerminationConditions.clear();
            if (terminationConditions != null) {
                Collections.addAll(iterationTerminationConditions, terminationConditions);
            }
            return this;
        }

        /**
         * Save the last model? If true: save the most recent model at each epoch, in addition to the
         * best model (whenever the best model improves). If false: only save the best model.
         * Default: false.<br>
         * Useful for example if you might want to continue training after a max-time termination
         * condition occurs.
         *
         * @param saveLastModel whether the most recent model should also be saved
         * @return this builder
         */
        public Builder<T> saveLastModel(boolean saveLastModel) {
            this.saveLastModel = saveLastModel;
            return this;
        }

        /**
         * How frequently should evaluations be conducted (in terms of epochs)? Defaults to every (1)
         * epochs.
         *
         * @param everyNEpochs evaluation interval in epochs; must be at least 1
         * @return this builder
         * @throws IllegalArgumentException if {@code everyNEpochs} is less than 1
         */
        public Builder<T> evaluateEveryNEpochs(int everyNEpochs) {
            // A zero or negative interval has no meaning as "every N epochs"; reject it up front
            // instead of letting it surface later during training.
            if (everyNEpochs < 1) {
                throw new IllegalArgumentException(
                        "evaluateEveryNEpochs must be >= 1, got " + everyNEpochs);
            }
            this.evaluateEveryNEpochs = everyNEpochs;
            return this;
        }

        /**
         * Score calculator. Used to calculate a score (such as loss function on a test set), every N
         * epochs, where N is set by {@link #evaluateEveryNEpochs(int)}.
         *
         * @param scoreCalculator the score calculator to use
         * @return this builder
         */
        public Builder<T> scoreCalculator(ScoreCalculator<T> scoreCalculator) {
            this.scoreCalculator = scoreCalculator;
            return this;
        }

        /**
         * Create the early stopping configuration from the current builder state.
         *
         * @return a new configuration; later changes to this builder do not affect it
         */
        public EarlyStoppingConfiguration<T> build() {
            return new EarlyStoppingConfiguration<>(this);
        }
    }
}
// © 2015 - 2024 Weber Informatics LLC | Privacy Policy