scripts.nn.examples.mnist_softmax-predict.dml
Declarative Machine Learning
#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------
# MNIST Softmax - Predict
#
# This script computes the class probability predictions of a
# trained softmax classifier on images of handwritten digits.
#
# Inputs:
# - X: File containing the images to classify.
# The format is "pixel_1, pixel_2, ..., pixel_n".
# - model_dir: Directory containing the trained weights and biases
# of the model.
# - out_dir: Directory to store class probability predictions for
# each image.
# - fmt: [DEFAULT: "csv"] File format of `X` and output predictions.
# Options include: "csv", "mm", "text", and "binary".
#
# Outputs:
# - probs: File `probs.<fmt>` in `out_dir` containing class probability
# predictions for each image.
#
# Data:
# The X file should contain images of handwritten digits, where each
# example is a 28x28 pixel image of grayscale values in the range
# [0,255], flattened into a single row of 784 pixels.
#
# Sample Invocation:
# 1. Download images.
#
# For example, save images to `nn/examples/data/mnist/images.csv`.
#
# 2. Execute using Spark:
# ```
# spark-submit --master local[*] --driver-memory 5G \
#   --conf spark.driver.maxResultSize=0 --conf spark.rpc.message.maxSize=128 \
#   $SYSTEMML_HOME/target/SystemML.jar -f nn/examples/mnist_softmax-predict.dml \
#   -nvargs X=nn/examples/data/mnist/images.csv \
#   model_dir=nn/examples/model/mnist_softmax out_dir=nn/examples/data/mnist
# ```
source("nn/examples/mnist_softmax.dml") as mnist_softmax
# Read input images
fmt = ifdef($fmt, "csv")
X = read($X, format=fmt)
# Scale images to [0,1]
X = X / 255.0
# Read model coefficients
W = read($model_dir+"/W")
b = read($model_dir+"/b")
# Compute class probabilities
probs = mnist_softmax::predict(X, W, b)
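# For reference, a sketch of what `predict` is assumed to compute. This is an
# assumption based on the `nn` library's affine and softmax layers, not a copy
# of `mnist_softmax::predict`, and the shapes (N x 784 `X`, 784 x 10 `W`,
# 1 x 10 `b`) are assumed. The prediction is an affine layer followed by a
# numerically stable row-wise softmax:
#   scores = X %*% W + b                     # N x 10 class scores
#   scores = scores - rowMaxs(scores)        # subtract each row's max for stability
#   unnorm = exp(scores)                     # unnormalized probabilities
#   probs_sketch = unnorm / rowSums(unnorm)  # each row sums to 1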
# Output results
write(probs, $out_dir+"/probs."+fmt, format=fmt)
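# Hypothetical follow-up (not part of this script): to turn `probs` into digit
# labels, take the column with the highest probability in each row. DML's
# `rowIndexMax` returns 1-based column indices. Assuming column k holds the
# probability of digit k-1 (as when training labels are one-hot encoded from
# `y+1`), subtract 1 to recover digits 0-9:
#   preds = rowIndexMax(probs) - 1           # N x 1 predicted digits (hypothetical name)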