All downloads are free. Search and download functionality uses the official Maven repository.

ai.djl.training.loss.Loss Maven / Gradle / Ivy

There is a newer version: 0.30.0
Show newest version
/*
 * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 * http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */
package ai.djl.training.loss;

import ai.djl.ndarray.NDList;
import ai.djl.training.evaluator.Evaluator;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/**
 * Loss functions (or Cost functions) are used to evaluate the model predictions against true labels
 * for optimization.
 *
 * 

Although all evaluators can be used to measure the performance of a model, not all of them are * suited to being used by an optimizer. Loss functions are usually non-negative where a larger loss * represents worse performance. They are also real-valued to accurately compare models. * *

When creating a loss function, you should avoid having the loss depend on the batch size. For * example, if you have a loss per item in a batch and sum those losses, your loss would be {@code * numItemsInBatch*avgLoss}. Instead, you should take the mean of those losses to reduce out the * batchSize factor. Otherwise, it can make it difficult to tune the learning rate since any change * in the batch size would throw it off. If you have a variable batch size, it would be even more * difficult. * *

For more details about the class internals, see {@link Evaluator}. */ public abstract class Loss extends Evaluator { private Map totalLoss; /** * Base class for metric with abstract update methods. * * @param name The display name of the Loss */ public Loss(String name) { super(name); totalLoss = new ConcurrentHashMap<>(); } /** * Returns a new instance of {@link L1Loss} with default weight and batch axis. * * @return a new instance of {@link L1Loss} */ public static L1Loss l1Loss() { return new L1Loss(); } /** * Returns a new instance of {@link L1Loss} with default weight and batch axis. * * @param name the name of the loss * @return a new instance of {@link L1Loss} */ public static L1Loss l1Loss(String name) { return new L1Loss(name); } /** * Returns a new instance of {@link L1Loss} with given weight and batch axis. * * @param name the name of the loss * @param weight the weight to apply on loss value, default 1 * @return a new instance of {@link L1Loss} */ public static L1Loss l1Loss(String name, float weight) { return new L1Loss(name, weight); } /** * Returns a new instance of {@link L2Loss} with default weight and batch axis. * * @return a new instance of {@link L2Loss} */ public static L2Loss l2Loss() { return new L2Loss(); } /** * Returns a new instance of {@link L2Loss} with default weight and batch axis. * * @param name the name of the loss * @return a new instance of {@link L2Loss} */ public static L2Loss l2Loss(String name) { return new L2Loss(name); } /** * Returns a new instance of {@link L2Loss} with given weight and batch axis. * * @param name the name of the loss * @param weight the weight to apply on loss value, default 1 * @return a new instance of {@link L2Loss} */ public static L2Loss l2Loss(String name, float weight) { return new L2Loss(name, weight); } /** * Returns a new instance of {@link SigmoidBinaryCrossEntropyLoss} with default arguments. 
* * @return a new instance of {@link SigmoidBinaryCrossEntropyLoss} */ public static SigmoidBinaryCrossEntropyLoss sigmoidBinaryCrossEntropyLoss() { return new SigmoidBinaryCrossEntropyLoss(); } /** * Returns a new instance of {@link SigmoidBinaryCrossEntropyLoss} with default arguments. * * @param name the name of the loss * @return a new instance of {@link SigmoidBinaryCrossEntropyLoss} */ public static SigmoidBinaryCrossEntropyLoss sigmoidBinaryCrossEntropyLoss(String name) { return new SigmoidBinaryCrossEntropyLoss(name); } /** * Returns a new instance of {@link SigmoidBinaryCrossEntropyLoss} with the given arguments. * * @param name the name of the loss * @param weight the weight to apply on the loss value, default 1 * @param fromSigmoid whether the input is from the output of sigmoid, default false * @return a new instance of {@link SigmoidBinaryCrossEntropyLoss} */ public static SigmoidBinaryCrossEntropyLoss sigmoidBinaryCrossEntropyLoss( String name, float weight, boolean fromSigmoid) { return new SigmoidBinaryCrossEntropyLoss(name, weight, fromSigmoid); } /** * Returns a new instance of {@link SoftmaxCrossEntropyLoss} with default arguments. * * @return a new instance of {@link SoftmaxCrossEntropyLoss} */ public static SoftmaxCrossEntropyLoss softmaxCrossEntropyLoss() { return new SoftmaxCrossEntropyLoss(); } /** * Returns a new instance of {@link SoftmaxCrossEntropyLoss} with default arguments. * * @param name the name of the loss * @return a new instance of {@link SoftmaxCrossEntropyLoss} */ public static SoftmaxCrossEntropyLoss softmaxCrossEntropyLoss(String name) { return new SoftmaxCrossEntropyLoss(name); } /** * Returns a new instance of {@link SoftmaxCrossEntropyLoss} with the given arguments. 
* * @param name the name of the loss * @param weight the weight to apply on the loss value, default 1 * @param classAxis the axis that represents the class probabilities, default -1 * @param sparseLabel whether labels are integer array or probabilities, default true * @param fromLogit whether labels are log probabilities or un-normalized numbers * @return a new instance of {@link SoftmaxCrossEntropyLoss} */ public static SoftmaxCrossEntropyLoss softmaxCrossEntropyLoss( String name, float weight, int classAxis, boolean sparseLabel, boolean fromLogit) { return new SoftmaxCrossEntropyLoss(name, weight, classAxis, sparseLabel, fromLogit); } /** * Returns a new instance of {@link MaskedSoftmaxCrossEntropyLoss} with default arguments. * * @return a new instance of {@link MaskedSoftmaxCrossEntropyLoss} */ public static MaskedSoftmaxCrossEntropyLoss maskedSoftmaxCrossEntropyLoss() { return new MaskedSoftmaxCrossEntropyLoss(); } /** * Returns a new instance of {@link MaskedSoftmaxCrossEntropyLoss} with default arguments. * * @param name the name of the loss * @return a new instance of {@link MaskedSoftmaxCrossEntropyLoss} */ public static MaskedSoftmaxCrossEntropyLoss maskedSoftmaxCrossEntropyLoss(String name) { return new MaskedSoftmaxCrossEntropyLoss(name); } /** * Returns a new instance of {@link MaskedSoftmaxCrossEntropyLoss} with the given arguments. 
* * @param name the name of the loss * @param weight the weight to apply on the loss value, default 1 * @param classAxis the axis that represents the class probabilities, default -1 * @param sparseLabel whether labels are integer array or probabilities, default true * @param fromLogit whether labels are log probabilities or un-normalized numbers * @return a new instance of {@link MaskedSoftmaxCrossEntropyLoss} */ public static MaskedSoftmaxCrossEntropyLoss maskedSoftmaxCrossEntropyLoss( String name, float weight, int classAxis, boolean sparseLabel, boolean fromLogit) { return new MaskedSoftmaxCrossEntropyLoss(name, weight, classAxis, sparseLabel, fromLogit); } /** * Returns a new instance of {@link HingeLoss} with default arguments. * * @return a new instance of {@link HingeLoss} */ public static HingeLoss hingeLoss() { return new HingeLoss(); } /** * Returns a new instance of {@link HingeLoss} with default arguments. * * @param name the name of the loss * @return a new instance of {@link HingeLoss} */ public static HingeLoss hingeLoss(String name) { return new HingeLoss(name); } /** * Returns a new instance of {@link HingeLoss} with the given arguments. * * @param name the name of the loss * @param margin the margin in hinge loss. 
Defaults to 1.0 * @param weight the weight to apply on loss value, default 1 * @return a new instance of {@link HingeLoss} */ public static HingeLoss hingeLoss(String name, int margin, float weight) { return new HingeLoss(name, margin, weight); } /** {@inheritDoc} */ @Override public void addAccumulator(String key) { totalInstances.put(key, 0L); totalLoss.put(key, 0f); } /** {@inheritDoc} */ @Override public void updateAccumulator(String key, NDList labels, NDList predictions) { // this is a synchronized operation, only call it at end of batch or epoch float update = evaluate(labels, predictions).sum().getFloat(); totalInstances.compute(key, (k, v) -> v + 1); totalLoss.compute(key, (k, v) -> v + update); } /** {@inheritDoc} */ @Override public void resetAccumulator(String key) { totalInstances.compute(key, (k, v) -> 0L); totalLoss.compute(key, (k, v) -> 0f); } /** {@inheritDoc} */ @Override public float getAccumulator(String key) { Long total = totalInstances.get(key); if (total == null) { throw new IllegalArgumentException("No loss found at that path"); } if (total == 0) { return Float.NaN; } return totalLoss.get(key) / totalInstances.get(key); } }





© 2015 - 2024 Weber Informatics LLC | Privacy Policy