All downloads are free. The search and download features use the official Maven repository.

dk.bayes.learn.em.GenericEMLearn.scala Maven / Gradle / Ivy

The newest version!
package dk.bayes.learn.em
import dk.bayes.model.clustergraph.ClusterGraph
import dk.bayes.infer.LoopyBP
import dk.bayes.model.clustergraph.factor.Factor
import scala.collection._
import scala.annotation.tailrec
import EMLearn._

/**
 * Default implementation of EM algorithm.
 *
 * Potentials are learned per cluster type id: the E-step collects the cluster beliefs for every
 * training sample, the M-step sums those beliefs for every cluster type id and normalises the sums
 * into CPT values.
 *
 * @author Daniel Korzekwa
 */
object GenericEMLearn extends EMLearn {

  /** Substituted for belief values that are exactly zero when beliefs are summed in the M-step. */
  private val ZERO_PROBABILITY = 1.0E-20

  /**
   * Represents sufficient statistics produced by E-step of EM algorithm.
   *
   * @param clusterBeliefsByTypeId Seq[Tuple2[clusterTypeId, cluster belief]], one entry per cluster and training sample
   * @param logLikelihood Sum of the values returned by LoopyBP.calibrateWithEvidence over all training samples
   */
  case class SufficientStats(clusterBeliefsByTypeId: Seq[Tuple2[Int, Factor]], logLikelihood: Double)

  /**
   * Learns cluster potentials from the training set and writes them back into the cluster graph
   * (the factors of the clusters of 'clusterGraph' are updated in place).
   *
   * @param clusterGraph Cluster graph whose cluster potentials are learned
   * @param trainSet Training data, must contain at least one sample
   * @param maxIterNum Maximum number of EM iterations. The check is performed after an iteration has completed,
   * so at least one iteration is always executed.
   * @param progress Callback invoked after every iteration with the iteration number (starting from 1)
   * and the log likelihood computed by the E-step of that iteration
   *
   * @throws IllegalArgumentException if the training set contains no samples
   */
  def learn(clusterGraph: ClusterGraph, trainSet: DataSet, maxIterNum: Int, progress: (Progress) => Unit = (progress: Progress) => {}): Unit = {

    require(trainSet.samples.size > 0, "No samples found in training set")

    /**
     * Returns learned cluster potentials by cluster type id.
     */
    @tailrec
    def trainRecursive(clusterPotentialsByTypeId: Map[Int, Factor], currIter: Int): Map[Int, Factor] = {

      val sufficientStats = expectationStep(clusterGraph, clusterPotentialsByTypeId, trainSet)
      val newClusterPotentialsByTypeId = maximisationStep(sufficientStats.clusterBeliefsByTypeId)

      // The reported log likelihood was computed with the potentials this iteration started with.
      progress(Progress(currIter, sufficientStats.logLikelihood))

      if (currIter < maxIterNum) trainRecursive(newClusterPotentialsByTypeId, currIter + 1)
      else newClusterPotentialsByTypeId
    }

    val loopyBP = LoopyBP(clusterGraph)
    /**Seq[Tuple2[clusterTypeId,cluster initial potentials]]*/
    val clusterPotentials = clusterGraph.getClusters().map(c => c.typeId -> loopyBP.clusterBelief(c.id))
    // Starting point of EM: the initial potentials pooled by cluster type id.
    val clusterPotentialsByTypeId: Map[Int, Factor] = maximisationStep(clusterPotentials)

    val finalClusterPotentialsByTypeId = trainRecursive(clusterPotentialsByTypeId, 1)

    updateInitialClusterPotentials(clusterGraph, finalClusterPotentialsByTypeId)

  }

  /**
   * Replaces the factor of every cluster in the graph with a copy of its current factor
   * that carries the values of the potentials given for the cluster's type id.
   *
   * @param clusterGraph Cluster graph to update in place
   * @param clusterPotentialsByTypeId Map[clusterTypeId, cluster potentials], must contain the type id of every cluster in the graph
   */
  private def updateInitialClusterPotentials(clusterGraph: ClusterGraph, clusterPotentialsByTypeId: Map[Int, Factor]): Unit = {
    for (cluster <- clusterGraph.getClusters()) {
      val clusterTypePotentials = clusterPotentialsByTypeId(cluster.typeId)
      val newClusterPotentials = cluster.getFactor().copy(clusterTypePotentials.getValues())
      cluster.updateFactor(newClusterPotentials)
    }
  }

  /**
   * Returns sufficient statistics.
   *
   * For every training sample the cluster graph is calibrated with the sample's evidence
   * and the beliefs of all clusters are collected.
   *
   * @param clusterGraph Cluster graph, its cluster factors are overwritten with 'clusterPotentialsByTypeId'
   * @param clusterPotentialsByTypeId Map[clusterTypeId, current estimate of cluster potentials]
   * @param trainSet Training data
   */
  private def expectationStep(clusterGraph: ClusterGraph, clusterPotentialsByTypeId: Map[Int, Factor], trainSet: DataSet): SufficientStats = {

    val loopyBP = LoopyBP(clusterGraph)

    // Accumulated as a side effect while the samples are traversed below.
    var dataLogLikelihood = 0d

    val clusterBeliefs: Seq[Tuple2[Int, Factor]] = trainSet.samples.flatMap { sample =>

      // Restore the current potentials before every sample.
      // NOTE(review): presumably needed because calibration with evidence changes the cluster factors - confirm.
      updateInitialClusterPotentials(clusterGraph, clusterPotentialsByTypeId)

      val evidence: Seq[Tuple2[Int, Int]] = DataSet.toEvidence(trainSet.variableIds, sample)

      dataLogLikelihood += loopyBP.calibrateWithEvidence(evidence)

      /**Seq[Tuple2[clusterTypeId,cluster belief]]*/
      val sampleClusterBeliefs: Seq[Tuple2[Int, Factor]] = clusterGraph.getClusters().map(c => c.typeId -> loopyBP.clusterBelief(c.id))
      sampleClusterBeliefs
    }

    SufficientStats(clusterBeliefs, dataLogLikelihood)
  }

  /**
   * Returns cluster initial potentials estimated from sufficient statistics using MLE (Maximum Likelihood Estimation).
   *
   * @param clusterBeliefs Seq[Tuple2[clusterTypeId, cluster belief]]
   *
   * @return Map[clusterTypeId, MLE estimated cluster initial potentials]
   */
  private def maximisationStep(clusterBeliefs: Seq[Tuple2[Int, Factor]]): Map[Int, Factor] = {

    // groupBy followed by a strict map: no lazily re-evaluated mapValues view is created.
    val clusterPotentialsByTypeId: Map[Int, Factor] = clusterBeliefs.groupBy(_._1).map {
      case (clusterTypeId, typeIdWithBeliefs) =>

        val beliefs = typeIdWithBeliefs.map(_._2)

        // groupBy never produces an empty group, so head is always defined here.
        val clusterBeliefFactor = beliefs.head
        val beliefValuesSum = new Array[Double](clusterBeliefFactor.getValues().size)

        // Element-wise sum of belief values over all beliefs of this cluster type.
        for (belief <- beliefs) {
          val beliefValues = belief.getValues()
          var i = 0
          while (i < beliefValues.size) {
            val value = beliefValues(i)
            beliefValuesSum(i) += (if (value == 0) ZERO_PROBABILITY else value)
            i += 1
          }
        }

        // Sums are normalised in slices of the size of the last variable of the factor.
        val cptVarSize = clusterBeliefFactor.getVariables().last.dim
        val cptValues = toCPT(beliefValuesSum, cptVarSize)
        val clusterPotentials = clusterBeliefFactor.copy(cptValues)

        (clusterTypeId -> clusterPotentials)
    }

    clusterPotentialsByTypeId
  }

  /**
   * Converts factor values to CPT values.
   *
   * The values are split into consecutive slices of 'sliceSize' elements (the last slice may be shorter)
   * and every element is divided by the sum of its slice.
   *
   * @param values Factor values
   * @param sliceSize Number of consecutive values that are normalised together
   *
   * @return CPT values, an empty array if 'values' is empty
   */
  private def toCPT(values: Seq[Double], sliceSize: Int): Array[Double] = {

    // The empty case is kept separate, so that 'sliceSize' is not validated when there is nothing to slice.
    if (values.isEmpty) Array.empty[Double]
    else values.grouped(sliceSize).flatMap { slice =>
      // Sum of the slice is computed once, not once per element.
      val sliceSum = slice.sum
      slice.map(elem => elem / sliceSum)
    }.toArray
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy