axle.ml.ConfusionMatrix.scala Maven / Gradle / Ivy
The newest version!
package axle.ml
import axle.matrix._
import axle.matrix.MatrixModule
import spire.algebra._
/**
 * Confusion matrix relating the labels extracted from `data` to the classes
 * that `classifier` assigns to the same data.
 *
 * Rows follow `labelList` (the distinct labels seen in `data`), columns follow
 * `classifier.classes`. Cell (r, c) is the number of data whose extracted label
 * is `labelList(r)` and for which the classifier returned `classes(c)`.
 *
 * @param classifier applied to every datum to obtain its assigned class
 * @param data the data to tabulate
 * @param labelExtractor yields the label of a datum
 * @tparam T type of a datum
 * @tparam C type of the classifier's result
 * @tparam L type of the extracted label
 */
abstract class ConfusionMatrix[T, C: Order, L: Order](classifier: Classifier[T, C], data: Seq[T], labelExtractor: T => L)
  extends MatrixModule {

  // One (label, assigned class) pair per datum, in `data` order.
  val label2clusterId = data.map(datum => (labelExtractor(datum), classifier(datum)))

  // Distinct labels; the row order is whatever iteration order the intermediate Set yields.
  val labelList = label2clusterId.map(_._1).toSet.toList

  // Label -> row index in `counts`.
  val labelIndices = labelList.zipWithIndex.toMap

  // (row index, assigned class) -> number of occurrences; combinations never seen count as 0.
  val labelIdClusterId2count = label2clusterId
    .map({ case (label, clusterId) => ((labelIndices(label), clusterId), 1) })
    .groupBy(_._1)
    .map({ case (k, v) => (k, v.map(_._2).sum) })
    .withDefaultValue(0)

  val classes = classifier.classes

  val counts = matrix[Int](labelList.length, classes.size, (r: Int, c: Int) => labelIdClusterId2count((r, classes(c))))

  // Column width used when rendering numbers. No cell, row sum or column sum can
  // exceed data.length, so its decimal digit count is a sufficient width.
  // (ceil(log10(n)) is one digit short when n is an exact power of ten, is 0 for
  // n == 1 -- making the format string "%0d", which java.util.Formatter rejects --
  // and is not a usable width for n == 0 because log10(0) is -Infinity.)
  val width = data.length.toString.length

  // Right-aligns `i` in a field of `width` characters.
  val formatNumber = (i: Int) => ("%" + width + "d").format(i)

  lazy val rowSums = counts.rowSums

  lazy val columnSums = counts.columnSums

  /**
   * Text rendering: one line per label holding that row's counts, " : ", the
   * row sum and the label; then a blank line; then one line of column sums.
   */
  lazy val asString = {
    val columnIndices = 0 until counts.columns
    val body = labelList.zipWithIndex.map({
      case (label, r) =>
        columnIndices.map(c => formatNumber(counts(r, c))).mkString(" ") +
          " : " + formatNumber(rowSums(r, 0)) + " " + label + "\n"
    }).mkString("")
    val footer = columnIndices.map(c => formatNumber(columnSums(0, c))).mkString(" ")
    body + "\n" + footer + "\n"
  }

  override def toString: String = asString

}
© 2015 - 2024 Weber Informatics LLC | Privacy Policy