All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.spark.ml.odkl.ExtendedMultivariateOnlineSummarizer.scala Maven / Gradle / Ivy

package org.apache.spark.ml.odkl

import java.io.{IOException, ObjectInputStream, ObjectOutputStream, OutputStream}
import java.nio.ByteBuffer

import com.esotericsoftware.kryo.io.{Input, Output}
import com.esotericsoftware.kryo.{DefaultSerializer, Kryo, Serializer}
import com.tdunning.math.stats.AVLTreeDigest
import odkl.analysis.spark.util.Logging
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.mllib
import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer

/**
  * Created by dmitriybugaichenko on 30.12.15.
  *
  * Utility used for estimating extended stat for the set of vectors. In addition to mean, deviation and count
  * estimates percentiles
  *
  * @param dimension   Expected dimension of vectors to aggregate
  * @param compression How should accuracy be traded for size?  A value of N here will give quantile errors
  *                    almost always less than 3/N with considerably smaller errors expected for extreme
  *                    quantiles. Conversely, you should expect to track about 5 N centroids for this
  *                    accuracy.
  */
class ExtendedMultivariateOnlineSummarizer
(
  val dimension: Int,
  val compression: Double) extends MultivariateOnlineSummarizer with Serializable with Logging {

  val percentileAggregators = Array.tabulate(dimension) { i => new SeriallizableAvlTreeDigest(compression) }

  override def add(sample: mllib.linalg.Vector): this.type = {
    require(sample.size == dimension, s"Expecting vector of size $dimension")
    super.add(sample)

    for (i <- 0 until dimension) percentileAggregators(i).add(sample(i))

    this
  }

  override def merge(other: MultivariateOnlineSummarizer): this.type = {
    require(other.isInstanceOf[ExtendedMultivariateOnlineSummarizer], "Extended summarizer expected.")
    require(dimension == other.asInstanceOf[ExtendedMultivariateOnlineSummarizer].dimension, s"Expecting summarizer with the same dimansion $dimension")

    super.merge(other)

    for (i <- 0 until dimension) try {
      percentileAggregators(i).add(other.asInstanceOf[ExtendedMultivariateOnlineSummarizer].percentileAggregators(i))
    } catch {
      case e: Exception => logError(s"Exception while aggregating index $i", e)
    }

    this
  }

  def percentile(p: Double): Vector = {
    require(p > 0 && p < 1, "Expected p between 0 and 1 both excluding")
    require(count > 0, "Expected at least one sample")

    Vectors.dense(Array.tabulate(dimension) {
      i => percentileAggregators(i).percentile(p)
    })
  }
}

/**
  * Serializable wrapper over the TDigest
  *
  * @param initialCompression How should accuracy be traded for size?  A value of N here will give quantile errors
  *                           almost always less than 3/N with considerably smaller errors expected for extreme
  *                           quantiles.  Conversely, you should expect to track about 5 N centroids for this
  *                           accuracy.
  */
@DefaultSerializer(classOf[SeriallizableAvlTreeDigest])
class SeriallizableAvlTreeDigest(val initialCompression: Double = 100) extends Serializer[SeriallizableAvlTreeDigest] with Serializable {
  def this() = this(100)

  def this(dig : AVLTreeDigest) = {
    this(dig.compression())
    digest = dig
  }

  var digest = new AVLTreeDigest(initialCompression)

  def effectiveCompression = digest.compression()

  @throws(classOf[IOException])
  private def writeObject(out: ObjectOutputStream): Unit = {
    writeToOutputStream(out)
  }

  def writeToOutputStream(out: OutputStream): Unit = {
    val bytes = Array.ofDim[Byte](digest.smallByteSize() + 4)
    val buffer: ByteBuffer = ByteBuffer.wrap(bytes)

    digest.asSmallBytes(buffer.putInt(bytes.length - 4))
    out.write(bytes)
  }

  @throws(classOf[IOException])
  private def readObject(in: ObjectInputStream): Unit = {
    val size: Int = in.readInt()
    val bytes = Array.ofDim[Byte](size)
    in.readFully(bytes)

    digest = AVLTreeDigest.fromBytes(ByteBuffer.wrap(bytes))
  }

  def add(x: Double) = digest.add(x)

  def add(x: SeriallizableAvlTreeDigest) = digest.add(x.digest)

  def percentile(p: Double) = {
    require(p > 0 && p < 1, "Expected p between 0 and 1 both excluding")
    digest.quantile(p)
  }

  override def write(kryo: Kryo, output: Output, obj: SeriallizableAvlTreeDigest): Unit = {
    obj.writeToOutputStream(output)
  }

  override def read(kryo: Kryo, input: Input, `type`: Class[SeriallizableAvlTreeDigest]): SeriallizableAvlTreeDigest = {
    val size = input.readInt()
    val bytes = Array.ofDim[Byte](size)
    input.read(bytes)

    new SeriallizableAvlTreeDigest(AVLTreeDigest.fromBytes(ByteBuffer.wrap(bytes)))
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy