All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.bnd.math.business.learning.Trainer.scala Maven / Gradle / Ivy

The newest version!
package com.bnd.math.business.learning

import java.util.ArrayList

import com.bnd.core.runnable._
import com.bnd.math.BndMathException
import com.bnd.math.domain.learning.MachineLearningSetting
import com.bnd.core.runnable.TimeRunnable
import com.bnd.core.runnable.SeqIndexAccessible.Implicits._

import scala.collection.mutable.{Publisher, Subscriber}

/**
 * Abstract class representing a trainer training underlying TimeRunnable.
 *
 * The trainer subscribes itself to the learner on construction. Each state update whose time is
 * at or after the current training time triggers a training step: the states of the output
 * components are read, compared with the next desired outputs of the training stream, the
 * resulting error and the first output are recorded, and `adapt` is invoked.
 *
 * @tparam T type of the component states, outputs, and errors
 * @tparam C type of the learner's components
 * @tparam S container of the learner's state, indexable via `SeqIndexAccessible`
 * @param calcOutputError called as `(desiredOutput, actualOutput)` for each output component
 * @param calcErrorMetrics called as `(actualOutputs, desiredOutputs)`; its result is recorded as
 *                         the error of the training step
 * @param setting supplies the initial delay, the single iteration length, and the relative
 *                output interpretation time used to schedule the training times
 * @param trainingStream supplies the desired outputs (`outputStream`) and the `outputShift`
 * @param learner the runnable being trained; it must publish its state events
 * @param outputComponents the components whose states are interpreted as outputs
 * @author © Peter Banda
 * @since 2015
 */
abstract class Trainer[T: Manifest, C, S[X]: SeqIndexAccessible](
  calcOutputError : (T, T) => T,
  calcErrorMetrics : (Seq[T], Seq[T]) => T)(
  setting : MachineLearningSetting,
  trainingStream : IOStream[T])(
  learner : TimeRunnable with Publisher[StateEvent[T, S]],
  outputComponents : Iterable[C]
) extends Subscriber[StateEvent[T, S], Publisher[StateEvent[T, S]]] {

  // Training errors and (first) outputs, one entry per performed training step
  private val _errors = new ArrayList[T]
  private val _outputs = new ArrayList[T]

  // Desired output iterator
  private val desiredOutputIterator = trainingStream.outputStream.iterator

  // Component indices; filled in lazily from the first state update that triggers training
  private var componentIndices : Option[Map[C, Int]] = None

  /**
   * Endless sequence of training (output interpretation) times. It starts at
   * `initialDelay + outputInterpretationRelativeTime + singleIterationLength * outputShift`
   * and advances by `singleIterationLength`.
   */
  val trainingTimeIterator : Iterator[BigDecimal] = {
    val initialDelay : BigDecimal = setting.getInitialDelay : Double
    val singleIterationLength : BigDecimal = setting.getSingleIterationLength : Double
    val outputInterpretationRelativeTime : BigDecimal = setting.getOutputInterpretationRelativeTime : Double
    val initInterpretationTime = initialDelay + outputInterpretationRelativeTime + singleIterationLength * trainingStream.outputShift
    // Iterator.iterate yields the same sequence as Stream.iterate(...).iterator without
    // building an intermediate memoizing Stream.
    Iterator.iterate(initInterpretationTime)(singleIterationLength + _)
  }

  protected var currentTrainingTime = trainingTimeIterator.next()
  protected var currentTrainingIteration = 0

  // The trainer listens to the underlying learner
  learner.subscribe(this)

  /**
   * Handles an event published by the learner. A state update whose time has reached the
   * current training time triggers a training step; any other event is ignored.
   *
   * @param pub the publisher of the event
   * @param event the published event
   */
  def notify(pub: Publisher[StateEvent[T, S]], event: StateEvent[T, S]): Unit =
    event match {
      case StateUpdatedEvent(time : BigDecimal, components : Iterable[C], state : S[T]) =>
        if (time >= currentTrainingTime) {
          // The component-to-index mapping is taken from the first update used for training
          if (componentIndices.isEmpty) componentIndices = Some(components.zipWithIndex.toMap)
          train(components, state)
        }
      case _ =>
    }

  /**
   * Runs the learner through the given number of training iterations. In each iteration the
   * learner is run until the current training time, which is then advanced to the next one.
   *
   * @param iterationNum the number of iterations to run
   */
  def train(iterationNum: Int) : Unit =
    (1 to iterationNum).foreach{iteration =>
      currentTrainingIteration = iteration
      learner.runUntil(currentTrainingTime)
      // hacky solution for an inclusion/exclusion of the input/output time
      learner.runFor(0d)
      currentTrainingTime = trainingTimeIterator.next()
    }

  /**
   * Performs a single training step for the given state: evaluates the output components
   * against the next desired outputs, records the error and the first output, and adapts.
   *
   * @throws BndMathException if the training stream has no more desired outputs
   */
  private def train(components : Iterable[C], state : S[T]): Unit = {
    if (!desiredOutputIterator.hasNext) {
      throw new BndMathException(s"No more desired outputs available for a trainer at time ${currentTrainingTime.toDouble}.")
    }

    val desiredOutputs = desiredOutputIterator.next()
    val outputComponentStatePairs = outputComponents.map(outputComponent =>
        (outputComponent, componentState(outputComponent, state)))
    val actualOutputs = outputComponentStatePairs.map(_._2)

    val outputComponentDesiredOutputErrorTuples = (outputComponentStatePairs, desiredOutputs).zipped.map{
      case ((outputComponent, output), desiredOutput) =>
        (outputComponent, desiredOutput, calcOutputError(desiredOutput, output))}

    val error = calcErrorMetrics(actualOutputs.toSeq, desiredOutputs)
    _errors.add(error)
    // Only the first output is recorded; `head` throws if there are no output components.
    _outputs.add(actualOutputs.head)

    adapt(outputComponentDesiredOutputErrorTuples, state)
  }

  /**
   * Adapts the learner based on the outcome of a training step.
   *
   * @param outputComponentErrorPairs tuples of (output component, desired output, output error)
   * @param state the learner's state the training step was evaluated on
   */
  private[learning] def adapt(outputComponentErrorPairs : Iterable[(C, T, T)], state : S[T]): Unit

  /**
   * Looks up the state of the given component.
   *
   * @param component the component whose state is requested
   * @param state the learner's state
   * @return the element of `state` at the index registered for `component`
   * @throws NoSuchElementException if the component indices have not been initialized yet or
   *                                the component has no registered index
   */
  protected def componentState(component: C, state: S[T]): T = {
    val indices = componentIndices.getOrElse(
      throw new NoSuchElementException("Component indices have not been initialized; no state update has triggered training yet."))
    val index = indices.getOrElse(component,
      throw new NoSuchElementException(s"Component '$component' has no index among the components reported by the learner."))
    state(index)
  }

  /** @return the errors recorded so far, one per training step (the internal, mutable list) */
  def errors(): ArrayList[T] = _errors

  /** @return the first outputs recorded so far, one per training step (the internal, mutable list) */
  def outputs(): ArrayList[T] = _outputs
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy