// org.nd4j.autodiff.listeners.records.History (Maven / Gradle / Ivy)
// The newest version!
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.nd4j.autodiff.listeners.records;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import lombok.Getter;
import org.nd4j.autodiff.listeners.Listener;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.common.base.Preconditions;
import org.nd4j.evaluation.IEvaluation;
import org.nd4j.evaluation.IMetric;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
/**
 * Read-only record of a training run: the per-epoch training and validation
 * evaluation records, the loss curve, and timing information.
 *
 * The lists passed to the constructor are wrapped in unmodifiable views, so
 * they cannot be modified through this object.
 */
@Getter
public class History {

    /** Training evaluation records, indexed by epoch (unmodifiable view). */
    private final List<EvaluationRecord> trainingHistory;

    /** Validation evaluation records, indexed by validation epoch (unmodifiable view). */
    private final List<EvaluationRecord> validationHistory;

    /** Loss curve recorded for the run. */
    private final LossCurve lossCurve;

    /** Total training time, in milliseconds. */
    private final long trainingTimeMillis;

    /** Validation times, in milliseconds (unmodifiable view). */
    private final List<Long> validationTimesMillis;

    /**
     * Create a history from the recorded results.
     *
     * @param training              training evaluation records, one per epoch; must not be null
     * @param validation            validation evaluation records; must not be null
     * @param loss                  the loss curve
     * @param trainingTimeMillis    total training time, in milliseconds
     * @param validationTimesMillis validation times, in milliseconds; must not be null
     */
    public History(List<EvaluationRecord> training, List<EvaluationRecord> validation, LossCurve loss,
                   long trainingTimeMillis, List<Long> validationTimesMillis){
        this.trainingHistory = Collections.unmodifiableList(training);
        this.validationHistory = Collections.unmodifiableList(validation);
        this.lossCurve = loss;
        this.trainingTimeMillis = trainingTimeMillis;
        this.validationTimesMillis = Collections.unmodifiableList(validationTimesMillis);
    }

    /**
     * Get the training evaluations
     */
    public List<EvaluationRecord> trainingEval(){
        return trainingHistory;
    }

    /**
     * Get the validation evaluations
     */
    public List<EvaluationRecord> validationEval(){
        return validationHistory;
    }

    /**
     * Get the loss curve
     */
    public LossCurve lossCurve(){
        return lossCurve;
    }

    /**
     * Get the total training time, in milliseconds
     */
    public long trainingTimeMillis(){
        return trainingTimeMillis;
    }

    /**
     * Get the recorded validation times, in milliseconds
     */
    public List<Long> validationTimesMillis(){
        return validationTimesMillis;
    }

    /**
     * Get the number of epochs trained for
     */
    public int trainingEpochs(){
        return trainingHistory.size();
    }

    /**
     * Get the number of epochs validation was ran on
     */
    public int validationEpochs(){
        return validationHistory.size();
    }

    /**
     * Get the results of a training evaluation on a given parameter for a given metric
     *
     * Only works if there is only one evaluation with the given metric for param
     *
     * @return one value per recorded training epoch
     */
    public List<Double> trainingEval(String param, IMetric metric){
        List<Double> data = new ArrayList<>();
        for(EvaluationRecord er : trainingHistory){
            data.add(er.getValue(param, metric));
        }
        return data;
    }

    /**
     * Get the results of a training evaluation on a given parameter for a given metric
     *
     * Only works if there is only one evaluation with the given metric for param
     *
     * @return one value per recorded training epoch
     */
    public List<Double> trainingEval(SDVariable param, IMetric metric){
        return trainingEval(param.name(), metric);
    }

    /**
     * Get the results of a training evaluation on a given parameter at a given index, for a given metric
     *
     * Note that it returns all recorded evaluations.
     * Index determines the evaluation used not the epoch's results to return.
     *
     * @return one value per recorded training epoch
     */
    public List<Double> trainingEval(String param, int index, IMetric metric){
        List<Double> data = new ArrayList<>();
        for(EvaluationRecord er : trainingHistory){
            data.add(er.getValue(param, index, metric));
        }
        return data;
    }

    /**
     * Get the results of a training evaluation on a given parameter at a given index, for a given metric
     *
     * Note that it returns all recorded evaluations.
     * Index determines the evaluation used not the epoch's results to return.
     *
     * @return one value per recorded training epoch
     */
    public List<Double> trainingEval(SDVariable param, int index, IMetric metric){
        return trainingEval(param.name(), index, metric);
    }

    /**
     * Get the results of a training evaluation for a given metric
     *
     * Only works if there is only one evaluation with the given metric
     *
     * @return one value per recorded training epoch
     */
    public List<Double> trainingEval(IMetric metric){
        List<Double> data = new ArrayList<>();
        for(EvaluationRecord er : trainingHistory){
            data.add(er.getValue(metric));
        }
        return data;
    }

    /**
     * Get the results of a training evaluation on a given parameter
     *
     * Only works if there is only one evaluation for param.
     *
     * @return one evaluation per recorded training epoch
     */
    public List<IEvaluation> trainingEval(String param){
        List<IEvaluation> data = new ArrayList<>();
        for(EvaluationRecord er : trainingHistory){
            data.add(er.evaluation(param));
        }
        return data;
    }

    /**
     * Get the results of a training evaluation on a given parameter
     *
     * Only works if there is only one evaluation for param.
     *
     * @return one evaluation per recorded training epoch
     */
    public List<IEvaluation> trainingEval(SDVariable param){
        return trainingEval(param.name());
    }

    /**
     * Get the results of a training evaluation on a given parameter at a given index
     *
     * Note that it returns all recorded evaluations.
     * Index determines the evaluation used not the epoch's results to return.
     *
     * @return one evaluation per recorded training epoch
     */
    public List<IEvaluation> trainingEval(String param, int index){
        List<IEvaluation> data = new ArrayList<>();
        for(EvaluationRecord er : trainingHistory){
            data.add(er.evaluation(param, index));
        }
        return data;
    }

    /**
     * Get the results of a training evaluation on a given parameter at a given index
     *
     * Note that it returns all recorded evaluations.
     * Index determines the evaluation used not the epoch's results to return.
     *
     * @return one evaluation per recorded training epoch
     */
    public List<IEvaluation> trainingEval(SDVariable param, int index){
        return trainingEval(param.name(), index);
    }

    /**
     * Get the results of a validation evaluation on a given parameter for a given metric
     *
     * Only works if there is only one evaluation with the given metric for param
     *
     * @return one value per recorded validation epoch
     */
    public List<Double> validationEval(String param, IMetric metric){
        List<Double> data = new ArrayList<>();
        for(EvaluationRecord er : validationHistory){
            data.add(er.getValue(param, metric));
        }
        return data;
    }

    /**
     * Get the results of a validation evaluation on a given parameter for a given metric
     *
     * Only works if there is only one evaluation with the given metric for param
     *
     * @return one value per recorded validation epoch
     */
    public List<Double> validationEval(SDVariable param, IMetric metric){
        return validationEval(param.name(), metric);
    }

    /**
     * Get the results of a validation evaluation on a given parameter at a given index, for a given metric
     *
     * Note that it returns all recorded evaluations.
     * Index determines the evaluation used not the epoch's results to return.
     *
     * @return one value per recorded validation epoch
     */
    public List<Double> validationEval(String param, int index, IMetric metric){
        List<Double> data = new ArrayList<>();
        for(EvaluationRecord er : validationHistory){
            data.add(er.getValue(param, index, metric));
        }
        return data;
    }

    /**
     * Get the results of a validation evaluation on a given parameter at a given index, for a given metric
     *
     * Note that it returns all recorded evaluations.
     * Index determines the evaluation used not the epoch's results to return.
     *
     * @return one value per recorded validation epoch
     */
    public List<Double> validationEval(SDVariable param, int index, IMetric metric){
        return validationEval(param.name(), index, metric);
    }

    /**
     * Get the results of a validation evaluation for a given metric
     *
     * Only works if there is only one evaluation with the given metric
     *
     * @return one value per recorded validation epoch
     */
    public List<Double> validationEval(IMetric metric){
        List<Double> data = new ArrayList<>();
        for(EvaluationRecord er : validationHistory){
            data.add(er.getValue(metric));
        }
        return data;
    }

    /**
     * Get the results of a validation evaluation on a given parameter
     *
     * Only works if there is only one evaluation for param.
     *
     * @return one evaluation per recorded validation epoch
     */
    public List<IEvaluation> validationEval(String param){
        List<IEvaluation> data = new ArrayList<>();
        for(EvaluationRecord er : validationHistory){
            data.add(er.evaluation(param));
        }
        return data;
    }

    /**
     * Get the results of a validation evaluation on a given parameter
     *
     * Only works if there is only one evaluation for param.
     *
     * @return one evaluation per recorded validation epoch
     */
    public List<IEvaluation> validationEval(SDVariable param){
        return validationEval(param.name());
    }

    /**
     * Get the results of a validation evaluation on a given parameter at a given index
     *
     * Note that it returns all recorded evaluations.
     * Index determines the evaluation used not the epoch's results to return.
     *
     * @return one evaluation per recorded validation epoch
     */
    public List<IEvaluation> validationEval(String param, int index){
        List<IEvaluation> data = new ArrayList<>();
        for(EvaluationRecord er : validationHistory){
            data.add(er.evaluation(param, index));
        }
        return data;
    }

    /**
     * Get the results of a validation evaluation on a given parameter at a given index
     *
     * Note that it returns all recorded evaluations.
     * Index determines the evaluation used not the epoch's results to return.
     *
     * @return one evaluation per recorded validation epoch
     */
    public List<IEvaluation> validationEval(SDVariable param, int index){
        return validationEval(param.name(), index);
    }

    /**
     * Gets the training evaluations ran during the last epoch
     *
     * @throws IllegalStateException if no training evaluations were recorded
     */
    public EvaluationRecord finalTrainingEvaluations(){
        Preconditions.checkState(!trainingHistory.isEmpty(), "Cannot get final training evaluation - history is empty");
        return trainingHistory.get(trainingHistory.size() - 1);
    }

    /**
     * Gets the validation evaluations ran during the last epoch
     *
     * @throws IllegalStateException if no validation evaluations were recorded
     */
    public EvaluationRecord finalValidationEvaluations(){
        Preconditions.checkState(!validationHistory.isEmpty(), "Cannot get final validation evaluation - history is empty");
        return validationHistory.get(validationHistory.size() - 1);
    }

    /**
     * Gets the training evaluation record for a given epoch.
     *
     * @param epoch The epoch to get results for. If negative, returns results for the epoch that many epochs
     *              from the end (-1 is the last recorded epoch).
     * @throws IndexOutOfBoundsException if the resolved epoch is outside the recorded training history
     */
    public EvaluationRecord trainingEvaluations(int epoch){
        return recordAt(trainingHistory, epoch);
    }

    /**
     * Gets the validation evaluation record for a given epoch.
     *
     * @param epoch The epoch to get results for. If negative, returns results for the epoch that many epochs
     *              from the end (-1 is the last recorded epoch).
     * @throws IndexOutOfBoundsException if the resolved epoch is outside the recorded validation history
     */
    public EvaluationRecord validationEvaluations(int epoch){
        return recordAt(validationHistory, epoch);
    }

    /**
     * Look up a record by epoch, resolving negative epochs relative to the end of the list.
     */
    private static EvaluationRecord recordAt(List<EvaluationRecord> history, int epoch){
        // epoch is negative in the second branch, so adding it counts back from the end:
        // size() + (-1) == size() - 1, the last element.
        int index = epoch >= 0 ? epoch : history.size() + epoch;
        return history.get(index);
    }
}