
/*
* DATABRICKS CONFIDENTIAL & PROPRIETARY
* __________________
*
* Copyright 2023-present Databricks, Inc.
* All Rights Reserved.
*
* NOTICE: All information contained herein is, and remains the property of Databricks, Inc.
* and its suppliers, if any. The intellectual and technical concepts contained herein are
* proprietary to Databricks, Inc. and its suppliers and may be covered by U.S. and foreign Patents,
* patents in process, and are protected by trade secret and/or copyright law. Dissemination, use,
* or reproduction of this information is strictly forbidden unless prior written permission is
* obtained from Databricks, Inc.
*
* If you view or obtain a copy of this information and believe Databricks, Inc. may not have
* intended it to be made available, please promptly report it to Databricks Legal Department
* @ [email protected].
*/
package com.databricks.connect
import java.io.{File, IOException}
import java.net.URI
import java.nio.file.{Files, Paths}
import java.util.Properties
import scala.collection.immutable
import scala.util.Try
import com.databricks.sdk.WorkspaceClient
import com.databricks.sdk.core.{DatabricksConfig, UserAgent}
import org.apache.spark.SparkException
import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.connect.client.SparkConnectClient
/**
* Entry-point for building a DB Connect Scala Client Session.
* See [[DatabricksSession.Builder.getOrCreate()]].
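*
* A typical usage looks like the following (illustrative only; the host, token and cluster
* id values are placeholders):
* {{{
*   val spark = DatabricksSession.builder()
*     .host("https://my-workspace.cloud.databricks.com")
*     .token("dapi...")
*     .clusterId("0123-456789-abcdefgh")
*     .getOrCreate()
* }}}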
*/
object DatabricksSession {
class Builder private[connect] ( // for unit testing
private val builder: SparkSession.Builder = SparkSession.builder(),
private val env: Map[String, String] = sys.env)
extends Logging {
private var clusterId: Option[String] = None
private var host: Option[String] = None
private var token: Option[String] = None
private var sdkConfig: Option[DatabricksConfig] = None
private var userAgent: Option[String] = None
private var headers: immutable.HashMap[String, String] = new immutable.HashMap()
private var artifacts: scala.collection.Map[String, File] = Map()
private var validateSession: Option[Boolean] = None
/**
* The instance of DatabricksConfig that was ultimately used to pick up connection
* parameters. This is populated as part of the `getOrCreate()` API.
* When the Databricks SDK was not used, this remains `None`.
* Package private for unit testing.
*/
private[connect] var resolvedSdkConfig: Option[DatabricksConfig] = None
private[connect] def this(env: Map[String, String]) = this(SparkSession.builder(), env)
/**
* Returns the same instance of the builder, unchanged.
* This API exists for backwards compatibility.
* @return instance of Builder
*/
def remote(): Builder = this
/**
* The id of the Databricks cluster to connect to.
* @param clusterId cluster id
* @return instance of Builder with the cluster id configured
*/
def clusterId(clusterId: String): Builder = {
this.clusterId = Some(clusterId)
this
}
/**
* The URL of the Databricks workspace to connect to.
*
* @param host Databricks workspace URL
* @return instance of Builder with the host configured
*/
def host(host: String): Builder = {
this.host = Some(host)
this
}
/**
* The access token to use for connecting. The user associated with this token
* should have access to the workspace and cluster. Subsequent queries will be executed
* on behalf of this user.
*
* @param token Databricks access token
* @return instance of Builder with the access token configured
*/
def token(token: String): Builder = {
this.token = Some(token)
this
}
/**
* A user agent string identifying the application using the Databricks Connect module.
* Databricks Connect sends a set of standard information, such as the OS, the Python version
* and the Databricks Connect version, as a user agent to the service. The value provided
* here will be included along with the rest of this information. It is recommended, but not
* required, to provide this value in the format "product/product-version" as described in
* https://datatracker.ietf.org/doc/html/rfc7231#section-5.5.3.
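*
* For example (hypothetical application name and version):
* {{{
*   DatabricksSession.builder().userAgent("my-app/1.2.3")
* }}}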
*
* @param userAgent
* The user agent string identifying the application that is using this module.
* @return
* instance of Builder with the user agent configured
*/
def userAgent(userAgent: String): Builder = {
if (userAgent.length > 2048) {
throw new IllegalArgumentException("User agent should not exceed 2048 characters.")
}
this.userAgent = Some(userAgent)
this
}
private def genUserAgent(): String = {
val applicationUserAgent = userAgent.getOrElse(
env.getOrElse("SPARK_CONNECT_USER_AGENT", "databricks-session")
)
// reuse databricks-sdk UserAgent until their version is accessible programmatically
val sdkUserAgent = UserAgent.asString().split(" ")(1)
Seq(
applicationUserAgent,
s"dbconnect/$version",
sdkUserAgent
).mkString(" ").trim()
}
/**
* Add a header to the Spark Connect gRPC requests.
* This method is cumulative (can be called repeatedly to add more headers).
*
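* For example (illustrative header names and values):
* {{{
*   DatabricksSession.builder()
*     .header("x-custom-header", "some-value")
*     .header("x-another-header", "another-value")
* }}}
*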
* @param header
* Name of the header to set
* @param value
* The value to set
* @return
* The same instance of this class with a header set.
*/
def header(header: String, value: String): Builder = {
this.headers = this.headers + (header -> value)
this
}
/**
* Provide a custom SDK Config object to use for connection configuration.
* Can be used to override the default SDK Config that will be used.
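*
* For example, a minimal sketch assuming the SDK's `setProfile` setter and a placeholder
* configuration profile named "my-profile":
* {{{
*   val config = new DatabricksConfig().setProfile("my-profile")
*   val spark = DatabricksSession.builder().sdkConfig(config).getOrCreate()
* }}}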
*
* @param config Databricks SDK Config
* @return instance of Builder with the SDK config configured
*/
def sdkConfig(config: DatabricksConfig): Builder = {
this.sdkConfig = Some(config)
this
}
/**
* Upload the compiled artifacts to the Databricks cluster and synchronize as
* part of session initialization.
*
* This is required when the program requires external dependencies to evaluate
* the Spark query.
* For any Spark query using a typed Dataset API such as filter(), map(), foreach(),
* etc., or if the query is using a UDF, the enclosing classes should be synchronized
* with the cluster.
*
* For example: the following code snippet includes the compiled artifact of class `Foo`.
* The `uri` in this example will either point to a JAR file or to a folder that contains
* the compiled class files.
*
* {{{
* val uri = Foo.getClass.getProtectionDomain.getCodeSource.getLocation.toURI
* val spark = DatabricksSession.builder()
* .addCompiledArtifacts(uri)
* ...
* }}}
*
* @param uri Specify the class file, JAR or directory to upload and synchronize.
* When a directory is specified, all deep nested JARs and class files will
* be synchronized.
* @return instance of this Builder
*/
def addCompiledArtifacts(uri: URI): Builder = {
val f = new File(uri.getPath)
if (!f.isFile && !f.isDirectory) {
throw new IOException(s"Path provided must be a valid file or directory: $uri")
}
Files.find(Paths.get(uri), 999,
(p, _) => p.toString.endsWith(".class") || p.toString.endsWith(".jar")
).forEach(p => {
val pUri = p.toUri
val pFile = p.toFile
val target =
if (
// In shared clusters, only jars placed in the root target directory are recognized.
// See https://databricks.atlassian.net/browse/SASP-2988
pFile.getName.endsWith(".jar")
// uri == pUri when the uri parameter is not a folder but a direct file.
|| uri.equals(pUri)
) {
pFile.getName
} else {
uri.relativize(pUri).toString
}
this.artifacts = this.artifacts + (target -> pFile)
})
this
}
/**
* For unit testing
*/
private[connect] def getCompiledArtifacts() = {
this.artifacts
}
/**
* Setting this option to true runs validations and throws an error if any of them fail.
* Validations are run only when connecting to a Databricks cluster:
* - the specified connection parameters are valid and can communicate with the cluster
* - the Databricks Runtime version of the cluster is greater than or equal to the Databricks
* Connect version.
* By default, these validations are run and only a warning is logged when they fail. Setting
* this option to false turns these validations off. Session validation is not supported for
* connection strings.
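*
* For example, to fail fast when the cluster is incompatible (illustrative):
* {{{
*   val spark = DatabricksSession.builder().validateSession(true).getOrCreate()
* }}}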
*
* @param enabled Boolean, optional
* @return instance of Builder with validateSession option configured
*/
def validateSession(enabled: Boolean = true): Builder = {
this.validateSession = Some(enabled)
this
}
// Package private variable for unit testing.
private[connect] var createWorkspaceClient: DatabricksConfig => WorkspaceClient =
config => new WorkspaceClient(config)
/**
* Validates the configuration by retrieving the used Databricks Runtime version with the
* Databricks SDK. Checks if there is an unsupported combination of Databricks Runtime &
* Databricks Connect versions. Throws an exception if the validation fails.
*/
private[connect] def validateSessionWithSdk: Unit = {
def parseDbrVersion(version: String): (Int, Option[Int]) = {
// should work with custom image names, "major.x"-image names, and standard image names
val pattern = raw"(\d+(\.\d+)?\.x)".r
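// e.g. an illustrative "13.3.x-scala2.12" yields (13, Some(3)), while a "14.x"-style
// custom image name yields (14, None)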
pattern.findFirstIn(version) match {
case Some(matched) =>
val versions = matched.split("\\.")
(versions(0).toInt, Try(versions(1).toInt).toOption)
case None =>
throw new IllegalArgumentException(
"Failed to parse minor & major version from Databricks Runtime" +
s"version string: $version"
)
}
}
logDebug(
"Validating compatibility between the Databricks Runtime and Databricks Connect versions"
)
val clusterId = resolvedSdkConfig match {
case Some(config) => config.getClusterId
case None =>
throw new SparkException(
"Session validation is not supported for connection strings."
)
}
val dbrVersionString = createWorkspaceClient(resolvedSdkConfig.get)
.clusters()
.get(clusterId)
.getSparkVersion
val dbrVersion = parseDbrVersion(dbrVersionString)
// dbConnectVersion should always have a valid major and minor version
val dbConnectVersion = version.split("\\.").take(2).map(_.toInt)
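// Incompatible when the cluster's DBR major.minor is strictly lower than the Databricks
// Connect major.minor; an unknown ("x") minor version is compared on the major version only.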
if (dbrVersion._1 < dbConnectVersion(0) || (dbrVersion._1 == dbConnectVersion(0) &&
dbrVersion._2.nonEmpty && dbrVersion._2.get < dbConnectVersion(1))) {
throw new SparkException(
"Unsupported combination of Databricks Runtime & Databricks Connect versions: " +
s"${dbrVersion match {
case (major, Some(minor)) => s"$major.$minor"
case (major, None) => s"$major.x"
}} (Databricks Runtime) < ${dbConnectVersion(0)}.${dbConnectVersion(1)} " +
"(Databricks Connect)."
)
}
logDebug("Session validated successfully.")
}
// scalastyle:off line.size.limit
/**
* Get a spark session if one was already created with the given configuration.
* Otherwise, create a new one.
*
* Connection parameters for connecting with Databricks Connect are collected in the
* following order. Once a configuration parameter is collected, the lookup for that
* parameter stops; collection of the other parameters continues down the chain.
* - Parameters configured directly in code via [[host()]], [[token()]], etc.
* - Parameters configured in the Databricks SDK's [[DatabricksConfig]] specified via
* [[sdkConfig()]]
* - The `SPARK_REMOTE` environment variable. For details on the connection string, see
* https://github.com/apache/spark/blob/master/connector/connect/docs/client-connection-string.md
* - Use the "default" configuration profile from the Databricks config file. For details
* on Databricks configuration profiles, see https://docs.databricks.com/dev-tools/auth.html
*
* By default, the resolved configuration is validated by retrieving the cluster's Databricks
* Runtime version with the Databricks SDK. A warning is logged if authentication fails or if
* the combination of Databricks Runtime & Databricks Connect versions is unsupported. Change
* this behaviour with [[DatabricksSession.builder.validateSession()]].
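*
* For example, relying entirely on the "default" configuration profile (illustrative):
* {{{
*   val spark = DatabricksSession.builder().getOrCreate()
* }}}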
*
* @return spark session
*/
// scalastyle:on line.size.limit
def getOrCreate(): SparkSession = {
this.sdkConfig.foreach(_.resolve())
var spark: SparkSession = null
if (!Array(this.host, this.clusterId, this.token).forall(_.isEmpty)) {
logDebug("DatabricksSession: Initializing from explicitly set host, cluster, token")
// use either provided configuration or default configuration
val config = this.sdkConfig.getOrElse(new DatabricksConfig())
// SASP-2727: Explicitly configure the auth type to be PAT.
config.setAuthType("pat")
config.resolve()
// For any explicitly provided config, override the config.
if (this.host.isDefined) {
config.setHost(this.host.get)
}
if (this.clusterId.isDefined) {
config.setClusterId(this.clusterId.get)
}
if (this.token.isDefined) {
config.setToken(this.token.get)
}
this.resolvedSdkConfig = Some(config)
spark = fromSdkConfig(config)
} else if (this.sdkConfig.isDefined) {
logDebug("DatabricksSession: Initializing from sdkConfig")
this.resolvedSdkConfig = this.sdkConfig
spark = fromSdkConfig(this.sdkConfig.get.resolve())
} else if (env.contains("SPARK_REMOTE")) {
// look for SPARK_REMOTE
logDebug("DatabricksSession: Initializing from SPARK_REMOTE")
if (this.headers.nonEmpty) {
throw new IllegalArgumentException(
"Can't configure custom headers with SPARK_REMOTE connection"
)
}
// SparkSession.Builder knows how to handle SPARK_REMOTE
this.resolvedSdkConfig = None
spark = builder.getOrCreate()
} else {
logDebug("DatabrickSession: Constructing from default SDK Config")
val config = new DatabricksConfig().resolve()
this.resolvedSdkConfig = Some(config)
spark = fromSdkConfig(config)
}
if (spark == null) {
throw new RuntimeException("Spark session could not be initialized: unexpected state")
}
for ((target, file) <- this.artifacts) {
logDebug(s"Uploading file [$file] to target path [$target].")
spark.addArtifact(Files.readAllBytes(file.toPath), target)
}
// skip validation if session-id is set manually
if (validateSession.getOrElse(true) && !headers.contains("x-databricks-session-id")) {
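// When validateSession was not explicitly requested, validation failures are only logged
// as warnings; with an explicit validateSession(true) the exception propagates.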
try {
validateSessionWithSdk
} catch {
case e: Throwable if validateSession.isEmpty => logWarning(e.getMessage)
}
}
spark
}
private def fromSdkConfig(config: DatabricksConfig): SparkSession = {
if (config.getHost == null) {
throw new IllegalArgumentException(
"DatabricksSession: Need host to construct session. Received null."
)
}
val host = config.getHost.stripPrefix("https://").stripPrefix("http://")
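// Three cases: a manually supplied session id header routes the request as-is; otherwise a
// cluster id is required and is forwarded via the x-databricks-cluster-id header.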
(this.headers.contains("x-databricks-session-id"), config.getClusterId) match {
case (true, _) =>
val conf = new DatabricksSparkClientConfiguration(config, host, genUserAgent(), headers)
fromSparkClientConf(conf)
case (_, null) =>
throw new IllegalArgumentException(
"DatabricksSession: need cluster id to " +
s"construct session. Got cluster id=${config.getClusterId}"
)
case (_, cluster_id) =>
val conf = new DatabricksSparkClientConfiguration(
config,
host,
genUserAgent(),
headers + ("x-databricks-cluster-id" -> config.getClusterId)
)
fromSparkClientConf(conf)
}
}
/* for unit testing */
protected[connect] def fromSparkClientConf(
conf: SparkConnectClient.Configuration): SparkSession = {
builder.client(conf.toSparkConnectClient).getOrCreate()
}
}
/**
* Get a new Builder to configure and construct a new [[org.apache.spark.sql.SparkSession]].
* @return new Builder
*/
def builder(): Builder = new Builder
val version = loadProperty("version", "dbconnect-version-info.properties")
val gitVersion = loadProperty("git_version", "dbconnect-version-info.properties")
private def loadProperty(key: String, file: String): String = {
val resourceStream = Thread
.currentThread()
.getContextClassLoader
.getResourceAsStream(file)
if (resourceStream == null) {
throw new SparkException(s"Could not find ${file}")
}
try {
val unknownProp = ""
val props = new Properties()
props.load(resourceStream)
props.getProperty(key, unknownProp)
} catch {
case e: Exception =>
throw new SparkException(s"Error loading properties from ${file}", e)
} finally {
if (resourceStream != null) {
try {
resourceStream.close()
} catch {
case e: Exception =>
throw new SparkException(s"Error closing ${file} resource stream", e)
}
}
}
}
}