All downloads are free. Search and download functionality uses the official Maven repository.

org.apache.spark.ml.classification.ClassificationSummary.scala Maven / Gradle / Ivy

The newest version!
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml.classification

import org.apache.spark.annotation.Since
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.mllib.evaluation.{BinaryClassificationMetrics, MulticlassMetrics}
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.functions.{col, lit}
import org.apache.spark.sql.types.DoubleType


/**
 * Abstraction for multiclass classification results for a given model.
 *
 * Every metric below is read from a single mllib `MulticlassMetrics` instance. That instance is
 * created while the implementing summary is initialized, from the prediction, label and weight
 * columns of `predictions`. When `weightCol` is not a column of `predictions`, every instance is
 * given weight 1.0.
 */
private[classification] trait ClassificationSummary extends Serializable {

  /**
   * DataFrame output by the model's `transform` method.
   */
  @Since("3.1.0")
  def predictions: DataFrame

  /** Field in "predictions" which gives the predicted label of each instance. */
  @Since("3.1.0")
  def predictionCol: String

  /** Field in "predictions" which gives the true label of each instance (if available). */
  @Since("3.1.0")
  def labelCol: String

  /**
   * Field in "predictions" which gives the weight of each instance. If "predictions" has no
   * column with this name, every instance is weighted 1.0.
   */
  @Since("3.1.0")
  def weightCol: String

  // Built once, during trait initialization. `@transient` excludes it from the serialized form
  // of the summary.
  @transient private val multiclassMetrics = {
    val weightIsPresent = predictions.schema.fieldNames.contains(weightCol)
    val weightExpr = if (weightIsPresent) col(weightCol).cast(DoubleType) else lit(1.0)
    val selected =
      predictions.select(col(predictionCol), col(labelCol).cast(DoubleType), weightExpr)
    // Each row is (prediction, label, weight). Only the label and weight are cast, so the
    // prediction column must already hold Doubles, or the pattern below throws a MatchError.
    val predictionLabelWeight = selected.rdd.map {
      case Row(prediction: Double, label: Double, weight: Double) => (prediction, label, weight)
    }
    new MulticlassMetrics(predictionLabelWeight)
  }

  /**
   * Evaluates a per-label metric for every entry of `multiclassMetrics.labels`.
   * The result preserves that array's (ascending) label order.
   */
  private def metricPerLabel(metric: Double => Double): Array[Double] =
    multiclassMetrics.labels.map(metric)

  /**
   * Returns the sequence of labels in ascending order. The arrays returned by the "ByLabel"
   * metrics (e.g. truePositiveRateByLabel) use this same order.
   *
   * Note: usually the labels are {0.0, 1.0, ..., numClasses-1}. However, if a label is missing
   * from the evaluated data, every array over labels (e.g. from truePositiveRateByLabel) has
   * length numClasses-1 instead of the expected numClasses.
   */
  @Since("3.1.0")
  def labels: Array[Double] = multiclassMetrics.labels

  /** Returns true positive rate for each label (category); identical to `recallByLabel`. */
  @Since("3.1.0")
  def truePositiveRateByLabel: Array[Double] = recallByLabel

  /** Returns false positive rate for each label (category). */
  @Since("3.1.0")
  def falsePositiveRateByLabel: Array[Double] =
    metricPerLabel(multiclassMetrics.falsePositiveRate(_))

  /** Returns precision for each label (category). */
  @Since("3.1.0")
  def precisionByLabel: Array[Double] =
    metricPerLabel(multiclassMetrics.precision(_))

  /** Returns recall for each label (category). */
  @Since("3.1.0")
  def recallByLabel: Array[Double] =
    metricPerLabel(multiclassMetrics.recall(_))

  /**
   * Returns f-measure for each label (category).
   *
   * @param beta weight of recall relative to precision in the f-measure
   */
  @Since("3.1.0")
  def fMeasureByLabel(beta: Double): Array[Double] =
    metricPerLabel(multiclassMetrics.fMeasure(_, beta))

  /** Returns f1-measure for each label (category), i.e. `fMeasureByLabel(1.0)`. */
  @Since("3.1.0")
  def fMeasureByLabel: Array[Double] = fMeasureByLabel(1.0)

  /**
   * Returns accuracy: the (weighted) number of correctly classified instances divided by the
   * (weighted) total number of instances.
   */
  @Since("3.1.0")
  def accuracy: Double = multiclassMetrics.accuracy

  /** Returns weighted true positive rate; identical to `weightedRecall`. */
  @Since("3.1.0")
  def weightedTruePositiveRate: Double = weightedRecall

  /** Returns weighted false positive rate. */
  @Since("3.1.0")
  def weightedFalsePositiveRate: Double = multiclassMetrics.weightedFalsePositiveRate

  /** Returns weighted averaged recall. */
  @Since("3.1.0")
  def weightedRecall: Double = multiclassMetrics.weightedRecall

  /** Returns weighted averaged precision. */
  @Since("3.1.0")
  def weightedPrecision: Double = multiclassMetrics.weightedPrecision

  /**
   * Returns weighted averaged f-measure.
   *
   * @param beta weight of recall relative to precision in the f-measure
   */
  @Since("3.1.0")
  def weightedFMeasure(beta: Double): Double = multiclassMetrics.weightedFMeasure(beta)

  /** Returns weighted averaged f1-measure. */
  @Since("3.1.0")
  def weightedFMeasure: Double = multiclassMetrics.weightedFMeasure(1.0)
}

/**
 * Abstraction for training results.
 */
private[classification] trait TrainingSummary {

  /**
   * Objective function (scaled loss + regularization) at each iteration.
   * It also holds one entry for the initial state, so its length is the number of
   * iterations plus one.
   */
  @Since("3.1.0")
  def objectiveHistory: Array[Double]

  /**
   * Number of training iterations: the length of `objectiveHistory` minus the entry for the
   * initial state.
   */
  @Since("3.1.0")
  def totalIterations: Int = {
    // An empty history would mean not even the initial state was recorded.
    assert(objectiveHistory.nonEmpty, "objectiveHistory length should be greater than 0.")
    objectiveHistory.size - 1
  }
}

/**
 * Abstraction for binary classification results for a given model.
 *
 * The threshold-based metrics are computed by an mllib `BinaryClassificationMetrics` instance.
 * That instance is created while the implementing summary is initialized. The score of each
 * instance is element 1 of the vector in `scoreCol`. Instances are weighted by `weightCol` when
 * that column exists in `predictions`, and by 1.0 otherwise.
 */
private[classification] trait BinaryClassificationSummary extends ClassificationSummary {

  private val sparkSession = predictions.sparkSession
  // Brings `toDF` into scope for the RDDs of (Double, Double) pairs returned by binaryMetrics.
  import sparkSession.implicits._

  /**
   *  Field in "predictions" which gives the probability or rawPrediction of each class as a
   *  vector.
   *
   *  The default is `null`, so implementations must override it. It is read while this trait
   *  is being initialized, so the override must already be available at that point (e.g. a
   *  `def` or a constructor `val`). A plain `val` defined in a subclass body would still be
   *  `null` when it is read.
   */
  def scoreCol: String = null

  @transient private val binaryMetrics = {
    // scoreCol is read eagerly here. Fail with a descriptive message instead of the opaque
    // NullPointerException that `col(null)` would throw when it was never overridden.
    require(scoreCol != null,
      s"${getClass.getName} must override scoreCol with the name of a column in predictions.")

    val weightColumn = if (predictions.schema.fieldNames.contains(weightCol)) {
      col(weightCol).cast(DoubleType)
    } else {
      lit(1.0)
    }

    // TODO: Allow the user to vary the number of bins using a setBins method in
    // BinaryClassificationMetrics. For now the default is set to 1000.
    new BinaryClassificationMetrics(
      predictions.select(col(scoreCol), col(labelCol).cast(DoubleType), weightColumn).rdd.map {
        // Element 1 of the per-class score vector is the score of the positive class.
        case Row(score: Vector, label: Double, weight: Double) => (score(1), label, weight)
      }, 1000
    )
  }

  /**
   * Returns the receiver operating characteristic (ROC) curve,
   * which is a DataFrame having two fields (FPR, TPR)
   * with (0.0, 0.0) prepended and (1.0, 1.0) appended to it.
   * See http://en.wikipedia.org/wiki/Receiver_operating_characteristic
   */
  @Since("3.1.0")
  @transient lazy val roc: DataFrame = binaryMetrics.roc().toDF("FPR", "TPR")

  /**
   * Computes the area under the receiver operating characteristic (ROC) curve.
   */
  @Since("3.1.0")
  lazy val areaUnderROC: Double = binaryMetrics.areaUnderROC()

  /**
   * Returns the precision-recall curve, which is a DataFrame containing two fields
   * (recall, precision). A point (0.0, p) is prepended to it, where p is the precision of the
   * lowest-recall point on the curve (see mllib `BinaryClassificationMetrics.pr`).
   */
  @Since("3.1.0")
  @transient lazy val pr: DataFrame = binaryMetrics.pr().toDF("recall", "precision")

  /**
   * Returns a DataFrame with two fields (threshold, F-Measure) curve with beta = 1.0.
   */
  @Since("3.1.0")
  @transient lazy val fMeasureByThreshold: DataFrame = {
    binaryMetrics.fMeasureByThreshold().toDF("threshold", "F-Measure")
  }

  /**
   * Returns a DataFrame with two fields (threshold, precision) curve.
   * The thresholds are the scores obtained when transforming the dataset. The metrics are built
   * with 1000 bins, so when there are more distinct scores than that, the curve is down-sampled
   * to approximately 1000 thresholds.
   */
  @Since("3.1.0")
  @transient lazy val precisionByThreshold: DataFrame = {
    binaryMetrics.precisionByThreshold().toDF("threshold", "precision")
  }

  /**
   * Returns a DataFrame with two fields (threshold, recall) curve.
   * The thresholds are the scores obtained when transforming the dataset. The metrics are built
   * with 1000 bins, so when there are more distinct scores than that, the curve is down-sampled
   * to approximately 1000 thresholds.
   */
  @Since("3.1.0")
  @transient lazy val recallByThreshold: DataFrame = {
    binaryMetrics.recallByThreshold().toDF("threshold", "recall")
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy