// org.deeplearning4j.spark.earlystopping.BaseSparkEarlyStoppingTrainer Maven / Gradle / Ivy
/*
*
* * Copyright 2016 Skymind,Inc.
* *
* * Licensed under the Apache License, Version 2.0 (the "License");
* * you may not use this file except in compliance with the License.
* * You may obtain a copy of the License at
* *
* * http://www.apache.org/licenses/LICENSE-2.0
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS,
* * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* * See the License for the specific language governing permissions and
* * limitations under the License.
*
*/
package org.deeplearning4j.spark.earlystopping;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.deeplearning4j.earlystopping.EarlyStoppingConfiguration;
import org.deeplearning4j.earlystopping.EarlyStoppingResult;
import org.deeplearning4j.earlystopping.listener.EarlyStoppingListener;
import org.deeplearning4j.earlystopping.scorecalc.ScoreCalculator;
import org.deeplearning4j.earlystopping.termination.EpochTerminationCondition;
import org.deeplearning4j.earlystopping.termination.IterationTerminationCondition;
import org.deeplearning4j.earlystopping.trainer.IEarlyStoppingTrainer;
import org.deeplearning4j.nn.api.Model;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.util.LinkedHashMap;
import java.util.Map;
/**
 * Base/abstract class for conducting early stopping training via Spark, on a
 * {@link org.deeplearning4j.nn.multilayer.MultiLayerNetwork} or a {@link org.deeplearning4j.nn.graph.ComputationGraph}.
 * <p>
 * Subclasses supply the actual distributed fitting ({@link #fit(JavaRDD)} / {@link #fitMulti(JavaRDD)}) and the
 * current training score ({@link #getScore()}); this class drives the epoch loop, evaluates the termination
 * conditions from the {@link EarlyStoppingConfiguration}, and saves the best/latest models via the configured
 * model saver.
 *
 * @param <T> type of the model being trained
 * @author Alex Black
 */
public abstract class BaseSparkEarlyStoppingTrainer<T extends Model> implements IEarlyStoppingTrainer<T> {

    private static final Logger log = LoggerFactory.getLogger(BaseSparkEarlyStoppingTrainer.class);

    private JavaSparkContext sc;
    private final EarlyStoppingConfiguration<T> esConfig;
    private T net;
    /** Training data as DataSet objects; when null, {@link #trainMulti} is used instead. */
    private final JavaRDD<DataSet> train;
    /** Training data as MultiDataSet objects; only used when {@link #train} is null. */
    private final JavaRDD<MultiDataSet> trainMulti;
    private EarlyStoppingListener<T> listener;

    /** Lowest score seen so far (lower is treated as better); Double.MAX_VALUE until a score is recorded. */
    private double bestModelScore = Double.MAX_VALUE;
    /** Epoch at which the best score was recorded; -1 until a score is recorded. */
    private int bestModelEpoch = -1;

    /**
     * @param sc         Spark context
     * @param esConfig   early stopping configuration; must define at least one (epoch or iteration) termination
     *                   condition
     * @param net        model to train
     * @param train      training data as DataSets (may be null if {@code trainMulti} is provided)
     * @param trainMulti training data as MultiDataSets (only used when {@code train} is null)
     * @param listener   optional listener notified of start/epoch/completion events; may be null
     * @throws IllegalArgumentException if both the epoch and iteration termination conditions are null/empty
     */
    protected BaseSparkEarlyStoppingTrainer(JavaSparkContext sc, EarlyStoppingConfiguration<T> esConfig, T net,
                    JavaRDD<DataSet> train, JavaRDD<MultiDataSet> trainMulti, EarlyStoppingListener<T> listener) {
        if ((esConfig.getEpochTerminationConditions() == null || esConfig.getEpochTerminationConditions().size() == 0)
                        && (esConfig.getIterationTerminationConditions() == null
                                        || esConfig.getIterationTerminationConditions().size() == 0)) {
            throw new IllegalArgumentException("Cannot conduct early stopping without a termination condition (both Iteration "
                            + "and Epoch termination conditions are null/empty)");
        }
        this.sc = sc;
        this.esConfig = esConfig;
        this.net = net;
        this.train = train;
        this.trainMulti = trainMulti;
        this.listener = listener;
    }

    /** Fit the model on the given DataSet RDD; called once per epoch of the early stopping loop. */
    protected abstract void fit(JavaRDD<DataSet> data);

    /** Fit the model on the given MultiDataSet RDD; called once per epoch of the early stopping loop. */
    protected abstract void fitMulti(JavaRDD<MultiDataSet> data);

    /** @return the current training score, passed to the iteration termination conditions after each fit call */
    protected abstract double getScore();

    /**
     * Runs early stopping training until a termination condition is hit.
     * <p>
     * Each pass of the loop fits the model once on the (cached) training RDD, checks the iteration termination
     * conditions against {@link #getScore()}, and - on every {@code evaluateEveryNEpochs}-th epoch - calculates the
     * score with the configured score calculator, saves the best/latest model as configured, notifies the listener
     * and checks the epoch termination conditions.
     *
     * @return the result describing why training stopped, the recorded scores and the best model
     * @throws IllegalStateException if neither a DataSet nor a MultiDataSet training RDD was provided
     * @throws RuntimeException      if saving or loading a model fails with an IOException
     */
    @Override
    public EarlyStoppingResult<T> fit() {
        if (train == null && trainMulti == null) {
            // Previously surfaced as a bare NullPointerException on trainMulti.cache()
            throw new IllegalStateException(
                            "Cannot conduct early stopping training: both DataSet and MultiDataSet training data are null");
        }

        log.info("Starting early stopping training");
        if (esConfig.getScoreCalculator() == null) {
            log.warn("No score calculator provided for early stopping. Score will be reported as 0.0 to epoch termination conditions");
        }

        //Initialize termination conditions:
        if (esConfig.getIterationTerminationConditions() != null) {
            for (IterationTerminationCondition c : esConfig.getIterationTerminationConditions()) {
                c.initialize();
            }
        }
        if (esConfig.getEpochTerminationConditions() != null) {
            for (EpochTerminationCondition c : esConfig.getEpochTerminationConditions()) {
                c.initialize();
            }
        }

        if (listener != null) {
            listener.onStart(esConfig, net);
        }

        Map<Integer, Double> scoreVsEpoch = new LinkedHashMap<>();

        // The same RDD is fitted on every epoch, so cache it up front
        if (train != null) {
            train.cache();
        } else {
            trainMulti.cache();
        }

        int epochCount = 0;
        while (true) { //Iterate (do epochs) until termination condition hit
            if (train != null) {
                fit(train);
            } else {
                fitMulti(trainMulti);
            }

            //TODO revisit per iteration termination conditions, ensuring they are evaluated *per averaging* not per epoch
            //Check per-iteration termination conditions
            double lastScore = getScore();
            IterationTerminationCondition terminationReason = findTriggeredIterationCondition(lastScore);
            if (terminationReason != null) {
                //Handle termination condition:
                log.info("Hit per iteration epoch termination condition at epoch {}, iteration {}. Reason: {}",
                                epochCount, epochCount, terminationReason);
                if (esConfig.isSaveLastModel()) {
                    saveLatestModel(0.0);
                }
                return completeTraining(EarlyStoppingResult.TerminationReason.IterationTerminationCondition,
                                terminationReason.toString(), scoreVsEpoch, epochCount);
            }

            log.info("Completed training epoch {}", epochCount);

            if (epochCount % esConfig.getEvaluateEveryNEpochs() == 0) {
                //Calculate score at this epoch:
                ScoreCalculator<T> scoreCalculator = esConfig.getScoreCalculator();
                double score = (scoreCalculator == null ? 0.0 : scoreCalculator.calculateScore(net));
                // NOTE(review): the key is epochCount-1 (so the first epoch is stored under -1) while
                // bestModelEpoch and the listener use epochCount. Left as-is because callers may rely on the keys.
                scoreVsEpoch.put(epochCount - 1, score);

                if (scoreCalculator != null && score < bestModelScore) {
                    //Save best model:
                    if (bestModelEpoch == -1) {
                        //First calculated/reported score
                        log.info("Score at epoch {}: {}", epochCount, score);
                    } else {
                        log.info("New best model: score = {}, epoch = {} (previous: score = {}, epoch = {})",
                                        score, epochCount, bestModelScore, bestModelEpoch);
                    }
                    bestModelScore = score;
                    bestModelEpoch = epochCount;

                    try {
                        esConfig.getModelSaver().saveBestModel(net, score);
                    } catch (IOException e) {
                        throw new RuntimeException("Error saving best model", e);
                    }
                }

                if (esConfig.isSaveLastModel()) {
                    saveLatestModel(score);
                }

                if (listener != null) {
                    listener.onEpoch(epochCount, score, esConfig, net);
                }

                //Check per-epoch termination conditions:
                EpochTerminationCondition termReason = findTriggeredEpochCondition(epochCount, score);
                if (termReason != null) {
                    log.info("Hit epoch termination condition at epoch {}. Details: {}", epochCount,
                                    termReason.toString());
                    return completeTraining(EarlyStoppingResult.TerminationReason.EpochTerminationCondition,
                                    termReason.toString(), scoreVsEpoch, epochCount + 1);
                }
            }

            // Must advance on EVERY epoch, not only on evaluated ones: otherwise, with evaluateEveryNEpochs > 1,
            // the counter sticks at 1 and neither evaluation nor epoch termination checks ever run again.
            epochCount++;
        }
    }

    /** Returns the first iteration termination condition triggered by the given score, or null if none is. */
    private IterationTerminationCondition findTriggeredIterationCondition(double lastScore) {
        if (esConfig.getIterationTerminationConditions() == null) {
            return null; // only epoch termination conditions were configured
        }
        for (IterationTerminationCondition c : esConfig.getIterationTerminationConditions()) {
            if (c.terminate(lastScore)) {
                return c;
            }
        }
        return null;
    }

    /** Returns the first epoch termination condition triggered at the given epoch/score, or null if none is. */
    private EpochTerminationCondition findTriggeredEpochCondition(int epochCount, double score) {
        if (esConfig.getEpochTerminationConditions() == null) {
            return null; // only iteration termination conditions were configured
        }
        for (EpochTerminationCondition c : esConfig.getEpochTerminationConditions()) {
            if (c.terminate(epochCount, score)) {
                return c;
            }
        }
        return null;
    }

    /** Saves the current model as the "latest" model, wrapping any IOException in a RuntimeException. */
    private void saveLatestModel(double score) {
        try {
            esConfig.getModelSaver().saveLatestModel(net, score);
        } catch (IOException e) {
            throw new RuntimeException("Error saving most recent model", e);
        }
    }

    /** Loads the best saved model, builds the final result and notifies the listener (if any) of completion. */
    private EarlyStoppingResult<T> completeTraining(EarlyStoppingResult.TerminationReason reason, String details,
                    Map<Integer, Double> scoreVsEpoch, int totalEpochs) {
        T bestModel;
        try {
            bestModel = esConfig.getModelSaver().getBestModel();
        } catch (IOException e2) {
            throw new RuntimeException(e2);
        }

        EarlyStoppingResult<T> result = new EarlyStoppingResult<>(
                        reason,
                        details,
                        scoreVsEpoch,
                        bestModelEpoch,
                        bestModelScore,
                        totalEpochs,
                        bestModel);
        if (listener != null) {
            listener.onCompletion(result);
        }
        return result;
    }

    /**
     * Replaces the listener that is notified of training start, per-epoch scores and completion.
     *
     * @param listener the new listener; may be null to disable notifications
     */
    @Override
    public void setListener(EarlyStoppingListener<T> listener) {
        this.listener = listener;
    }
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy