
package org.apache.spark.mllib.evaluation

import org.apache.spark.annotation.Since
import org.apache.spark.internal.Logging
import org.apache.spark.mllib.evaluation.binary._
import org.apache.spark.rdd.{RDD, UnionRDD}
import org.apache.spark.sql.DataFrame
/**
 * Evaluator for binary classification, extended with gains, lift, and
 * Kolmogorov-Smirnov (KS) curves.
 *
 * @param scoreAndLabels an RDD of (score, label) pairs.
 * @param numBins if greater than 0, the curves computed internally are down-sampled to
 *                approximately this many bins; if 0, no down-sampling is performed.
 */
class BinaryClassificationMetricsExt(
    val scoreAndLabels: RDD[(Double, Double)],
    val numBins: Int) extends Logging {

  require(numBins >= 0, "numBins must be nonnegative")

  /**
   * Defaults `numBins` to 0.
   */
  @Since("1.0.0")
  def this(scoreAndLabels: RDD[(Double, Double)]) = this(scoreAndLabels, 0)

  /**
   * An auxiliary constructor taking a DataFrame.
   *
   * @param scoreAndLabels a DataFrame with two double columns: score and label
   */
  private[mllib] def this(scoreAndLabels: DataFrame) =
    this(scoreAndLabels.rdd.map(r => (r.getDouble(0), r.getDouble(1))))

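  // A minimal usage sketch (assumes a SparkContext `sc` is in scope; the
  // (score, label) values below are made-up):
  //
  //   val scoreAndLabels = sc.parallelize(Seq(
  //     (0.9, 1.0), (0.8, 1.0), (0.6, 0.0), (0.3, 1.0), (0.1, 0.0)))
  //   val metrics = new BinaryClassificationMetricsExt(scoreAndLabels)
  //   println(metrics.areaUnderROC())
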
  /**
   * Unpersists intermediate RDDs used in the computation.
   */
  @Since("1.0.0")
  def unpersist(): Unit = {
    cumulativeCounts.unpersist()
  }

  /**
   * Returns thresholds in descending order.
   */
  @Since("1.0.0")
  def thresholds(): RDD[Double] = cumulativeCounts.map(_._1)

  /**
   * Returns the (reach, false positive rate) arm of the Kolmogorov-Smirnov (KS) chart,
   * with (0.0, 0.0) prepended and (1.0, 1.0) appended to it.
   */
  def ksMinus(): RDD[(Double, Double)] = {
    val kChart = createCurve(Reach, FalsePositiveRate)
    val sc = confusions.context
    val first = sc.makeRDD(Seq((0.0, 0.0)), 1)
    val last = sc.makeRDD(Seq((1.0, 1.0)), 1)
    new UnionRDD[(Double, Double)](sc, Seq(first, kChart, last))
  }

  /**
   * Returns the (reach, recall) arm of the Kolmogorov-Smirnov (KS) chart,
   * with (0.0, 0.0) prepended and (1.0, 1.0) appended to it.
   */
  def ksPlus(): RDD[(Double, Double)] = {
    val sChart = createCurve(Reach, Recall)
    val sc = confusions.context
    val first = sc.makeRDD(Seq((0.0, 0.0)), 1)
    val last = sc.makeRDD(Seq((1.0, 1.0)), 1)
    new UnionRDD[(Double, Double)](sc, Seq(first, sChart, last))
  }

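  // The KS statistic itself is the maximum gap between the two curves. A
  // sketch (zip is valid here because both curves are built from the same
  // `confusions` RDD plus identical one-element endpoint RDDs):
  //
  //   val ks = metrics.ksPlus().zip(metrics.ksMinus())
  //     .map { case ((_, recall), (_, fpr)) => recall - fpr }
  //     .max()
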
  /**
   * Returns the cumulative gains curve, which is an RDD of (reach, recall)
   * with (0.0, 0.0) prepended and (1.0, 1.0) appended to it.
   */
  def gains(): RDD[(Double, Double)] = {
    val gainsChart = createCurve(Reach, Recall)
    val sc = confusions.context
    val first = sc.makeRDD(Seq((0.0, 0.0)), 1)
    val last = sc.makeRDD(Seq((1.0, 1.0)), 1)
    new UnionRDD[(Double, Double)](sc, Seq(first, gainsChart, last))
  }

  /**
   * Returns the lift curve, which is an RDD of (reach, recall / reach) pairs.
   */
  def lift(): RDD[(Double, Double)] = createCurve(Reach, Lift)

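  // Worked example: if targeting the highest-scored 10% of the population
  // (reach = 0.1) captures 30% of all positives (recall = 0.3), then the
  // lift at that point is 0.3 / 0.1 = 3.0: three times better than random
  // targeting.
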
  /**
   * Returns the receiver operating characteristic (ROC) curve,
   * which is an RDD of (false positive rate, true positive rate)
   * with (0.0, 0.0) prepended and (1.0, 1.0) appended to it.
   *
   * @see <a href="https://en.wikipedia.org/wiki/Receiver_operating_characteristic">
   * Receiver operating characteristic (Wikipedia)</a>
   */
  @Since("1.0.0")
  def roc(): RDD[(Double, Double)] = {
    val rocCurve = createCurve(FalsePositiveRate, Recall)
    val sc = confusions.context
    val first = sc.makeRDD(Seq((0.0, 0.0)), 1)
    val last = sc.makeRDD(Seq((1.0, 1.0)), 1)
    new UnionRDD[(Double, Double)](sc, Seq(first, rocCurve, last))
  }

  /**
   * Computes the area under the receiver operating characteristic (ROC) curve.
   */
  @Since("1.0.0")
  def areaUnderROC(): Double = AreaUnderCurve.of(roc())

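  // Sketch: materializing the ROC curve on the driver, e.g. for plotting
  // (assumes a `metrics` instance as in the constructor example above):
  //
  //   val rocPoints: Array[(Double, Double)] = metrics.roc().collect()
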
  /**
   * Returns the precision-recall curve, which is an RDD of (recall, precision),
   * NOT (precision, recall), with (0.0, 1.0) prepended to it.
   *
   * @see <a href="https://en.wikipedia.org/wiki/Precision_and_recall">
   * Precision and recall (Wikipedia)</a>
   */
  @Since("1.0.0")
  def pr(): RDD[(Double, Double)] = {
    val prCurve = createCurve(Recall, Precision)
    val sc = confusions.context
    val first = sc.makeRDD(Seq((0.0, 1.0)), 1)
    first.union(prCurve)
  }

  /**
   * Computes the area under the precision-recall curve.
   */
  @Since("1.0.0")
  def areaUnderPR(): Double = AreaUnderCurve.of(pr())

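  // Sketch: comparing the two summary areas for the same `metrics` instance:
  //
  //   println(s"AUROC = ${metrics.areaUnderROC()}, AUPRC = ${metrics.areaUnderPR()}")
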
  /**
   * Returns the (threshold, F-Measure) curve.
   *
   * @param beta the beta factor in F-Measure computation.
   * @return an RDD of (threshold, F-Measure) pairs.
   * @see <a href="https://en.wikipedia.org/wiki/F1_score">F1 score (Wikipedia)</a>
   */
  @Since("1.0.0")
  def fMeasureByThreshold(beta: Double): RDD[(Double, Double)] = createCurve(FMeasure(beta))

  /**
   * Returns the (threshold, F-Measure) curve with beta = 1.0.
   */
  @Since("1.0.0")
  def fMeasureByThreshold(): RDD[(Double, Double)] = fMeasureByThreshold(1.0)

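  // Sketch: picking the threshold that maximizes F1:
  //
  //   val (bestThreshold, bestF1) = metrics.fMeasureByThreshold()
  //     .reduce((a, b) => if (a._2 >= b._2) a else b)
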
  /**
   * Returns the (threshold, precision) curve.
   */
  @Since("1.0.0")
  def precisionByThreshold(): RDD[(Double, Double)] = createCurve(Precision)

  /**
   * Returns the (threshold, recall) curve.
   */
  @Since("1.0.0")
  def recallByThreshold(): RDD[(Double, Double)] = createCurve(Recall)

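  // Sketch: inspecting the precision/recall trade-off per threshold by
  // joining the two curves on the threshold key:
  //
  //   val tradeOff: RDD[(Double, (Double, Double))] =
  //     metrics.precisionByThreshold().join(metrics.recallByThreshold())
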
  private lazy val (
    cumulativeCounts: RDD[(Double, BinaryLabelCounter)],
    confusions: RDD[(Double, BinaryConfusionMatrix)]) = {
    // Create a bin for each distinct score value, count positives and negatives within each bin,
    // and then sort by score values in descending order.
    val counts = scoreAndLabels.combineByKey(
      createCombiner = (label: Double) => new BinaryLabelCounter(0L, 0L) += label,
      mergeValue = (c: BinaryLabelCounter, label: Double) => c += label,
      mergeCombiners = (c1: BinaryLabelCounter, c2: BinaryLabelCounter) => c1 += c2
    ).sortByKey(ascending = false)

    val binnedCounts =
      // Only down-sample if numBins is greater than 0
      if (numBins == 0) {
        // Use the original bins directly
        counts
      } else {
        val countsSize = counts.count()
        // Group the iterator into chunks of about countsSize / numBins points,
        // so that the resulting number of bins is about numBins.
        var grouping = countsSize / numBins
        if (grouping < 2) {
          // numBins was more than half the number of distinct scores;
          // there is no real point in down-sampling.
          logInfo(s"Curve is too small ($countsSize) for $numBins bins to be useful")
          counts
        } else {
          if (grouping >= Int.MaxValue) {
            logWarning(
              s"Curve too large ($countsSize) for $numBins bins; capping at ${Int.MaxValue}")
            grouping = Int.MaxValue
          }
          counts.mapPartitions(_.grouped(grouping.toInt).map { pairs =>
            // The score of the combined point is the first (highest) score in the chunk
            val firstScore = pairs.head._1
            // The combined point aggregates all counts in this chunk
            val agg = new BinaryLabelCounter()
            pairs.foreach(pair => agg += pair._2)
            (firstScore, agg)
          })
        }
      }

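    // Worked example: with 1,000,000 distinct scores and numBins = 1000,
    // grouping = 1000, so each emitted point above aggregates roughly 1000
    // consecutive score bins (chunking is per partition, so the exact number
    // of output bins can differ slightly from numBins).
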
    // Sum the counts within each partition, then use a prefix sum over the
    // per-partition totals to get the cumulative count at each partition start.
    val agg = binnedCounts.values.mapPartitions { iter =>
      val agg = new BinaryLabelCounter()
      iter.foreach(agg += _)
      Iterator(agg)
    }.collect()
    val partitionwiseCumulativeCounts =
      agg.scanLeft(new BinaryLabelCounter())((agg, c) => agg.clone() += c)
    val totalCount = partitionwiseCumulativeCounts.last
    logInfo(s"Total counts: $totalCount")
    val cumulativeCounts = binnedCounts.mapPartitionsWithIndex(
      (index: Int, iter: Iterator[(Double, BinaryLabelCounter)]) => {
        val cumCount = partitionwiseCumulativeCounts(index)
        iter.map { case (score, c) =>
          cumCount += c
          (score, cumCount.clone())
        }
      }, preservesPartitioning = true)
    cumulativeCounts.persist()
    // Each cumulative count, together with the grand total, determines the
    // confusion matrix at that score threshold.
    val confusions = cumulativeCounts.map { case (score, cumCount) =>
      (score, BinaryConfusionMatrixImpl(cumCount, totalCount).asInstanceOf[BinaryConfusionMatrix])
    }
    (cumulativeCounts, confusions)
  }

  /** Creates a curve of (threshold, metric). */
  private def createCurve(y: BinaryClassificationMetricComputer): RDD[(Double, Double)] = {
    confusions.map { case (s, c) =>
      (s, y(c))
    }
  }

  /** Creates a curve of (metricX, metricY). */
  private def createCurve(
      x: BinaryClassificationMetricComputer,
      y: BinaryClassificationMetricComputer): RDD[(Double, Double)] = {
    confusions.map { case (_, c) =>
      (x(c), y(c))
    }
  }
}

/**
 * Reach: the fraction of the total population predicted positive at the current
 * threshold, (TP + FP) / (P + N). Defined as 1.0 when the population is empty.
 */
private[evaluation] object Reach extends BinaryClassificationMetricComputer {
  override def apply(c: BinaryConfusionMatrix): Double = {
    val totalPopulation = c.numNegatives + c.numPositives
    if (totalPopulation == 0) {
      1.0
    } else {
      (c.numTruePositives.toDouble + c.numFalsePositives.toDouble) / totalPopulation
    }
  }
}

/**
 * Lift: recall divided by reach, i.e. how many times more positives the targeted
 * fraction of the population captures compared with random targeting.
 */
private[evaluation] object Lift extends BinaryClassificationMetricComputer {
  override def apply(c: BinaryConfusionMatrix): Double = {
    Recall(c) / Reach(c)
  }
}