All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.nd4j.evaluation.custom.CustomEvaluation Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
/*
 * Copyright (c) 2015-2019 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 */

package org.nd4j.evaluation.custom;

import org.nd4j.shade.guava.collect.Lists;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.NonNull;
import lombok.RequiredArgsConstructor;
import org.nd4j.evaluation.BaseEvaluation;
import org.nd4j.evaluation.IEvaluation;
import org.nd4j.evaluation.IMetric;
import org.nd4j.evaluation.regression.RegressionEvaluation;
import org.nd4j.evaluation.regression.RegressionEvaluation.Metric;
import org.nd4j.linalg.api.ndarray.INDArray;

/**
 * A evaluation using lambdas to calculate the score.
 *
 * Uses 3 lambdas:
* EvaluationLambda: takes in the labels, predictions, mask, and metadata and returns a value of type T
* MergeLambda: takes in two lists of Ts, returns one. Used in merging for distributed training
* ResultLambda (in Metric): takes a list of Ts, returns a double value
*
* The EvaluationLambda will be called on each batch, and the results will be stored in a list. * MergeLambda merges two of those lists for distributed training (think Spark or Map-Reduce). * ResultLambda gets a score from that list. * */ @Data @EqualsAndHashCode(callSuper = true) public class CustomEvaluation extends BaseEvaluation { /** * The metric used to get a score for the CustomEvaluation. Uses a ResultLambda */ @AllArgsConstructor @RequiredArgsConstructor public static class Metric implements IMetric{ @Getter @NonNull private ResultLambda getResult; private boolean minimize = false; @Override public Class getEvaluationClass() { return CustomEvaluation.class; } @Override public boolean minimize() { return minimize; } /** * A metric that takes the average of a list of doubles */ public static Metric doubleAverage(boolean minimize){ return new Metric<>(new ResultLambda() { @Override public double toResult(List data) { int count = 0; double sum = 0; for (Double d : data) { count++; sum += d; } return sum / count; } }, minimize); } /** * A metric that takes the max of a list of doubles */ public static Metric doubleMax(boolean minimize){ return new Metric<>(new ResultLambda() { @Override public double toResult(List data) { double max = 0; for (Double d : data) { if(d > max) max = d; } return max; } }, minimize); } /** * A metric that takes the min of a list of doubles */ public static Metric doubleMin(boolean minimize){ return new Metric<>(new ResultLambda() { @Override public double toResult(List data) { double max = 0; for (Double d : data) { if(d < max) max = d; } return max; } }, minimize); } } /** * A MergeLambda that merges by concatenating the two lists */ public static MergeLambda mergeConcatenate(){ return new MergeLambda() { @Override public List merge(List a, List b) { List res = Lists.newArrayList(a); res.addAll(b); return res; } }; } @NonNull private EvaluationLambda evaluationLambda; @NonNull private MergeLambda mergeLambda; private List evaluations = new ArrayList<>(); @Override public void eval(INDArray labels, INDArray networkPredictions, INDArray maskArray, List recordMetaData) { evaluations.add(evaluationLambda.eval(labels, networkPredictions, maskArray, recordMetaData)); } @Override public void merge(CustomEvaluation other) { evaluations = mergeLambda.merge(evaluations, other.evaluations); } @Override public void reset() { evaluations = new ArrayList<>(); } @Override public String stats() { return ""; } @Override public double getValue(IMetric metric) { if(metric instanceof Metric){ return ((Metric) metric).getGetResult().toResult(evaluations); } else throw new IllegalStateException("Can't get value for non-regression Metric " + metric); } @Override public CustomEvaluation newInstance() { return new CustomEvaluation(evaluationLambda, mergeLambda); } }




© 2015 - 2024 Weber Informatics LLC | Privacy Policy