/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.mllib.clustering

import scala.collection.mutable.IndexedSeq

import breeze.linalg.{diag, DenseMatrix => BreezeMatrix, DenseVector => BDV, Vector => BV}

import org.apache.spark.annotation.Since
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.linalg.{BLAS, DenseMatrix, Matrices, Vector, Vectors}
import org.apache.spark.mllib.stat.distribution.MultivariateGaussian
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.util.Utils

/**
 * This class performs expectation maximization for multivariate Gaussian
 * Mixture Models (GMMs).  A GMM represents a composite distribution of
 * independent Gaussian distributions with associated "mixing" weights
 * specifying each distribution's contribution to the composite.
 *
 * Given a set of sample points, this class will maximize the log-likelihood
 * for a mixture of k Gaussians, iterating until the log-likelihood changes by
 * less than convergenceTol, or until it has reached the max number of iterations.
 * While this process is generally guaranteed to converge, it is not guaranteed
 * to find a global optimum.
 *
 * Note: For high-dimensional data (with many features), this algorithm may perform poorly.
 *       This is due to high-dimensional data (a) making it difficult to cluster at all (based
 *       on statistical/theoretical arguments) and (b) numerical issues with Gaussian distributions.
 *
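 * A minimal usage sketch (illustrative only; it assumes an existing SparkContext
 * named sc and a hypothetical whitespace-delimited data file):
 * {{{
 *   import org.apache.spark.mllib.linalg.Vectors
 *
 *   val data = sc.textFile("data/gmm_data.txt")   // hypothetical path
 *     .map(line => Vectors.dense(line.trim.split(' ').map(_.toDouble)))
 *
 *   val model = new GaussianMixture()
 *     .setK(3)
 *     .setConvergenceTol(0.01)
 *     .setMaxIterations(100)
 *     .run(data)
 *
 *   model.gaussians.zip(model.weights).foreach { case (g, w) =>
 *     println(s"weight=$w mu=${g.mu} sigma=${g.sigma}")
 *   }
 * }}}
 *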
 * @param k The number of independent Gaussians in the mixture model
 * @param convergenceTol The maximum change in log-likelihood at which convergence
 * is considered to have occurred.
 * @param maxIterations The maximum number of iterations to perform
 */
@Since("1.3.0")
class GaussianMixture private (
    private var k: Int,
    private var convergenceTol: Double,
    private var maxIterations: Int,
    private var seed: Long) extends Serializable {

  /**
   * Constructs a default instance. The default parameters are {k: 2, convergenceTol: 0.01,
   * maxIterations: 100, seed: random}.
   */
  @Since("1.3.0")
  def this() = this(2, 0.01, 100, Utils.random.nextLong())

  // number of samples per cluster to use when initializing Gaussians
  private val nSamples = 5

  // an initializing GMM can be provided rather than using the
  // default random starting point
  private var initialModel: Option[GaussianMixtureModel] = None

  /**
   * Set the initial GMM starting point, bypassing the random initialization.
   * You must call setK() prior to calling this method, and the condition
   * (model.k == this.k) must be met; failure will result in an IllegalArgumentException
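   *
   * For example (a sketch; previousModel stands for an assumed, previously
   * trained GaussianMixtureModel):
   * {{{
   *   val gmm = new GaussianMixture()
   *     .setK(previousModel.k)
   *     .setInitialModel(previousModel)
   * }}}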
   */
  @Since("1.3.0")
  def setInitialModel(model: GaussianMixtureModel): this.type = {
    if (model.k == k) {
      initialModel = Some(model)
    } else {
      throw new IllegalArgumentException("mismatched cluster count (model.k != k)")
    }
    this
  }

  /**
   * Return the user-supplied initial GMM, if one was supplied
   */
  @Since("1.3.0")
  def getInitialModel: Option[GaussianMixtureModel] = initialModel

  /**
   * Set the number of Gaussians in the mixture model.  Default: 2
   */
  @Since("1.3.0")
  def setK(k: Int): this.type = {
    this.k = k
    this
  }

  /**
   * Return the number of Gaussians in the mixture model
   */
  @Since("1.3.0")
  def getK: Int = k

  /**
   * Set the maximum number of iterations to run. Default: 100
   */
  @Since("1.3.0")
  def setMaxIterations(maxIterations: Int): this.type = {
    this.maxIterations = maxIterations
    this
  }

  /**
   * Return the maximum number of iterations to run
   */
  @Since("1.3.0")
  def getMaxIterations: Int = maxIterations

  /**
   * Set the largest change in log-likelihood at which convergence is
   * considered to have occurred.
   */
  @Since("1.3.0")
  def setConvergenceTol(convergenceTol: Double): this.type = {
    this.convergenceTol = convergenceTol
    this
  }

  /**
   * Return the largest change in log-likelihood at which convergence is
   * considered to have occurred.
   */
  @Since("1.3.0")
  def getConvergenceTol: Double = convergenceTol

  /**
   * Set the random seed
   */
  @Since("1.3.0")
  def setSeed(seed: Long): this.type = {
    this.seed = seed
    this
  }

  /**
   * Return the random seed
   */
  @Since("1.3.0")
  def getSeed: Long = seed

  /**
   * Perform expectation maximization
   */
  @Since("1.3.0")
  def run(data: RDD[Vector]): GaussianMixtureModel = {
    val sc = data.sparkContext

    // we will operate on the data as Breeze vectors
    val breezeData = data.map(_.toBreeze).cache()

    // Get length of the input vectors
    val d = breezeData.first().length

    val shouldDistributeGaussians = GaussianMixture.shouldDistributeGaussians(k, d)

    // Determine initial weights and corresponding Gaussians.
    // If the user supplied an initial GMM, we use those values, otherwise
    // we start with uniform weights, a random mean from the data, and
    // diagonal covariance matrices using component variances
    // derived from the samples
    val (weights, gaussians) = initialModel match {
      case Some(gmm) => (gmm.weights, gmm.gaussians)

      case None => {
        val samples = breezeData.takeSample(withReplacement = true, k * nSamples, seed)
        (Array.fill(k)(1.0 / k), Array.tabulate(k) { i =>
          val slice = samples.view(i * nSamples, (i + 1) * nSamples)
          new MultivariateGaussian(vectorMean(slice), initCovariance(slice))
        })
      }
    }

    var llh = Double.MinValue // current log-likelihood
    var llhp = 0.0            // previous log-likelihood

    var iter = 0
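    // EM loop: each pass performs one E step (aggregate per-cluster expectations over
    // the data) and one M step (re-estimate weights, means, and covariances), stopping
    // when the log-likelihood changes by less than convergenceTol or after maxIterations.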
    while (iter < maxIterations && math.abs(llh - llhp) > convergenceTol) {
      // create and broadcast curried cluster contribution function
      val compute = sc.broadcast(ExpectationSum.add(weights, gaussians)_)

      // aggregate the cluster contribution for all sample points
      val sums = breezeData.aggregate(ExpectationSum.zero(k, d))(compute.value, _ += _)
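      // sums now holds, for each cluster j:
      //   weights(j) = sum_i gamma_ij                  (soft counts)
      //   means(j)   = sum_i gamma_ij * x_i            (unnormalized means)
      //   sigmas(j)  = sum_i gamma_ij * x_i * x_i^T    (unnormalized second moments)
      // where gamma_ij is the responsibility of cluster j for point x_i:
      //   gamma_ij = w_j * N(x_i | mu_j, sigma_j) / sum_l w_l * N(x_i | mu_l, sigma_l)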

      // Create new distributions based on the partial assignments
      // (often referred to as the "M" step in literature)
      val sumWeights = sums.weights.sum

      if (shouldDistributeGaussians) {
        val numPartitions = math.min(k, 1024)
        val tuples =
          Seq.tabulate(k)(i => (sums.means(i), sums.sigmas(i), sums.weights(i)))
        val (ws, gs) = sc.parallelize(tuples, numPartitions).map { case (mean, sigma, weight) =>
          updateWeightsAndGaussians(mean, sigma, weight, sumWeights)
        }.collect().unzip
        Array.copy(ws.toArray, 0, weights, 0, ws.length)
        Array.copy(gs.toArray, 0, gaussians, 0, gs.length)
      } else {
        var i = 0
        while (i < k) {
          val (weight, gaussian) =
            updateWeightsAndGaussians(sums.means(i), sums.sigmas(i), sums.weights(i), sumWeights)
          weights(i) = weight
          gaussians(i) = gaussian
          i = i + 1
        }
      }

      llhp = llh // current becomes previous
      llh = sums.logLikelihood // this is the freshly computed log-likelihood
      iter += 1
    }

    new GaussianMixtureModel(weights, gaussians)
  }

  /**
   * Java-friendly version of [[run()]]
   */
  @Since("1.3.0")
  def run(data: JavaRDD[Vector]): GaussianMixtureModel = run(data.rdd)

  private def updateWeightsAndGaussians(
      mean: BDV[Double],
      sigma: BreezeMatrix[Double],
      weight: Double,
      sumWeights: Double): (Double, MultivariateGaussian) = {
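    // M step for a single cluster, where `weight` = sum_i gamma_i (the soft count),
    // `mean` = sum_i gamma_i * x_i, and `sigma` = sum_i gamma_i * x_i * x_i^T:
    // dividing `mean` by the soft count yields mu; subtracting weight * mu * mu^T
    // from `sigma` (the rank-1 syr update below) and dividing by the soft count
    // yields the covariance estimate; normalizing by sumWeights yields the
    // mixing weight.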
    val mu = (mean /= weight)
    BLAS.syr(-weight, Vectors.fromBreeze(mu),
      Matrices.fromBreeze(sigma).asInstanceOf[DenseMatrix])
    val newWeight = weight / sumWeights
    val newGaussian = new MultivariateGaussian(mu, sigma / weight)
    (newWeight, newGaussian)
  }

  /** Average of dense breeze vectors */
  private def vectorMean(x: IndexedSeq[BV[Double]]): BDV[Double] = {
    val v = BDV.zeros[Double](x(0).length)
    x.foreach(xi => v += xi)
    v / x.length.toDouble
  }

  /**
   * Construct a matrix whose diagonal entries are the element-wise
   * variance of the input vectors (a biased variance estimate)
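   * (i.e. sigma(j, j) = (1/n) * sum_i (x_i(j) - mu(j))^2, with zero off-diagonal entries)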
   */
  private def initCovariance(x: IndexedSeq[BV[Double]]): BreezeMatrix[Double] = {
    val mu = vectorMean(x)
    val ss = BDV.zeros[Double](x(0).length)
    x.foreach(xi => ss += (xi - mu) :^ 2.0)
    diag(ss / x.length.toDouble)
  }
}

private[clustering] object GaussianMixture {
  /**
   * Heuristic to decide whether to distribute the computation of the
   * [[MultivariateGaussian]]s: roughly, distribute when d > 25, except when k is very small.
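   * For example, with k = 10 and d = 30, (9.0 / 10) * 30 = 27 > 25, so the
   * Gaussians would be computed in a distributed manner.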
   * @param k  Number of Gaussians in the mixture
   * @param d  Number of features
   */
  def shouldDistributeGaussians(k: Int, d: Int): Boolean = ((k - 1.0) / k) * d > 25
}

// companion object to provide the zero constructor for ExpectationSum
private object ExpectationSum {
  def zero(k: Int, d: Int): ExpectationSum = {
    new ExpectationSum(0.0, Array.fill(k)(0.0),
      Array.fill(k)(BDV.zeros(d)), Array.fill(k)(BreezeMatrix.zeros(d, d)))
  }

  // compute cluster contributions for each input point
  // (U, T) => U for aggregation
  def add(
      weights: Array[Double],
      dists: Array[MultivariateGaussian])
      (sums: ExpectationSum, x: BV[Double]): ExpectationSum = {
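    // p(j) is the (EPSILON-smoothed) responsibility of cluster j for point x:
    // the weighted density w_j * N(x | mu_j, sigma_j), normalized across clusters below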
    val p = weights.zip(dists).map {
      case (weight, dist) => MLUtils.EPSILON + weight * dist.pdf(x)
    }
    val pSum = p.sum
    sums.logLikelihood += math.log(pSum)
    var i = 0
    while (i < sums.k) {
      p(i) /= pSum
      sums.weights(i) += p(i)
      sums.means(i) += x * p(i)
      BLAS.syr(p(i), Vectors.fromBreeze(x),
        Matrices.fromBreeze(sums.sigmas(i)).asInstanceOf[DenseMatrix])
      i = i + 1
    }
    sums
  }
}

// Aggregation class for partial expectation results
private class ExpectationSum(
    var logLikelihood: Double,
    val weights: Array[Double],
    val means: Array[BDV[Double]],
    val sigmas: Array[BreezeMatrix[Double]]) extends Serializable {

  val k = weights.length

  def +=(x: ExpectationSum): ExpectationSum = {
    var i = 0
    while (i < k) {
      weights(i) += x.weights(i)
      means(i) += x.means(i)
      sigmas(i) += x.sigmas(i)
      i = i + 1
    }
    logLikelihood += x.logLikelihood
    this
  }
}



