
com.databricks.connect.DatabricksSession.scala (from the databricks-connect artifact)
Develop locally and connect IDEs, notebook servers and running applications to Databricks clusters.
/*
* DATABRICKS CONFIDENTIAL & PROPRIETARY
* __________________
*
* Copyright 2023-present Databricks, Inc.
* All Rights Reserved.
*
* NOTICE: All information contained herein is, and remains the property of Databricks, Inc.
* and its suppliers, if any. The intellectual and technical concepts contained herein are
* proprietary to Databricks, Inc. and its suppliers and may be covered by U.S. and foreign Patents,
* patents in process, and are protected by trade secret and/or copyright law. Dissemination, use,
* or reproduction of this information is strictly forbidden unless prior written permission is
* obtained from Databricks, Inc.
*
* If you view or obtain a copy of this information and believe Databricks, Inc. may not have
* intended it to be made available, please promptly report it to Databricks Legal Department
* @ [email protected].
*/
package com.databricks.connect
import java.io.{File, IOException}
import java.net.URI
import java.nio.file.{Files, Paths}
import java.util.{Properties, UUID}
import scala.collection.immutable
import scala.util.Try
import com.databricks.sdk.WorkspaceClient
import com.databricks.sdk.core.{DatabricksConfig, UserAgent}
import org.apache.spark.SparkException
import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.connect.client.SparkConnectClient
/**
* Entry-point for building a DB Connect Scala Client Session. See
* [[DatabricksSession.Builder.getOrCreate()]].
*/
object DatabricksSession {
class Builder private[connect] ( // for unit testing
private val builder: SparkSession.Builder = SparkSession.builder(),
private val env: Map[String, String] = sys.env)
extends Logging {
private var clusterId: Option[String] = None
private var serverlessMode: Option[Boolean] = None
private var host: Option[String] = None
private var token: Option[String] = None
private var sdkConfig: Option[DatabricksConfig] = None
private var userAgent: Option[String] = None
private var headers: immutable.HashMap[String, String] = new immutable.HashMap()
private var artifacts: scala.collection.Map[String, File] = Map()
private var validateSession: Option[Boolean] = None
/**
* The instance of DatabricksConfig that was finally used to pick up connection parameters.
* This is populated as part of the `getOrCreate()` API. When the Databricks SDK was not used,
* `None` will be returned. Package private for unit testing.
*/
private[connect] var resolvedSdkConfig: Option[DatabricksConfig] = None
private[connect] def this(env: Map[String, String]) = this(SparkSession.builder(), env)
/**
* Returns the same instance of the builder unchanged. This API exists for backwards
* compatibility.
* @return
* instance of Builder
*/
def remote(): Builder = this
/**
* The Databricks cluster id to connect to.
* Cannot be used together with serverless(enabled = true).
*
* @param clusterId
* cluster id
* @return
* instance of Builder with the cluster id configured
*/
def clusterId(clusterId: String): Builder = {
this.clusterId = Some(clusterId)
this
}
/**
* [IMPORTANT]: This API and support for Scala in serverless are in early PRIVATE PREVIEW.
* Some operations may not work.
*
* Connect to the serverless endpoint of the workspace.
* Cannot be used together with clusterId.
*
* @param enabled
* Boolean flag that enables serverless mode.
* @return
* instance of Builder with the serverless mode configured
*/
def serverless(enabled: Boolean = true): Builder = {
this.serverlessMode = Some(enabled)
this
}
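// Illustrative sketch: connecting to serverless compute. The host and token below are
// placeholder values; in practice they can also come from the SDK's default config resolution.
// {{{
//   val spark = DatabricksSession.builder()
//     .host("https://my-workspace.cloud.databricks.com") // placeholder workspace URL
//     .token("dapi-REDACTED")                            // placeholder personal access token
//     .serverless()
//     .getOrCreate()
// }}}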
/**
* The URL of the Databricks workspace to connect to.
*
* @param host
* Databricks workspace URL
* @return
* instance of Builder with the host configured
*/
def host(host: String): Builder = {
this.host = Some(host)
this
}
/**
* The access token to use for connecting. The user associated with this token should
* have access to the workspace and cluster. Subsequent queries will be executed on behalf of
* this user.
*
* @param token
* Databricks access token
* @return
* instance of Builder with the access token configured
*/
def token(token: String): Builder = {
this.token = Some(token)
this
}
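// Illustrative sketch: connecting to a specific cluster with explicitly provided connection
// parameters. All values below are placeholders.
// {{{
//   val spark = DatabricksSession.builder()
//     .host("https://my-workspace.cloud.databricks.com") // placeholder workspace URL
//     .token("dapi-REDACTED")                            // placeholder personal access token
//     .clusterId("0123-456789-abcdefgh")                 // placeholder cluster id
//     .getOrCreate()
// }}}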
/**
* A user agent string identifying the application using the Databricks Connect module.
* Databricks Connect sends a set of standard information, such as the OS, the Python version
* and the version of Databricks Connect, as part of the user agent reported to the service.
* The value provided here will be included along with the rest of that information. It is
* recommended to provide this value in the "product/product-version" format described in
* https://datatracker.ietf.org/doc/html/rfc7231#section-5.5.3, but this is not required.
*
* @param userAgent
* The user agent string identifying the application that is using this module.
* @return
* instance of Builder with the user agent configured
*/
def userAgent(userAgent: String): Builder = {
if (userAgent.length > 2048) {
throw new IllegalArgumentException("User agent should not exceed 2048 characters.")
}
this.userAgent = Some(userAgent)
this
}
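// Illustrative sketch: tagging requests with an application-specific user agent, using the
// recommended "product/product-version" format. The application name and version below are
// placeholders.
// {{{
//   val spark = DatabricksSession.builder()
//     .userAgent("my-etl-job/1.2.3") // hypothetical application identifier
//     .getOrCreate()
// }}}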
private def genUserAgent(): String = {
val applicationUserAgent =
userAgent.getOrElse(env.getOrElse("SPARK_CONNECT_USER_AGENT", "databricks-session"))
// reuse databricks-sdk UserAgent until their version is accessible programmatically
val sdkUserAgent = UserAgent.asString().split(" ")(1)
Seq(applicationUserAgent, s"dbconnect/$version", sdkUserAgent).mkString(" ").trim()
}
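// For illustration (placeholder values): with userAgent("my-etl-job/1.2.3") the generated
// string has the shape "my-etl-job/1.2.3 dbconnect/<version> <sdk-name>/<sdk-version>",
// where the last token is taken from the Databricks SDK's own UserAgent string as per the
// code above.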
/**
* Add a header to the Spark Connect GRPC requests. This method is cumulative (can be called
* repeatedly to add more headers).
*
* @param header
* Name of the header to set
* @param value
* The value to set
* @return
* The same instance of this class with a header set.
*/
def header(header: String, value: String): Builder = {
this.headers = this.headers + (header -> value)
this
}
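// Illustrative sketch: headers are cumulative, so repeated calls add to the same map. The
// header names and values below are placeholders, not headers required by the service.
// {{{
//   val spark = DatabricksSession.builder()
//     .header("x-custom-trace-id", "42")        // hypothetical header
//     .header("x-custom-team", "data-platform") // hypothetical header
//     .getOrCreate()
// }}}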
/**
* Provide a custom SDK Config object to use for connection configuration. Can be used to
* override the default SDK Config that will be used.
*
* @param config
* Databricks SDK Config
* @return
* instance of Builder with the SDK config configured
*/
def sdkConfig(config: DatabricksConfig): Builder = {
this.sdkConfig = Some(config)
this
}
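// Illustrative sketch: supplying a pre-built Databricks SDK config instead of individual
// parameters. The profile name is a placeholder, and setProfile is assumed to be available in
// the SDK version on the classpath.
// {{{
//   val config = new DatabricksConfig()
//   config.setProfile("DEV") // hypothetical profile from ~/.databrickscfg
//   val spark = DatabricksSession.builder()
//     .sdkConfig(config)
//     .getOrCreate()
// }}}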
/**
* Upload the compiled artifacts to the Databricks cluster and synchronize as part of session
* initialization.
*
* This is required when the program requires external dependencies to evaluate the Spark
* query. For any Spark query using a typed Dataset API such as filter(), map(), foreach(),
* etc., or if the query is using a UDF, the enclosing classes should be synchronized with the
* cluster.
*
* For example: the following code snippet includes the compiled artifact of class `Foo`. The
* `uri` in this example will either point to a JAR file or to a folder that contains the
* compiled class files.
*
* {{{
* val uri = Foo.getClass.getProtectionDomain.getCodeSource.getLocation.toURI
* val spark = DatabricksSession.builder()
* .addCompiledArtifacts(uri)
* ...
* }}}
*
* @param uri
* Specify the class file, JAR or directory to upload and synchronize. When a directory is
* specified, all deeply nested JAR and class files will be synchronized.
* @return
* instance of this Builder
*/
def addCompiledArtifacts(uri: URI): Builder = {
val f = new File(uri.getPath)
if (!f.isFile && !f.isDirectory) {
throw new IOException(s"Path provided must be a valid file or directory: $uri")
}
Files
.find(
Paths.get(uri),
999,
(p, _) => p.toString.endsWith(".class") || p.toString.endsWith(".jar"))
.forEach(p => {
val pUri = p.toUri
val pFile = p.toFile
var target: String = null
if (
// In shared clusters, only jars placed in the root target directory are recognized.
// See https://databricks.atlassian.net/browse/SASP-2988
pFile.getName.endsWith(".jar")
// uri == pUri when the uri parameter is not a folder but a direct file.
|| uri.equals(pUri)) {
target = pFile.getName
} else {
target = uri.relativize(pUri).toString
}
this.artifacts = this.artifacts + (target -> pFile)
})
this
}
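// Illustrative sketch: synchronizing a pre-built application JAR in addition to the enclosing
// class's code source shown in the scaladoc above. The path below is a placeholder.
// {{{
//   val jarUri = java.nio.file.Paths.get("target/scala-2.12/my-app.jar").toUri // placeholder
//   val spark = DatabricksSession.builder()
//     .addCompiledArtifacts(jarUri)
//     .getOrCreate()
// }}}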
/**
* For unit testing
*/
private[connect] def getCompiledArtifacts = {
this.artifacts
}
/**
* Setting this option runs the following validations and throws an error if any of them fail.
* Validations are run only when connecting to a Databricks cluster:
* - the specified connection parameters are valid and can communicate with the cluster
* - the Databricks Runtime version of the cluster is greater than or equal to the
* Databricks Connect version
*
* By default, these validations are run and only a warning is logged when they fail. Disabling
* this option turns the validations off. Validations are always skipped for connection
* strings.
*
* @param enabled
* Boolean, optional
* @return
* instance of Builder with validateSession option configured
*/
def validateSession(enabled: Boolean = true): Builder = {
this.validateSession = Some(enabled)
this
}
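// Illustrative sketch: failing fast when the client and cluster versions are incompatible,
// instead of the default behaviour of only logging a warning. The cluster id is a placeholder.
// {{{
//   val spark = DatabricksSession.builder()
//     .clusterId("0123-456789-abcdefgh") // placeholder cluster id
//     .validateSession(true)
//     .getOrCreate()
// }}}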
// Package private variable for unit testing.
private[connect] var getVersion: () => String = () => version
// Package private variable for unit testing.
private[connect] var createWorkspaceClient: DatabricksConfig => WorkspaceClient =
config => new WorkspaceClient(config)
/**
* Validates the configuration by retrieving the used Databricks Runtime version with the
* Databricks SDK. Checks if there is an unsupported combination of Databricks Runtime &
* Databricks Connect versions. Throws an exception if the validation fails.
*/
private[connect] def validateSessionWithSdk: Unit = {
case class DBRVersion(major: Int, minor: Option[Int]) {
override def toString: String = {
minor match {
case Some(minor) => s"$major.$minor"
case None => s"$major.x"
}
}
}
case class DBConnectVersion(major: Int, minor: Int) {
override def toString: String = s"$major.$minor"
}
def parseDbrVersion(version: String): DBRVersion = {
// should work with custom image names, "major.x"-image names, and standard image names
val pattern = raw"(\d+(\.\d+)?\.x)".r
pattern.findFirstIn(version) match {
case Some(matched) =>
val versions = matched.split("\\.")
DBRVersion(versions(0).toInt, Try(versions(1).toInt).toOption)
case None =>
throw new IllegalArgumentException(
"Failed to parse minor & major version from Databricks Runtime" +
s"version string: $version")
}
}
def parseDBConnectVersion(version: String): DBConnectVersion = {
val versions = version.split("\\.").map(_.toInt)
// dbConnectVersion should always have a valid major and minor version
DBConnectVersion(versions(0), versions(1))
}
logDebug(
"Validating compatibility between the Databricks Runtime and Databricks Connect versions")
val clusterId = resolvedSdkConfig match {
case Some(config) => config.getClusterId
case None =>
throw new SparkException("Session validation is not supported for connection strings.")
}
val dbrVersionString = createWorkspaceClient(resolvedSdkConfig.get)
.clusters()
.get(clusterId)
.getSparkVersion
val dbrVersion = parseDbrVersion(dbrVersionString)
// dbConnectVersion should always have a valid major and minor version
val dbConnectVersion = parseDBConnectVersion(getVersion())
if (dbrVersion.major < dbConnectVersion.major || (dbrVersion.major == dbConnectVersion.major
&& dbrVersion.minor.nonEmpty && dbrVersion.minor.get < dbConnectVersion.minor)) {
throw new SparkException(
"Unsupported combination of Databricks Runtime & Databricks Connect versions: " +
s"${dbrVersion.toString} (Databricks Runtime) < ${dbConnectVersion.toString} " +
"(Databricks Connect)."
)
}
// Raise an error when the minor versions of DBR and DB Connect are inconsistent (by default,
// getOrCreate() downgrades this to a warning). Currently we only guarantee compatibility
// within the same minor version, so mismatched minor versions can lead to errors in certain
// operations, such as using Scala UDFs. If an error occurs, the user should retry with a
// matching version of DB Connect.
// See this ticket for more information: go/jira/SASP-3736
if (dbrVersion.major != dbConnectVersion.major || dbrVersion.minor.isEmpty ||
dbrVersion.minor.get != dbConnectVersion.minor) {
throw new SparkException(
s"Minor version mismatch: Databricks Runtime (${dbrVersion.toString}) vs. Databricks " +
s"Connect (${dbConnectVersion.toString}). Compatibility between Databricks Connect " +
s"and Databricks Runtime is guaranteed within the same minor version. This mismatch " +
s"may lead to errors with User-defined Functions. If you encounter any errors, " +
s"use the version of Databricks Connect that matches the Databricks Runtime " +
s"version you are running."
)
}
logDebug("Session validated successfully.")
}
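// Worked example of the parsing above (placeholder version strings): a cluster Spark version
// such as "14.3.x-scala2.12" yields DBRVersion(14, Some(3)), while a "major.x"-style image
// name such as "14.x-snapshot" yields DBRVersion(14, None); an empty minor skips the strict
// "DBR < DB Connect" comparison but still triggers the minor-version mismatch error above.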
// scalastyle:off line.size.limit
/**
* Get a spark session if one was already created with the given configuration. Otherwise,
* create a new one.
*
* Connection parameters for Databricks Connect are collected in the following order. Once a
* configuration parameter is collected, the lookup for that parameter stops; collection of the
* other parameters continues down the chain.
* - Parameters configured directly in code via [[host()]], [[token()]], etc.
* - Parameters configured in the Databricks SDK's [[DatabricksConfig]] specified via
* [[sdkConfig()]]
* - The `SPARK_REMOTE` environment variable. For details on the connection string, see
* https://github.com/apache/spark/blob/master/connector/connect/docs/client-connection-string.md
* - Use the "default" configuration profile from the Databricks config file. For details on
* Databricks configuration profiles, see https://docs.databricks.com/dev-tools/auth.html
*
* By default, the resolved configuration is validated by retrieving the Databricks Runtime
* version of the cluster with the Databricks SDK. A warning is logged if authentication fails
* or if there is an unsupported combination of Databricks Runtime & Databricks Connect
* versions. Change this behaviour with [[DatabricksSession.builder().validateSession()]].
*
* @return
* spark session
*/
// scalastyle:on line.size.limit
def getOrCreate(): SparkSession = {
this.sdkConfig.foreach(_.resolve())
var spark: SparkSession = null
if (Array(this.host, this.clusterId, this.token).exists(_.isDefined)
|| this.serverlessMode.contains(true)) {
logDebug("DatabricksSession: Initializing from explicitly set host, cluster, token")
// use either provided configuration or default configuration
val config = this.sdkConfig.getOrElse(new DatabricksConfig())
// SASP-2727: Explicitly configure the auth type to be PAT.
config.setAuthType("pat")
config.resolve()
// For any explicitly provided config, override the config.
if (this.host.isDefined) {
config.setHost(this.host.get)
}
if (this.clusterId.isDefined && this.serverlessMode.contains(true)) {
throw new IllegalArgumentException("Can't set both cluster id and serverless.")
}
if (this.clusterId.isDefined) {
config.setClusterId(this.clusterId.get)
}
if (this.token.isDefined) {
config.setToken(this.token.get)
}
this.resolvedSdkConfig = Some(config)
spark = fromSdkConfig(config)
} else if (this.sdkConfig.isDefined) {
logDebug("DatabricksSession: Initializing from sdkConfig")
this.resolvedSdkConfig = this.sdkConfig
spark = fromSdkConfig(this.sdkConfig.get.resolve())
} else if (env.contains("SPARK_REMOTE")) {
// look for SPARK_REMOTE
logDebug("DatabricksSession: Initializing from SPARK_REMOTE")
if (this.headers.nonEmpty) {
throw new IllegalArgumentException(
"Can't configure custom headers with SPARK_REMOTE connection")
}
// SparkSession.Builder knows how to handle SPARK_REMOTE
this.resolvedSdkConfig = None
spark = builder.getOrCreate()
} else {
logDebug("DatabrickSession: Constructing from default SDK Config")
val config = new DatabricksConfig().resolve()
this.resolvedSdkConfig = Some(config)
spark = fromSdkConfig(config)
}
if (spark == null) {
throw new RuntimeException("Spark session could not be initialized: unexpected state")
}
for ((target, file) <- this.artifacts) {
logDebug(s"Uploading file [$file] to target path [$target].")
spark.addArtifact(Files.readAllBytes(file.toPath), target)
}
// Validate the session when validateSession is unset or explicitly enabled.
// Only fail if validateSession is explicitly enabled; when unset, log a warning instead.
// Skip validation if serverless mode is used or the session id header is set manually.
if (
validateSession.getOrElse(true)
&& !this.serverlessMode.contains(true)
&& !headers.contains("x-databricks-session-id")
) {
try {
validateSessionWithSdk
} catch {
case e: Throwable if validateSession.isEmpty => logWarning(e.getMessage)
}
}
spark
}
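// Illustrative sketch: with no parameters set on the builder, getOrCreate() falls back to
// SPARK_REMOTE or the Databricks SDK's default configuration (for example, environment
// variables or the DEFAULT profile in ~/.databrickscfg).
// {{{
//   val spark = DatabricksSession.builder().getOrCreate()
//   spark.range(10).count() // simple smoke test of the connection
// }}}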
private def fromSdkConfig(config: DatabricksConfig): SparkSession = {
if (config.getHost == null) {
throw new IllegalArgumentException(
"DatabricksSession: Need host to construct session. Received null.")
}
val host = config.getHost.stripPrefix("https://").stripPrefix("http://")
// local headers variable will be modified when appending "x-databricks-session-id",
// but this.headers will not
var headers = this.headers
(
this.serverlessMode.contains(true),
headers.contains("x-databricks-session-id"),
config.getClusterId,
) match {
case (true, _, _) | (_, true, _) =>
if (!headers.contains("x-databricks-session-id")) {
headers += ("x-databricks-session-id" -> UUID.randomUUID().toString)
}
logDebug("Using serverless compute")
logDebug(
"Using SparkSession with remote session id: " + headers("x-databricks-session-id"))
val conf = new DatabricksSparkClientConfiguration(config, host, genUserAgent(), headers)
fromSparkClientConf(conf)
case (_, _, null) =>
throw new IllegalArgumentException(
"DatabricksSession: need cluster id or serverless to construct a session.")
case (_, _, cluster_id) =>
logDebug("Using cluster: " + cluster_id)
val conf = new DatabricksSparkClientConfiguration(
config,
host,
genUserAgent(),
headers + ("x-databricks-cluster-id" -> cluster_id))
fromSparkClientConf(conf)
}
}
/* package protected for unit testing */
protected[connect] def fromSparkClientConf(
conf: SparkConnectClient.Configuration): SparkSession = {
builder.client(conf.toSparkConnectClient).getOrCreate()
}
}
/**
* Get a new Builder to configure and construct a new [[org.apache.spark.sql.SparkSession]].
* @return
* new Builder
*/
def builder(): Builder = new Builder
val version: String = loadProperty("version", "dbconnect-version-info.properties")
val gitVersion: String = loadProperty("git_version", "dbconnect-version-info.properties")
private def loadProperty(key: String, file: String): String = {
val resourceStream = Thread
.currentThread()
.getContextClassLoader
.getResourceAsStream(file)
if (resourceStream == null) {
throw new SparkException(s"Could not find ${file}")
}
try {
val unknownProp = ""
val props = new Properties()
props.load(resourceStream)
props.getProperty(key, unknownProp)
} catch {
case e: Exception =>
throw new SparkException(s"Error loading properties from ${file}", e)
} finally {
if (resourceStream != null) {
try {
resourceStream.close()
} catch {
case e: Exception =>
throw new SparkException(s"Error closing ${file} resource stream", e)
}
}
}
}
}