![JAR search and dependency download from the Maven repository](/logo.png)
com.tradeshift.blayze.features.Gaussian.kt Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of blayze Show documentation
Show all versions of blayze Show documentation
A fast and flexible Naive Bayes implementation for the JVM written in Kotlin
package com.tradeshift.blayze.features
import com.tradeshift.blayze.Protos
import com.tradeshift.blayze.dto.Outcome
import kotlin.math.*
/**
* A feature for numbers that approximately follow a normal distribution, e.g. age, amounts, etc.
*/
/**
 * A feature for numbers that approximately follow a normal distribution, e.g. age, amounts, etc.
 *
 * For every outcome the feature keeps a [StreamingEstimator] of the values seen together with that outcome, and
 * scores a new value by its log-density under a normal distribution with the estimated mean and standard deviation.
 * Instances are immutable: [batchUpdate] returns a new [Gaussian] and leaves this one unchanged.
 *
 * NOTE(review): confirm the type arguments of the [Feature] supertype against its declaration.
 *
 * @param estimators per-outcome running statistics; outcomes absent from the map are scored as having no data
 * (see [logProbability]).
 */
class Gaussian(
    private val estimators: Map<Outcome, StreamingEstimator> = mapOf()
) : Feature<Gaussian, Double> {

    /**
     * Returns a new [Gaussian] with every `(outcome, value)` pair in [updates] folded into that outcome's running
     * statistics. An outcome seen for the first time starts a fresh [StreamingEstimator] from its value.
     */
    override fun batchUpdate(updates: List<Pair<Outcome, Double>>): Gaussian {
        val updated = estimators.toMutableMap()
        for ((outcome, x) in updates) {
            updated[outcome] = updated[outcome]?.add(x) ?: StreamingEstimator(x)
        }
        return Gaussian(updated)
    }

    /**
     * Returns the log-density of [value] for each of [outcomes], keyed by outcome in the iteration order of
     * [outcomes].
     *
     * An outcome with no estimator, or whose estimated standard deviation is zero (fewer than two values seen, or
     * all values identical), is assigned `0.0` (log 1).
     * NOTE(review): other outcomes' log-densities may be negative or positive, so `0.0` is not equivalent to
     * excluding the feature for that outcome — confirm this is intended.
     */
    override fun logProbability(outcomes: Set<Outcome>, value: Double): Map<Outcome, Double> =
        outcomes.associate { outcome -> outcome to logProbabilityOutcome(outcome, value) }

    /**
     * Log of the normal density at [value], with mu and sigma taken from [outcome]'s estimator:
     *
     *     p(x | mu, sigma)      = 1 / sqrt(2*pi*sigma^2) * exp(-(x - mu)^2 / (2*sigma^2))
     *     log(p(x | mu, sigma)) = -log(sqrt(2*pi*sigma^2)) - (x - mu)^2 / (2*sigma^2)
     *                           = -log(sigma) - log(sqrt(2*pi)) - (x - mu)^2 / (2*sigma^2)
     *
     * Returns `0.0` when [outcome] has no estimator or a zero standard deviation, since there is then no
     * non-degenerate distribution to evaluate.
     */
    private fun logProbabilityOutcome(outcome: Outcome, value: Double): Double {
        val (mu, sigma) = estimators[outcome] ?: return 0.0
        if (sigma == 0.0) {
            return 0.0
        }
        return -ln(sigma) - LOG_SQRT_2PI - (value - mu).pow(2) / (2 * sigma.pow(2))
    }

    /**
     * Immutable running count, mean and sum of squared deviations of a stream of values, updated with Welford's
     * online recurrence (which avoids the cancellation error of the naive sum-of-squares formula).
     *
     * B. P. Welford (1962). "Note on a method for calculating corrected sums of squares and products".
     *
     * @property mean the running mean of the values added so far.
     */
    class StreamingEstimator private constructor(
        private val count: Int,
        val mean: Double,
        /** Sum of squared deviations from [mean] ("M2" in Welford's recurrence). */
        private val m2: Double
    ) {
        /** Starts an estimator from a single observation [x]. */
        constructor(x: Double) : this(1, x, 0.0)

        /** Returns a new estimator that additionally includes [x]; this instance is unchanged. */
        fun add(x: Double): StreamingEstimator {
            val newCount = count + 1
            val delta = x - mean
            val newMean = mean + delta / newCount
            // Welford: M2 grows by the product of the deviations from the old and the new mean.
            val newM2 = m2 + delta * (x - newMean)
            return StreamingEstimator(newCount, newMean, newM2)
        }

        /** Sample standard deviation (n - 1 denominator); `0.0` until at least two values have been added. */
        val stdev: Double by lazy {
            if (count < 2) 0.0 else sqrt(m2 / (count - 1))
        }

        /** Destructuring support: the mean. */
        operator fun component1(): Double = mean

        /** Destructuring support: the sample standard deviation. */
        operator fun component2(): Double = stdev

        /** Serializes count, mean and m2 to protobuf. */
        fun toProto(): Protos.StreamingEstimator =
            Protos.StreamingEstimator.newBuilder()
                .setCount(count)
                .setMean(mean)
                .setM2(m2)
                .build()

        companion object {
            /** Rebuilds an estimator from the count, mean and m2 written by [toProto]. */
            fun fromProto(proto: Protos.StreamingEstimator): StreamingEstimator =
                StreamingEstimator(proto.count, proto.mean, proto.m2)
        }
    }

    /** Serializes the per-outcome estimators to protobuf. */
    fun toProto(): Protos.Gaussian =
        Protos.Gaussian.newBuilder()
            .putAllEstimators(estimators.mapValues { (_, estimator) -> estimator.toProto() })
            .build()

    companion object {
        /** log(sqrt(2 * pi)): the constant term of the normal log-density, computed once. */
        private val LOG_SQRT_2PI = ln(sqrt(2 * PI))

        /** Rebuilds a [Gaussian] from the protobuf form written by [toProto]. */
        fun fromProto(proto: Protos.Gaussian): Gaussian =
            Gaussian(proto.estimatorsMap.mapValues { (_, estimator) -> StreamingEstimator.fromProto(estimator) })
    }
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy