org.deeplearning4j.eval.RegressionEvaluation Maven / Gradle / Ivy
package org.deeplearning4j.eval;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.Abs;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
/**
* Evaluation class for regression algorithms.
* Provides the following metrics, for each column:
* <ul>
* <li>MSE: mean squared error</li>
* <li>MAE: mean absolute error</li>
* <li>RMSE: root mean squared error</li>
* <li>RSE: relative squared error</li>
* <li>Pearson correlation coefficient</li>
* </ul>
* See for example: http://www.saedsayad.com/model_evaluation_r.htm
* For classification, see {@link Evaluation}
*
* @author Alex Black
*/
public class RegressionEvaluation {
/** Precision used by the stats() method when the caller does not specify one. */
public static final int DEFAULT_PRECISION = 5;
// Column names; the size of this list sets the length of every per-column accumulator below.
private List columnNames;
// Precision supplied at construction time, used by the stats() method.
private int precision;
// Number of examples seen so far.
// NOTE(review): not updated in the visible part of eval(); presumably incremented in its per-row loop — confirm.
private int exampleCount = 0;
private INDArray labelsSumPerColumn; //sum(actual) per column -> used to calculate mean
private INDArray sumSquaredErrorsPerColumn; //sum of (predicted - actual)^2 per column
private INDArray sumAbsErrorsPerColumn; //sum of abs(predicted - actual) per column
// Running per-column state for the online variance / correlation computation that eval() references.
// It is updated in eval()'s per-row loop. Presumably these are, respectively, the mean of the labels,
// the mean of the predictions, and M2 (the sum of squared deviations of the labels) — TODO confirm.
private INDArray currentMean;
private INDArray currentPredictionMean;
private INDArray m2Actual;
private INDArray sumOfProducts; //sum(actual * predicted) per column
private INDArray sumSquaredLabels; //sum(actual^2) per column
private INDArray sumSquaredPredicted; //sum(predicted^2) per column
/**
 * Creates a regression evaluation with {@code nColumns} automatically named columns
 * ("col_0", "col_1", ...) and {@link #DEFAULT_PRECISION} for the stats() method.
 *
 * @param nColumns Number of columns
 */
public RegressionEvaluation(int nColumns) {
    // Same as the (nColumns, precision) overload, with the default precision filled in.
    this(nColumns, DEFAULT_PRECISION);
}
/**
 * Creates a regression evaluation with {@code nColumns} automatically named columns
 * ("col_0", "col_1", ...) and the specified precision for the stats() method.
 *
 * @param nColumns  Number of columns
 * @param precision Precision used by the stats() method
 */
public RegressionEvaluation(int nColumns, int precision) {
    this(createDefaultColumnNames(nColumns), precision);
}
/**
 * Creates a regression evaluation whose columns carry the given names, using
 * {@link #DEFAULT_PRECISION} for the stats() method.
 *
 * @param columnNames Names of the columns, one per column
 */
public RegressionEvaluation(String... columnNames) {
    // Wrap the varargs array as a List and let the List overload apply the default precision.
    this(Arrays.asList(columnNames));
}
/**
 * Creates a regression evaluation for the named columns, using
 * {@link #DEFAULT_PRECISION} for the stats() method.
 *
 * @param columnNames Names of the columns; one per column being evaluated
 */
public RegressionEvaluation(List columnNames) {
    this(columnNames, RegressionEvaluation.DEFAULT_PRECISION);
}
/**
 * Creates a regression evaluation for the named columns, with the specified precision
 * for the stats() method. Every per-column accumulator is initialised to zeros, with
 * one entry per column name.
 *
 * @param columnNames Names of the columns; must not be null. The list is copied, so later
 *                    changes to the caller's list (or to a varargs array it wraps) do not
 *                    affect this evaluation.
 * @param precision   Precision used by the stats() method
 * @throws NullPointerException if {@code columnNames} is null
 */
public RegressionEvaluation(List columnNames, int precision) {
    if (columnNames == null) {
        throw new NullPointerException("columnNames must not be null");
    }
    // Defensive copy. The accumulators below are sized from this list, so its size
    // must not change afterwards through the caller's reference.
    this.columnNames = new ArrayList<>(columnNames);
    this.precision = precision;
    int n = this.columnNames.size();
    // NOTE(review): behaviour of Nd4j.zeros(0) for an empty name list is not checked here — confirm.
    labelsSumPerColumn = Nd4j.zeros(n);
    sumSquaredErrorsPerColumn = Nd4j.zeros(n);
    sumAbsErrorsPerColumn = Nd4j.zeros(n);
    currentMean = Nd4j.zeros(n);
    m2Actual = Nd4j.zeros(n);
    currentPredictionMean = Nd4j.zeros(n);
    sumOfProducts = Nd4j.zeros(n);
    sumSquaredLabels = Nd4j.zeros(n);
    sumSquaredPredicted = Nd4j.zeros(n);
}
/**
 * Builds the default column names "col_0", "col_1", ..., one per column.
 *
 * @param nColumns Number of names to generate; must be non-negative
 * @return a new mutable list containing {@code nColumns} names
 * @throws IllegalArgumentException if {@code nColumns} is negative
 */
private static List<String> createDefaultColumnNames(int nColumns) {
    if (nColumns < 0) {
        throw new IllegalArgumentException("nColumns must be non-negative, got " + nColumns);
    }
    List<String> names = new ArrayList<>(nColumns);
    for (int i = 0; i < nColumns; i++) {
        names.add("col_" + i);
    }
    return names;
}
public void eval(INDArray labels, INDArray predictions) {
//References for the calculations is this section:
//https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Online_algorithm
//https://en.wikipedia.org/wiki/Pearson_product-moment_correlation_coefficient#For_a_sample
//Doing online calculation of means, sum of squares, etc.
labelsSumPerColumn.addi(labels.sum(0));
INDArray error = predictions.sub(labels);
INDArray absErrorSum = Nd4j.getExecutioner().execAndReturn(new Abs(error.dup())).sum(0);
INDArray squaredErrorSum = error.mul(error).sum(0);
sumAbsErrorsPerColumn.addi(absErrorSum);
sumSquaredErrorsPerColumn.addi(squaredErrorSum);
sumOfProducts.addi(labels.mul(predictions).sum(0));
sumSquaredLabels.addi(labels.mul(labels).sum(0));
sumSquaredPredicted.addi(predictions.mul(predictions).sum(0));
int nRows = labels.size(0);
for( int i=0; i
© 2015 - 2025 Weber Informatics LLC | Privacy Policy