All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.databricks.labs.automl.reports.ReportingTools.scala Maven / Gradle / Ivy

package com.databricks.labs.automl.reports

import com.databricks.labs.automl.utils.SparkSessionWrapper
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._

import scala.collection.mutable.{ArrayBuffer, ListBuffer}

/**
  * Helpers for turning feature-importance values and decision-tree debug strings into
  * human-readable reports.
  *
  * Several methods remove a trailing `_si` from feature names (presumably the suffix appended
  * to string-indexed columns elsewhere in the project -- confirm against the pipeline code).
  */
trait ReportingTools extends SparkSessionWrapper {

  /**
    * Builds a two-column DataFrame ("Feature", "Importance") from parallel arrays.
    *
    * The arrays are paired positionally with `zip`, so if their lengths differ the surplus
    * elements of the longer one are silently dropped.
    *
    * @param columns     feature names; a trailing `_si` is removed from each
    * @param importances importance values, paired positionally with `columns`
    * @return rows ordered by descending importance, with importance multiplied by 100
    */
  def generateFrameReport(columns: Array[String], importances: Array[Double]): DataFrame = {
    import spark.sqlContext.implicits._
    sc.parallelize(columns zip importances).toDF("Feature", "Importance").orderBy($"Importance".desc)
      .withColumn("Importance", col("Importance") * 100.0)
      .withColumn("Feature", split(col("Feature"), "_si$")(0))
  }

  /**
    * Removes a trailing `_si` from every field name, keeping each name paired with its index.
    *
    * `stripSuffix` is used instead of `split("_si$")(0)` because the latter throws an
    * `ArrayIndexOutOfBoundsException` when a name is exactly "_si" (Java's `split` drops
    * trailing empty strings, yielding an empty array).
    *
    * @param indexedFields (field name, feature index) pairs
    * @return the same pairs, in the same order, with cleaned names
    */
  def cleanupFieldArray(indexedFields: Array[(String, Int)]): List[(String, Int)] =
    indexedFields.map { case (name, index) => (name.stripSuffix("_si"), index) }.toList

  /**
    * Replaces every `feature <n>` token in a model debug string with the field name mapped
    * to index `n`.
    *
    * The substitution is a single pass over whole `feature <n>` tokens, so:
    *  - "feature 1" can no longer clobber the prefix of "feature 10", "feature 11", ...;
    *  - a substituted name is never rewritten by a later substitution;
    *  - names containing `$` or `\` are inserted literally rather than being interpreted as
    *    regex replacement syntax.
    * Tokens whose index has no entry in `featureIndex` are left unchanged. If an index appears
    * more than once in `featureIndex`, the last entry wins.
    *
    * @param modelDebugString text containing `feature <n>` tokens
    * @param featureIndex     (field name, feature index) pairs
    * @return the debug string with known feature indices replaced by field names
    */
  def generateDecisionTextReport(modelDebugString: String, featureIndex: List[(String, Int)]): String = {

    val nameByIndex: Map[String, String] =
      featureIndex.map { case (name, index) => index.toString -> name }.toMap

    // Capture the complete (optionally negative) integer so indices are matched as whole tokens.
    val featureToken = """feature (-?\d+)""".r

    featureToken.replaceAllIn(modelDebugString, found => {
      val replacement = nameByIndex.getOrElse(found.group(1), found.matched)
      // replaceAllIn interprets `$` and `\` in the result; quote it to insert the text verbatim.
      scala.util.matching.Regex.quoteReplacement(replacement)
    })
  }

  /**
    * Renders one "Column <name> is feature <index>" line per field, joined by newlines.
    *
    * @param fieldIndexArray (field name, feature index) pairs; names are cleaned with
    *                        [[cleanupFieldArray]] before rendering
    * @return the newline-separated listing (empty string for an empty input)
    */
  def reportFields(fieldIndexArray: Array[(String, Int)]): String =
    cleanupFieldArray(fieldIndexArray)
      .map { case (name, index) => s"Column $name is feature $index" }
      .mkString("\n")

  /**
    * Returns the first-column values of the `topNCutoff` rows with the highest "Importance".
    *
    * @param featureFrame DataFrame with an "Importance" column; the first column is read as
    *                     the feature name
    * @param topNCutoff   maximum number of rows to return
    * @return first-column values as strings, ordered by descending importance
    */
  def extractTopFeaturesByCount(featureFrame: DataFrame, topNCutoff: Int): Array[String] =
    firstColumnAsStrings(featureFrame.sort(col("Importance").desc).limit(topNCutoff))

  /**
    * Returns the first-column values of all rows whose "Importance" is at least the cutoff.
    *
    * @param featureFrame               DataFrame with an "Importance" column; the first column
    *                                   is read as the feature name
    * @param importancePercentageCutoff inclusive lower bound on "Importance"
    * @return first-column values as strings, ordered by descending importance
    */
  def extractTopFeaturesByImportance(featureFrame: DataFrame, importancePercentageCutoff: Double): Array[String] =
    firstColumnAsStrings(
      featureFrame.filter(col("Importance") >= importancePercentageCutoff)
        .sort(col("Importance").desc)
    )

  /** Collects the frame to the driver and returns each row's first column via `toString`. */
  private def firstColumnAsStrings(frame: DataFrame): Array[String] =
    frame.collect().map(row => row(0).toString)

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy