All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.kotlinnlp.simplednn.deeplearning.multitasknetwork.MultiTaskNetwork.kt Maven / Gradle / Ivy

Go to download

SimpleDNN is a lightweight, open-source machine learning library written in Kotlin, whose purpose is to support the development of feed-forward and recurrent Artificial Neural Networks.

There is a newer version: 0.14.0
Show newest version
/* Copyright 2016-present The KotlinNLP Authors. All Rights Reserved.
 *
 * This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this
 * file, you can obtain one at http://mozilla.org/MPL/2.0/.
 * ------------------------------------------------------------------*/

package com.kotlinnlp.simplednn.deeplearning.multitasknetwork

import com.kotlinnlp.simplednn.core.arrays.DistributionArray
import com.kotlinnlp.simplednn.core.neuralprocessor.feedforward.FeedforwardNeuralProcessor
import com.kotlinnlp.simplednn.simplemath.ndarray.NDArray
import com.kotlinnlp.simplednn.simplemath.ndarray.dense.DenseNDArray
import com.kotlinnlp.simplednn.utils.ItemsPool

/**
 * A multi-task network is composed of a single input feed-forward network shared by several output networks, each
 * with its own feed-forward output layer.
 *
 * The output of the input network (the hidden representation) is forwarded to every output network. During the
 * backward, the input errors of all the output networks are summed and propagated through the input network.
 *
 * @property model the model of this network
 * @property id an identification number useful to track a specific [MultiTaskNetwork]
 */
class MultiTaskNetwork<InputNDArrayType : NDArray<InputNDArrayType>>(
  val model: MultiTaskNetworkModel,
  override val id: Int = 0
) : ItemsPool.IDItem {

  /**
   * The neural processor of the input network.
   */
  val inputProcessor = FeedforwardNeuralProcessor<InputNDArrayType>(this.model.inputNetwork)

  /**
   * The list of neural processors of the output networks, in the same order as [MultiTaskNetworkModel.outputNetworks].
   */
  val outputProcessors: List<FeedforwardNeuralProcessor<DenseNDArray>> =
    this.model.outputNetworks.map { FeedforwardNeuralProcessor<DenseNDArray>(it) }

  /**
   * @param copy a Boolean indicating whether the returned errors must be a copy or a reference
   *
   * @return the errors of the neural parameters (of the input network and of each output network)
   */
  fun getParamsErrors(copy: Boolean) = MultiTaskNetworkParameters(
    inputParams = this.inputProcessor.getParamsErrors(copy = copy),
    outputParamsList = this.outputProcessors.map { it.getParamsErrors(copy = copy) }
  )

  /**
   * @param copy a Boolean indicating whether the returned errors must be a copy or a reference
   *
   * @return the errors of the input
   */
  fun getInputErrors(copy: Boolean = true): DenseNDArray = this.inputProcessor.getInputErrors(copy = copy)

  /**
   * Forward features.
   *
   * @param featuresArray the features to forward from the input to the output
   * @param useDropout whether to apply the dropout
   *
   * @return the list of output arrays, one for each output network
   */
  fun forward(featuresArray: InputNDArrayType, useDropout: Boolean = false): List<DenseNDArray> {

    val hiddenOutput: DenseNDArray = this.inputProcessor.forward(featuresArray = featuresArray, useDropout = useDropout)

    // The same hidden representation is shared by all the output networks.
    return this.outputProcessors.map { it.forward(featuresArray = hiddenOutput, useDropout = useDropout) }
  }

  /**
   * Forward features, saving the contributions of the input with respect to the output.
   *
   * @param featuresArray the features to forward from the input to the output
   * @param saveContributions whether to save the contributions of each input to its output (needed to calculate
   *                          the relevance)
   * @param useDropout whether to apply the dropout
   *
   * @return the list of output arrays, one for each output network
   */
  fun forward(featuresArray: InputNDArrayType,
              saveContributions: Boolean,
              useDropout: Boolean = false): List<DenseNDArray> {

    val hiddenOutput: DenseNDArray = this.inputProcessor.forward(
      featuresArray = featuresArray,
      saveContributions = saveContributions,
      useDropout = useDropout)

    return this.outputProcessors.map {
      it.forward(featuresArray = hiddenOutput, saveContributions = saveContributions, useDropout = useDropout)
    }
  }

  /**
   * Calculate the relevance of the input with respect to the output of the network with the given [networkIndex],
   * propagating backward the given distribution on the outcomes.
   *
   * @param networkIndex the index of an output network (starting from 0)
   * @param relevantOutcomesDistribution the distribution which indicates which outcomes are relevant, used
   *                                     as reference to calculate the relevance of the input
   * @param copy whether to return a copy of the relevance or not (default = true)
   *
   * @throws IllegalArgumentException if [networkIndex] is not a valid index of the output networks
   *
   * @return the input relevance array (If the input is Dense it is Dense, if the input is Sparse or SparseBinary it
   *         is Sparse)
   */
  fun calculateInputRelevance(networkIndex: Int,
                              relevantOutcomesDistribution: DistributionArray,
                              copy: Boolean = true): NDArray<*> {

    // 'indices' is the half-open range [0, size): an index equal to 'size' must be rejected as well.
    require(networkIndex in this.outputProcessors.indices) {
      "Invalid network index: %d. Must be in range [0, %d].".format(networkIndex, this.outputProcessors.lastIndex)
    }

    val hiddenRelevance: DenseNDArray = this.outputProcessors[networkIndex].calculateInputRelevance(
      relevantOutcomesDistribution = relevantOutcomesDistribution,
      copy = false) as DenseNDArray // the hidden layer is always dense, the same for its relevance

    // The relevance of the hidden layer becomes the reference distribution for the input network.
    return this.inputProcessor.calculateInputRelevance(
      relevantOutcomesDistribution = DistributionArray(hiddenRelevance),
      copy = copy)
  }

  /**
   * Backward errors.
   *
   * The input errors of all the output networks are summed and the sum is propagated through the input network.
   *
   * @param outputErrorsList the list of output errors, one for each output network
   * @param propagateToInput whether to propagate the errors to the input
   * @param inputMePropK the input layer k factor of the 'meProp' algorithm to propagate from the k (in percentage)
   *                     hidden nodes with the top errors (can be null)
   * @param outputMePropK a list of k factors (one for each output layer) of the 'meProp' algorithm to propagate from
   *                      the k (in percentage) output nodes with the top errors (the list and each element can be null)
   *
   * @throws IllegalArgumentException if [outputErrorsList] or [outputMePropK] has a size different from the number
   *                                  of output networks
   */
  fun backward(outputErrorsList: List<DenseNDArray>,
               propagateToInput: Boolean = false,
               inputMePropK: Double? = null,
               outputMePropK: List<Double?>? = null) {

    val hiddenErrors: DenseNDArray =
      this.backwardOutputProcessors(outputErrorsList = outputErrorsList, outputMePropK = outputMePropK)

    this.inputProcessor
      .backward(outputErrors = hiddenErrors, propagateToInput = propagateToInput, mePropK = listOf(inputMePropK))
  }

  /**
   * Output processors backwards.
   *
   * @param outputErrorsList the list of output errors, one for each output network
   * @param outputMePropK a list of k factors (one for each output layer) of the 'meProp' algorithm to propagate from
   *                      the k (in percentage) output nodes with the top errors (the list and each element can be null)
   *
   * @throws IllegalArgumentException if [outputErrorsList] or [outputMePropK] has a size different from the number
   *                                  of output networks
   * @throws IllegalStateException if there are no output networks
   *
   * @return the sum of the input errors of each output network
   */
  private fun backwardOutputProcessors(outputErrorsList: List<DenseNDArray>,
                                       outputMePropK: List<Double?>? = null): DenseNDArray {

    require(outputErrorsList.size == this.outputProcessors.size) {
      "The list of output errors must have a size equal to the number of output networks."
    }

    // Validated up front, so that no output processor is backwarded before an invalid argument is detected.
    require(outputMePropK == null || outputMePropK.size == this.outputProcessors.size) {
      "The list of meProp k factors must have a size equal to the number of output networks."
    }

    var hiddenErrors: DenseNDArray? = null

    this.outputProcessors.zip(outputErrorsList).forEachIndexed { i, (processor, errors) ->

      processor.backward(outputErrors = errors, propagateToInput = true, mePropK = listOf(outputMePropK?.get(i)))

      // Local read-only snapshot: a captured 'var' cannot be smart-cast inside the lambda.
      val accumulated: DenseNDArray? = hiddenErrors

      if (accumulated == null)
        // The first errors are copied, so the accumulation does not modify the processor's own input errors.
        hiddenErrors = processor.getInputErrors(copy = true)
      else
        accumulated.assignSum(processor.getInputErrors(copy = false))
    }

    return checkNotNull(hiddenErrors) { "Cannot backward a multi-task network without output networks." }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy