#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------

/*
 * Cross-Entropy loss function.
 */

forward = function(matrix[double] pred, matrix[double] y)
    return (double loss) {
  /*
   * Computes the forward pass for a cross-entropy loss function.  The
   * inputs consist of N examples, each with K dimensions corresponding
   * to normalized probabilities of K classes.
   *
   *   ```
   *   L_i = -y_i^T * log(pred_i)
   *   L = (1/N) sum(L_i) for i=1 to N
   *   ```
   *
   * In these equations, `L` is the total loss, `L_i` is the loss for
   * example `i`, `y_i` is the K-dimensional vector of target class
   * probabilities, `pred_i` is the K-dimensional vector of predicted
   * class probabilities, and `N` is the number of examples.
   *
   * This can be interpreted as the negative log-likelihood assuming
   * a Bernoulli distribution generalized to K dimensions (i.e., a
   * categorical distribution), or a Multinomial with one observation.
   *
   * Inputs:
   *  - pred: Predictions, of shape (N, K).
   *  - y: Targets, of shape (N, K).
   *
   * Outputs:
   *  - loss: Average loss.
   */
  N = nrow(y)
  eps = 1e-10  # numerical stability to avoid log(0)
  losses = rowSums(-y * log(pred+eps))
  loss = sum(losses) / N
}

backward = function(matrix[double] pred, matrix[double] y)
    return (matrix[double] dpred) {
  /*
   * Computes the backward pass of a cross-entropy loss function.  The
   * inputs consist of N examples, each with K dimensions corresponding
   * to normalized probabilities of K classes.
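   *
   * As a sketch of where the gradient comes from: differentiating the
   * forward-pass loss `L = (1/N) sum(L_i)`, with
   * `L_i = -y_i^T * log(pred_i)`, with respect to each predicted
   * probability gives the elementwise expression
   *
   *   ```
   *   dpred_ik = -(1/N) * y_ik / pred_ik
   *   ```
   *
   * where `k` indexes the K classes (an index introduced here for
   * illustration).  The code below divides by `pred+eps` rather than
   * `pred`, which matches the derivative of the `log(pred+eps)` term
   * used in the forward pass.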
   *
   * Inputs:
   *  - pred: Predictions, of shape (N, K).
   *  - y: Targets, of shape (N, K).
   *
   * Outputs:
   *  - dpred: Gradient wrt `pred`, of shape (N, K).
   */
  N = nrow(y)
  eps = 1e-10  # numerical stability to avoid divide-by-zero
  dpred = (1/N) * -y * (1/(pred+eps))
}
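
/*
 * Example usage, as a minimal sketch.  The `source` path and the
 * `cross_entropy_loss` namespace alias are assumptions that depend on
 * where this script lives relative to the working directory; the toy
 * `pred` and `y` values are made up for illustration.
 *
 *   ```
 *   source("nn/layers/cross_entropy_loss.dml") as cross_entropy_loss
 *
 *   # Two examples (N=2), three classes (K=3); each row of `pred` sums to 1.
 *   pred = matrix("0.7 0.2 0.1 0.1 0.8 0.1", rows=2, cols=3)
 *   # One-hot targets: class 1 for the first example, class 2 for the second.
 *   y = matrix("1 0 0 0 1 0", rows=2, cols=3)
 *
 *   loss = cross_entropy_loss::forward(pred, y)    # scalar average loss
 *   dpred = cross_entropy_loss::backward(pred, y)  # gradient, shape (2, 3)
 *
 *   print("loss: " + loss)
 *   print(toString(dpred))
 *   ```
 *
 * For these inputs the loss is (-log(0.7) - log(0.8)) / 2, roughly 0.29,
 * and `dpred` is nonzero only where `y` is nonzero.
 */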
