All Downloads are FREE. Search and download functionality uses the official Maven repository.

breeze.cluster.KMeans.scala Maven / Gradle / Ivy

package breeze.cluster

import breeze.math.{MutableInnerProductSpace, MutableNormedSpace}
import breeze.numerics._
import breeze.linalg._
import breeze.util._
import breeze.util.Implicits._
import collection.mutable.ArrayBuffer
import breeze.stats.distributions.Multinomial

/**
 * Implements the KMeans++ algorithm for clustering points in a normed space.
 *
 * Seeding picks the first mean uniformly at random and every further mean with
 * probability proportional to its squared distance to the nearest mean chosen
 * so far. The means are then refined iteratively until the clustering error
 * stops changing (see [[State.converged]]).
 *
 * @param k         number of clusters requested
 * @param tolerance tolerance passed to `closeTo` when comparing the errors of two
 *                  successive iterations
 * @param space     vector-space operations (subtraction, norm, zeros, in-place
 *                  addition and scaling) for `T`
 * @tparam T type of the points being clustered
 * @author dlwh
 */
class KMeans[T](k: Int, tolerance: Double = 1E-4)(implicit space: MutableInnerProductSpace[T, Double]) {
  import space._

  /**
   * One step of the clustering.
   *
   * @param means         the current cluster centers
   * @param error         sum of squared distances from each point to its nearest center;
   *                      for states produced by [[iterates]] this is measured against the
   *                      centers of the preceding state
   * @param previousError the error of the preceding state (infinite for the initial state)
   */
  case class State(means: IndexedSeq[T], error: Double, previousError: Double=Double.PositiveInfinity) {
    /** True when `error` and `previousError` are `closeTo` each other at `tolerance`. */
    def converged: Boolean = closeTo(error,previousError,tolerance)
  }


  /**
   * Runs the clustering to convergence and returns the final state.
   *
   * @param points the points to cluster
   * @return the last state produced by [[iterates]]
   */
  def cluster(points: IndexedSeq[T]): State = {
    iterates(points).last
  }

  /**
   * Lazily produces the sequence of clustering states, ending with (and including)
   * the first converged one.
   *
   * If there are no more points than clusters, every point becomes its own mean
   * and a single zero-error state is returned.
   *
   * Note: only clusters that received at least one point produce a mean, so a
   * state may hold fewer means than its predecessor.
   *
   * @param points the points to cluster
   * @return an iterator over successive states
   */
  def iterates(points: IndexedSeq[T]): Iterator[State] = {
    if(points.length <= k) Iterator(State(points, 0.0))
    else {
      Iterator.iterate(initialState(points)){ current =>
        // (point, index of nearest mean, distance to that mean) for every point
        val assignments = points.par.map { x =>
          val dists = current.means.map{m => distance(x,m)}
          val nearest = dists.argmin // which cluster minimizes euclidean distance
          (x, nearest, dists(nearest))
        }

        // groupBy only yields clusters that actually received points
        val newClusters = assignments.groupBy(_._2).seq.values.map{_.map(_._1)}
        // new mean = in-place sum of the members divided by their count
        val newMeans = newClusters.par.map(clust => clust.foldLeft(zeros(clust.head))(_ += _) /= clust.size.toDouble)
        val error = assignments.map(tuple => math.pow(tuple._3, 2)).sum

        State(newMeans.seq.toIndexedSeq, error, current.error)
      }.takeUpToWhere(state => state.converged)
    }
  }

  /** Distance between two points: the norm of their difference. */
  private def distance(a: T, b: T) =  norm(a - b)

  /**
   * KMeans++ seeding: draws `k` means from `points`.
   *
   * `probs(i)` holds the sampling weight of point `i`. It starts uniform for the
   * first draw; afterwards it is the squared distance from point `i` to the
   * nearest mean chosen so far.
   *
   * @param points the points to seed from
   * @return a state holding the chosen means and the sum of the final weights as error
   */
  private def initialState(points: IndexedSeq[T]):State = {
    val probs = Array.fill(points.length)(1.0)
    val means = new ArrayBuffer[T]()

    while(means.length < k) {
      // If every point coincides with an already chosen mean, all weights are zero
      // and there is no mass to sample from; fall back to a uniform draw.
      val weights = if(probs.sum > 0.0) probs else Array.fill(points.length)(1.0)
      val newMean = points(Multinomial(new DenseVector(weights)).sample())
      val isFirstMean = means.isEmpty
      means += newMean
      for(i <- (0 until probs.length).par) {
        val sqDist = math.pow(distance(newMean, points(i)),2)
        // The first mean must overwrite the uniform placeholder weights; taking the
        // min against the initial 1.0 would cap every squared distance at 1.0 and
        // degrade the D^2 weighting to uniform sampling for widely spread data.
        probs(i) = if(isFirstMean) sqDist else math.min(probs(i), sqDist)
      }
    }

    // snapshot the buffer so the state does not expose a mutable collection
    State(means.toIndexedSeq, probs.sum)
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy