SumSignRelevanceTest.kt (simplednn 0.5.4, source code)
SimpleDNN is a lightweight, open-source machine learning library written in Kotlin whose purpose is to
support the development of feed-forward and recurrent Artificial Neural Networks.
/* Copyright 2016-present The KotlinNLP Authors. All Rights Reserved.
*
* This Source Code Form is subject to the terms of the Mozilla Public
* License, v. 2.0. If a copy of the MPL was not distributed with this
* file, you can obtain one at http://mozilla.org/MPL/2.0/.
* ------------------------------------------------------------------*/
import com.kotlinnlp.simplednn.core.arrays.DistributionArray
import com.kotlinnlp.simplednn.core.optimizer.ParamsOptimizer
import com.kotlinnlp.simplednn.core.functionalities.activations.Softmax
import com.kotlinnlp.simplednn.core.functionalities.updatemethods.learningrate.LearningRateMethod
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArrayFactory
import com.kotlinnlp.simplednn.core.functionalities.activations.Tanh
import com.kotlinnlp.simplednn.core.functionalities.losses.SoftmaxCrossEntropyCalculator
import com.kotlinnlp.simplednn.core.neuralprocessor.recurrent.RecurrentNeuralProcessor
import com.kotlinnlp.simplednn.dataset.*
import com.kotlinnlp.simplednn.core.functionalities.outputevaluation.ClassificationEvaluation
import com.kotlinnlp.simplednn.core.neuralnetwork.preset.SimpleRecurrentNeuralNetwork
import com.kotlinnlp.simplednn.helpers.training.SequenceWithFinalOutputTrainingHelper
import com.kotlinnlp.simplednn.helpers.validation.SequenceWithFinalOutputValidationHelper
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray
fun main(args: Array<String>) {
println("Start 'Sum Sign Relevance Test'")
SumSignRelevanceTest(dataset = DatasetBuilder.build()).start()
println("End.")
}
/**
* A builder of the 'sum sign' dataset: sequences of single-value features whose gold output encodes the sign of their sum.
*/
object DatasetBuilder {
/**
* @return a dataset of examples of sequences
*/
fun build(): Corpus<SequenceExampleWithFinalOutput<DenseNDArray>> = Corpus(
training = arrayListOf(*Array(size = 10000, init = { this.createExample() })),
validation = arrayListOf(*Array(size = 1000, init = { this.createExample() })),
test = arrayListOf(*Array(size = 100, init = { this.createExample() }))
)
/**
* @return an example containing a sequence of single features, each with a random value in {-1.0, 0.0, 1.0}, and a gold
* output which is the sign of the sum of the features, represented as a one-hot encoding [0 = negative,
* 1 = zero, 2 = positive] (a worked example follows this function)
*/
private fun createExample(): SequenceExampleWithFinalOutput<DenseNDArray> {
val features = arrayListOf(*Array(size = 10, init = { this.getRandomInput() }))
val outputGoldIndex = Math.signum(features.sumByDouble { it[0] }) + 1
return SequenceExampleWithFinalOutput(
sequenceFeatures = features,
outputGold = DenseNDArrayFactory.oneHotEncoder(length = 3, oneAt = outputGoldIndex.toInt())
)
}
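// Worked example (added for illustration, not part of the original source): a sequence of features summing
// to 3.0 gives Math.signum(3.0) + 1 = 2.0, so the gold output is the one-hot vector [0.0, 0.0, 1.0];
// a sum of 0.0 maps to index 1 ([0.0, 1.0, 0.0]) and a negative sum maps to index 0 ([1.0, 0.0, 0.0]).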
/**
* @return a [DenseNDArray] containing a single value within {-1.0, 0.0, 1.0}
*/
private fun getRandomInput(): DenseNDArray {
val value = Math.round(Math.random() * 2.0 - 1.0).toDouble()
return DenseNDArrayFactory.arrayOf(doubleArrayOf(value))
}
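// Note (added for clarity): Math.round(Math.random() * 2.0 - 1.0) rounds a uniform draw from [-1.0, 1.0),
// so it yields -1 on [-1.0, -0.5), 0 on [-0.5, 0.5) and 1 on [0.5, 1.0), i.e. the three values are
// sampled with probabilities 0.25, 0.5 and 0.25 respectively.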
}
/**
* A test that trains a simple recurrent network on the 'sum sign' task and prints the relevance of the inputs of
* correctly classified test sequences.
*/
class SumSignRelevanceTest(val dataset: Corpus<SequenceExampleWithFinalOutput<DenseNDArray>>) {
/**
* The number of examples to print.
*/
private val examplesToPrint: Int = 20
/**
* The recurrent network under test: 1 input unit, 10 Tanh hidden units, 3 Softmax output units.
*/
private val neuralNetwork = SimpleRecurrentNeuralNetwork(
inputSize = 1,
hiddenSize = 10,
hiddenActivation = Tanh(),
outputSize = 3,
outputActivation = Softmax())
/**
* Start the test.
*/
fun start() {
this.train()
this.printRelevance()
}
/**
* Train the network on a dataset.
*/
private fun train() {
println("\n-- TRAINING\n")
val optimizer = ParamsOptimizer(
params = this.neuralNetwork.model,
updateMethod = LearningRateMethod(learningRate = 0.01))
val trainingHelper = SequenceWithFinalOutputTrainingHelper(
neuralProcessor = RecurrentNeuralProcessor<DenseNDArray>(this.neuralNetwork),
optimizer = optimizer,
lossCalculator = SoftmaxCrossEntropyCalculator(),
verbose = true)
val validationHelper = SequenceWithFinalOutputValidationHelper(
neuralProcessor = RecurrentNeuralProcessor<DenseNDArray>(this.neuralNetwork),
outputEvaluationFunction = ClassificationEvaluation())
trainingHelper.train(
trainingExamples = this.dataset.training,
validationExamples = this.dataset.validation,
epochs = 3,
shuffler = Shuffler(enablePseudoRandom = true, seed = 1),
batchSize = 1,
validationHelper = validationHelper)
}
/**
* Print the relevance of the first [examplesToPrint] correctly predicted examples of the test set.
*/
private fun printRelevance() {
println("\n-- PRINT RELEVANCE OF %d EXAMPLES\n".format(this.examplesToPrint))
val validationProcessor = RecurrentNeuralProcessor<DenseNDArray>(this.neuralNetwork)
val validationHelper = SequenceWithFinalOutputValidationHelper(
neuralProcessor = validationProcessor,
outputEvaluationFunction = ClassificationEvaluation())
var exampleIndex = 0
validationHelper.validate(
examples = this.dataset.test,
saveContributions = true,
onPrediction = { example, isCorrect ->
if (isCorrect && exampleIndex < this.examplesToPrint) {
this.printSequenceRelevance(
neuralProcessor = validationProcessor,
example = example,
exampleIndex = exampleIndex++)
}
})
}
/**
* Print the relevance of each input of the sequence (the printed format is illustrated after this function).
*
* @param neuralProcessor the neural processor of the validation
* @param example the validated sequence
* @param exampleIndex the index of the example within the printed ones
*/
private fun printSequenceRelevance(neuralProcessor: RecurrentNeuralProcessor<DenseNDArray>,
example: SequenceExampleWithFinalOutput<DenseNDArray>,
exampleIndex: Int) {
val sequenceRelevance = this.getSequenceRelevance(
neuralProcessor = neuralProcessor,
outputGold = example.outputGold
)
println("EXAMPLE %d".format(exampleIndex + 1))
println("Gold: %d".format(example.outputGold.argMaxIndex() - 1))
println("Sequence (input | relevance):")
(0 until sequenceRelevance.size).forEach { i ->
println("\t%4.1f | %8.1f".format(example.sequenceFeatures[i][0], sequenceRelevance[i]))
}
println()
}
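// The printed block for each example therefore looks like (placeholders, not real output):
//   EXAMPLE <n>
//   Gold: <-1, 0 or 1>
//   Sequence (input | relevance):
// followed by one "<input> | <relevance>" row per timestep, formatted as "%4.1f | %8.1f".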
/**
* @param neuralProcessor the neural processor of the validation
* @param outputGold the gold output array
*
* @return an array containing the relevance of each input of the sequence with respect to the gold output
*/
private fun getSequenceRelevance(neuralProcessor: RecurrentNeuralProcessor<DenseNDArray>,
outputGold: DenseNDArray): Array<Double> {
val outcomesDistr = DistributionArray.oneHot(length = 3, oneAt = outputGold.argMaxIndex())
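// For each input state i, compute its relevance with respect to the output of the last state of the
// sequence (index 9), concentrating the relevant outcomes on the gold class, and sum the resulting
// array into a single score.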
return Array(
size = 10,
init = { i ->
neuralProcessor.calculateRelevance(
stateFrom = i,
stateTo = 9,
relevantOutcomesDistribution = outcomesDistr).sum()
}
)
}
}