All downloads are free. Search and download functionality uses the official Maven repository.

org.nd4j.autodiff.listeners.records.EvaluationRecord Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
/*
 * Copyright (c) 2015-2019 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 */

package org.nd4j.autodiff.listeners.records;

import org.nd4j.shade.guava.base.Predicates;
import org.nd4j.shade.guava.collect.Collections2;
import org.nd4j.shade.guava.collect.ImmutableMap;
import org.nd4j.shade.guava.collect.Lists;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import lombok.Getter;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.base.Preconditions;
import org.nd4j.evaluation.IEvaluation;
import org.nd4j.evaluation.IMetric;

/**
 * A helper class to hold evaluations and provide methods to easily query them
 */
@Getter
public class EvaluationRecord {

    private Map> evaluations;
    private Map, IEvaluation> classEvaluations = new HashMap<>();
    private boolean isEmpty = true;

    public EvaluationRecord(Map> evaluations) {
        this.evaluations = Collections.unmodifiableMap(evaluations);

        for (List le : evaluations.values()) {
            for (IEvaluation e : le) {
                isEmpty = false;
                if (classEvaluations.containsKey(e.getClass()))
                    classEvaluations.remove(e.getClass());
                else
                    classEvaluations.put(e.getClass(), e);
            }
        }
    }

    private EvaluationRecord() {

    }

    public boolean isEmpty() {
        return isEmpty;
    }

    /**
     * Get all evaluations
     */
    public Map> evaluations() {
        return evaluations;
    }

    /**
     * Get evaluations for a given param/variable
     *
     * @param param The target param/variable
     */
    public List evaluations(String param) {
        Preconditions.checkArgument(evaluations.containsKey(param),
                "No evaluations for %s.", param);

        return evaluations.get(param);
    }

    /**
     * Get evaluations for a given param/variable
     *
     * @param param The target param/variable
     */
    public List evaluations(SDVariable param) {
        return evaluations(param.getVarName());
    }

    /**
     * Get the evaluation for param at the specified index
     */
    public IEvaluation evaluation(String param, int index) {
        return evaluations(param).get(index);
    }

    /**
     * Get the evaluation for param at the specified index
     */
    public IEvaluation evaluation(SDVariable param, int index) {
        return evaluation(param.getVarName(), index);
    }

    /**
     * Get the evaluation for a given param/variable
     * 

* Will throw an exception if there are more than one or no evaluations for the param * * @param param The target param/variable */ public T evaluation(String param) { Preconditions.checkArgument(evaluations.containsKey(param), "No evaluations for %s.", param); Preconditions.checkArgument(evaluations.get(param).size() == 1, "Multiple evaluations for %s. Use evaluations().", param); return (T) evaluations.get(param).get(0); } /** * Get the evaluation for a given param/variable *

* Will throw an exception if there are more than one or no evaluations for the param * * @param param The target param/variable */ public T evaluation(SDVariable param) { return evaluation(param.getVarName()); } /** * Get the evaluation of a given type *

* Will throw an exception if there are more than one or no evaluations of that type * * @param evalClass The type of evaluation to look for */ public > T evaluation(Class evalClass) { Preconditions.checkArgument(classEvaluations.containsKey(evalClass), "Can't get evaluation for %s. Either no evaluations with that class are present, or more than one are.", evalClass); return (T) classEvaluations.get(evalClass); } /** * Get the evaluation of a given type, for a given param/variable *

* Will throw an exception if there are more than one or no evaluations of that type for the given param * * @param param The target param/variable * @param evalClass The type of evaluation to look for */ public > T evaluation(String param, Class evalClass) { Collection evals = Collections2.filter(evaluations(param), Predicates.instanceOf(evalClass)); Preconditions.checkArgument(evals.size() == 1, "Multiple or no evaluations of type %s for param %s.", evalClass, param); return (T) evals.iterator().next(); } /** * Get the evaluation of a given type, for a given param/variable *

* Will throw an exception if there are more than one or no evaluations of that type for the given param * * @param param The target param/variable * @param evalClass The type of evaluation to look for */ public > T evaluation(SDVariable param, Class evalClass) { return evaluation(param.getVarName(), evalClass); } /** * Get the metric's value for the evaluation of the metric's type *

* Will throw an exception if there are more than one or no evaluations of that type * * @param metric The metric to calculate */ public double getValue(IMetric metric) { return evaluation(metric.getEvaluationClass()).getValue(metric); } /** * Get the metric's value for the evaluation of the metric's type, for a given param/variable *

* Will throw an exception if there are more than one or no evaluations of that type for the given param * * @param param The target param/variable * @param metric The metric to calculate */ public double getValue(String param, IMetric metric) { return evaluation(param, metric.getEvaluationClass()).getValue(metric); } /** * Get the metric's value for the evaluation of the metric's type, for a given param/variable *

* Will throw an exception if there are more than one or no evaluations of that type for the given param * * @param param The target param/variable * @param metric The metric to calculate */ public double getValue(SDVariable param, IMetric metric) { return getValue(param.getVarName(), metric); } /** * Get the metric's value for the evaluation for a given param/variable at the given index *

* Will throw an exception if the target evaluation doesn't support the given metric * * @param param The target param/variable * @param index The index of the target evaluation on the param * @param metric The metric to calculate */ public double getValue(String param, int index, IMetric metric) { return evaluation(param, index).getValue(metric); } /** * Get the metric's value for the evaluation for a given param/variable at the given index *

* Will throw an exception if the target evaluation doesn't support the given metric * * @param param The target param/variable * @param index The index of the target evaluation on the param * @param metric The metric to calculate */ public double getValue(SDVariable param, int index, IMetric metric) { return getValue(param.getVarName(), index, metric); } }





© 2015 - 2024 Weber Informatics LLC | Privacy Policy