All Downloads are FREE. Search and download functionality uses the official Maven repository.

spire.example.kmeans.scala Maven / Gradle / Ivy

The newest version!
package spire.example

import spire.algebra._
import spire.math._
import spire.implicits._

import scala.collection.generic.CanBuildFrom
import scala.reflect.ClassTag
import scala.{ specialized => spec }
import scala.util.Random.{ nextInt, nextDouble, nextGaussian }
import scala.annotation.tailrec

/**
 * An example using `NormedVectorSpace`s to create a generic k-Means
 * implementation. We also abstract over the collection type for the fun of it.
 * We implement Lloyd's algorithm, which has problems of its own, but performs
 * well enough.
 */
object KMeansExample extends App {

  /**
   * Returns a collection of cluster centers for `points0`, computed with
   * Lloyd's algorithm.
   *
   * The first `k` points of `points0` are used as the initial centers, so the
   * result has `min(k, points0.size)` elements. It is empty when `k <= 0` or
   * `points0` is empty.
   *
   * @param points0 the points to cluster
   * @param k the number of clusters to find
   * @return the cluster centers, in the same collection type as `points0`
   */
  def kMeans[V, @spec(Double) A, CC[V] <: Iterable[V]](points0: CC[V], k: Int)(implicit
      vs: NormedVectorSpace[V, A], order: Order[A],
      cbf: CanBuildFrom[Nothing, V, CC[V]], ct: ClassTag[V]): CC[V] = {

    val points = points0.toArray

    // Maps the index of each point to the index of the cluster center it is
    // closest to (according to the norm). Ties go to the lowest index.
    // `clusters` must be non-empty whenever `points` is non-empty.
    def assign(clusters: Array[V]): Array[Int] = {
      val assignments = new Array[Int](points.length)
      cfor(0)(_ < points.length, _ + 1) { i =>
        var min = (points(i) - clusters(0)).norm
        var idx = 0
        cfor(1)(_ < clusters.length, _ + 1) { j =>
          val dist = (points(i) - clusters(j)).norm
          if (dist < min) {
            min = dist
            idx = j
          }
        }
        assignments(i) = idx
      }
      assignments
    }

    // Moves each center to the centroid of the points assigned to it.
    // A cluster with no points has no centroid, so it keeps its previous
    // center rather than being divided by a zero count.
    def centroids(assignments: Array[Int], previous: Array[V]): Array[V] = {
      val sums = Array.fill[V](previous.length)(vs.zero)
      val counts = new Array[Int](previous.length)
      cfor(0)(_ < points.length, _ + 1) { i =>
        val idx = assignments(i)
        sums(idx) = sums(idx) + points(i)
        counts(idx) += 1
      }
      cfor(0)(_ < sums.length, _ + 1) { j =>
        sums(j) =
          if (counts(j) == 0) previous(j)
          else sums(j) :/ vs.scalar.fromInt(counts(j))
      }
      sums
    }

    // This is the main loop of the k-means algorithm. Given the current
    // centers and the assignments produced by the previous centers, we
    // re-assign every point to its closest current center. If no point
    // switched clusters we have converged. Otherwise we move each center to
    // the centroid of its points, rinse, repeat.
    @tailrec
    def loop(assignments0: Array[Int], clusters0: Array[V]): Array[V] = {
      val assignments = assign(clusters0)
      if (assignments === assignments0) clusters0
      else loop(assignments, centroids(assignments, clusters0))
    }

    // Our seed points are chosen rather naively. However, the points below are
    // generated randomly, so we don't need to worry about being too smart here.
    val init: Array[V] = points take k

    // Seed the loop with a "previous" assignment of -1 for every point. It can
    // never equal a real assignment, so at least one centroid step always runs.
    // (Seeding with `assign(init)` would stop the loop immediately and return
    // the seeds unchanged.) With no seeds there is nothing to iterate.
    val clusters =
      if (init.isEmpty) init
      else loop(Array.fill(points.length)(-1), init)

    // We work with arrays above, but turn it into the collection type the user
    // wants before we return the clusters.
    val bldr = cbf()
    cfor(0)(_ < clusters.length, _ + 1) { i =>
      bldr += clusters(i)
    }
    bldr.result()
  }

  /**
   * Generates `n` points in `d` dimensions, clustered around `k` random
   * centers.
   *
   * Each center has coordinates drawn uniformly from [0, 10). Each point is a
   * randomly chosen center plus standard-normal noise in every coordinate.
   *
   * @param d the number of dimensions
   * @param k the number of centers (must be positive)
   * @param n the number of points to generate
   * @param f converts raw coordinates into the vector type `V`
   * @return the generated points, in the collection type `CC`
   */
  def genPoints[CC[_], V, @spec(Double) A](d: Int, k: Int, n: Int)(f: Array[Double] => V)(implicit
      vs: VectorSpace[V, A], cbf: CanBuildFrom[Nothing, V, CC[V]]): CC[V] = {

    // `gen` is by-name, so it is re-evaluated for every coordinate.
    def randPoint(gen: => Double): V = f((1 to d).map(_ => gen)(collection.breakOut))

    val centers: Vector[V] = (1 to k).map({ _ =>
      randPoint(nextDouble() * 10)
    })(collection.breakOut)

    val bldr = cbf()
    cfor(0)(_ < n, _ + 1) { _ =>
      bldr += centers(nextInt(k)) + randPoint(nextGaussian())
    }
    bldr.result()
  }

  implicit val mc: java.math.MathContext = java.math.MathContext.DECIMAL128

  // We construct 3 sets of points, each using different VectorSpaces. We'll
  // cluster each one, using the same k-means algorithm.

  val points0 = genPoints[List, Array[Double], Double](15, 5, 10000)(identity)
  val points1 = genPoints[List, Vector[Double], Double](5, 10, 10000)(_.toVector)
  val points2 = genPoints[List, Vector[BigDecimal], BigDecimal](7, 8, 2000)(_.map(BigDecimal(_)).toVector)

  println("Finding clusters of Array[Double] points.")
  val cluster0 = kMeans(points0, 5)

  println("Finding clusters of Vector[Double] points.")
  val cluster1 = kMeans(points1, 10)

  println("Finding clusters of Vector[BigDecimal] points.")
  val cluster2 = kMeans(points2, 8)

  println("Finished finding our clusters! Yay!")
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy