All downloads are free. Search and download functionality uses the official Maven repository.

com.johnsnowlabs.nlp.annotators.classifier.dl.ClassifierMetrics.scala Maven / Gradle / Ivy

/*
 * Copyright 2017-2022 John Snow Labs
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.johnsnowlabs.nlp.annotators.classifier.dl

import com.johnsnowlabs.ml.tensorflow.Logging
import org.apache.spark.ml.util.Identifiable

/** Helpers for computing and reporting precision / recall / F1 of a classifier from per-label
  * true-positive, false-positive and false-negative counts.
  *
  * Every report line is printed to stdout and also handed to `outputLog`, which is not defined
  * in this trait (presumably supplied by [[Logging]]).
  */
trait ClassifierMetrics extends Logging {

  /** Computes precision, recall and F1 from raw counts.
    *
    * A ratio whose denominator is zero evaluates to NaN in floating point (0 / 0); every NaN is
    * reported as 0 so callers always receive finite numbers.
    *
    * @param tp
    *   number of true positives
    * @param fp
    *   number of false positives
    * @param fn
    *   number of false negatives
    * @return
    *   `(precision, recall, f1)`
    */
  def calcStat(tp: Int, fp: Int, fn: Int): (Float, Float, Float) = {
    val precision = tp.toFloat / (tp.toFloat + fp.toFloat)
    val recall = tp.toFloat / (tp.toFloat + fn.toFloat)
    // Computed from the raw (possibly NaN) values so that an undefined precision or recall
    // also yields an undefined, and therefore zeroed, F1.
    val f1 = 2 * ((precision * recall) / (precision + recall))

    (nanToZero(precision), nanToZero(recall), nanToZero(f1))
  }

  /** Prints (and logs through `outputLog`) per-label and aggregated metrics.
    *
    * @param labels
    *   labels to report on; a label missing from any of the count maps is counted as 0. May be
    *   empty, in which case the macro averages are reported as 0.
    * @param truePositives
    *   true-positive count per label
    * @param falsePositives
    *   false-positive count per label
    * @param falseNegatives
    *   false-negative count per label
    * @param extended
    *   when true, additionally emits the per-label table and the totals line
    * @param enableOutputLogs
    *   forwarded unchanged to `outputLog`
    * @param outputLogsPath
    *   forwarded unchanged to `outputLog`
    * @param uuid
    *   forwarded unchanged to `outputLog`; defaults to a fresh random UID
    * @return
    *   `(microF1, macroF1)`
    */
  def aggregatedMetrics(
      labels: Seq[String],
      truePositives: Map[String, Int],
      falsePositives: Map[String, Int],
      falseNegatives: Map[String, Int],
      extended: Boolean,
      enableOutputLogs: Boolean,
      outputLogsPath: String,
      uuid: String = Identifiable.randomUID("ClassifierMetrics")): (Float, Float) = {

    // Single place that emits a report line, so stdout and the log can never diverge.
    def report(line: String): Unit = {
      println(line)
      outputLog(line, uuid, enableOutputLogs, outputLogsPath)
    }

    val totalTruePositives = truePositives.values.sum
    val totalFalsePositives = falsePositives.values.sum
    val totalFalseNegatives = falseNegatives.values.sum
    val (precision, recall, f1) =
      calcStat(totalTruePositives, totalFalsePositives, totalFalseNegatives)

    // `max`/`maxBy` throw on an empty collection, so guard the empty-label case explicitly.
    val maxLabelLength = if (labels.isEmpty) 0 else labels.map(_.length).max
    val headLabelSpace = computeLabelSpace(maxLabelLength)

    if (extended) {
      report(s"label$headLabelSpace tp\t fp\t fn\t prec\t rec\t f1")
    }

    // Accumulate per-label precision and recall in label order for the macro average.
    val (precisionSum, recallSum) = labels.foldLeft((0f, 0f)) {
      case ((precisionAcc, recallAcc), label) =>
        val tp = truePositives.getOrElse(label, 0)
        val fp = falsePositives.getOrElse(label, 0)
        val fn = falseNegatives.getOrElse(label, 0)
        val (labelPrecision, labelRecall, labelF1) = calcStat(tp, fp, fn)

        if (extended) {
          val labelOutput = computeLabelOutput(label, maxLabelLength, headLabelSpace.length)
          report(s"$labelOutput $tp\t $fp\t $fn\t $labelPrecision\t $labelRecall\t $labelF1")
        }

        (precisionAcc + labelPrecision, recallAcc + labelRecall)
    }

    // With no labels the averages would be 0 / 0; report 0 instead of NaN.
    val labelCount = labels.length
    val macroPrecision = if (labelCount == 0) 0f else precisionSum / labelCount
    val macroRecall = if (labelCount == 0) 0f else recallSum / labelCount
    // Guarded like the F1 in calcStat: both averages being 0 would otherwise give NaN.
    val macroF1 =
      nanToZero(2 * ((macroPrecision * macroRecall) / (macroPrecision + macroRecall)))

    if (extended) {
      report(
        s"tp: $totalTruePositives fp: $totalFalsePositives fn: $totalFalseNegatives labels: ${labels.length}")
    }
    // Macro average, e.g. for two labels: Precision = (P1 + P2) / 2
    report(s"Macro-average\t prec: $macroPrecision, rec: $macroRecall, f1: $macroF1")
    // Micro average, e.g. for two labels: Precision = (TP1 + TP2) / (TP1 + TP2 + FP1 + FP2)
    report(s"Micro-average\t prec: $precision, recall: $recall, f1: $f1")

    (f1, macroF1)
  }

  /** Replaces NaN with 0 and returns every other value unchanged. */
  private def nanToZero(value: Float): Float = if (value.isNaN) 0f else value

  /** Builds the run of spaces that follows the literal "label" in the table header.
    *
    * The header column ends up two characters wider than the longest label, or twelve characters
    * wide when no label is longer than "label" itself.
    */
  private def computeLabelSpace(maxLabelLength: Int): String = {
    val headerLength = "label".length
    val spaceToFill = maxLabelLength - headerLength
    val width = if (spaceToFill <= 0) headerLength + 2 else spaceToFill + 2

    " " * width
  }

  /** Right-pads `label` with spaces so that it spans the same width as the header column.
    *
    * @param label
    *   label to pad
    * @param maxLabelLength
    *   length of the longest label in the report
    * @param headLabelLength
    *   number of spaces that follow "label" in the header, i.e. the length of the string
    *   returned by `computeLabelSpace`
    */
  private def computeLabelOutput(
      label: String,
      maxLabelLength: Int,
      headLabelLength: Int): String = {

    val headerLength = "label".length
    val padding =
      if (label.length > headerLength) maxLabelLength - label.length + 2
      // Short labels first fill up to the width of "label", then add the header's own padding.
      else headerLength - label.length + headLabelLength

    label + (" " * padding)
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy