All Downloads are FREE. Search and download functionalities are using the official Maven repository.

ai.djl.training.evaluator.Evaluator Maven / Gradle / Ivy

There is a newer version: 0.30.0
Show newest version
/*
 * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 * http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */
package ai.djl.training.evaluator;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/**
 * Base class for all {@code Evaluator}s that can be used to evaluate the performance of a model.
 *
 * 

The {@code Evaluator}s can all be monitored to make an assessment about the performance of the * model. However, only ones that further extend {@link ai.djl.training.loss.Loss} are suited to * being used to directly optimize a model. * *

In addition to computing the evaluation, an evaluator can accumulate values to compute a total * evaluation. For different purposes, it is possible to have multiple accumulators on a single * evaluator. Each accumulator must be added with a String key to identify the accumulator. Before * using an accumulator, you must {@link Evaluator#addAccumulator(String)}. Then, call {@link * Evaluator#updateAccumulator(String, NDList, NDList)} to add more data to the accumulator. You can * use {@link Evaluator#getAccumulator(String)} to retrieve the accumulated value and {@link * Evaluator#resetAccumulator(String)} to reset the accumulator to the same value as when just * added. */ public abstract class Evaluator { private String name; protected Map totalInstances; /** * Creates an evaluator with abstract update methods. * * @param name the name of the evaluator */ public Evaluator(String name) { this.name = name; totalInstances = new ConcurrentHashMap<>(); } /** * Returns the name of this {@code Evaluator}. * * @return the name of this {@code Evaluator} */ public String getName() { return name; } /** * Calculates the evaluation between the labels and the predictions. * * @param labels the correct values * @param predictions the predicted values * @return the evaluation result */ public abstract NDArray evaluate(NDList labels, NDList predictions); /** * Adds an accumulator for the results of the evaluation with the given key. * * @param key the key for the new accumulator */ public abstract void addAccumulator(String key); /** * Updates the evaluator with the given key based on a {@link NDList} of labels and predictions. * *

This is a synchronized operation. You should only call it at the end of a batch or epoch. * * @param key the key of the accumulator to update * @param labels a {@code NDList} of labels * @param predictions a {@code NDList} of predictions */ public abstract void updateAccumulator(String key, NDList labels, NDList predictions); /** * Resets the evaluator value with the given key. * * @param key the key of the accumulator to reset */ public abstract void resetAccumulator(String key); /** * Returns the accumulated evaluator value. * * @param key the key of the accumulator to get * @return the accumulated value * @throws IllegalArgumentException if no accumulator was added with the given key */ public abstract float getAccumulator(String key); /** * Checks if the two input {@code NDArray} have the same length or shape. * * @param labels a {@code NDArray} of labels * @param predictions a {@code NDArray} of predictions * @param checkDimOnly whether to check for first dimension only */ protected void checkLabelShapes(NDArray labels, NDArray predictions, boolean checkDimOnly) { if (labels.getShape().get(0) != predictions.getShape().get(0)) { throw new IllegalArgumentException( "The size of labels(" + labels.size() + ") does not match that of predictions(" + predictions.size() + ")"); } if (!checkDimOnly) { if (!labels.getShape().equals(predictions.getShape())) { throw new IllegalArgumentException( "The shape of labels(" + labels.getShape() + ") does not match that of predictions(" + predictions.getShape() + ")"); } } } /** * Checks the length of NDArrays. * * @param labels a {@code NDArray} of labels * @param predictions a {@code NDArray} of predictions */ protected void checkLabelShapes(NDArray labels, NDArray predictions) { checkLabelShapes(labels, predictions, true); } }





© 2015 - 2024 Weber Informatics LLC | Privacy Policy