
SumSignRelevanceTest.kt (simplednn 0.5.4 source code) - Maven / Gradle / Ivy


SimpleDNN is a lightweight, open-source machine learning library written in Kotlin whose purpose is to support the development of feed-forward and recurrent Artificial Neural Networks.
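To try this example in your own project, the library can be pulled in with Gradle. A minimal sketch in the Kotlin DSL, assuming the artifact is published under the group id com.kotlinnlp (inferred from the package names in the source below; the artifact name and version come from this page):

// build.gradle.kts (assumed coordinates: group com.kotlinnlp, artifact simplednn, version 0.5.4)
dependencies {
    implementation("com.kotlinnlp:simplednn:0.5.4")
}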

/* Copyright 2016-present The KotlinNLP Authors. All Rights Reserved.
 *
 * This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this
 * file, you can obtain one at http://mozilla.org/MPL/2.0/.
 * ------------------------------------------------------------------*/

import com.kotlinnlp.simplednn.core.arrays.DistributionArray
import com.kotlinnlp.simplednn.core.optimizer.ParamsOptimizer
import com.kotlinnlp.simplednn.core.functionalities.activations.Softmax
import com.kotlinnlp.simplednn.core.functionalities.updatemethods.learningrate.LearningRateMethod
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArrayFactory
import com.kotlinnlp.simplednn.core.functionalities.activations.Tanh
import com.kotlinnlp.simplednn.core.functionalities.losses.SoftmaxCrossEntropyCalculator
import com.kotlinnlp.simplednn.core.neuralprocessor.recurrent.RecurrentNeuralProcessor
import com.kotlinnlp.simplednn.dataset.*
import com.kotlinnlp.simplednn.core.functionalities.outputevaluation.ClassificationEvaluation
import com.kotlinnlp.simplednn.core.neuralnetwork.preset.SimpleRecurrentNeuralNetwork
import com.kotlinnlp.simplednn.helpers.training.SequenceWithFinalOutputTrainingHelper
import com.kotlinnlp.simplednn.helpers.validation.SequenceWithFinalOutputValidationHelper
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray

fun main(args: Array<String>) {

  println("Start 'Sum Sign Relevance Test'")

  SumSignRelevanceTest(dataset = DatasetBuilder.build()).start()

  println("End.")
}

/**
 * A builder of the dataset: sequences of random values in {-1.0, 0.0, 1.0}, each labeled with the sign of their sum.
 */
object DatasetBuilder {

  /**
   * @return a dataset of examples of sequences
   */
  fun build(): Corpus<SequenceExampleWithFinalOutput<DenseNDArray>> = Corpus(
    training = arrayListOf(*Array(size = 10000, init = { this.createExample() })),
    validation = arrayListOf(*Array(size = 1000, init = { this.createExample() })),
    test = arrayListOf(*Array(size = 100, init = { this.createExample() }))
  )

  /**
   * @return an example containing a sequence of single features with a random value in {-1.0, 0.0, 1.0} and a gold
   *         output which is the sign of the sum of the features, represented as a one-hot encoding
   *         [0 = negative, 1 = zero, 2 = positive]
   */
  private fun createExample(): SequenceExampleWithFinalOutput<DenseNDArray> {

    val features = arrayListOf(*Array(size = 10, init = { this.getRandomInput() }))
    val outputGoldIndex = Math.signum(features.sumByDouble { it[0] }) + 1

    return SequenceExampleWithFinalOutput(
      sequenceFeatures = features,
      outputGold = DenseNDArrayFactory.oneHotEncoder(length = 3, oneAt = outputGoldIndex.toInt())
    )
  }

  /**
   * @return a [DenseNDArray] containing a single value in {-1.0, 0.0, 1.0}
   */
  private fun getRandomInput(): DenseNDArray {
    val value = Math.round(Math.random() * 2.0 - 1.0).toDouble()
    return DenseNDArrayFactory.arrayOf(doubleArrayOf(value))
  }
}

/**
 * A test that trains a [SimpleRecurrentNeuralNetwork] to predict the sign of the sum of a sequence of inputs and then
 * prints the relevance of each input of the sequence with respect to the prediction.
 *
 * @property dataset the dataset used to train, validate and test the network
 */
class SumSignRelevanceTest(val dataset: Corpus<SequenceExampleWithFinalOutput<DenseNDArray>>) {

  /**
   * The number of examples to print.
   */
  private val examplesToPrint: Int = 20

  /**
   * The recurrent neural network used for the test.
   */
  private val neuralNetwork = SimpleRecurrentNeuralNetwork(
    inputSize = 1,
    hiddenSize = 10,
    hiddenActivation = Tanh(),
    outputSize = 3,
    outputActivation = Softmax())

  /**
   * Start the test.
   */
  fun start() {

    this.train()
    this.printRelevance()
  }

  /**
   * Train the network on a dataset.
   */
  private fun train() {

    println("\n-- TRAINING\n")

    val optimizer = ParamsOptimizer(
      params = this.neuralNetwork.model,
      updateMethod = LearningRateMethod(learningRate = 0.01))

    val trainingHelper = SequenceWithFinalOutputTrainingHelper(
      neuralProcessor = RecurrentNeuralProcessor<DenseNDArray>(this.neuralNetwork),
      optimizer = optimizer,
      lossCalculator = SoftmaxCrossEntropyCalculator(),
      verbose = true)

    val validationHelper = SequenceWithFinalOutputValidationHelper(
      neuralProcessor = RecurrentNeuralProcessor<DenseNDArray>(this.neuralNetwork),
      outputEvaluationFunction = ClassificationEvaluation())

    trainingHelper.train(
      trainingExamples = this.dataset.training,
      validationExamples = this.dataset.validation,
      epochs = 3,
      shuffler = Shuffler(enablePseudoRandom = true, seed = 1),
      batchSize = 1,
      validationHelper = validationHelper)
  }

  /**
   * Print the relevance of each example of the dataset.
   */
  private fun printRelevance() {

    println("\n-- PRINT RELEVANCE OF %d EXAMPLES\n".format(this.examplesToPrint))

    val validationProcessor = RecurrentNeuralProcessor<DenseNDArray>(this.neuralNetwork)

    val validationHelper = SequenceWithFinalOutputValidationHelper(
      neuralProcessor = validationProcessor,
      outputEvaluationFunction = ClassificationEvaluation())

    var exampleIndex = 0

    validationHelper.validate(
      examples = this.dataset.test,
      saveContributions = true,
      onPrediction = { example, isCorrect ->
        if (isCorrect && exampleIndex < this.examplesToPrint) {
          this.printSequenceRelevance(
            neuralProcessor = validationProcessor,
            example = example,
            exampleIndex = exampleIndex++)
        }
      })
  }

  /**
   * Print the relevance of each input of the sequence.
   *
   * @param neuralProcessor the neural processor of the validation
   * @param example the validated sequence
   * @param exampleIndex the index of the example (used only for printing)
   */
  private fun printSequenceRelevance(neuralProcessor: RecurrentNeuralProcessor<DenseNDArray>,
                                     example: SequenceExampleWithFinalOutput,
                                     exampleIndex: Int) {

    val sequenceRelevance = this.getSequenceRelevance(
      neuralProcessor = neuralProcessor,
      outputGold = example.outputGold
    )

    println("EXAMPLE %d".format(exampleIndex + 1))
    println("Gold: %d".format(example.outputGold.argMaxIndex() - 1))
    println("Sequence (input | relevance):")

    (0 until sequenceRelevance.size).forEach { i ->
      println("\t%4.1f | %8.1f".format(example.sequenceFeatures[i][0], sequenceRelevance[i]))
    }

    println()
  }

  /**
   * @param neuralProcessor the neural processor of the validation
   * @param outputGold the gold output array
   *
   * @return an array containing the relevance of each input of the sequence with respect to the gold output
   */
  private fun getSequenceRelevance(neuralProcessor: RecurrentNeuralProcessor<DenseNDArray>,
                                   outputGold: DenseNDArray): Array<Double> {

    val outcomesDistr = DistributionArray.oneHot(length = 3, oneAt = outputGold.argMaxIndex())

    return Array(
      size = 10, // one relevance value per step (the sequences built by DatasetBuilder have length 10)
      init = { i ->
        neuralProcessor.calculateRelevance(
          stateFrom = i,
          stateTo = 9, // the last state of the sequence
          relevantOutcomesDistribution = outcomesDistr).sum()
      }
    )
  }
}