All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.deeplearning4j.earlystopping.EarlyStoppingConfiguration Maven / Gradle / Ivy

/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.deeplearning4j.earlystopping;

import lombok.Data;
import lombok.NoArgsConstructor;
import org.deeplearning4j.earlystopping.saver.InMemoryModelSaver;
import org.deeplearning4j.earlystopping.scorecalc.ScoreCalculator;
import org.deeplearning4j.earlystopping.termination.EpochTerminationCondition;
import org.deeplearning4j.earlystopping.termination.IterationTerminationCondition;
import org.deeplearning4j.exception.DL4JInvalidConfigException;
import org.deeplearning4j.nn.api.Model;
import org.nd4j.common.function.Supplier;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

@Data
@NoArgsConstructor
public class EarlyStoppingConfiguration implements Serializable {

    private EarlyStoppingModelSaver modelSaver;
    private List epochTerminationConditions;
    private List iterationTerminationConditions;
    private boolean saveLastModel;
    private int evaluateEveryNEpochs;
    private ScoreCalculator scoreCalculator;
    private Supplier scoreCalculatorSupplier;

    private EarlyStoppingConfiguration(Builder builder) {
        this.modelSaver = builder.modelSaver;
        this.epochTerminationConditions = builder.epochTerminationConditions;
        this.iterationTerminationConditions = builder.iterationTerminationConditions;
        this.saveLastModel = builder.saveLastModel;
        this.evaluateEveryNEpochs = builder.evaluateEveryNEpochs;
        this.scoreCalculator = builder.scoreCalculator;
        this.scoreCalculatorSupplier = builder.scoreCalculatorSupplier;
    }

    public ScoreCalculator getScoreCalculator(){
        if(scoreCalculatorSupplier != null){
            return scoreCalculatorSupplier.get();
        }
        return scoreCalculator;
    }


    public void validate() {
        if(scoreCalculator == null && scoreCalculatorSupplier == null) {
            throw new DL4JInvalidConfigException("A score calculator or score calculator supplier must be defined.");
        }

        if(modelSaver == null) {
            throw new DL4JInvalidConfigException("A model saver must be defined");
        }

        boolean hasTermination = false;
        if(iterationTerminationConditions != null && !iterationTerminationConditions.isEmpty()) {
            hasTermination = true;
        }

        else if(epochTerminationConditions != null && !epochTerminationConditions.isEmpty()) {
            hasTermination = true;
        }

        if(!hasTermination) {
            throw new DL4JInvalidConfigException("No termination conditions defined.");
        }
    }


    public static class Builder {

        private EarlyStoppingModelSaver modelSaver = new InMemoryModelSaver<>();
        private List epochTerminationConditions = new ArrayList<>();
        private List iterationTerminationConditions = new ArrayList<>();
        private boolean saveLastModel = false;
        private int evaluateEveryNEpochs = 1;
        private ScoreCalculator scoreCalculator;
        private Supplier scoreCalculatorSupplier;


        /** How should models be saved? (Default: in memory)*/
        public Builder modelSaver(EarlyStoppingModelSaver modelSaver) {
            this.modelSaver = modelSaver;
            return this;
        }

        /** Termination conditions to be evaluated every N epochs, with N set by evaluateEveryNEpochs option */
        public Builder epochTerminationConditions(EpochTerminationCondition... terminationConditions) {
            epochTerminationConditions.clear();
            Collections.addAll(epochTerminationConditions, terminationConditions);
            return this;
        }

        /** Termination conditions to be evaluated every N epochs, with N set by evaluateEveryNEpochs option */
        public Builder epochTerminationConditions(List terminationConditions) {
            this.epochTerminationConditions = terminationConditions;
            return this;
        }

        /** Termination conditions to be evaluated every iteration (minibatch)*/
        public Builder iterationTerminationConditions(IterationTerminationCondition... terminationConditions) {
            iterationTerminationConditions.clear();
            Collections.addAll(iterationTerminationConditions, terminationConditions);
            return this;
        }

        /** Save the last model? If true: save the most recent model at each epoch, in addition to the best
         * model (whenever the best model improves). If false: only save the best model. Default: false
         * Useful for example if you might want to continue training after a max-time terminatino condition
         * occurs.
         */
        public Builder saveLastModel(boolean saveLastModel) {
            this.saveLastModel = saveLastModel;
            return this;
        }

        /** How frequently should evaluations be conducted (in terms of epochs)? Defaults to every (1) epochs. */
        public Builder evaluateEveryNEpochs(int everyNEpochs) {
            this.evaluateEveryNEpochs = everyNEpochs;
            return this;
        }

        /** Score calculator. Used to calculate a score (such as loss function on a test set), every N epochs,
         * where N is set by {@link #evaluateEveryNEpochs}
         */
        public Builder scoreCalculator(ScoreCalculator scoreCalculator) {
            this.scoreCalculator = scoreCalculator;
            return this;
        }

        /** Score calculator. Used to calculate a score (such as loss function on a test set), every N epochs,
         * where N is set by {@link #evaluateEveryNEpochs}
         */
        public Builder scoreCalculator(Supplier scoreCalculatorSupplier){
            this.scoreCalculatorSupplier = scoreCalculatorSupplier;
            return this;
        }

        /** Create the early stopping configuration */
        public EarlyStoppingConfiguration build() {
            return new EarlyStoppingConfiguration<>(this);
        }

    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy