mnist.MNISTSparseBinaryTest.kt Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of simplednn Show documentation
Show all versions of simplednn Show documentation
SimpleDNN is a lightweight, open-source machine learning library written in Kotlin. Its purpose is to
support the development of feed-forward and recurrent artificial neural networks.
/* Copyright 2016-present The KotlinNLP Authors. All Rights Reserved.
*
* This Source Code Form is subject to the terms of the Mozilla Public
* License, v. 2.0. If a copy of the MPL was not distributed with this
* file, you can obtain one at http://mozilla.org/MPL/2.0/.
* ------------------------------------------------------------------*/
package mnist
import com.kotlinnlp.simplednn.core.optimizer.ParamsOptimizer
import com.kotlinnlp.simplednn.core.functionalities.activations.ELU
import com.kotlinnlp.simplednn.core.functionalities.activations.Softmax
import com.kotlinnlp.simplednn.core.functionalities.decaymethods.HyperbolicDecay
import com.kotlinnlp.simplednn.core.functionalities.updatemethods.learningrate.LearningRateMethod
import com.kotlinnlp.simplednn.core.neuralnetwork.preset.FeedforwardNeuralNetwork
import com.kotlinnlp.simplednn.helpers.training.FeedforwardTrainingHelper
import com.kotlinnlp.simplednn.core.neuralprocessor.feedforward.FeedforwardNeuralProcessor
import com.kotlinnlp.simplednn.dataset.*
import com.kotlinnlp.simplednn.core.functionalities.outputevaluation.ClassificationEvaluation
import com.kotlinnlp.simplednn.helpers.validation.FeedforwardValidationHelper
import com.kotlinnlp.simplednn.core.arrays.DistributionArray
import com.kotlinnlp.simplednn.core.functionalities.losses.SoftmaxCrossEntropyCalculator
import com.kotlinnlp.simplednn.core.layers.LayerType
import com.kotlinnlp.simplednn.simplemath.ndarray.*
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArrayFactory
import com.kotlinnlp.simplednn.simplemath.ndarray.sparsebinary.SparseBinaryNDArray
import mnist.helpers.MNISTSparseExampleExtractor
import utils.CorpusReader
/**
 * Entry point of the 'MNIST Sparse Binary Test'.
 *
 * Reads the MNIST corpus from the paths given by the configuration file, runs
 * [MNISTSparseBinaryTest] on it and prints a start and an end marker around the run.
 *
 * NOTE(review): the generic type arguments had been stripped from this file; the example type
 * `SimpleExample<SparseBinaryNDArray>` is reconstructed from the file's imports - confirm.
 *
 * @param args the command line arguments (not used)
 */
fun main(args: Array<String>) {

    println("Start 'MNIST Sparse Binary Test'")

    // The whole corpus is read at once (perLine = false); each example has 10 output classes.
    val dataset = CorpusReader<SimpleExample<SparseBinaryNDArray>>().read(
        corpusPath = Configuration.loadFromFile().mnist.datasets_paths,
        exampleExtractor = MNISTSparseExampleExtractor(outputSize = 10),
        perLine = false)

    MNISTSparseBinaryTest(dataset).start()

    println("End.")
}
/**
 * Trains a feed-forward neural network on a corpus of MNIST examples with sparse binary inputs,
 * then prints the input relevance of the first validation examples as ASCII images.
 *
 * NOTE(review): the generic type arguments had been stripped from this file; the example type
 * `SimpleExample<SparseBinaryNDArray>` is reconstructed from the file's imports - confirm.
 *
 * @property dataset the corpus that provides the training and the validation examples
 */
class MNISTSparseBinaryTest(val dataset: Corpus<SimpleExample<SparseBinaryNDArray>>) {

    /**
     * The network under test: sparse binary input, one ELU hidden layer and a Softmax output layer.
     */
    private val neuralNetwork = FeedforwardNeuralNetwork(
        inputSize = INPUT_SIZE,
        inputType = LayerType.Input.SparseBinary,
        hiddenSize = HIDDEN_SIZE,
        hiddenActivation = ELU(),
        outputSize = OUTPUT_SIZE,
        outputActivation = Softmax())

    /**
     * Train the network, then print the input relevance of the first validation examples.
     */
    fun start() {

        this.train()

        // `take` (unlike `subList(0, n)`) does not throw when fewer than n validation examples exist.
        this.printImages(examples = ArrayList(this.dataset.validation.take(PRINTED_EXAMPLES)))
    }

    /**
     * Train the network on the training set for 3 epochs with a batch size of 1, validating on the
     * validation set.
     *
     * The parameters are updated with a learning rate of 0.01 subject to a hyperbolic decay and the
     * loss is computed with the softmax cross-entropy. Examples are shuffled with a fixed seed.
     */
    private fun train() {

        println("\n-- TRAINING\n")

        val optimizer = ParamsOptimizer(
            params = this.neuralNetwork.model,
            updateMethod = LearningRateMethod(
                learningRate = 0.01,
                decayMethod = HyperbolicDecay(decay = 0.5, initLearningRate = 0.01)))

        val trainingHelper = FeedforwardTrainingHelper(
            neuralProcessor = FeedforwardNeuralProcessor(this.neuralNetwork),
            optimizer = optimizer,
            lossCalculator = SoftmaxCrossEntropyCalculator(),
            verbose = true)

        // Validation uses its own processor built on the same network.
        val validationHelper = FeedforwardValidationHelper(
            neuralProcessor = FeedforwardNeuralProcessor(this.neuralNetwork),
            outputEvaluationFunction = ClassificationEvaluation())

        trainingHelper.train(
            trainingExamples = this.dataset.training,
            validationExamples = this.dataset.validation,
            epochs = 3,
            batchSize = 1,
            shuffler = Shuffler(enablePseudoRandom = true, seed = 1),
            validationHelper = validationHelper)
    }

    /**
     * Validate the given examples and, after each prediction, print the relevance of the input
     * calculated respect to a uniform distribution over the output classes.
     *
     * @param examples the examples whose input relevance is printed
     */
    private fun printImages(examples: ArrayList<SimpleExample<SparseBinaryNDArray>>) {

        println("\n-- PRINT IMAGES RELEVANCE\n")

        val neuralProcessor = FeedforwardNeuralProcessor(neuralNetwork)

        val validationHelper = FeedforwardValidationHelper(
            neuralProcessor = neuralProcessor,
            outputEvaluationFunction = ClassificationEvaluation())

        validationHelper.validate(
            examples = examples,
            // NOTE(review): presumably required by calculateInputRelevance() below - confirm.
            saveContributions = true,
            onPrediction = { example, _ ->

                val sparseRelevance = neuralProcessor.calculateInputRelevance(
                    DistributionArray.uniform(length = OUTPUT_SIZE))

                // Copy the sparse relevance into a dense array so that every pixel can be read by index.
                val denseRelevance: DenseNDArray =
                    DenseNDArrayFactory.zeros(Shape(INPUT_SIZE)).assignValues(sparseRelevance)

                // The header printed with the image is the index of the gold class of the example.
                this.printImage(image = denseRelevance, value = example.outputGold.argMaxIndex())
            }
        )
    }

    /**
     * Print an image as a square grid of characters, preceded by a header line containing [value].
     * A cell is drawn as "# " when the related element of [image] is positive, as blanks otherwise.
     *
     * @param image the flattened image, read row by row
     * @param value the number printed in the header line
     */
    private fun printImage(image: DenseNDArray, value: Int) {

        println("------------------ %d -----------------".format(value))

        for (row in 0 until IMAGE_SIDE) {

            val line = (0 until IMAGE_SIDE).joinToString(separator = "") { col ->
                // Both cells are two characters wide, otherwise the columns of the grid do not align.
                if (image[row * IMAGE_SIDE + col] > 0.0) "# " else "  "
            }

            println(line)
        }
    }

    private companion object {

        /** The number of pixels per side of an image. */
        const val IMAGE_SIDE = 28

        /** The size of the input layer (one unit per pixel). */
        const val INPUT_SIZE = IMAGE_SIDE * IMAGE_SIDE

        /** The size of the hidden layer. */
        const val HIDDEN_SIZE = 100

        /** The size of the output layer (the number of classes). */
        const val OUTPUT_SIZE = 10

        /** The maximum number of validation examples whose relevance is printed. */
        const val PRINTED_EXAMPLES = 20
    }
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy