Please wait. This can take some minutes ...
Many resources are needed to download a project. Please understand that we have to compensate our server costs. Thank you in advance.
Project price only 1 $
You can buy this project and download/modify it as often as you want.
org.nd4j.evaluation.BaseEvaluation Maven / Gradle / Ivy
/* ******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.evaluation;
import lombok.EqualsAndHashCode;
import lombok.NonNull;
import org.nd4j.common.base.Preconditions;
import org.nd4j.evaluation.classification.*;
import org.nd4j.evaluation.regression.RegressionEvaluation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastTo;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.common.primitives.Pair;
import org.nd4j.common.primitives.Triple;
import org.nd4j.common.util.ArrayUtil;
import org.nd4j.serde.json.JsonMappers;
import org.nd4j.shade.jackson.core.JsonProcessingException;
import org.nd4j.shade.jackson.databind.exc.InvalidTypeIdException;
import java.io.IOException;
import java.io.Serializable;
import java.util.Arrays;
import java.util.List;
/**
* BaseEvaluation implements common evaluation functionality (for time series, etc) for {@link Evaluation},
* {@link RegressionEvaluation}, {@link ROC}, {@link ROCMultiClass} etc.
*
* @author Alex Black
*/
@EqualsAndHashCode
public abstract class BaseEvaluation implements IEvaluation {
/**
 * Deserialize an evaluation instance from its YAML representation, using the mapper returned by
 * {@code JsonMappers.getYamlMapper()}.
 *
 * @param yaml  YAML representation
 * @param clazz Class to deserialize the YAML as
 * @param <T>   Type to return
 * @return Evaluation instance
 * @throws RuntimeException wrapping the {@link IOException} raised by the mapper if the YAML cannot be read
 */
public static <T extends IEvaluation> T fromYaml(String yaml, Class<T> clazz) {
    try {
        return JsonMappers.getYamlMapper().readValue(yaml, clazz);
    } catch (IOException e) {
        throw new RuntimeException(e);
    }
}
/**
 * Deserialize an evaluation instance from its JSON representation, using the mapper returned by
 * {@code JsonMappers.getMapper()}.
 * <p>
 * If the mapper fails with an {@link InvalidTypeIdException} whose message contains
 * "Could not resolve type id", a second attempt is made via
 * {@link #attempFromLegacyFromJson(String, InvalidTypeIdException)}, which rewrites legacy class names.
 *
 * @param json  JSON representation of the evaluation instance
 * @param clazz Class to deserialize the JSON as
 * @param <T>   Type to return
 * @return Evaluation instance
 * @throws RuntimeException if deserialization fails; the underlying exception is preserved as the cause
 */
@SuppressWarnings("unchecked") //The legacy path returns whatever type the JSON declares; the caller chose T
public static <T extends IEvaluation> T fromJson(String json, Class<T> clazz) {
    try {
        return JsonMappers.getMapper().readValue(json, clazz);
    } catch (InvalidTypeIdException e) {
        if (e.getMessage().contains("Could not resolve type id")) {
            try {
                return (T) attempFromLegacyFromJson(json, e);
            } catch (Throwable t) {
                throw new RuntimeException("Cannot deserialize from JSON - JSON is invalid?", t);
            }
        }
        throw new RuntimeException(e);
    } catch (IOException e) {
        throw new RuntimeException(e);
    }
}
/**
 * Attempt to load DL4J IEvaluation JSON from 1.0.0-beta2 or earlier.
 * Given IEvaluation classes were moved to ND4J with no major changes, a simple "find and replace" for the class
 * names is used. Replacement is literal (not regex based), so the '.' characters in the class names only match
 * themselves.
 * <p>
 * The order of the checks matters: a class name that is a prefix of another one (Evaluation vs.
 * EvaluationBinary/EvaluationCalibration, ROC vs. ROCBinary/ROCMultiClass) must be checked after the longer names,
 * otherwise the longer names would never be reached.
 *
 * @param json              JSON to attempt to deserialize
 * @param originalException Original exception to be re-thrown if it isn't legacy JSON
 * @param <T>               Type to return
 * @return Evaluation instance deserialized from the rewritten JSON
 * @throws InvalidTypeIdException the original exception, if no legacy class name is found in the JSON
 */
@SuppressWarnings("unchecked") //The concrete class is picked from the JSON content; the caller chose T
protected static <T extends IEvaluation> T attempFromLegacyFromJson(String json, InvalidTypeIdException originalException) throws InvalidTypeIdException {
    if (json.contains("org.deeplearning4j.eval.EvaluationBinary")) { //Has to be checked before Evaluation due to Evaluation being a prefix
        String newJson = json.replace("org.deeplearning4j.eval.EvaluationBinary", "org.nd4j.evaluation.classification.EvaluationBinary")
                .replace("org.deeplearning4j.eval.ROC", "org.nd4j.evaluation.classification.ROC")
                .replace("org.deeplearning4j.eval.curves.", "org.nd4j.evaluation.curves.");
        return (T) fromJson(newJson, EvaluationBinary.class);
    }
    if (json.contains("org.deeplearning4j.eval.EvaluationCalibration")) { //Has to be checked before Evaluation due to Evaluation being a prefix
        String newJson = json.replace("org.deeplearning4j.eval.EvaluationCalibration", "org.nd4j.evaluation.classification.EvaluationCalibration")
                .replace("org.deeplearning4j.eval.curves.", "org.nd4j.evaluation.curves.");
        return (T) fromJson(newJson, EvaluationCalibration.class);
    }
    if (json.contains("org.deeplearning4j.eval.Evaluation")) {
        String newJson = json.replace("org.deeplearning4j.eval.Evaluation", "org.nd4j.evaluation.classification.Evaluation");
        return (T) fromJson(newJson, Evaluation.class);
    }
    if (json.contains("org.deeplearning4j.eval.ROCBinary")) {
        String newJson = json.replace("org.deeplearning4j.eval.ROCBinary", "org.nd4j.evaluation.classification.ROCBinary")
                .replace("org.deeplearning4j.eval.ROC", "org.nd4j.evaluation.classification.ROC") //Nested ROC instances internally
                .replace("org.deeplearning4j.eval.curves.", "org.nd4j.evaluation.curves.");
        return (T) fromJson(newJson, ROCBinary.class);
    }
    if (json.contains("org.deeplearning4j.eval.ROCMultiClass")) {
        String newJson = json.replace("org.deeplearning4j.eval.ROCMultiClass", "org.nd4j.evaluation.classification.ROCMultiClass")
                .replace("org.deeplearning4j.eval.ROC", "org.nd4j.evaluation.classification.ROC") //Nested ROC instances internally
                .replace("org.deeplearning4j.eval.curves.", "org.nd4j.evaluation.curves.");
        return (T) fromJson(newJson, ROCMultiClass.class);
    }
    if (json.contains("org.deeplearning4j.eval.ROC")) { //Has to be checked after ROCBinary/ROCMultiClass due to it being a prefix
        String newJson = json.replace("org.deeplearning4j.eval.ROC", "org.nd4j.evaluation.classification.ROC")
                .replace("org.deeplearning4j.eval.curves.", "org.nd4j.evaluation.curves.");
        return (T) fromJson(newJson, ROC.class);
    }
    if (json.contains("org.deeplearning4j.eval.RegressionEvaluation")) {
        String newJson = json.replace("org.deeplearning4j.eval.RegressionEvaluation", "org.nd4j.evaluation.regression.RegressionEvaluation");
        return (T) fromJson(newJson, RegressionEvaluation.class);
    }
    throw originalException;
}
/**
 * Prepare labels, predictions and an optional mask array for evaluation.
 * <p>
 * Rank 2 labels (only {@code axis == 1} is accepted):
 * <ul>
 *     <li>no mask: labels and predictions are returned unchanged, with a null mask</li>
 *     <li>rank 1 or column vector mask: treated as per-example masking; only the rows with a non-zero mask value
 *     are returned (null mask). If every row is masked, {@code null} is returned</li>
 *     <li>any other mask: must have the same shape as the labels, and is returned unchanged together with the
 *     labels and predictions</li>
 * </ul>
 * Rank 3, 4 or 5 labels:
 * <ul>
 *     <li>no mask: the arrays are passed to {@code reshapeSameShapeTo2d}</li>
 *     <li>rank 3 labels with rank 2 mask: per time step masking, via
 *     {@code EvaluationUtils.extractNonMaskedTimeSteps}; {@code null} is returned if that method returns null</li>
 *     <li>rank 3 labels with any other mask: the mask must have the same shape as the labels; all arrays are
 *     passed to {@code reshapeSameShapeTo2d}</li>
 *     <li>rank 4/5 labels: the mask must have the same shape as the labels, or be 1d with length equal to
 *     labels dimension 0, or have the same rank and a shape broadcastable to the labels shape. In the latter
 *     two cases the mask is broadcast to the labels shape before {@code reshapeSameShapeTo2d} is called</li>
 * </ul>
 *
 * @param labels      Labels array
 * @param predictions Predictions array
 * @param mask        Mask array. May be null
 * @param axis        Axis argument; must be 1 for rank 2 labels, otherwise forwarded to {@code reshapeSameShapeTo2d}
 * @return Triple of (labels, predictions, mask) where the mask may be null; or {@code null} if nothing is left
 * to evaluate after masking
 * @throws UnsupportedOperationException for rank 4/5 labels with a mask shape that is not supported
 * @throws IllegalStateException         if the labels rank is not 2, 3, 4 or 5, or if a precondition on the
 *                                       axis/mask is violated
 */
public static Triple<INDArray, INDArray, INDArray> reshapeAndExtractNotMasked(INDArray labels, INDArray predictions, INDArray mask, int axis) {
    if (labels.rank() == 2) {
        Preconditions.checkState(axis == 1, "Only axis=1 is supported 2d data - got axis=%s for labels array shape %ndShape", axis, labels);
        if (mask == null) {
            //no-op
            return new Triple<>(labels, predictions, null);
        } else {
            //2 possible cases: per-output masking, and per example masking
            if (mask.rank() == 1 || mask.isColumnVector()) {
                int notMaskedCount = mask.neq(0.0).castTo(DataType.INT).sumNumber().intValue();
                if (notMaskedCount == 0) {
                    //All steps masked - nothing left to evaluate
                    return null;
                }
                if (notMaskedCount == mask.length()) {
                    //No masked steps - returned as-is
                    return new Triple<>(labels, predictions, null);
                }
                //Collect the indices of the rows to keep (non-zero mask value)
                int[] arr = mask.toIntVector();
                int[] idxs = new int[notMaskedCount];
                int pos = 0;
                for (int i = 0; i < arr.length; i++) {
                    if (arr[i] != 0) {
                        idxs[pos++] = i;
                    }
                }
                INDArray retLabel = Nd4j.pullRows(labels, 1, idxs, 'c');
                INDArray retPredictions = Nd4j.pullRows(predictions, 1, idxs, 'c');
                return new Triple<>(retLabel, retPredictions, null);
            } else {
                Preconditions.checkState(labels.equalShapes(mask), "If a mask array is present for 2d data, it must either be a vector (column vector)" +
                        " or have shape equal to the labels (for per-output masking, when supported). Got labels shape %ndShape, mask shape %ndShape",
                        labels, mask);
                //Assume evaluation instances with per-output masking will handle that themselves (or throw exception if not supported)
                return new Triple<>(labels, predictions, mask);
            }
        }
    } else if (labels.rank() == 3 || labels.rank() == 4 || labels.rank() == 5) {
        if (mask == null) {
            return reshapeSameShapeTo2d(axis, labels, predictions, mask);
        } else {
            if (labels.rank() == 3) {
                if (mask.rank() == 2) {
                    //Per time step masking
                    Pair<INDArray, INDArray> p = EvaluationUtils.extractNonMaskedTimeSteps(labels, predictions, mask);
                    if (p == null) {
                        return null;
                    }
                    return new Triple<>(p.getFirst(), p.getSecond(), null);
                } else {
                    //Per output mask
                    Preconditions.checkState(labels.equalShapes(mask), "If a mask array is present for 3d data, it must either be 2d (shape [minibatch, sequenceLength])" +
                            " or have shape equal to the labels (for per-output masking, when supported). Got labels shape %ndShape, mask shape %ndShape",
                            labels, mask);
                    //Assume evaluation instances with per-output masking will handle that themselves (or throw exception if not supported)
                    //Just reshape to 2d
                    return reshapeSameShapeTo2d(axis, labels, predictions, mask);
                }
            } else {
                if (labels.equalShapes(mask)) {
                    //Per output masking case
                    return reshapeSameShapeTo2d(axis, labels, predictions, mask);
                } else if (mask.rank() == 1) {
                    //Treat 1D mask as per-example masking
                    //This branch is reached for rank 4 and rank 5 labels, hence the actual rank is reported
                    Preconditions.checkState(mask.length() == labels.size(0), "For rank %s labels with shape %ndShape and 1d" +
                            " mask of shape %ndShape, the mask array length must equal labels dimension 0 size", labels.rank(), labels, mask);
                    //Reshape the mask to [n, 1, ..., 1] (same rank as labels) so it can be broadcast to the labels shape
                    long[] reshape = ArrayUtil.nTimes(labels.rank(), 1L);
                    reshape[0] = mask.size(0);
                    INDArray mReshape = mask.reshape(reshape);
                    INDArray bMask = Nd4j.createUninitialized(mask.dataType(), labels.shape());
                    BroadcastTo b = new BroadcastTo(mReshape, labels.shape(), bMask);
                    Nd4j.exec(b);
                    return reshapeSameShapeTo2d(axis, labels, predictions, bMask);
                } else if (mask.rank() == labels.rank() && Shape.areShapesBroadcastable(mask.shape(), labels.shape())) {
                    //Same rank, but different shape -> broadcast
                    INDArray bMask = Nd4j.createUninitialized(mask.dataType(), labels.shape());
                    BroadcastTo b = new BroadcastTo(mask, labels.shape(), bMask);
                    Nd4j.exec(b);
                    return reshapeSameShapeTo2d(axis, labels, predictions, bMask);
                }
                throw new UnsupportedOperationException("Evaluation case not supported: labels shape " + Arrays.toString(labels.shape()) +
                        " with mask shape " + Arrays.toString(mask.shape()));
            }
        }
    } else {
        throw new IllegalStateException("Unknown array type passed to evaluation: labels array rank " + labels.rank()
                + " with shape " + labels.shapeInfoToString() + ". Labels and predictions must always be rank 2 or higher, with leading dimension being minibatch dimension");
    }
}
/**
* Helper called by reshapeAndExtractNotMasked for labels of rank 3, 4 or 5, with a mask that is either null or
* has been brought to the labels shape by the caller.
*
* NOTE(review): the body of this method is corrupted in this copy of the file. The loop header below is fused
* with the tail of the final statement, and everything in between is missing. Judging by the method name and
* the visible remainder, it presumably builds a permutation of the dimensions and returns a Triple of
* (labels, predictions, mask) reshaped to 2d - TODO restore the body from the upstream source and confirm.
*
* @param axis axis argument forwarded by the caller (its exact use is not visible here)
* @param labels labels array; its rank determines the length of the permutation array
* @param predictions predictions array
* @param mask mask array; callers pass null when no mask is present
*/
private static Triple reshapeSameShapeTo2d(int axis, INDArray labels, INDArray predictions, INDArray mask){
int[] permuteDims = new int[labels.rank()];
int j=0;
// NOTE(review): truncated line - does not compile as-is (see method comment)
for( int i=0; i(lOut, pOut, mOut);
}
/**
* Evaluate the given labels against the network predictions, with no mask array and no record meta data:
* delegates to the four-argument eval overload, passing null for both.
*
* @param labels labels array
* @param networkPredictions network predictions array
*/
@Override
public void eval(INDArray labels, INDArray networkPredictions) {
eval(labels, networkPredictions, null, null);
}
@Override
public void eval(@NonNull INDArray labels, @NonNull final INDArray predictions, final List extends Serializable> recordMetaData) {
eval(labels, predictions, null, recordMetaData);
}
/**
* Evaluate the given labels against the network predictions, with a mask array but no record meta data:
* delegates to the four-argument eval overload, passing null for the record meta data.
*
* @param labels labels array
* @param networkPredictions network predictions array
* @param maskArray mask array, forwarded unchanged to the four-argument overload
*/
@Override
public void eval(INDArray labels, INDArray networkPredictions, INDArray maskArray) {
eval(labels, networkPredictions, maskArray, null);
}
/**
* Evaluate time series data without a labels mask: delegates to the three-argument evalTimeSeries overload,
* passing null for the mask.
*
* @param labels labels array
* @param predicted predictions array
*/
@Override
public void evalTimeSeries(INDArray labels, INDArray predicted) {
evalTimeSeries(labels, predicted, null);
}
/**
 * Evaluate time series data with an optional labels mask. The non-masked time steps are extracted via
 * {@code EvaluationUtils.extractNonMaskedTimeSteps}, and the resulting pair of arrays is passed to
 * {@link #eval(INDArray, INDArray)}. If the extraction returns null, nothing is evaluated.
 *
 * @param labels      labels array
 * @param predictions predictions array
 * @param labelsMask  labels mask array; null is passed by the two-argument overload
 */
@Override
public void evalTimeSeries(INDArray labels, INDArray predictions, INDArray labelsMask) {
    Pair<INDArray, INDArray> pair = EvaluationUtils.extractNonMaskedTimeSteps(labels, predictions, labelsMask);
    if (pair == null) {
        //No non-masked steps
        return;
    }
    INDArray labels2d = pair.getFirst();
    INDArray predicted2d = pair.getSecond();
    eval(labels2d, predicted2d);
}
/**
 * Serialize this evaluation instance to JSON, using the mapper returned by {@code JsonMappers.getMapper()}.
 *
 * @return JSON representation of the evaluation instance
 * @throws RuntimeException wrapping the {@link JsonProcessingException} if serialization fails
 */
@Override
public String toJson() {
    final String json;
    try {
        json = JsonMappers.getMapper().writeValueAsString(this);
    } catch (JsonProcessingException serializationFailure) {
        throw new RuntimeException(serializationFailure);
    }
    return json;
}
/**
* @return the result of stats() for this evaluation instance
*/
@Override
public String toString() {
return stats();
}
/**
 * Serialize this evaluation instance to YAML, using the mapper returned by {@code JsonMappers.getYamlMapper()}.
 *
 * @return YAML representation of the evaluation instance
 * @throws RuntimeException wrapping the {@link JsonProcessingException} if serialization fails
 */
@Override
public String toYaml() {
    final String yaml;
    try {
        yaml = JsonMappers.getYamlMapper().writeValueAsString(this);
    } catch (JsonProcessingException serializationFailure) {
        throw new RuntimeException(serializationFailure);
    }
    return yaml;
}
}