All downloads are free. The search and download features use the official Maven repository.

org.nd4j.evaluation.custom.CustomEvaluation Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.nd4j.evaluation.custom;

import org.nd4j.shade.guava.collect.Lists;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.NonNull;
import lombok.RequiredArgsConstructor;
import org.nd4j.evaluation.BaseEvaluation;
import org.nd4j.evaluation.IEvaluation;
import org.nd4j.evaluation.IMetric;
import org.nd4j.linalg.api.ndarray.INDArray;

@Data
@EqualsAndHashCode(callSuper = true)
public class CustomEvaluation extends BaseEvaluation {

    /**
     * The metric used to get a score for the CustomEvaluation.  Uses a ResultLambda
     */
    @AllArgsConstructor
    @RequiredArgsConstructor
    public static class Metric implements IMetric{

        @Getter
        @NonNull private ResultLambda getResult;

        private boolean minimize = false;

        @Override
        public Class getEvaluationClass() {
            return CustomEvaluation.class;
        }

        @Override
        public boolean minimize() {
            return minimize;
        }

        /**
         * A metric that takes the average of a list of doubles
         */
        public static Metric doubleAverage(boolean minimize){
            return new Metric<>(new ResultLambda() {
                @Override
                public double toResult(List data) {
                    int count = 0;
                    double sum = 0;
                    for (Double d : data) {
                        count++;
                        sum += d;
                    }
                    return sum / count;
                }
            }, minimize);
        }


        /**
         * A metric that takes the max of a list of doubles
         */
        public static Metric doubleMax(boolean minimize){
            return new Metric<>(new ResultLambda() {
                @Override
                public double toResult(List data) {
                    double max = 0;
                    for (Double d : data) {
                        if(d > max)
                            max = d;
                    }
                    return max;
                }
            }, minimize);
        }


        /**
         * A metric that takes the min of a list of doubles
         */
        public static Metric doubleMin(boolean minimize){
            return new Metric<>(new ResultLambda() {
                @Override
                public double toResult(List data) {
                    double max = 0;
                    for (Double d : data) {
                        if(d < max)
                            max = d;
                    }
                    return max;
                }
            }, minimize);
        }
    }

    /**
     * A MergeLambda that merges by concatenating the two lists
     */
    public static  MergeLambda mergeConcatenate(){
        return new MergeLambda() {
            @Override
            public List merge(List a, List b) {
                List res = Lists.newArrayList(a);
                res.addAll(b);
                return res;
            }
        };
    }

    @NonNull private EvaluationLambda evaluationLambda;
    @NonNull private MergeLambda mergeLambda;

    private List evaluations = new ArrayList<>();

    @Override
    public void eval(INDArray labels, INDArray networkPredictions, INDArray maskArray,
            List recordMetaData) {
        evaluations.add(evaluationLambda.eval(labels, networkPredictions, maskArray, recordMetaData));
    }

    @Override
    public void merge(CustomEvaluation other) {
        evaluations = mergeLambda.merge(evaluations, other.evaluations);
    }

    @Override
    public void reset() {
        evaluations = new ArrayList<>();
    }

    @Override
    public String stats() {
        return "";
    }

    @Override
    public double getValue(IMetric metric) {
        if(metric instanceof Metric){
            return ((Metric) metric).getGetResult().toResult(evaluations);
        } else
            throw new IllegalStateException("Can't get value for non-regression Metric " + metric);
    }

    @Override
    public CustomEvaluation newInstance() {
        return new CustomEvaluation(evaluationLambda, mergeLambda);
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy