All downloads are free. Search and download functionality uses the official Maven repository.

com.github.tsingjyujing.geo.algorithm.cluster.BaseGeoKMeans.scala Maven / Gradle / Ivy

There is a newer version: 2.8.9-2.11
Show newest version
package com.github.tsingjyujing.geo.algorithm.cluster

import com.github.tsingjyujing.geo.algorithm.containers.{ClusterResult, LabeledPoint}
import com.github.tsingjyujing.geo.basic.IGeoPoint
import com.github.tsingjyujing.geo.exceptions.ParameterException
import com.github.tsingjyujing.geo.util.GeoUtil

import scala.sys.process.stdout
import scala.util.control.Breaks._

/**
  * Template for k-means clustering of geographic points.
  *
  * Concrete implementations supply the initial centers through `initializePoints`. Each iteration
  * assigns every point to its nearest center (distance measured with `geoTo`) and then moves each
  * center to `GeoUtil.mean` of the points assigned to it.
  *
  * @tparam V type of the points being clustered
  * @author [email protected]
  * @version 1.0
  * @since 2.7
  */
trait BaseGeoKMeans[V <: IGeoPoint] {

    /**
      * Get the initial center points.
      *
      * @param points sample points
      * @param k      requested number of centers
      * @return the initial centers; if their count differs from `k`, `kChangedEvent` is raised
      *         on the first iteration
      */
    def initializePoints(points: Iterable[V], k: Int): Iterable[IGeoPoint]

    /**
      * Report the loss of one iteration. The default prints to stdout and overwrites the current
      * console line (leading `\r`, no trailing newline). Override it to output somewhere else.
      *
      * @param currentStep index of the current iteration, starting at 0
      * @param lossValue   loss value: sum of the distances from each point to its nearest center
      * @param pointCount  number of points being clustered
      */
    def lossOutput(currentStep: Int, lossValue: Double, pointCount: Int): Unit = {
        stdout.print("\rLoss[%d] := %f\t\tMean(loss) = %f km".format(currentStep, lossValue, lossValue / pointCount))
        stdout.flush()
    }

    /**
      * Called when the number of centers differs from the previous iteration. This happens,
      * for example, when a cluster receives no point and is dropped during re-centering.
      * The default prints a warning to stderr.
      *
      * @param currentK number of centers in the current iteration
      * @param lastK    number of centers before the change
      */
    def kChangedEvent(currentK: Int, lastK: Int): Unit = {
        System.err.println("Warning: K changed in iteration! From %d to %d".format(lastK, currentK))
    }

    /**
      * Do k-means training algorithm.
      *
      * Iterates until the loss stops decreasing or `maxStepCount` iterations have run. If the last
      * iteration made the loss strictly worse, the centers of the previous iteration are used for
      * the result instead.
      *
      * @param points       sample points (copied once into an indexed sequence for the iterations)
      * @param k            number of clusters, must be at least 1
      * @param maxStepCount maximum number of iterations
      * @return every point labeled with the index of its nearest final center; with `k == 1` or
      *         no points at all, every point is labeled 0
      * @throws ParameterException if `k < 1`
      */
    def apply(
                 points: Iterable[V],
                 k: Int,
                 maxStepCount: Int = 100
             ): ClusterResult[Int, V] = {
        if (k < 1) {
            throw new ParameterException("Wrong K: K should greater than 0")
        } else if (k == 1 || points.isEmpty) {
            // A single cluster, or nothing to cluster: no iterations needed
            ClusterResult(points.map(LabeledPoint(0, _)))
        } else {
            // Materialize once: `map` keeps the runtime collection type, so a Set-typed `points`
            // would otherwise deduplicate equal distances when summing the loss.
            val samples = points.toIndexedSeq
            val initialCenters = initializePoints(points, k).toIndexedSeq
            val finalCenters = iterateCenters(
                samples, initialCenters, initialCenters, k, Double.MaxValue, 0, maxStepCount
            )
            // Generate result from the selected centers
            ClusterResult(points.map(point => {
                val classId = nearestCenter(finalCenters, point)._2
                LabeledPoint(classId, point)
            }))
        }
    }

    /**
      * Find the center closest to `point`. Ties go to the lowest index.
      * Throws if `centers` is empty.
      *
      * @return (distance to the center, index of the center)
      */
    private def nearestCenter(centers: IndexedSeq[IGeoPoint], point: V): (Double, Int) =
        centers.indices.map(index => {
            val distance: Double = centers(index).geoTo(point)
            (distance, index)
        }).minBy(_._1)

    /**
      * EM step 2 (maximization): each center becomes `GeoUtil.mean` of the points assigned to it.
      * Indices that received no point produce no center, so the result may be shorter.
      */
    private def recomputeCenters(assignment: IndexedSeq[(V, Double, Int)]): IndexedSeq[IGeoPoint] =
        assignment.groupBy(_._3).values.map(group => {
            val pointsInClass: Iterable[V] = group.map(_._1)
            GeoUtil.mean(pointsInClass)
        }).toIndexedSeq

    /**
      * Run the expectation/maximization iterations.
      *
      * @param samples         points being clustered
      * @param centers         centers evaluated in this iteration
      * @param previousCenters centers of the previous iteration (equal to `centers` at step 0)
      * @param lastK           number of centers in the previous iteration (`k` at step 0)
      * @param previousLoss    loss of the previous iteration (`Double.MaxValue` at step 0)
      * @param step            index of this iteration
      * @param maxStepCount    iteration limit
      * @return the centers used to label the points
      */
    @scala.annotation.tailrec
    private def iterateCenters(
                                  samples: IndexedSeq[V],
                                  centers: IndexedSeq[IGeoPoint],
                                  previousCenters: IndexedSeq[IGeoPoint],
                                  lastK: Int,
                                  previousLoss: Double,
                                  step: Int,
                                  maxStepCount: Int
                              ): IndexedSeq[IGeoPoint] = {
        if (step >= maxStepCount) {
            centers
        } else {
            val currentK = centers.size
            if (currentK != lastK) kChangedEvent(currentK, lastK)

            // EM step 1 (expectation): assign every point to its nearest center
            val assignment = samples.map(point => {
                val (distance, classId) = nearestCenter(centers, point)
                (point, distance, classId)
            })
            val currentLoss = assignment.map(_._2).sum
            lossOutput(step, currentLoss, samples.size)

            if (currentLoss > previousLoss) {
                // Loss got worse: the previous centers are the better result
                previousCenters
            } else if (currentLoss == previousLoss) {
                // Loss stopped decreasing: converged
                centers
            } else {
                iterateCenters(
                    samples, recomputeCenters(assignment), centers, currentK, currentLoss, step + 1, maxStepCount
                )
            }
        }
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy