All Downloads are FREE. Search and download functionalities are using the official Maven repository.

metrics.Confusion.scala Maven / Gradle / Ivy

The newest version!
package jjm.metrics

import cats.Monoid
import cats.MonoidK
import cats.Show
import cats.implicits._
import io.circe.Encoder
import io.circe.Decoder

/** Confusion matrix that keeps the individual items falling in each cell.
  *
  * The outer map is keyed by the gold label; the inner map is keyed by the
  * label paired with it (the prediction, as built by `Confusion.instance`).
  *
  * @param matrix gold label -> predicted label -> items recorded for that pair
  * @tparam A label type
  * @tparam I type of the items stored in each cell
  */
case class Confusion[A, I](
  matrix: Map[A, Map[A, List[I]]]
) {

  /** Collapses every cell to the number of items it holds. */
  def stats: Confusion.Stats[A] = {
    val countsByGold = matrix.map { case (goldLabel, itemsByPred) =>
      val countsByPred = itemsByPred.map { case (predLabel, items) =>
        predLabel -> items.size
      }
      goldLabel -> countsByPred
    }
    Confusion.Stats(countsByGold)
  }
}
object Confusion {

  /** Count-only confusion matrix.
    *
    * @param matrix outer key is the gold label, inner key is the predicted
    *               label, value is the number of instances in that cell
    * @tparam A label type
    */
  case class Stats[A](
    matrix: Map[A, Map[A, Int]]
  ) {

    /** Every label that occurs as an inner (predicted) or outer (gold) key. */
    def allLabels: Set[A] = matrix.values.flatMap(_.keys).toSet ++ matrix.keySet

    /** Per label, the sum of its row: how often it occurs as the gold label (0 if never). */
    def goldCounts: Map[A, Int] =
      allLabels.map(label => label -> matrix.get(label).fold(0)(_.values.sum)).toMap

    /** Per label, the sum of its column: how often it occurs as the predicted label (0 if never). */
    def predCounts: Map[A, Int] =
      allLabels.map { label =>
        label -> matrix.values.iterator.map(_.getOrElse(label, 0)).sum
      }.toMap

    /** Sum of all cells in the matrix. */
    def total: Int = matrix.values.iterator.map(_.values.sum).sum

    /** Renders the matrix as comma-separated text, one line per gold label.
      *
      * The first line is a header: a single space followed by the shown labels.
      * Each subsequent line starts with the shown gold label followed by the
      * counts for every retained predicted label. Labels are ordered by
      * descending (gold count + predicted count).
      *
      * @param classFreqBound labels whose gold count plus predicted count is below
      *                       this bound are dropped from both rows and columns
      * @param s              renders a label as text
      * @return the header and rows joined with newlines
      */
    def prettyString(classFreqBound: Int)(implicit s: Show[A]): String = {
      // Compute the marginals once; each call to goldCounts/predCounts rebuilds them.
      val gCounts = goldCounts
      val pCounts = predCounts
      def frequency(label: A): Int = gCounts(label) + pCounts(label)

      val sortedFilteredClasses = gCounts.keys.toList
        .filter(label => frequency(label) >= classFreqBound)
        .sortBy(label => -frequency(label))

      val header = (" " :: sortedFilteredClasses.map(s.show)).mkString(",")
      val body = sortedFilteredClasses.map { gold =>
        // A gold label that only ever appears as a prediction has no row of its own.
        val row = matrix.getOrElse(gold, Map.empty[A, Int])
        val cells = sortedFilteredClasses.map { pred =>
          val count = row.getOrElse(pred, 0)
          f"$count%d"
        }
        s.show(gold) + "," + cells.mkString(",")
      }.mkString("\n")
      header + "\n" + body
    }

  }
  object Stats {
    /** Monoid for `Stats[A]`, derived from its single `matrix` field via
      * `cats.derived.semiauto`.
      *
      * The result type is spelled out so implicit resolution does not depend on
      * inferring it from the body.
      */
    implicit def confStatsMonoid[A]: Monoid[Stats[A]] = {
      cats.derived.semiauto.monoid[Stats[A]]
    }

    import io.circe.syntax._

    /** Encodes the matrix as a JSON array of `((gold, pred), count)` entries,
      * flattening the nested maps into one entry per cell.
      */
    implicit def confusionStatsEncoder[A: Encoder]: Encoder[Stats[A]] = {
      Encoder.instance[Stats[A]](
        _.matrix.toVector.foldMap { case (goldLabel, predLabelCounts) =>
          // Re-key each row by the (gold, pred) pair so all rows merge into one flat map.
          predLabelCounts.map { case (predLabel, count) =>
            (goldLabel, predLabel) -> count
          }
        }.toVector.asJson
      )
    }

    /** Decodes the array of `((gold, pred), count)` entries written by the
      * encoder, rebuilding the nested gold -> pred -> count map by combining
      * one single-cell map per entry.
      */
    implicit def confusionStatsDecoder[A: Decoder]: Decoder[Stats[A]] = {
      Decoder.instance[Stats[A]](cursor =>
        cursor.as[Vector[((A, A), Int)]].map(
          _.foldMap {
          case ((goldLabel, predLabel), count) =>
            Map(goldLabel -> Map(predLabel -> count))
          }
        ).map(Stats(_))
      )
    }
  }

  /** `MonoidK` for `Confusion` with the label type fixed: the empty value has
    * an empty matrix, and combining two values merges their matrices with `|+|`.
    */
  implicit def confusionMonoidK[A]: MonoidK[Confusion[A, *]] =
    new MonoidK[Confusion[A, *]] {
      override def empty[I] = Confusion[A, I](Map())

      override def combineK[I](x: Confusion[A, I], y: Confusion[A, I]) = {
        val mergedMatrix = x.matrix |+| y.matrix
        Confusion[A, I](mergedMatrix)
      }
    }
  /** `Monoid` for a fully applied `Confusion[A, I]`, obtained from `confusionMonoidK`. */
  implicit def confusionMonoid[A, I]: Monoid[Confusion[A, I]] =
    confusionMonoidK[A].algebra[I]

  /** `HasMetrics` instance for `BinaryConf[A]`: the metrics are those of the
    * confusion's `stats`.
    *
    * The result type is explicit so implicit resolution does not rely on
    * inferring it from the anonymous class.
    * NOTE(review): `HasMetrics` and `BinaryConf` are declared outside this
    * file; confirm that `BinaryConf#stats` exposes `metrics` as used here.
    */
  implicit def confAHasMetrics[A]: HasMetrics[BinaryConf[A]] = new HasMetrics[BinaryConf[A]] {
    def getMetrics(conf: BinaryConf[A]) = conf.stats.metrics
  }

  /** Builds a confusion holding a single item in the cell (`gold`, `pred`).
    *
    * @param gold  gold label (outer key)
    * @param pred  predicted label (inner key)
    * @param value the item recorded for this gold/predicted pair
    */
  def instance[A, I](gold: A, pred: A, value: I): Confusion[A, I] = {
    val predictedRow = Map(pred -> List(value))
    Confusion(Map(gold -> predictedRow))
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy