All downloads are free. The search and download functionality uses the official Maven repository.

org.deeplearning4j.earlystopping.trainer.BaseEarlyStoppingTrainer Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
/*-
 *
 *  * Copyright 2016 Skymind,Inc.
 *  *
 *  *    Licensed under the Apache License, Version 2.0 (the "License");
 *  *    you may not use this file except in compliance with the License.
 *  *    You may obtain a copy of the License at
 *  *
 *  *        http://www.apache.org/licenses/LICENSE-2.0
 *  *
 *  *    Unless required by applicable law or agreed to in writing, software
 *  *    distributed under the License is distributed on an "AS IS" BASIS,
 *  *    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  *    See the License for the specific language governing permissions and
 *  *    limitations under the License.
 *
 */

package org.deeplearning4j.earlystopping.trainer;

import org.deeplearning4j.earlystopping.EarlyStoppingConfiguration;
import org.deeplearning4j.earlystopping.EarlyStoppingResult;
import org.deeplearning4j.earlystopping.listener.EarlyStoppingListener;
import org.deeplearning4j.earlystopping.scorecalc.ScoreCalculator;
import org.deeplearning4j.earlystopping.termination.EpochTerminationCondition;
import org.deeplearning4j.earlystopping.termination.IterationTerminationCondition;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.Map;

/**Base/abstract class for conducting early stopping training locally (single machine).
* Can be used to train a {@link MultiLayerNetwork} or a {@link ComputationGraph} via early stopping * @author Alex Black */ public abstract class BaseEarlyStoppingTrainer implements IEarlyStoppingTrainer { private static Logger log = LoggerFactory.getLogger(BaseEarlyStoppingTrainer.class); protected T model; protected final EarlyStoppingConfiguration esConfig; private final DataSetIterator train; private final MultiDataSetIterator trainMulti; private final Iterator iterator; private EarlyStoppingListener listener; private double bestModelScore = Double.MAX_VALUE; private int bestModelEpoch = -1; protected BaseEarlyStoppingTrainer(EarlyStoppingConfiguration earlyStoppingConfiguration, T model, DataSetIterator train, MultiDataSetIterator trainMulti, EarlyStoppingListener listener) { this.esConfig = earlyStoppingConfiguration; this.model = model; this.train = train; this.trainMulti = trainMulti; this.iterator = (train != null ? train : trainMulti); this.listener = listener; } protected abstract void fit(DataSet ds); protected abstract void fit(MultiDataSet mds); @Override public EarlyStoppingResult fit() { log.info("Starting early stopping training"); if (esConfig.getScoreCalculator() == null) log.warn("No score calculator provided for early stopping. 
Score will be reported as 0.0 to epoch termination conditions"); //Initialize termination conditions: if (esConfig.getIterationTerminationConditions() != null) { for (IterationTerminationCondition c : esConfig.getIterationTerminationConditions()) { c.initialize(); } } if (esConfig.getEpochTerminationConditions() != null) { for (EpochTerminationCondition c : esConfig.getEpochTerminationConditions()) { c.initialize(); } } if (listener != null) { listener.onStart(esConfig, model); } Map scoreVsEpoch = new LinkedHashMap<>(); int epochCount = 0; while (true) { reset(); double lastScore; boolean terminate = false; IterationTerminationCondition terminationReason = null; int iterCount = 0; while (iterator.hasNext()) { try { if (train != null) { fit((DataSet) iterator.next()); } else fit(trainMulti.next()); } catch (Exception e) { log.warn("Early stopping training terminated due to exception at epoch {}, iteration {}", epochCount, iterCount, e); //Load best model to return T bestModel; try { bestModel = esConfig.getModelSaver().getBestModel(); } catch (IOException e2) { throw new RuntimeException(e2); } return new EarlyStoppingResult<>(EarlyStoppingResult.TerminationReason.Error, e.toString(), scoreVsEpoch, bestModelEpoch, bestModelScore, epochCount, bestModel); } //Check per-iteration termination conditions lastScore = model.score(); for (IterationTerminationCondition c : esConfig.getIterationTerminationConditions()) { if (c.terminate(lastScore)) { terminate = true; terminationReason = c; break; } } if (terminate) { break; } iterCount++; } if (terminate) { //Handle termination condition: log.info("Hit per iteration epoch termination condition at epoch {}, iteration {}. 
Reason: {}", epochCount, iterCount, terminationReason); if (esConfig.isSaveLastModel()) { //Save last model: try { esConfig.getModelSaver().saveLatestModel(model, 0.0); } catch (IOException e) { throw new RuntimeException("Error saving most recent model", e); } } T bestModel; try { bestModel = esConfig.getModelSaver().getBestModel(); } catch (IOException e2) { throw new RuntimeException(e2); } EarlyStoppingResult result = new EarlyStoppingResult<>( EarlyStoppingResult.TerminationReason.IterationTerminationCondition, terminationReason.toString(), scoreVsEpoch, bestModelEpoch, bestModelScore, epochCount, bestModel); if (listener != null) { listener.onCompletion(result); } return result; } log.info("Completed training epoch {}", epochCount); if ((epochCount == 0 && esConfig.getEvaluateEveryNEpochs() == 1) || epochCount % esConfig.getEvaluateEveryNEpochs() == 0) { //Calculate score at this epoch: ScoreCalculator sc = esConfig.getScoreCalculator(); double score = (sc == null ? 0.0 : esConfig.getScoreCalculator().calculateScore(model)); scoreVsEpoch.put(epochCount - 1, score); if (sc != null && score < bestModelScore) { //Save best model: if (bestModelEpoch == -1) { //First calculated/reported score log.info("Score at epoch {}: {}", epochCount, score); } else { log.info("New best model: score = {}, epoch = {} (previous: score = {}, epoch = {})", score, epochCount, bestModelScore, bestModelEpoch); } bestModelScore = score; bestModelEpoch = epochCount; try { esConfig.getModelSaver().saveBestModel(model, score); } catch (IOException e) { throw new RuntimeException("Error saving best model", e); } } if (esConfig.isSaveLastModel()) { //Save last model: try { esConfig.getModelSaver().saveLatestModel(model, score); } catch (IOException e) { throw new RuntimeException("Error saving most recent model", e); } } if (listener != null) { listener.onEpoch(epochCount, score, esConfig, model); } //Check per-epoch termination conditions: boolean epochTerminate = false; 
EpochTerminationCondition termReason = null; for (EpochTerminationCondition c : esConfig.getEpochTerminationConditions()) { if (c.terminate(epochCount, score)) { epochTerminate = true; termReason = c; break; } } if (epochTerminate) { log.info("Hit epoch termination condition at epoch {}. Details: {}", epochCount, termReason.toString()); T bestModel; try { bestModel = esConfig.getModelSaver().getBestModel(); } catch (IOException e2) { throw new RuntimeException(e2); } EarlyStoppingResult result = new EarlyStoppingResult<>( EarlyStoppingResult.TerminationReason.EpochTerminationCondition, termReason.toString(), scoreVsEpoch, bestModelEpoch, bestModelScore, epochCount + 1, bestModel); if (listener != null) { listener.onCompletion(result); } return result; } } epochCount++; } } @Override public void setListener(EarlyStoppingListener listener) { this.listener = listener; } protected void reset() { if (train != null) { train.reset(); } if (trainMulti != null) { trainMulti.reset(); } } }




© 2015 - 2024 Weber Informatics LLC | Privacy Policy