![JAR search and dependency download from the Maven repository](/logo.png)
com.tradeshift.blayze.features.Multinomial.kt Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of blayze Show documentation
Show all versions of blayze Show documentation
A fast and flexible Naive Bayes implementation for the JVM written in Kotlin
package com.tradeshift.blayze.features
import com.tradeshift.blayze.Protos
import com.tradeshift.blayze.collection.Counter
import com.tradeshift.blayze.collection.Table
import com.tradeshift.blayze.collection.tableOf
import com.tradeshift.blayze.dto.Outcome
import kotlin.math.ln
import kotlin.math.pow
/**
* A feature for multinomial data.
*
* @property includeFeatureProbability Include new features with this probability. See Ad Click Prediction: a View from the Trenches, Table 2
* @property pseudoCount Add this number to all counts, even zero counts. Prevents 0 probability. See http://en.wikipedia.org/wiki/Naive_Bayes_classifier#Multinomial_naive_Bayes
* @property countTable Table with row of outcomes and column of features and the count for each occurrence e.g. {{"positive": {"awesome": 67, "terrible": 14}, "negative": {"awesome": 11, "terrible": 114}}
*/
class Multinomial(
private val includeFeatureProbability: Double = 1.0,
private val pseudoCount: Double = 1.0,
private val countTable: Table = tableOf()
) : Feature> {
private val outcomeCounter: Counter by lazy { sumColumns(countTable) }
override fun batchUpdate(updates: List>>): Multinomial {
val updatedTable = countTable.toMutableTable()
val knownFeatures = updatedTable.columnKeySet.toMutableSet()
for ((outcome, counts) in updates) {
for ((feature, value) in counts) {
if (knownFeatures.contains(feature) || Math.random() < (1.0 - (1.0 - includeFeatureProbability).pow(value))) {
knownFeatures.add(feature)
val count = updatedTable[outcome, feature] ?: 0
updatedTable.put(outcome, feature, count + value)
}
}
}
return Multinomial(includeFeatureProbability, pseudoCount, updatedTable.toTable())
}
override fun logProbability(outcomes: Set, value: Counter): Map {
val knownFeatures: Set = value.keys.intersect(countTable.columnKeySet)
val numOfFeatures = countTable.columnKeySet.size
val results = mutableMapOf()
for (outcome in outcomes) {
var logProb = 0.0
for (feature in knownFeatures) {
val count = ((countTable[outcome, feature] ?: 0) + pseudoCount)
logProb += value[feature] * (ln(count) - ln(outcomeCounter[outcome] + (numOfFeatures * pseudoCount)))
}
results[outcome] = logProb
}
return results
}
private fun sumColumns(table: Table): Counter {
val counts = mutableMapOf()
for (cell in table.entries) {
val outcome = cell.key.first
val current = counts.getOrDefault(outcome, 0)
counts[outcome] = current + cell.value
}
return Counter(counts)
}
fun toProto(): Protos.Multinomial {
return Protos.Multinomial.newBuilder()
.setIncludeFeatureProbability(includeFeatureProbability)
.setPseudoCount(pseudoCount)
.setTable(Protos.Table.newBuilder()
.addAllEntries(countTable.entries.map {
Protos.Entry.newBuilder()
.setRowKey(it.key.first)
.setColumnKey(it.key.second)
.setCount(it.value)
.build()
}).build()
).build()
}
companion object {
fun fromProto(proto: Protos.Multinomial): Multinomial {
return Multinomial(proto.includeFeatureProbability, proto.pseudoCount, tableOf(proto.table.entriesList.map { Pair(Pair(it.rowKey, it.columnKey), it.count) }))
}
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy