// ml.comet.examples.mnist.MnistExperimentExample (Maven / Gradle / Ivy)
package ml.comet.examples.mnist;
import com.beust.jcommander.JCommander;
import com.beust.jcommander.Parameter;
import java.io.IOException;
import ml.comet.experiment.ExperimentBuilder;
import ml.comet.experiment.OnlineExperiment;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.api.BaseTrainingListener;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * A simple Multi-Layer Perceptron (MLP) applied to digit classification for
 * the MNIST dataset.
 *
 * <p>This file builds one input layer and one hidden layer.
 *
 * <p>The input layer has an input dimension of numRows*numColumns, where these variables indicate the
 * number of vertical and horizontal pixels in the image. This layer uses a rectified linear unit
 * (relu) activation function. The weights for this layer are initialized by using Xavier initialization
 * to avoid having a steep learning curve. This layer will have 1000 output signals to the hidden layer.
 *
 * <p>The hidden layer has an input dimension of 1000. These are fed from the input layer. The weights
 * for this layer are also initialized using Xavier initialization. The activation function for this
 * layer is a softmax, which normalizes all the 10 outputs such that the normalized sums
 * add up to 1. The highest of these normalized values is picked as the predicted class.
 */
public final class MnistExperimentExample {
private static final Logger log = LoggerFactory.getLogger(MnistExperimentExample.class);
* The number of epochs to perform.
@Parameter(names = {"--epochs", "-e"}, description = "number of epochs to perform")
int numEpochs = 2;
* The experiment entry point.
You should set three environment variables to run this experiment:
* - COMET_API_KEY - the API key to access Comet (MANDATORY)
* - COMET_WORKSPACE_NAME - the name of the workspace for your project (OPTIONAL)
* - COMET_PROJECT_NAME - the name of the project (OPTIONAL)
* Alternatively you can set these values in the resources/application.conf file
* @param args the command line arguments.
public static void main(String[] args) {
MnistExperimentExample main = new MnistExperimentExample();
// update application.conf or provide environment variables as described in JavaDoc.
OnlineExperiment experiment = ExperimentBuilder
try {
} catch (Exception e) {
System.out.println("--- Failed to run experiment ---");
} finally {
// make sure to close experiment
* The experiment runner.
* @param experiment the Comet experiment instance.
* @throws IOException if any exception raised.
public void runMnistExperiment(OnlineExperiment experiment) throws IOException {"****************MNIST Experiment Example Started********************");
//number of rows and columns in the input pictures
final int numRows = 28;
final int numColumns = 28;
int outputNum = 10; // number of output classes
int batchSize = 128; // batch size for each epoch
int rngSeed = 123; // random number seed for reproducibility
experiment.logParameter("numRows", numRows);
experiment.logParameter("numColumns", numColumns);
experiment.logParameter("outputNum", outputNum);
experiment.logParameter("batchSize", batchSize);
experiment.logParameter("rngSeed", rngSeed);
experiment.logParameter("numEpochs", numEpochs);
double lr = 0.006;
double nesterovsMomentum = 0.9;
double l2Regularization = 1e-4;
experiment.logParameter("learningRate", lr);
experiment.logParameter("nesterovsMomentum", nesterovsMomentum);
experiment.logParameter("l2Regularization", l2Regularization);
OptimizationAlgorithm optimizationAlgorithm = OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT;
experiment.logParameter("optimizationAlgorithm", OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT);"Build model....");
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.seed(rngSeed) //include a random seed for reproducibility
// use stochastic gradient descent as an optimization algorithm
.updater(new Nesterovs(lr, nesterovsMomentum))
.layer(new DenseLayer.Builder() //create the first, input layer with xavier initialization
.nIn(numRows * numColumns)
.layer(new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD) //create hidden layer
MultiLayerNetwork model = new MultiLayerNetwork(conf);
//print the score with every 1 iteration
model.setListeners(new StepScoreListener(experiment, 1, log));
// Get the train dataset iterator
DataSetIterator mnistTrain = new MnistDataSetIterator(batchSize, true, rngSeed);"Train model....");, numEpochs);
// Get the test dataset iterator
DataSetIterator mnistTest = new MnistDataSetIterator(batchSize, false, rngSeed);"Evaluate model....");
Evaluation eval = model.evaluate(mnistTest);;
experiment.logHtml(eval.getConfusionMatrix().toHTML(), false);"****************MNIST Experiment Example finished********************");
* The listener to be invoked at each iteration of model training.
static class StepScoreListener extends BaseTrainingListener {
private final OnlineExperiment experiment;
private int printIterations;
private final Logger log;
StepScoreListener(OnlineExperiment experiment, int printIterations, Logger log) {
this.experiment = experiment;
this.printIterations = printIterations;
this.log = log;
public void iterationDone(Model model, int iteration, int epoch) {
if (printIterations <= 0) {
printIterations = 1;
// print score and log metric
if (iteration % printIterations == 0) {
double result = model.score();"Score at step/epoch {}/{} is {} ", iteration, epoch, result);
this.experiment.logMetric("score", model.score(), iteration);