All downloads are free. Search and download functionality uses the official Maven repository.

com.databricks.connect.DatabricksSession.scala Maven / Gradle / Ivy

/*
 * DATABRICKS CONFIDENTIAL & PROPRIETARY
 * __________________
 *
 * Copyright 2023-present Databricks, Inc.
 * All Rights Reserved.
 *
 * NOTICE:  All information contained herein is, and remains the property of Databricks, Inc.
 * and its suppliers, if any.  The intellectual and technical concepts contained herein are
 * proprietary to Databricks, Inc. and its suppliers and may be covered by U.S. and foreign Patents,
 * patents in process, and are protected by trade secret and/or copyright law. Dissemination, use,
 * or reproduction of this information is strictly forbidden unless prior written permission is
 * obtained from Databricks, Inc.
 *
 * If you view or obtain a copy of this information and believe Databricks, Inc. may not have
 * intended it to be made available, please promptly report it to Databricks Legal Department
 * @ [email protected].
 */

package com.databricks.connect

import java.io.{File, IOException}
import java.net.URI
import java.nio.file.{Files, Paths}
import java.util.Properties

import scala.collection.immutable
import scala.util.Try
import scala.util.control.NonFatal

import com.databricks.sdk.WorkspaceClient
import com.databricks.sdk.core.{DatabricksConfig, UserAgent}

import org.apache.spark.SparkException
import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.connect.client.SparkConnectClient

/**
 * Entry-point for building a DB Connect Scala Client Session.
 * See [[DatabricksSession.Builder.getOrCreate()]].
 */
object DatabricksSession {

  /**
   * Fluent builder for a Databricks Connect [[org.apache.spark.sql.SparkSession]].
   *
   * Connection parameters may be supplied explicitly ([[host()]], [[clusterId()]], [[token()]]),
   * through a Databricks SDK [[DatabricksConfig]] ([[sdkConfig()]]), through the `SPARK_REMOTE`
   * environment variable, or picked up from the default SDK configuration. See [[getOrCreate()]]
   * for the exact lookup order.
   *
   * @param builder the underlying Spark session builder (injectable for unit testing)
   * @param env     the environment variables to consult (injectable for unit testing)
   */
  class Builder private[connect] ( // for unit testing
      private val builder: SparkSession.Builder = SparkSession.builder(),
      private val env: Map[String, String] = sys.env)
      extends Logging {
    // Values set through the fluent setters below; `None` means "not configured in code".
    private var clusterId: Option[String] = None
    private var host: Option[String] = None
    private var token: Option[String] = None
    private var sdkConfig: Option[DatabricksConfig] = None
    private var userAgent: Option[String] = None
    // Extra headers attached to every Spark Connect request.
    private var headers: immutable.HashMap[String, String] = new immutable.HashMap()
    // Upload target path -> local file, populated by addCompiledArtifacts().
    private var artifacts: scala.collection.Map[String, File] = Map()
    // None = never set by the caller (validate, but only warn on failure).
    private var validateSession: Option[Boolean] = None

    /**
     * The instance of DatabricksConfig that was finally used to pick up connection
     * parameters. This is populated as part of `getOrCreate()` API.
     * When the Databricks SDK was not used, `None` will be returned.
     * Package private for unit testing.
     */
    private[connect] var resolvedSdkConfig: Option[DatabricksConfig] = None

    private[connect] def this(env: Map[String, String]) = this(SparkSession.builder(), env)

    /**
     * API returns the same instance of the builder unchanged.
     * This API is here for backwards compatibility.
     * @return instance of Builder
     */
    def remote(): Builder = this

    /**
     * The Databricks cluster id to connect.
     * @param clusterId cluster id
     * @return instance of Builder with the cluster id configured
     */
    def clusterId(clusterId: String): Builder = {
      this.clusterId = Some(clusterId)
      this
    }

    /**
     * The URL of the Databricks workspace to connect.
     *
     * @param host Databricks workspace URL
     * @return instance of Builder with the host configured
     */
    def host(host: String): Builder = {
      this.host = Some(host)
      this
    }

    /**
     * The access token to use for connecting. The user associated with this token
     * should have access to the workspace and cluster. Subsequent queries will be executed
     * on behalf of this user.
     *
     * @param token Databricks access token
     * @return instance of Builder with the access token configured
     */
    def token(token: String): Builder = {
      this.token = Some(token)
      this
    }

    /**
     * A user agent string identifying the application using the Databricks Connect module.
     * Databricks Connect sends its own version and the Databricks SDK's user agent token to the
     * service; the value provided here is sent in front of that information. It is recommended
     * to provide this value in the format `product/product-version` as described in
     * https://datatracker.ietf.org/doc/html/rfc7231#section-5.5.3, but is not required.
     *
     * @param userAgent
     *   The user agent string identifying the application that is using this module.
     * @return
     *   instance of Builder with the user agent configured
     * @throws IllegalArgumentException if the user agent is longer than 2048 characters
     */
    def userAgent(userAgent: String): Builder = {
      if (userAgent.length > 2048) {
        throw new IllegalArgumentException("User agent should not exceed 2048 characters.")
      }
      this.userAgent = Some(userAgent)
      this
    }

    /**
     * Builds the user agent sent with every request: the application user agent (from
     * [[userAgent()]], else the `SPARK_CONNECT_USER_AGENT` environment variable, else
     * "databricks-session"), then `dbconnect/<version>`, then one token of the SDK user agent.
     */
    private def genUserAgent(): String = {
      val applicationUserAgent =
        userAgent.getOrElse(env.getOrElse("SPARK_CONNECT_USER_AGENT", "databricks-session"))
      // Reuse the databricks-sdk UserAgent until its version is accessible programmatically.
      // The second whitespace-separated token is presumably the SDK's own product/version; if
      // the format ever lacks it, send nothing rather than failing session creation.
      val sdkUserAgent = UserAgent.asString().split(" ").lift(1).getOrElse("")

      // trim() drops the trailing separator when sdkUserAgent is empty.
      Seq(applicationUserAgent, s"dbconnect/$version", sdkUserAgent).mkString(" ").trim()
    }

    /**
     * Add a header to the Spark Connect GRPC requests.
     * This method is cumulative (can be called repeatedly to add more headers).
     *
     * @param header
     *   Name of the header to set
     * @param value
     *   The value to set
     * @return
     *   The same instance of this class with a header set.
     */
    def header(header: String, value: String): Builder = {
      this.headers = this.headers + (header -> value)
      this
    }

    /**
     * Provide a custom SDK Config object to use for connection configuration.
     * Can be used to override the default SDK Config that will be used.
     *
     * @param config Databricks SDK Config
     * @return instance of Builder with the SDK config configured
     */
    def sdkConfig(config: DatabricksConfig): Builder = {
      this.sdkConfig = Some(config)
      this
    }

    /**
     * Upload the compiled artifacts to the Databricks cluster and synchronize as
     * part of session initialization.
     *
     * This is required when the program requires external dependencies to evaluate
     * the Spark query.
     * For any Spark query using a typed Dataset API such as filter(), map(), foreach(),
     * etc., or if the query is using a UDF, the enclosing classes should be synchronized
     * with the cluster.
     *
     * For example: the following code snippet includes the compiled artifact of class `Foo`.
     * The `uri` in this example will either point to a JAR file or to a folder that contains
     * the compiled class files.
     *
     * {{{
     *   val uri = Foo.getClass.getProtectionDomain.getCodeSource.getLocation.toURI
     *   val spark = DatabricksSession.builder()
     *       .addCompiledArtifacts(uri)
     *       ...
     * }}}
     *
     * @param uri Specify the class file, JAR or directory to upload and synchronize.
     *             When a directory is specified, all deep nested JARs and class files will
     *             be synchronized.
     * @return instance of this Builder
     * @throws IOException if `uri` is neither an existing file nor an existing directory
     */
    def addCompiledArtifacts(uri: URI): Builder = {
      val f = new File(uri.getPath)
      if (!f.isFile && !f.isDirectory) {
        throw new IOException(s"Path provided must be a valid file or directory: $uri")
      }

      // Files.find keeps directory handles open until the returned stream is closed.
      val matches = Files.find(Paths.get(uri), 999,
        (p, _) => p.toString.endsWith(".class") || p.toString.endsWith(".jar")
      )
      try {
        matches.forEach(p => {
          val pUri = p.toUri
          val pFile = p.toFile

          val target =
            if (
              // In shared clusters, only jars placed in the root target directory are
              // recognized. See https://databricks.atlassian.net/browse/SASP-2988
              pFile.getName.endsWith(".jar")

              // uri == pUri when the uri parameter is not a folder but a direct file.
              || uri.equals(pUri)
            ) {
              pFile.getName
            } else {
              // Keep the path relative to the given directory, e.g. "pkg/Foo.class".
              uri.relativize(pUri).toString
            }

          this.artifacts = this.artifacts + (target -> pFile)
        })
      } finally {
        matches.close()
      }
      this
    }

    /**
     * For unit testing
     */
    private[connect] def getCompiledArtifacts(): scala.collection.Map[String, File] = {
      this.artifacts
    }

    /**
     * Controls session validation, which runs after the session is created and checks that:
     * - the specified connection parameters are valid and can communicate with the cluster
     * - the Databricks Runtime version of the cluster is greater than or equal to the Databricks
     *   Connect version.
     * When enabled, an exception is thrown if any check fails. When this option is never set,
     * the same checks run but a failure is only logged as a warning. Passing `false` turns the
     * checks off. Validation is skipped when an `x-databricks-session-id` header is set, and it
     * is not supported for connection strings (`SPARK_REMOTE`): there it fails when explicitly
     * enabled and only logs a warning when unset.
     *
     * @param enabled whether to run the validations and fail on errors (defaults to `true`)
     * @return instance of Builder with validateSession option configured
     */
    def validateSession(enabled: Boolean = true): Builder = {
      this.validateSession = Some(enabled)
      this
    }

    // Package private variable for unit testing.
    private[connect] var createWorkspaceClient: DatabricksConfig => WorkspaceClient =
      config => new WorkspaceClient(config)

    /**
     * Validates the configuration by retrieving the used Databricks Runtime version with the
     * Databricks SDK. Checks if there is an unsupported combination of Databricks Runtime &
     * Databricks Connect versions. Throws an exception if the validation fails.
     *
     * @throws SparkException if no SDK config was resolved (connection strings) or the Databricks
     *                        Runtime version is older than the Databricks Connect version
     * @throws IllegalArgumentException if the cluster's Spark version string cannot be parsed
     */
    private[connect] def validateSessionWithSdk: Unit = {
      // Extracts (major, optional minor) from an image name: "13.3.x-scala2.12" -> (13, Some(3)),
      // "14.x-snapshot" -> (14, None). Works with custom and "major.x" image names.
      def parseDbrVersion(version: String): (Int, Option[Int]) = {
        val pattern = raw"(\d+(\.\d+)?\.x)".r
        pattern.findFirstIn(version) match {
          case Some(matched) =>
            val versions = matched.split("\\.")
            (versions(0).toInt, Try(versions(1).toInt).toOption)
          case None =>
            throw new IllegalArgumentException(
              "Failed to parse minor & major version from Databricks Runtime " +
              s"version string: $version"
            )
        }
      }

      logDebug(
        "Validating compatibility between the Databricks Runtime and Databricks Connect versions"
      )
      val config = resolvedSdkConfig.getOrElse(
        throw new SparkException("Session validation is not supported for connection strings.")
      )
      val clusterId = config.getClusterId
      val dbrVersionString = createWorkspaceClient(config)
        .clusters()
        .get(clusterId)
        .getSparkVersion
      val (dbrMajor, dbrMinor) = parseDbrVersion(dbrVersionString)
      // dbConnectVersion should always have a valid major and minor version
      val dbConnectVersion = version.split("\\.").take(2).map(_.toInt)

      // A "major.x" runtime (no minor) is accepted whenever its major version matches.
      val dbrOlderThanConnect = dbrMajor < dbConnectVersion(0) ||
        (dbrMajor == dbConnectVersion(0) && dbrMinor.exists(_ < dbConnectVersion(1)))
      if (dbrOlderThanConnect) {
        val dbrDisplay = dbrMinor.fold(s"$dbrMajor.x")(minor => s"$dbrMajor.$minor")
        throw new SparkException(
          "Unsupported combination of Databricks Runtime & Databricks Connect versions: " +
          s"$dbrDisplay (Databricks Runtime) < ${dbConnectVersion(0)}.${dbConnectVersion(1)} " +
          "(Databricks Connect)."
        )
      }
      logDebug("Session validated successfully.")
    }

    // scalastyle:off line.size.limit
    /**
     * Get a spark session if one was already created with the given configuration.
     * Otherwise, create a new one.
     *
     * Connection parameters for connecting to Databricks Connect is collected in the
     * following order. When one of the configuration parameters is collected, the lookup
     * for that parameter stops; collection for other parameters continues down the chain.
     * - Parameters configured directly in code via [[host()]], [[token()]], etc.
     * - Parameters configured in the Databricks SDK's [[DatabricksConfig]] specified via
     *   [[sdkConfig()]]
     * - The `SPARK_REMOTE` environment variable. For details on the connection string, see
     *   https://github.com/apache/spark/blob/master/connector/connect/docs/client-connection-string.md
     * - Use the "default" configuration profile from the Databricks config file. For details
     *   on Databricks configuration profiles, see https://docs.databricks.com/dev-tools/auth.html
     *
     * By default, validates the used configuration by retrieving the used Databricks Runtime
     * version with the Databricks SDK. Logs a warning if the authentication fails or if there is
     * an unsupported combination of Databricks Runtime & Databricks Connect versions. Change this
     * behaviour with [[DatabricksSession.builder.validateSession()]].
     *
     * @return spark session
     */
    // scalastyle:on line.size.limit
    def getOrCreate(): SparkSession = {
      this.sdkConfig.foreach(_.resolve())

      val providedConfig: Option[DatabricksConfig] = this.sdkConfig
      val hasExplicitParams = Seq(this.host, this.clusterId, this.token).exists(_.isDefined)

      val spark: SparkSession = providedConfig match {
        case _ if hasExplicitParams =>
          logDebug("DatabricksSession: Initializing from explicitly set host, cluster, token")

          // use either provided configuration or default configuration
          val config = providedConfig.getOrElse(new DatabricksConfig())
          // SASP-2727: Explicitly configure the auth type to be PAT.
          config.setAuthType("pat")
          config.resolve()

          // Values set in code override whatever the SDK config resolved.
          this.host.foreach(h => config.setHost(h))
          this.clusterId.foreach(id => config.setClusterId(id))
          this.token.foreach(t => config.setToken(t))

          this.resolvedSdkConfig = Some(config)
          fromSdkConfig(config)

        case Some(config) =>
          logDebug("DatabricksSession: Initializing from sdkConfig")
          this.resolvedSdkConfig = providedConfig
          fromSdkConfig(config.resolve())

        case None if env.contains("SPARK_REMOTE") =>
          logDebug("DatabricksSession: Initializing from SPARK_REMOTE")

          if (this.headers.nonEmpty) {
            throw new IllegalArgumentException(
              "Can't configure custom headers with SPARK_REMOTE connection"
            )
          }

          // SparkSession.Builder knows how to handle SPARK_REMOTE
          this.resolvedSdkConfig = None
          builder.getOrCreate()

        case None =>
          logDebug("DatabricksSession: Constructing from default SDK Config")
          val config = new DatabricksConfig().resolve()
          this.resolvedSdkConfig = Some(config)
          fromSdkConfig(config)
      }

      // Guards against a null from the (test-overridable) session factories.
      if (spark == null) {
        throw new RuntimeException("Spark session could not be initialized: unexpected state")
      }

      for ((target, file) <- this.artifacts) {
        logDebug(s"Uploading file [$file] to target path [$target].")
        spark.addArtifact(Files.readAllBytes(file.toPath), target)
      }

      // skip validation if session-id is set manually
      if (validateSession.getOrElse(true) && !headers.contains("x-databricks-session-id")) {
        try {
          validateSessionWithSdk
        } catch {
          // Validation was not explicitly requested: downgrade non-fatal failures to a warning.
          // Fatal errors (e.g. OutOfMemoryError) always propagate.
          case NonFatal(e) if validateSession.isEmpty =>
            logWarning(Option(e.getMessage).getOrElse(e.toString))
        }
      }

      spark
    }

    /**
     * Creates a Spark Connect session from a resolved SDK config.
     *
     * @param config resolved SDK config providing host, cluster id and authentication
     * @return the session created by [[fromSparkClientConf]]
     * @throws IllegalArgumentException if the host is missing, or the cluster id is missing and
     *                                  no `x-databricks-session-id` header was set
     */
    private def fromSdkConfig(config: DatabricksConfig): SparkSession = {
      if (config.getHost == null) {
        throw new IllegalArgumentException(
          "DatabricksSession: Need host to construct session. Received null."
        )
      }
      val host = config.getHost.stripPrefix("https://").stripPrefix("http://")

      // With an explicit session id the cluster id is not required; otherwise it is sent as a
      // header so requests are routed to that cluster.
      val connectHeaders =
        (this.headers.contains("x-databricks-session-id"), config.getClusterId) match {
          case (true, _) => headers
          case (_, null) =>
            throw new IllegalArgumentException(
              "DatabricksSession: need cluster id to " +
              s"construct session. Got cluster id=${config.getClusterId}"
            )
          case (_, id) => headers + ("x-databricks-cluster-id" -> id)
        }

      fromSparkClientConf(
        new DatabricksSparkClientConfiguration(config, host, genUserAgent(), connectHeaders)
      )
    }

    /* for unit testing */
    protected[connect] def fromSparkClientConf(
        conf: SparkConnectClient.Configuration): SparkSession = {
      builder.client(conf.toSparkConnectClient).getOrCreate()
    }
  }

  /**
   * Get a new Builder to configure and construct a new [[org.apache.spark.sql.SparkSession]].
   * @return new Builder
   */
  def builder(): Builder = new Builder

  /**
   * Databricks Connect version: the `version` key of `dbconnect-version-info.properties` on the
   * classpath, or "" when the key is absent. Initializing this object throws a SparkException if
   * the resource cannot be found or read.
   */
  val version: String = loadProperty("version", "dbconnect-version-info.properties")

  /** The `git_version` key of `dbconnect-version-info.properties`, or "" when absent. */
  val gitVersion: String = loadProperty("git_version", "dbconnect-version-info.properties")

  /**
   * Reads a single property from a classpath resource.
   *
   * @param key  property name
   * @param file classpath resource name of the properties file
   * @return the property value, or "" when the key is absent
   * @throws SparkException if the resource is missing, cannot be parsed, or cannot be closed
   */
  private def loadProperty(key: String, file: String): String = {
    // Fall back to this object's class loader when the thread has no context class loader.
    val classLoader =
      Option(Thread.currentThread().getContextClassLoader).getOrElse(getClass.getClassLoader)
    val resourceStream = classLoader.getResourceAsStream(file)
    if (resourceStream == null) {
      throw new SparkException(s"Could not find ${file}")
    }

    val props = new Properties()
    try {
      props.load(resourceStream)
    } catch {
      case e: Exception =>
        // Best-effort close: the load failure is the error worth reporting, so a close failure
        // is attached as suppressed instead of masking it.
        try resourceStream.close() catch { case NonFatal(closeError) => e.addSuppressed(closeError) }
        throw new SparkException(s"Error loading properties from ${file}", e)
    }

    try {
      resourceStream.close()
    } catch {
      case e: Exception =>
        throw new SparkException(s"Error closing ${file} resource stream", e)
    }

    val unknownProp = ""
    props.getProperty(key, unknownProp)
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy