All downloads are free. Search and download functionality uses the official Maven repository.

org.mlflow.spark.autologging.DatasourceAttributeExtractor.scala Maven / Gradle / Ivy

The newest version!
package org.mlflow.spark.autologging

import java.util.Locale

import org.apache.spark.sql.SparkAutologgingUtils
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan}
import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, FileTable}
import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation}
import org.apache.spark.sql.connector.catalog.Table
import org.apache.spark.sql.execution.ui.SparkListenerSQLExecutionEnd
import org.apache.spark.sql.sources.DataSourceRegister
import org.slf4j.{Logger, LoggerFactory}

import scala.util.control.NonFatal

/**
 * Describes a single Spark datasource that a query read from.
 *
 * @param path       location or name identifying the datasource
 * @param versionOpt table version, when one could be determined
 * @param formatOpt  datasource format (e.g. "delta"), when known
 */
private[autologging] case class SparkTableInfo(
    path: String, versionOpt: Option[String], formatOpt: Option[String])

/**
 * Base trait for extracting Spark datasource attributes from a Spark logical plan.
 *
 * Implementations supply [[maybeGetDeltaTableInfo]]. This trait walks the analyzed plan of a
 * finished SQL execution and maps each leaf node to an optional [[SparkTableInfo]].
 */
private[autologging] trait DatasourceAttributeExtractorBase {
  protected val logger: Logger = LoggerFactory.getLogger(getClass)

  /**
   * Build a SparkTableInfo from a DataSource V2 table.
   *
   * For a FileTable whose name begins with its format name as the first space-separated token,
   * that token is stripped and the remainder is used as the path. Otherwise the full name is
   * used as the path. In both cases the format is reported lowercased, so both naming cases
   * yield the same format string. Tables that are not FileTables are reported by name with no
   * format.
   */
  private def getSparkTableInfoFromTable(table: Table): Option[SparkTableInfo] = {
    table match {
      case fileTable: FileTable =>
        // Locale.ROOT keeps the case-folding independent of the JVM default locale.
        val lowercaseFormat = fileTable.formatName.toLowerCase(Locale.ROOT)
        val splitName = fileTable.name.split(" ")
        val nameStartsWithFormat =
          splitName.headOption.exists(_.toLowerCase(Locale.ROOT) == lowercaseFormat)
        // The remaining tokens are re-joined with spaces in case the path itself contains them.
        val path = if (nameStartsWithFormat) splitName.tail.mkString(" ") else fileTable.name
        Option(SparkTableInfo(path, None, Option(lowercaseFormat)))
      case other: Table =>
        Option(SparkTableInfo(other.name, None, None))
    }
  }

  /**
   * Return Delta-specific table info for `plan` if the implementation recognizes it as a Delta
   * table. Return None to fall back to the generic extraction in [[getTableInfoToLog]].
   */
  protected def maybeGetDeltaTableInfo(plan: LogicalPlan): Option[SparkTableInfo]

  /**
   * Get SparkTableInfo representing the datasource that was read from leaf node of a Spark SQL
   * query plan.
   *
   * Delta detection is tried first. Otherwise DataSourceV2Relation and HadoopFsRelation-backed
   * LogicalRelation nodes are recognized. Any other node yields None.
   */
  protected def getTableInfoToLog(leafNode: LogicalPlan): Option[SparkTableInfo] = {
    maybeGetDeltaTableInfo(leafNode).orElse {
      leafNode match {
        // According to the original authors, DataSourceV2Relation reads were disabled in the
        // Spark 3.0.0 stable release and were still absent in Spark 3.2.0. This case is kept to
        // support earlier versions and in case it returns in the future.
        case relation: DataSourceV2Relation =>
          getSparkTableInfoFromTable(relation.table)
        // This is the case for Spark 3.x except for 3.0.0-preview
        case LogicalRelation(HadoopFsRelation(index, _, _, _, fileFormat, _), _, _, _) =>
          // Only the first root path is recorded; "unknown" if the index reports none.
          val path = index.rootPaths.headOption.map(_.toString).getOrElse("unknown")
          val formatOpt = fileFormat match {
            case format: DataSourceRegister => Option(format.shortName)
            case _ => None
          }
          Option(SparkTableInfo(path, None, formatOpt))
        case _ => None
      }
    }
  }

  /** Collect the leaf nodes of `lp` in left-to-right order; a null plan yields no leaves. */
  private def getLeafNodes(lp: LogicalPlan): Seq[LogicalPlan] = lp match {
    case null => Seq.empty
    case leaf: LeafNode => Seq(leaf)
    case inner => inner.children.flatMap(getLeafNodes)
  }

  /**
   * Get SparkTableInfo representing the datasource(s) that were read by the SQL execution
   * reported by `event`.
   *
   * The QueryExecution is obtained via SparkAutologgingUtils.getQueryExecution (presumably by
   * reading the event's "qe" field; confirm against that helper). If that lookup returns null,
   * the result is empty.
   */
  def getTableInfos(event: SparkListenerSQLExecutionEnd): Seq[SparkTableInfo] = {
    Option(SparkAutologgingUtils.getQueryExecution(event))
      .map(qe => getLeafNodes(qe.analyzed).flatMap(getTableInfoToLog))
      .getOrElse(Seq.empty)
  }
}

/**
 * Default datasource attribute extractor.
 *
 * It never reports Delta-specific info, so every leaf node goes through the generic
 * extraction in [[DatasourceAttributeExtractorBase]].
 */
object DatasourceAttributeExtractor extends DatasourceAttributeExtractorBase {
  // TODO: attempt to detect Delta table info when Delta Lake becomes compatible with Spark 3.0
  override def maybeGetDeltaTableInfo(leafNode: LogicalPlan): Option[SparkTableInfo] =
    Option.empty[SparkTableInfo]
}

/**
 * Datasource attribute extractor for REPL-ID aware environments (e.g. Databricks).
 *
 * It adds Delta table detection and redacts every reported datasource path, using
 * Databricks classes that are looked up via reflection.
 */
object ReplAwareDatasourceAttributeExtractor extends DatasourceAttributeExtractorBase {

  /**
   * If `leafNode` is a LogicalRelation that the Databricks `DeltaTable` extractor recognizes,
   * return its path, its version (when available) and the format "delta". Otherwise return
   * None.
   *
   * Reflection failures are not caught here and propagate to the caller.
   */
  override protected def maybeGetDeltaTableInfo(leafNode: LogicalPlan): Option[SparkTableInfo] = {
    leafNode match {
      case lr: LogicalRelation =>
        // First, check whether LogicalRelation is a Delta table
        val deltaTable =
          ReflectionUtils.getScalaObjectByName("com.databricks.sql.transaction.tahoe.DeltaTable")
        val deltaFileIndexOpt =
          ReflectionUtils.callMethod(deltaTable, "unapply", Seq(lr)).asInstanceOf[Option[Any]]
        deltaFileIndexOpt.map { fileIndex =>
          val path = ReflectionUtils.getField(fileIndex, "path").toString
          // Prefer `tableVersion`; fall back to `version` only when the former yields nothing.
          val versionOpt = ReflectionUtils
            .maybeCallMethod(fileIndex, "tableVersion", Seq.empty)
            .orElse(ReflectionUtils.maybeCallMethod(fileIndex, "version", Seq.empty))
            .map(_.toString)
          SparkTableInfo(path, versionOpt, Option("delta"))
        }
      case _ => None
    }
  }

  /**
   * Apply redaction to sensitive data within a datasource path (e.g. S3 keys). It uses the
   * Databricks log redactor, which is looked up via reflection.
   *
   * Despite the name, this fails closed. Any non-fatal error is logged and rethrown, so an
   * unredacted value is never returned.
   */
  private def tryRedactString(value: String): String = {
    try {
      val redactor = ReflectionUtils.getScalaObjectByName(
        "com.databricks.spark.util.DatabricksSparkLogRedactor")
      ReflectionUtils.callMethod(redactor, "redact", Seq(value)).asInstanceOf[String]
    } catch {
      case NonFatal(e) =>
        val msg = ExceptionUtils.getUnexpectedExceptionMessage(
          e, "while applying redaction to datasource paths")
        logger.error(msg)
        throw e
    }
  }

  /** Return a copy of `tableInfo` with its path redacted; version and format are unchanged. */
  private def applyRedaction(tableInfo: SparkTableInfo): SparkTableInfo =
    tableInfo.copy(path = tryRedactString(tableInfo.path))

  /** Same as the base implementation, but with every datasource path redacted. */
  override def getTableInfos(event: SparkListenerSQLExecutionEnd): Seq[SparkTableInfo] =
    super.getTableInfos(event).map(applyRedaction)
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy