All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.databricks.connect.DatabricksSession.scala Maven / Gradle / Ivy

Go to download

Develop locally and connect IDEs, notebook servers and running applications to Databricks clusters.

The newest version!
/*
 * DATABRICKS CONFIDENTIAL & PROPRIETARY
 * __________________
 *
 * Copyright 2023-present Databricks, Inc.
 * All Rights Reserved.
 *
 * NOTICE:  All information contained herein is, and remains the property of Databricks, Inc.
 * and its suppliers, if any.  The intellectual and technical concepts contained herein are
 * proprietary to Databricks, Inc. and its suppliers and may be covered by U.S. and foreign Patents,
 * patents in process, and are protected by trade secret and/or copyright law. Dissemination, use,
 * or reproduction of this information is strictly forbidden unless prior written permission is
 * obtained from Databricks, Inc.
 *
 * If you view or obtain a copy of this information and believe Databricks, Inc. may not have
 * intended it to be made available, please promptly report it to Databricks Legal Department
 * @ [email protected].
 */

package com.databricks.connect

import java.io.{File, IOException}
import java.net.URI
import java.nio.file.{Files, Paths}
import java.util.{Properties, UUID}

import scala.collection.immutable
import scala.util.Try
import scala.util.control.NonFatal

import com.databricks.sdk.WorkspaceClient
import com.databricks.sdk.core.{DatabricksConfig, UserAgent}

import org.apache.spark.SparkException
import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.connect.client.SparkConnectClient

/**
 * Entry-point for building a DB Connect Scala Client Session. See
 * [[DatabricksSession.Builder.getOrCreate()]].
 */
object DatabricksSession {

  /**
   * Fluent builder for a Databricks Connect [[org.apache.spark.sql.SparkSession]].
   *
   * Connection parameters may be set explicitly on the builder, taken from a Databricks SDK
   * [[DatabricksConfig]], from the `SPARK_REMOTE` environment variable, or from the default SDK
   * configuration. See `getOrCreate()` for the exact resolution order.
   *
   * @param builder
   *   underlying Spark Connect session builder (injectable for unit testing)
   * @param env
   *   environment variables consulted for `SPARK_REMOTE` and `SPARK_CONNECT_USER_AGENT`
   *   (injectable for unit testing)
   */
  class Builder private[connect] ( // for unit testing
      private val builder: SparkSession.Builder = SparkSession.builder(),
      private val env: Map[String, String] = sys.env)
    extends Logging {
    private var clusterId: Option[String] = None
    private var serverlessMode: Option[Boolean] = None
    private var host: Option[String] = None
    private var token: Option[String] = None
    private var sdkConfig: Option[DatabricksConfig] = None
    private var userAgent: Option[String] = None
    private var headers: immutable.HashMap[String, String] = new immutable.HashMap()
    // Upload target name -> local file, populated by addCompiledArtifacts().
    private var artifacts: scala.collection.Map[String, File] = Map()
    private var validateSession: Option[Boolean] = None

    /**
     * The instance of DatabricksConfig that was finally used to pick up connection parameters.
     * This is populated as part of `getOrCreate()` API. When the Databricks SDK was not used,
     * `None` will be returned. Package private for unit testing.
     */
    private[connect] var resolvedSdkConfig: Option[DatabricksConfig] = None

    private[connect] def this(env: Map[String, String]) = this(SparkSession.builder(), env)

    /**
     * API returns the same instance of the builder unchanged. This API is here for backwards
     * compatibility.
     * @return
     *   instance of Builder
     */
    def remote(): Builder = this

    /**
     * The Databricks cluster id to connect to. Can't be combined with `serverless(true)`;
     * `getOrCreate()` throws an IllegalArgumentException if both are set.
     *
     * @param clusterId
     *   cluster id
     * @return
     *   instance of Builder with the cluster id configured
     */
    def clusterId(clusterId: String): Builder = {
      this.clusterId = Some(clusterId)
      this
    }

    /**
     * [IMPORTANT]: This API and support for Scala in serverless is in early PRIVATE PREVIEW.
     * Some operations may not work.
     *
     * Connect to the serverless endpoint of the workspace. Can't be combined with `clusterId`;
     * `getOrCreate()` throws an IllegalArgumentException if both are set.
     *
     * @param enabled
     *   Boolean flag that enables serverless mode.
     * @return
     *   instance of Builder with the serverless mode configured
     */
    def serverless(enabled: Boolean = true): Builder = {
      this.serverlessMode = Some(enabled)
      this
    }

    /**
     * The URL of the Databricks workspace to connect. A leading `https://` or `http://` is
     * stripped when the session is constructed.
     *
     * @param host
     *   Databricks workspace URL
     * @return
     *   instance of Builder with the host configured
     */
    def host(host: String): Builder = {
      this.host = Some(host)
      this
    }

    /**
     * The access token to use for connecting. The user associated with this token should have
     * access to the workspace and cluster. Subsequent queries will be executed on behalf of this
     * user.
     *
     * @param token
     *   Databricks access token
     * @return
     *   instance of Builder with the access token configured
     */
    def token(token: String): Builder = {
      this.token = Some(token)
      this
    }

    /**
     * A user agent string identifying the application using the Databricks Connect module.
     * Databricks Connect sends its own version and the Databricks SDK version as part of the user
     * agent; the value provided here is prepended to that information. It is recommended to
     * provide this value in the format "product_name/product_version" as described in
     * https://datatracker.ietf.org/doc/html/rfc7231#section-5.5.3, but it is not required.
     *
     * @param userAgent
     *   The user agent string identifying the application that is using this module.
     * @return
     *   instance of Builder with the user agent configured
     * @throws IllegalArgumentException
     *   if the user agent is longer than 2048 characters
     */
    def userAgent(userAgent: String): Builder = {
      if (userAgent.length > 2048) {
        throw new IllegalArgumentException("User agent should not exceed 2048 characters.")
      }
      this.userAgent = Some(userAgent)
      this
    }

    /**
     * Builds the user agent sent to the service: the application user agent (explicit value,
     * else `SPARK_CONNECT_USER_AGENT`, else "databricks-session"), `dbconnect/<version>`, and the
     * second space-separated token of the Databricks SDK user agent.
     */
    private def genUserAgent(): String = {
      val applicationUserAgent =
        userAgent.getOrElse(env.getOrElse("SPARK_CONNECT_USER_AGENT", "databricks-session"))
      // Reuse the databricks-sdk UserAgent until its version is accessible programmatically.
      // Tolerate an unexpected format instead of failing session creation with an index error.
      val sdkUserAgent = UserAgent.asString().split(" ").lift(1).getOrElse("")

      Seq(applicationUserAgent, s"dbconnect/$version", sdkUserAgent).mkString(" ").trim()
    }

    /**
     * Add a header to the Spark Connect GRPC requests. This method is cumulative (can be called
     * repeatedly to add more headers).
     *
     * @param header
     *   Name of the header to set
     * @param value
     *   The value to set
     * @return
     *   The same instance of this class with a header set.
     */
    def header(header: String, value: String): Builder = {
      this.headers = this.headers + (header -> value)
      this
    }

    /**
     * Provide a custom SDK Config object to use for connection configuration. Can be used to
     * override the default SDK Config that will be used. Note that `getOrCreate()` resolves and
     * may modify this object (e.g. when host, cluster id or token are also set explicitly).
     *
     * @param config
     *   Databricks SDK Config
     * @return
     *   instance of Builder with the SDK config configured
     */
    def sdkConfig(config: DatabricksConfig): Builder = {
      this.sdkConfig = Some(config)
      this
    }

    /**
     * Upload the compiled artifacts to the Databricks cluster and synchronize as part of session
     * initialization.
     *
     * This is required when the program requires external dependencies to evaluate the Spark
     * query. For any Spark query using a typed Dataset API such as filter(), map(), foreach(),
     * etc., or if the query is using a UDF, the enclosing classes should be synchronized with the
     * cluster.
     *
     * For example: the following code snippet includes the compiled artifact of class `Foo`. The
     * `uri` in this example will either point to a JAR file or to a folder that contains the
     * compiled class files.
     *
     * {{{
     *   val uri = Foo.getClass.getProtectionDomain.getCodeSource.getLocation.toURI
     *   val spark = DatabricksSession.builder()
     *       .addCompiledArtifacts(uri)
     *       ...
     * }}}
     *
     * @param uri
     *   Specify the class file, JAR or directory to upload and synchronize. When a directory is
     *   specified, all deep nested JARs and class files will be synchronized.
     * @return
     *   instance of this Builder
     * @throws IOException
     *   if `uri` does not resolve to an existing local file or directory
     */
    def addCompiledArtifacts(uri: URI): Builder = {
      // Resolve through NIO so that opaque / non-file URIs fail with the documented IOException
      // rather than a NullPointerException from `new File(uri.getPath)`.
      val root = Try(Paths.get(uri)).toOption
        .filter(p => Files.isRegularFile(p) || Files.isDirectory(p))
        .getOrElse(throw new IOException(s"Path provided must be a valid file or directory: $uri"))

      // When the uri points directly at a file, that file is the only match.
      val rootIsFile = Files.isRegularFile(root)
      val rootUri = root.toUri

      // Files.find returns a stream that holds open directory handles; it must be closed.
      val matches = Files.find(
        root,
        999,
        (p, _) => {
          val name = p.toString
          (name.endsWith(".class") || name.endsWith(".jar")) && !Files.isDirectory(p)
        })
      try {
        matches.forEach { p =>
          val pFile = p.toFile
          val target =
            // In shared clusters, only jars placed in the root target directory are recognized.
            // See https://databricks.atlassian.net/browse/SASP-2988
            if (rootIsFile || pFile.getName.endsWith(".jar")) pFile.getName
            else rootUri.relativize(p.toUri).toString
          this.artifacts = this.artifacts + (target -> pFile)
        }
      } finally {
        matches.close()
      }
      this
    }

    /**
     * For unit testing
     */
    private[connect] def getCompiledArtifacts = {
      this.artifacts
    }

    /**
     * Controls session validation. Validation runs only when connecting to a Databricks cluster:
     * it is skipped in serverless mode and when an `x-databricks-session-id` header is set, and
     * it is not supported for `SPARK_REMOTE` connection strings (it fails there).
     *
     * Validation checks that:
     *   - the connection parameters can be used to look up the cluster with the Databricks SDK
     *   - the Databricks Runtime version of the cluster has the same major.minor version as
     *     Databricks Connect
     *
     * If this option is never set, validation runs and any failure is logged as a warning.
     * `validateSession(true)` makes validation failures throw; `validateSession(false)` skips
     * validation entirely.
     *
     * @param enabled
     *   Boolean, optional
     * @return
     *   instance of Builder with validateSession option configured
     */
    def validateSession(enabled: Boolean = true): Builder = {
      this.validateSession = Some(enabled)
      this
    }

    // Package private variable for unit testing.
    private[connect] var getVersion: () => String = () => version

    // Package private variable for unit testing.
    private[connect] var createWorkspaceClient: DatabricksConfig => WorkspaceClient =
      config => new WorkspaceClient(config)

    /**
     * Validates the configuration by retrieving the used Databricks Runtime version with the
     * Databricks SDK. Checks if there is an unsupported combination of Databricks Runtime &
     * Databricks Connect versions. Throws an exception if the validation fails.
     *
     * @throws SparkException
     *   if no SDK config was used (connection strings), if the Databricks Runtime is older than
     *   Databricks Connect, or if their major.minor versions differ
     * @throws IllegalArgumentException
     *   if either version string cannot be parsed
     */
    private[connect] def validateSessionWithSdk: Unit = {
      // Databricks Runtime version; `minor` is None for "major.x" images.
      case class DBRVersion(major: Int, minor: Option[Int]) {
        override def toString: String = s"$major.${minor.fold("x")(_.toString)}"
      }
      case class DBConnectVersion(major: Int, minor: Int) {
        override def toString: String = s"$major.$minor"
      }

      def parseDbrVersion(version: String): DBRVersion = {
        // Finds the first "<major>.<minor>.x" or "<major>.x" token, so it works with standard
        // image names (e.g. "14.3.x-scala2.12"), "major.x" image names and custom image names.
        val pattern = raw"(\d+(\.\d+)?\.x)".r
        pattern.findFirstIn(version) match {
          case Some(matched) =>
            val versions = matched.split("\\.")
            DBRVersion(versions(0).toInt, Try(versions(1).toInt).toOption)
          case None =>
            throw new IllegalArgumentException(
              "Failed to parse minor & major version from Databricks Runtime " +
                s"version string: $version")
        }
      }

      def parseDBConnectVersion(version: String): DBConnectVersion = {
        // Only the leading "major.minor" is relevant; patch and pre-release suffixes such as
        // "15.1.0" or "15.1.0-SNAPSHOT" are tolerated.
        val leadingMajorMinor = raw"(\d+)\.(\d+)(?:\D.*)?".r
        version.trim match {
          case leadingMajorMinor(major, minor) => DBConnectVersion(major.toInt, minor.toInt)
          case _ =>
            throw new IllegalArgumentException(
              "Failed to parse major & minor version from Databricks Connect " +
                s"version string: $version")
        }
      }

      logDebug(
        "Validating compatibility between the Databricks Runtime and Databricks Connect versions")
      val config = resolvedSdkConfig.getOrElse(
        throw new SparkException("Session validation is not supported for connection strings."))
      val dbrVersionString = createWorkspaceClient(config)
        .clusters()
        .get(config.getClusterId)
        .getSparkVersion
      val dbrVersion = parseDbrVersion(dbrVersionString)
      val dbConnectVersion = parseDBConnectVersion(getVersion())

      val dbrOlderThanDbConnect =
        dbrVersion.major < dbConnectVersion.major ||
          (dbrVersion.major == dbConnectVersion.major &&
            dbrVersion.minor.exists(_ < dbConnectVersion.minor))
      if (dbrOlderThanDbConnect) {
        throw new SparkException(
          "Unsupported combination of Databricks Runtime & Databricks Connect versions: " +
          s"${dbrVersion.toString} (Databricks Runtime) < ${dbConnectVersion.toString} " +
          "(Databricks Connect)."
        )
      }

      // Compatibility is only guaranteed within the same minor version; mismatched minor
      // versions can lead to errors in certain operations, such as using Scala UDFs. A mismatch
      // (including an unknown DBR minor version) is raised here; when validateSession was left
      // unset, getOrCreate() catches it and only logs a warning.
      //  See this ticket for more information: go/jira/SASP-3736
      val sameMinorVersion =
        dbrVersion.major == dbConnectVersion.major &&
          dbrVersion.minor.contains(dbConnectVersion.minor)
      if (!sameMinorVersion) {
        throw new SparkException(
          s"Minor version mismatch: Databricks Runtime (${dbrVersion.toString}) vs. Databricks " +
            s"Connect (${dbConnectVersion.toString}). Compatibility between Databricks Connect " +
            s"and Databricks Runtime is guaranteed within the same minor version. This mismatch " +
            s"may lead to errors with User-defined Functions. If you encounter any errors, " +
            s"use the version of Databricks Connect that matches the Databricks Runtime " +
            s"version you are running."
        )
      }
      logDebug("Session validated successfully.")
    }

    // scalastyle:off line.size.limit
    /**
     * Get a spark session if one was already created with the given configuration. Otherwise,
     * create a new one.
     *
     * Connection parameters for connecting to Databricks Connect is collected in the following
     * order. When one of the configuration parameters is collected, the lookup for that parameter
     * stops; collection for other parameters continues down the chain.
     *   - Parameters configured directly in code via [[host()]], [[token()]], etc.
     *   - Parameters configured in the Databricks SDK's [[DatabricksConfig]] specified via
     *     [[sdkConfig()]]
     *   - The `SPARK_REMOTE` environment variable. For details on the connection string, see
     *     https://github.com/apache/spark/blob/master/connector/connect/docs/client-connection-string.md
     *   - Use the "default" configuration profile from the Databricks config file. For details on
     *     Databricks configuration profiles, see https://docs.databricks.com/dev-tools/auth.html
     *
     * Any artifacts registered with `addCompiledArtifacts` are uploaded to the new session.
     *
     * By default, validates the used configuration by retrieving the used Databricks Runtime
     * version with the Databricks SDK. Logs a warning if the authentication fails or if there is
     * an unsupported combination of Databricks Runtime & Databricks Connect versions. Change this
     * behaviour with [[DatabricksSession.builder().validateSession()]].
     *
     * @return
     *   spark session
     * @throws IllegalArgumentException
     *   if both a cluster id and serverless mode are set, if custom headers are combined with a
     *   `SPARK_REMOTE` connection, or if the resolved configuration lacks a host or compute target
     */
    // scalastyle:on line.size.limit
    def getOrCreate(): SparkSession = {
      this.sdkConfig.foreach(_.resolve())

      val hasExplicitSettings =
        Seq(this.host, this.clusterId, this.token).exists(_.isDefined) ||
          this.serverlessMode.contains(true)
      val providedConfig: Option[DatabricksConfig] = this.sdkConfig

      val spark: SparkSession = providedConfig match {
        case _ if hasExplicitSettings =>
          logDebug("DatabricksSession: Initializing from explicitly set host, cluster, token")
          val config = explicitlyConfiguredSdkConfig()
          this.resolvedSdkConfig = Some(config)
          fromSdkConfig(config)

        case Some(config) =>
          logDebug("DatabricksSession: Initializing from sdkConfig")
          this.resolvedSdkConfig = Some(config)
          fromSdkConfig(config.resolve())

        case None if env.contains("SPARK_REMOTE") =>
          logDebug("DatabricksSession: Initializing from SPARK_REMOTE")
          if (this.headers.nonEmpty) {
            throw new IllegalArgumentException(
              "Can't configure custom headers with SPARK_REMOTE connection")
          }
          // SparkSession.Builder knows how to handle SPARK_REMOTE
          this.resolvedSdkConfig = None
          builder.getOrCreate()

        case None =>
          logDebug("DatabricksSession: Constructing from default SDK Config")
          val config = new DatabricksConfig().resolve()
          this.resolvedSdkConfig = Some(config)
          fromSdkConfig(config)
      }

      // Defensive: fromSparkClientConf is overridable (e.g. in tests).
      if (spark == null) {
        throw new RuntimeException("Spark session could not be initialized: unexpected state")
      }

      for ((target, file) <- this.artifacts) {
        logDebug(s"Uploading file [$file] to target path [$target].")
        spark.addArtifact(Files.readAllBytes(file.toPath), target)
      }

      // Validate the session for None and True.
      // Only fail if validateSession is explicitly enabled. For None, log a warning.
      // Skip validation if serverless mode is used or session-id is set manually.
      val shouldValidate =
        validateSession.getOrElse(true) &&
          !this.serverlessMode.contains(true) &&
          !headers.contains("x-databricks-session-id")
      if (shouldValidate) {
        try {
          validateSessionWithSdk
        } catch {
          // NonFatal: never swallow VM errors or interrupts while downgrading to a warning.
          case NonFatal(e) if validateSession.isEmpty => logWarning(e.getMessage)
        }
      }

      spark
    }

    /**
     * Builds the SDK config used when host, cluster id, token or serverless mode was set
     * explicitly. Starts from the config passed to `sdkConfig(...)` (modifying it in place) or a
     * fresh default config, forces PAT authentication, then overrides host, cluster id and token
     * with the explicitly set values.
     *
     * @throws IllegalArgumentException
     *   if both a cluster id and serverless mode are set; checked before any config is modified
     */
    private def explicitlyConfiguredSdkConfig(): DatabricksConfig = {
      if (this.clusterId.isDefined && this.serverlessMode.contains(true)) {
        throw new IllegalArgumentException("Can't set both cluster id and serverless.")
      }

      // use either provided configuration or default configuration
      val config = this.sdkConfig.getOrElse(new DatabricksConfig())
      // SASP-2727: Explicitly configure the auth type to be PAT.
      config.setAuthType("pat")
      config.resolve()

      // Explicitly provided values override whatever the config resolved.
      this.host.foreach(h => config.setHost(h))
      this.clusterId.foreach(id => config.setClusterId(id))
      this.token.foreach(t => config.setToken(t))
      config
    }

    /**
     * Creates the session from a resolved SDK config. In serverless mode, or when the caller set
     * an `x-databricks-session-id` header, requests are routed by session id (a random UUID is
     * generated when none was given); otherwise they are routed by the config's cluster id via
     * the `x-databricks-cluster-id` header. `this.headers` itself is never modified.
     *
     * @throws IllegalArgumentException
     *   if the config has no host, or has no cluster id when not using a session id
     */
    private def fromSdkConfig(config: DatabricksConfig): SparkSession = {
      if (config.getHost == null) {
        throw new IllegalArgumentException(
          "DatabricksSession: Need host to construct session. Received null.")
      }
      val host = config.getHost.stripPrefix("https://").stripPrefix("http://")
      val sessionIdHeader = "x-databricks-session-id"

      if (this.serverlessMode.contains(true) || this.headers.contains(sessionIdHeader)) {
        val headers =
          if (this.headers.contains(sessionIdHeader)) this.headers
          else this.headers + (sessionIdHeader -> UUID.randomUUID().toString)
        logDebug("Using serverless compute")
        logDebug("Using SparkSession with remote session id: " + headers(sessionIdHeader))
        fromSparkClientConf(
          new DatabricksSparkClientConfiguration(config, host, genUserAgent(), headers))
      } else {
        val clusterId = config.getClusterId
        if (clusterId == null) {
          throw new IllegalArgumentException(
            "DatabricksSession: need cluster id or serverless to construct a session.")
        }
        logDebug("Using cluster: " + clusterId)
        fromSparkClientConf(
          new DatabricksSparkClientConfiguration(
            config,
            host,
            genUserAgent(),
            this.headers + ("x-databricks-cluster-id" -> clusterId)))
      }
    }

    /* package protected for unit testing */
    protected[connect] def fromSparkClientConf(
        conf: SparkConnectClient.Configuration): SparkSession = {
      builder.client(conf.toSparkConnectClient).getOrCreate()
    }
  }

  /**
   * Get a new Builder to configure and construct a new [[org.apache.spark.sql.SparkSession]].
   * @return
   *   new Builder
   */
  def builder(): Builder = new Builder

  /** Databricks Connect version, read from `dbconnect-version-info.properties` ("" if absent). */
  val version: String = loadProperty("version", "dbconnect-version-info.properties")

  /** Git version, read from `dbconnect-version-info.properties` ("" if absent). */
  val gitVersion: String = loadProperty("git_version", "dbconnect-version-info.properties")

  /**
   * Reads a single property from a classpath `.properties` resource.
   *
   * @param key
   *   property name to look up
   * @param file
   *   classpath resource name of the properties file
   * @return
   *   the property value, or "" when the key is absent
   * @throws SparkException
   *   if the resource is missing, cannot be read, or cannot be closed
   */
  private def loadProperty(key: String, file: String): String = {
    // Thread.getContextClassLoader may return null; fall back to this object's class loader.
    val classLoader =
      Option(Thread.currentThread().getContextClassLoader).getOrElse(getClass.getClassLoader)
    val resourceStream = classLoader.getResourceAsStream(file)
    if (resourceStream == null) {
      throw new SparkException(s"Could not find ${file}")
    }

    var loadFailure: Throwable = null
    try {
      val unknownProp = ""
      val props = new Properties()
      props.load(resourceStream)
      props.getProperty(key, unknownProp)
    } catch {
      case e: Exception =>
        loadFailure = new SparkException(s"Error loading properties from ${file}", e)
        throw loadFailure
    } finally {
      try {
        resourceStream.close()
      } catch {
        case e: Exception =>
          val closeFailure = new SparkException(s"Error closing ${file} resource stream", e)
          // Never let a close failure mask the original load failure.
          if (loadFailure != null) loadFailure.addSuppressed(closeFailure)
          else throw closeFailure
      }
    }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy