All Downloads are FREE. Search and download functionalities are using the official Maven repository.

axle.ml.Classifier.scala Maven / Gradle / Ivy

The newest version!
package axle.ml

import spire.optional.unicode.Σ
import spire.math._
import spire.implicits._
import spire.algebra._
import axle._
import axle.matrix.JblasMatrixModule
import axle.algebra._
import Semigroups._

abstract class Classifier[DATA, CLASS: Order: Eq] extends Function1[DATA, CLASS] {

  def apply(d: DATA): CLASS

  def classes: IndexedSeq[CLASS]

  /**
   * For a given class (label value), predictedVsActual returns a tally of 4 cases:
   *
   * 1. true positive
   * 2. false positive
   * 3. false negative
   * 4. true negative
   *
   */

  private[this] def predictedVsActual(data: Seq[DATA], classExtractor: DATA => CLASS, k: CLASS): (Int, Int, Int, Int) = Σ(data.map(d => {
    val actual: CLASS = classExtractor(d)
    val predicted: CLASS = this(d)
    (actual === k, predicted === k) match {
      case (true, true) => (1, 0, 0, 0) // true positive
      case (false, true) => (0, 1, 0, 0) // false positive
      case (false, false) => (0, 0, 1, 0) // false negative
      case (true, false) => (0, 0, 0, 1) // true negative
    }
  }))

  def performance(data: Seq[DATA], classExtractor: DATA => CLASS, k: CLASS): ClassifierPerformance[Rational] = {

    val (tp, fp, fn, tn) = predictedVsActual(data, classExtractor, k)

    ClassifierPerformance[Rational](
      Rational(tp, tp + fp), // precision
      Rational(tp, tp + fn), // recall
      Rational(tn, tn + fp), // specificity aka "true negative rate"
      Rational(tp + tn, tp + tn + fp + fn) // accuracy
      )
  }

  def confusionMatrix[L: Order](data: Seq[DATA], labelExtractor: DATA => L): ConfusionMatrix[DATA, CLASS, L] =
    new ConfusionMatrix(this, data, labelExtractor) with JblasMatrixModule

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy