All Downloads are FREE. Search and download functionalities are using the official Maven repository.

scripts.nn.examples.mnist_softmax.dml Maven / Gradle / Ivy

There is a newer version: 1.2.0
Show newest version
#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------

/*
 * MNIST Softmax Example
 */
# Imports
source("nn/layers/affine.dml") as affine
source("nn/layers/cross_entropy_loss.dml") as cross_entropy_loss
source("nn/layers/softmax.dml") as softmax
source("nn/optim/sgd_nesterov.dml") as sgd_nesterov

train = function(matrix[double] X, matrix[double] Y,
                 matrix[double] X_val, matrix[double] Y_val,
                 int epochs)
    return (matrix[double] W, matrix[double] b) {
  /*
   * Fits a softmax classifier (affine -> softmax) using mini-batch
   * SGD with Nesterov momentum.
   *
   * X holds N training examples with D features each; Y holds the
   * matching one-hot targets over K classes.  Every epoch performs a
   * fixed number of mini-batch iterations (1000), wrapping around the
   * training data via a modulo on the batch offset, rather than
   * exactly one pass over the data.  Training and validation metrics
   * are printed on every iteration.
   *
   * Inputs:
   *  - X: Training data matrix, of shape (N, D).
   *  - Y: Training target matrix, of shape (N, K).
   *  - X_val: Validation data matrix, of shape (N_val, D).
   *  - Y_val: Validation target matrix, of shape (N_val, K).
   *  - epochs: Number of outer training loops to run.
   *
   * Outputs:
   *  - W: Learned weights matrix, of shape (D, K).
   *  - b: Learned biases vector, of shape (1, K).
   */
  num_examples = nrow(X)
  num_features = ncol(X)
  num_classes = ncol(Y)

  # Model parameters for the affine layer feeding the softmax.
  [W, b] = affine::init(num_features, num_classes)
  # Rescale the initial weights by sqrt(1/D) / sqrt(2/D).
  # NOTE(review): presumably this swaps the scaling applied inside
  # affine::init for a sqrt(1/D) scaling -- confirm against affine.dml.
  W = W / sqrt(2.0/(num_features)) * sqrt(1/(num_features))

  # Optimizer hyperparameters and state (SGD w/ Nesterov momentum).
  step_size = 0.2    # learning rate
  momentum = 0       # starts at 0, annealed after every epoch
  lr_decay = 0.99    # multiplicative learning rate decay per epoch
  vel_W = sgd_nesterov::init(W)
  vel_b = sgd_nesterov::init(b)

  print("Starting optimization")
  batch_size = 50
  num_iters = 1000  # fixed count; not derived from N / batch_size
  for (epoch in 1:epochs) {
    for (it in 1:num_iters) {
      # Slice out the current mini-batch.  The offset wraps around the
      # data set, and the last batch before a wrap may be shorter.
      offset = ((it-1) * batch_size) %% num_examples
      first = offset + 1
      last = min(num_examples, first + batch_size - 1)
      X_mb = X[first:last,]
      Y_mb = Y[first:last,]

      # Forward pass: affine scores, then class probabilities.
      scores = affine::forward(X_mb, W, b)
      probs = softmax::forward(scores)

      # Training metrics on the current mini-batch.
      loss = cross_entropy_loss::forward(probs, Y_mb)
      hits = rowIndexMax(probs) == rowIndexMax(Y_mb)
      accuracy = mean(hits)

      # Validation metrics on the full validation set, using the
      # parameters as they are before this iteration's update.
      probs_val = predict(X_val, W, b)
      loss_val = cross_entropy_loss::forward(probs_val, Y_val)
      hits_val = rowIndexMax(probs_val) == rowIndexMax(Y_val)
      accuracy_val = mean(hits_val)

      msg = "Epoch: " + epoch + ", Iter: " + it + ", Train Loss: " + loss + ", Train Accuracy: " +
            accuracy + ", Val Loss: " + loss_val + ", Val Accuracy: " + accuracy_val
      print(msg)

      # Backward pass: loss -> softmax -> affine.
      dprobs = cross_entropy_loss::backward(probs, Y_mb)
      dscores = softmax::backward(dprobs, scores)
      [dX, dW, db] = affine::backward(dscores, X_mb, W, b)  # dX is not used

      # Parameter updates.
      [W, vel_W] = sgd_nesterov::update(W, dW, step_size, momentum, vel_W)
      [b, vel_b] = sgd_nesterov::update(b, db, step_size, momentum, vel_b)
    }
    # Move the momentum towards 0.999; it reaches 0.999 after the
    # final epoch's adjustment.
    momentum = momentum + (0.999 - momentum)/(1+epochs-epoch)
    # Shrink the learning rate for the next epoch.
    step_size = step_size * lr_decay
  }
}

predict = function(matrix[double] X, matrix[double] W, matrix[double] b)
    return (matrix[double] probs) {
  /*
   * Produces class probability predictions from a softmax classifier
   * by running the affine layer followed by the softmax layer.
   *
   * X holds N examples with D features each.
   *
   * Inputs:
   *  - X: Input data matrix, of shape (N, D).
   *  - W: Weights (parameters) matrix, of shape (D, K).
   *  - b: Biases vector, of shape (1, K).
   *
   * Outputs:
   *  - probs: Class probabilities, of shape (N, K).
   */
  # Affine layer output for every example.
  scores = affine::forward(X, W, b)
  # Softmax over the affine output yields the probabilities.
  probs = softmax::forward(scores)
}

eval = function(matrix[double] probs, matrix[double] Y)
    return (double loss, double accuracy) {
  /*
   * Scores the predictions of a softmax classifier.
   *
   * probs holds predicted class probabilities for N examples over
   * K classes; Y holds the matching one-hot encoded targets.
   *
   * Inputs:
   *  - probs: Class probabilities, of shape (N, K).
   *  - Y: Target matrix, of shape (N, K).
   *
   * Outputs:
   *  - loss: Scalar cross-entropy loss.
   *  - accuracy: Scalar fraction of examples whose highest-probability
   *      class matches the target class.
   */
  # Accuracy: compare the arg-max column of each row.
  predicted = rowIndexMax(probs)
  actual = rowIndexMax(Y)
  accuracy = mean(predicted == actual)

  # Loss: cross-entropy between predictions and targets.
  loss = cross_entropy_loss::forward(probs, Y)
}

generate_dummy_data = function()
    return (matrix[double] X, matrix[double] Y, int C, int Hin, int Win) {
  /*
   * Generate a random dummy dataset with MNIST-like dimensions.
   *
   * Features are drawn from a standard normal distribution, and each
   * example gets a random class label in 1..T that is one-hot encoded.
   *
   * Outputs:
   *  - X: Input data matrix, of shape (N, D), where D = C*Hin*Win.
   *  - Y: One-hot target matrix, of shape (N, T); always T columns.
   *  - C: Number of input channels (dimensionality of input depth).
   *  - Hin: Input height.
   *  - Win: Input width.
   */
  # Dataset dimensions
  N = 1024  # num examples
  C = 1  # num input channels
  Hin = 28  # input height
  Win = 28  # input width
  T = 10  # num targets

  # Random features
  X = rand(rows=N, cols=C*Hin*Win, pdf="normal")

  # Random labels in 1..T.  Note that rounding a uniform draw over
  # [1, T] gives the two endpoint classes half the probability of the
  # interior classes; this is acceptable for dummy data.
  classes = round(rand(rows=N, cols=1, min=1, max=T, pdf="uniform"))

  # One-hot encoding.  The output dimensions are given explicitly so
  # that Y has exactly T columns even if class T was never sampled;
  # without them, table() sizes its output by the largest value seen.
  Y = table(seq(1, N), classes, N, T)
}





© 2015 - 2024 Weber Informatics LLC | Privacy Policy