// org.nd4j.evaluation.custom.CustomEvaluation - Maven / Gradle / Ivy
// The newest version!
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.nd4j.evaluation.custom;
import org.nd4j.shade.guava.collect.Lists;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.NonNull;
import lombok.RequiredArgsConstructor;
import org.nd4j.evaluation.BaseEvaluation;
import org.nd4j.evaluation.IEvaluation;
import org.nd4j.evaluation.IMetric;
import org.nd4j.linalg.api.ndarray.INDArray;
@Data
@EqualsAndHashCode(callSuper = true)
public class CustomEvaluation extends BaseEvaluation {
/**
* The metric used to get a score for the CustomEvaluation. Uses a ResultLambda
*/
@AllArgsConstructor
@RequiredArgsConstructor
public static class Metric implements IMetric{
@Getter
@NonNull private ResultLambda getResult;
private boolean minimize = false;
@Override
public Class getEvaluationClass() {
return CustomEvaluation.class;
}
@Override
public boolean minimize() {
return minimize;
}
/**
* A metric that takes the average of a list of doubles
*/
public static Metric doubleAverage(boolean minimize){
return new Metric<>(new ResultLambda() {
@Override
public double toResult(List data) {
int count = 0;
double sum = 0;
for (Double d : data) {
count++;
sum += d;
}
return sum / count;
}
}, minimize);
}
/**
* A metric that takes the max of a list of doubles
*/
public static Metric doubleMax(boolean minimize){
return new Metric<>(new ResultLambda() {
@Override
public double toResult(List data) {
double max = 0;
for (Double d : data) {
if(d > max)
max = d;
}
return max;
}
}, minimize);
}
/**
* A metric that takes the min of a list of doubles
*/
public static Metric doubleMin(boolean minimize){
return new Metric<>(new ResultLambda() {
@Override
public double toResult(List data) {
double max = 0;
for (Double d : data) {
if(d < max)
max = d;
}
return max;
}
}, minimize);
}
}
/**
* A MergeLambda that merges by concatenating the two lists
*/
public static MergeLambda mergeConcatenate(){
return new MergeLambda() {
@Override
public List merge(List a, List b) {
List res = Lists.newArrayList(a);
res.addAll(b);
return res;
}
};
}
@NonNull private EvaluationLambda evaluationLambda;
@NonNull private MergeLambda mergeLambda;
private List evaluations = new ArrayList<>();
@Override
public void eval(INDArray labels, INDArray networkPredictions, INDArray maskArray,
List recordMetaData) {
evaluations.add(evaluationLambda.eval(labels, networkPredictions, maskArray, recordMetaData));
}
@Override
public void merge(CustomEvaluation other) {
evaluations = mergeLambda.merge(evaluations, other.evaluations);
}
@Override
public void reset() {
evaluations = new ArrayList<>();
}
@Override
public String stats() {
return "";
}
@Override
public double getValue(IMetric metric) {
if(metric instanceof Metric){
return ((Metric) metric).getGetResult().toResult(evaluations);
} else
throw new IllegalStateException("Can't get value for non-regression Metric " + metric);
}
@Override
public CustomEvaluation newInstance() {
return new CustomEvaluation(evaluationLambda, mergeLambda);
}
}