org.apache.spark.sql.SparkSession.scala Maven / Gradle / Ivy
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* See the License for the specific language governing permissions and
* limitations under the License.
package org.apache.spark.sql
import java.nio.file.{Files, Paths}
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.{AtomicLong, AtomicReference}
import scala.jdk.CollectionConverters._
import scala.reflect.runtime.universe.TypeTag
import scala.util.Try
import{CacheBuilder, CacheLoader}
import io.grpc.ClientInterceptor
import org.apache.arrow.memory.RootAllocator
import org.apache.spark.annotation.{DeveloperApi, Experimental, Since}
import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.ExecutePlanResponse
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalog.Catalog
import org.apache.spark.sql.catalyst.{JavaTypeInference, ScalaReflection}
import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, RowEncoder}
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BoxedLongEncoder, UnboundRowEncoder}
import org.apache.spark.sql.connect.client.{ClassFinder, CloseableIterator, SparkConnectClient, SparkResult}
import org.apache.spark.sql.connect.client.SparkConnectClient.Configuration
import org.apache.spark.sql.connect.client.arrow.ArrowSerializer
import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.internal.{CatalogImpl, SessionCleaner, SqlApiConf}
import org.apache.spark.sql.internal.ColumnNodeToProtoConverter.{toExpr, toTypedExpr}
import org.apache.spark.sql.streaming.DataStreamReader
import org.apache.spark.sql.streaming.StreamingQueryManager
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.ArrayImplicits._
* The entry point to programming Spark with the Dataset and DataFrame API.
* In environments where a session has been created upfront (e.g. REPL, notebooks), use the
* builder to get an existing session:
* {{{
* SparkSession.builder().getOrCreate()
* }}}
* The builder can also be used to create a new session:
* {{{
* SparkSession.builder
* .remote("sc://localhost:15001/myapp")
* .getOrCreate()
* }}}
class SparkSession private[sql] (
private[sql] val client: SparkConnectClient,
private val planIdGenerator: AtomicLong)
extends api.SparkSession[Dataset]
with Logging {
// Arrow root allocator created for this session; createDataset passes it to
// ArrowSerializer.serialize when encoding local data.
private[this] val allocator = new RootAllocator()
private[sql] lazy val cleaner = new SessionCleaner(this)
// a unique session ID for this session from client.
private[sql] def sessionId: String = client.sessionId
lazy val version: String = {
// Observations keyed by plan id: registerObservation adds entries (putIfAbsent) and
// setMetricsAndUnregisterObservation removes them.
private[sql] val observationRegistry = new ConcurrentHashMap[Long, Observation]()
private[sql] def hijackServerSideSessionIdForTesting(suffix: String) = {
* Runtime configuration interface for Spark.
* This is the interface through which the user can get and set all Spark configurations that
* are relevant to Spark SQL. When getting the value of a config, this defaults to the value set
* in the server, if any.
* @since 3.4.0
val conf: RuntimeConfig = new RuntimeConfig(client)
/** @inheritdoc */
val emptyDataFrame: DataFrame = emptyDataset(UnboundRowEncoder)
/** @inheritdoc */
def emptyDataset[T: Encoder]: Dataset[T] = createDataset[T](Nil)
private def createDataset[T](encoder: AgnosticEncoder[T], data: Iterator[T]): Dataset[T] = {
newDataset(encoder) { builder =>
if (data.nonEmpty) {
val arrowData = ArrowSerializer.serialize(data, encoder, allocator, timeZoneId)
if (arrowData.size() <= conf.get(SqlApiConf.LOCAL_RELATION_CACHE_THRESHOLD_KEY).toInt) {
} else {
val hash = client.cacheLocalRelation(arrowData, encoder.schema.json)
} else {
/** @inheritdoc */
def createDataFrame[A <: Product: TypeTag](data: Seq[A]): DataFrame = {
createDataset(ScalaReflection.encoderFor[A], data.iterator).toDF()
/** @inheritdoc */
def createDataFrame(rows: java.util.List[Row], schema: StructType): DataFrame = {
createDataset(RowEncoder.encoderFor(schema), rows.iterator().asScala).toDF()
/** @inheritdoc */
def createDataFrame(data: java.util.List[_], beanClass: Class[_]): DataFrame = {
val encoder = JavaTypeInference.encoderFor(beanClass.asInstanceOf[Class[Any]])
createDataset(encoder, data.iterator().asScala).toDF()
/** @inheritdoc */
def createDataset[T: Encoder](data: Seq[T]): Dataset[T] = {
createDataset(encoderFor[T], data.iterator)
/** @inheritdoc */
def createDataset[T: Encoder](data: java.util.List[T]): Dataset[T] = {
/** @inheritdoc */
def sql(sqlText: String, args: Array[_]): DataFrame = newDataFrame { builder =>
// Send the SQL once to the server and then check the output.
val cmd = newCommand(b =>
val plan = proto.Plan.newBuilder().setCommand(cmd)
val responseIter = client.execute(
try {
val response = responseIter
.getOrElse(throw new RuntimeException("SQLCommandResult must be present"))
// Update the builder with the values from the result.
} finally {
// consume the rest of the iterator
responseIter.foreach(_ => ())
/** @inheritdoc */
def sql(sqlText: String, args: Map[String, Any]): DataFrame = {
sql(sqlText, args.asJava)
/** @inheritdoc */
override def sql(sqlText: String, args: java.util.Map[String, Any]): DataFrame = newDataFrame {
builder =>
// Send the SQL once to the server and then check the output.
val cmd = newCommand(b =>
.putAllNamedArguments( { case (k, v) => (k, lit(v).expr) }.asJava)))
val plan = proto.Plan.newBuilder().setCommand(cmd)
val responseIter = client.execute(
try {
val response = responseIter
.getOrElse(throw new RuntimeException("SQLCommandResult must be present"))
// Update the builder with the values from the result.
} finally {
// consume the rest of the iterator
responseIter.foreach(_ => ())
/** @inheritdoc */
override def sql(query: String): DataFrame = {
sql(query, Array.empty)
/** @inheritdoc */
def read: DataFrameReader = new DataFrameReader(this)
* Returns a `DataStreamReader` that can be used to read streaming data in as a `DataFrame`.
* {{{
* sparkSession.readStream.parquet("/path/to/directory/of/parquet/files")
* sparkSession.readStream.schema(schema).json("/path/to/directory/of/json/files")
* }}}
* @since 3.5.0
def readStream: DataStreamReader = new DataStreamReader(this)
lazy val streams: StreamingQueryManager = new StreamingQueryManager(this)
/** @inheritdoc */
lazy val catalog: Catalog = new CatalogImpl(this)
/** @inheritdoc */
def table(tableName: String): DataFrame = {
/** @inheritdoc */
def range(end: Long): Dataset[java.lang.Long] = range(0, end)
/** @inheritdoc */
def range(start: Long, end: Long): Dataset[java.lang.Long] = {
range(start, end, step = 1)
/** @inheritdoc */
def range(start: Long, end: Long, step: Long): Dataset[java.lang.Long] = {
range(start, end, step, None)
/** @inheritdoc */
def range(start: Long, end: Long, step: Long, numPartitions: Int): Dataset[java.lang.Long] = {
range(start, end, step, Option(numPartitions))
/** @inheritdoc */
lazy val udf: UDFRegistration = new UDFRegistration(this)
// scalastyle:off
// Disable style checker so "implicits" object can start with lowercase i
* (Scala-specific) Implicit methods available in Scala for converting common names and Symbols
* into [[Column]]s, and for converting common Scala objects into `DataFrame`s.
* {{{
* val sparkSession = SparkSession.builder.getOrCreate()
* import sparkSession.implicits._
* }}}
* @since 3.4.0
object implicits extends SQLImplicits(this) with Serializable
// scalastyle:on
/** @inheritdoc */
def newSession(): SparkSession = {
private def range(
start: Long,
end: Long,
step: Long,
numPartitions: Option[Int]): Dataset[java.lang.Long] = {
newDataset(BoxedLongEncoder) { builder =>
val rangeBuilder = builder.getRangeBuilder
def newDataFrame(f: proto.Relation.Builder => Unit): DataFrame = {
def newDataset[T](encoder: AgnosticEncoder[T])(
f: proto.Relation.Builder => Unit): Dataset[T] = {
val builder = proto.Relation.newBuilder()
val plan = proto.Plan.newBuilder().setRoot(builder).build()
new Dataset[T](this, plan, encoder)
private[sql] def newCommand[T](f: proto.Command.Builder => Unit): proto.Command = {
val builder = proto.Command.newBuilder()
private[sql] def analyze(
plan: proto.Plan,
method: proto.AnalyzePlanRequest.AnalyzeCase,
explainMode: Option[proto.AnalyzePlanRequest.Explain.ExplainMode] = None)
: proto.AnalyzePlanResponse = {
client.analyze(method, Some(plan), explainMode)
private[sql] def analyze(
f: proto.AnalyzePlanRequest.Builder => Unit): proto.AnalyzePlanResponse = {
val builder = proto.AnalyzePlanRequest.newBuilder()
private[sql] def sameSemantics(plan: proto.Plan, otherPlan: proto.Plan): Boolean = {
client.sameSemantics(plan, otherPlan).getSameSemantics.getResult
private[sql] def semanticHash(plan: proto.Plan): Int = {
// Resolves this session's time zone id by looking up the session-local time zone key in the
// runtime configuration.
private[sql] def timeZoneId: String = {
  val timeZoneKey = SqlApiConf.SESSION_LOCAL_TIMEZONE_KEY
  conf.get(timeZoneKey)
}
private[sql] def execute[T](plan: proto.Plan, encoder: AgnosticEncoder[T]): SparkResult[T] = {
val value = client.execute(plan)
new SparkResult(
private[sql] def execute(f: proto.Relation.Builder => Unit): Unit = {
val builder = proto.Relation.newBuilder()
val plan = proto.Plan.newBuilder().setRoot(builder).build()
// .foreach forces that the iterator is consumed and closed
client.execute(plan).foreach(_ => ())
def execute(command: proto.Command): Seq[ExecutePlanResponse] = {
val plan = proto.Plan.newBuilder().setCommand(command).build()
// .toSeq forces that the iterator is consumed and closed. On top, ignore all
// progress messages.
private[sql] def execute(plan: proto.Plan): CloseableIterator[ExecutePlanResponse] =
private[sql] def registerUdf(udf: proto.CommonInlineUserDefinedFunction): Unit = {
val command = proto.Command.newBuilder().setRegisterFunction(udf).build()
/** @inheritdoc */
override def addArtifact(path: String): Unit = client.addArtifact(path)
/** @inheritdoc */
override def addArtifact(uri: URI): Unit = client.addArtifact(uri)
/** @inheritdoc */
override def addArtifact(bytes: Array[Byte], target: String): Unit = {
client.addArtifact(bytes, target)
/** @inheritdoc */
override def addArtifact(source: String, target: String): Unit = {
client.addArtifact(source, target)
/** @inheritdoc */
override def addArtifacts(uri: URI*): Unit = client.addArtifacts(uri)
* Register a ClassFinder for dynamically generated classes.
* @since 3.5.0
def registerClassFinder(finder: ClassFinder): Unit = client.registerClassFinder(finder)
* This resets the plan id generator so we can produce plans that are comparable.
* For testing only!
private[sql] def resetPlanIdGenerator(): Unit = {
* Interrupt all operations of this session currently running on the connected server.
* @return
* sequence of operationIds of interrupted operations. Note: there is still a possibility of
* an operation finishing just as it is interrupted.
* @since 3.5.0
def interruptAll(): Seq[String] = {
* Interrupt all operations of this session with the given operation tag.
* @return
* sequence of operationIds of interrupted operations. Note: there is still a possibility of
* an operation finishing just as it is interrupted.
* @since 3.5.0
def interruptTag(tag: String): Seq[String] = {
* Interrupt an operation of this session with the given operationId.
* @return
* sequence of operationIds of interrupted operations. Note: there is still a possibility of
* an operation finishing just as it is interrupted.
* @since 3.5.0
def interruptOperation(operationId: String): Seq[String] = {
* Close the [[SparkSession]].
* Release the current session and close the GRPC connection to the server. The API will not
* error if any of these operations fail. Closing a closed session is a no-op.
* Close the allocator. Fail if there are still open SparkResults.
* @since 3.4.0
override def close(): Unit = {
if (releaseSessionOnClose) {
try {
} catch {
case e: Exception => logWarning("session.stop: Failed to release session", e)
try {
} catch {
case e: Exception => logWarning("session.stop: Failed to shutdown the client", e)
* Add a tag to be assigned to all the operations started by this thread in this session.
* Often, a unit of execution in an application consists of multiple Spark executions.
* Application programmers can use this method to group all those jobs together and give a group
* tag. The application can use `org.apache.spark.sql.SparkSession.interruptTag` to cancel all
* running executions with this tag. For example:
* {{{
* // In the main thread:
* spark.addTag("myjobs")
* spark.range(10).map(i => { Thread.sleep(10); i }).collect()
* // In a separate thread:
* spark.interruptTag("myjobs")
* }}}
* There may be multiple tags present at the same time, so different parts of the application may
* use different tags to perform cancellation at different levels of granularity.
* @param tag
* The tag to be added. Cannot contain ',' (comma) character or be an empty string.
* @since 3.5.0
def addTag(tag: String): Unit = {
* Remove a tag previously added to be assigned to all the operations started by this thread in
* this session. Noop if such a tag was not added earlier.
* @param tag
* The tag to be removed. Cannot contain ',' (comma) character or be an empty string.
* @since 3.5.0
def removeTag(tag: String): Unit = {
* Get the tags that are currently set to be assigned to all the operations started by this
* thread.
* @since 3.5.0
def getTags(): Set[String] = {
* Clear the current thread's operation tags.
* @since 3.5.0
def clearTags(): Unit = {
* We cannot deserialize a connect [[SparkSession]] because of a class clash on the server side.
* We null out the instance for now.
private def writeReplace(): Any = null
* Set to false to prevent client.releaseSession on close() (testing only)
private[sql] var releaseSessionOnClose = true
private[sql] def registerObservation(planId: Long, observation: Observation): Unit = {
observationRegistry.putIfAbsent(planId, observation)
private[sql] def setMetricsAndUnregisterObservation(planId: Long, metrics: Row): Unit = {
val observationOrNull = observationRegistry.remove(planId)
if (observationOrNull != null) {
// Enriches Column with conversions to Spark Connect protobuf expressions.
implicit class RichColumn(c: Column) {
// Converts the wrapped column to a proto.Expression via toExpr.
def expr: proto.Expression = toExpr(c)
// Converts the wrapped column to a proto.Expression via toTypedExpr, using the given encoder.
def typedExpr[T](e: Encoder[T]): proto.Expression = toTypedExpr(c, e)
// The minimal builder needed to create a spark session.
// TODO: implement all methods mentioned in the scaladoc of [[SparkSession]]
object SparkSession extends Logging {
// Presumably the maximum number of entries kept in the `sessions` cache.
// NOTE(review): the call that applies this limit is not visible in this excerpt — confirm.
private val MAX_CACHED_SESSIONS = 100
// Plan id counter handed to every SparkSession constructed by this object
// (see create and Builder.tryCreateSessionFromClient).
private val planIdGenerator = new AtomicLong
// Process handle of the local Spark Connect server started by withLocalConnectServer;
// None until one has been launched.
private var server: Option[Process] = None
// Subset of the JVM system properties: entries whose key starts with "spark." and whose
// value is non-empty. withLocalConnectServer forwards them (except keys starting with
// "spark.remote") as `--conf k=v` arguments to the server script.
private[sql] val sparkOptions = sys.props.filter { p =>
p._1.startsWith("spark.") && p._2.nonEmpty
private val sessions = CacheBuilder
.build(new CacheLoader[Configuration, SparkSession] {
override def load(c: Configuration): SparkSession = create(c)
/** The active SparkSession for the current thread. */
private val activeThreadSession = new InheritableThreadLocal[SparkSession]
/** Reference to the root SparkSession. */
private val defaultSession = new AtomicReference[SparkSession]
* Set the (global) default [[SparkSession]], and (thread-local) active [[SparkSession]] when
* they are not set yet or the associated [[SparkConnectClient]] is unusable.
private def setDefaultAndActiveSession(session: SparkSession): Unit = {
val currentDefault = defaultSession.getAcquire
if (currentDefault == null || !currentDefault.client.isSessionValid) {
// Update `defaultSession` if it is null or the contained session is not valid. There is a
// chance that the following `compareAndSet` fails if a new default session has just been set,
// but that does not matter since that event has happened after this method was invoked.
defaultSession.compareAndSet(currentDefault, session)
if (getActiveSession.isEmpty) {
* Create a new Spark Connect server to connect locally.
private[sql] def withLocalConnectServer[T](f: => T): T = {
synchronized {
val remoteString = sparkOptions
.orElse(Option(System.getProperty("spark.remote"))) // Set from Spark Submit
val maybeConnectScript =
Option(System.getenv("SPARK_HOME")).map(Paths.get(_, "sbin", ""))
if (server.isEmpty &&
remoteString.exists(_.startsWith("local")) &&
maybeConnectScript.exists(Files.exists(_))) {
server = Some {
val args =
Seq(maybeConnectScript.get.toString, "--master", remoteString.get) ++ sparkOptions
.filter(p => !p._1.startsWith("spark.remote"))
.flatMap { case (k, v) => Seq("--conf", s"$k=$v") }
val pb = new ProcessBuilder(args: _*)
// So don't exclude spark-sql jar in classpath
// Let the server start. We will directly request to set the configurations
// and this sleep makes it less noisy with retries.
System.setProperty("spark.remote", "sc://localhost")
// scalastyle:off runtimeaddshutdownhook
Runtime.getRuntime.addShutdownHook(new Thread() {
override def run(): Unit = if (server.isDefined) {
new ProcessBuilder(maybeConnectScript.get.toString)
// scalastyle:on runtimeaddshutdownhook
* Create a new [[SparkSession]] based on the connect client [[Configuration]].
private[sql] def create(configuration: Configuration): SparkSession = {
new SparkSession(configuration.toSparkConnectClient, planIdGenerator)
* Hook called when a session is closed.
private[sql] def onSessionClose(session: SparkSession): Unit = {
defaultSession.compareAndSet(session, null)
if (getActiveSession.contains(session)) {
* Creates a [[SparkSession.Builder]] for constructing a [[SparkSession]].
* @since 3.4.0
def builder(): Builder = new Builder()
class Builder() extends Logging {
// Initialize the connection string of the Spark Connect client builder from SPARK_REMOTE
// by default, if it exists. The connection string can be overridden using
// the remote() function, as it takes precedence over the SPARK_REMOTE environment variable.
private val builder = SparkConnectClient.builder().loadFromEnvironment()
// Client injected through client(...); when it is non-null and its session is valid,
// tryCreateSessionFromClient wraps it in a new SparkSession.
private var client: SparkConnectClient = _
// Options accumulated by the config(...) overloads and later set on the session in applyOptions.
private[this] val options = new scala.collection.mutable.HashMap[String, String]
def remote(connectionString: String): Builder = {
* Add an interceptor to be used during channel creation.
* Note that interceptors added last are executed first by gRPC.
* @since 3.5.0
def interceptor(interceptor: ClientInterceptor): Builder = {
private[sql] def client(client: SparkConnectClient): Builder = {
this.client = client
* Sets a config option. Options set using this method are automatically propagated to the
* Spark Connect session. Only runtime options are supported.
* @since 3.5.0
def config(key: String, value: String): Builder = synchronized {
options += key -> value
* Sets a config option. Options set using this method are automatically propagated to the
* Spark Connect session. Only runtime options are supported.
* @since 3.5.0
def config(key: String, value: Long): Builder = synchronized {
options += key -> value.toString
* Sets a config option. Options set using this method are automatically propagated to the
* Spark Connect session. Only runtime options are supported.
* @since 3.5.0
def config(key: String, value: Double): Builder = synchronized {
options += key -> value.toString
* Sets a config option. Options set using this method are automatically propagated to the
* Spark Connect session. Only runtime options are supported.
* @since 3.5.0
def config(key: String, value: Boolean): Builder = synchronized {
options += key -> value.toString
* Sets config options from a map. Options set using this method are automatically propagated
* to the Spark Connect session. Only runtime options are supported.
* @since 3.5.0
def config(map: Map[String, Any]): Builder = synchronized {
map.foreach { kv: (String, Any) =>
options += kv._1 -> kv._2.toString
* Sets config options from a `java.util.Map`. Options set using this method are automatically
* propagated to both `SparkConf` and SparkSession's own configuration.
* @since 3.5.0
def config(map: java.util.Map[String, Any]): Builder = synchronized {
// No-op in Spark Connect: returns this builder unchanged.
@deprecated("enableHiveSupport does not work in Spark Connect")
def enableHiveSupport(): Builder = this
// No-op in Spark Connect: `master` is ignored and this builder is returned unchanged.
@deprecated("master does not work in Spark Connect, please use remote instead")
def master(master: String): Builder = this
// No-op in Spark Connect: `name` is ignored and this builder is returned unchanged.
@deprecated("appName does not work in Spark Connect")
def appName(name: String): Builder = this
private def tryCreateSessionFromClient(): Option[SparkSession] = {
if (client != null && client.isSessionValid) {
Option(new SparkSession(client, planIdGenerator))
} else {
private def applyOptions(session: SparkSession): Unit = {
// Only attempts to set Spark SQL configurations.
// If the configurations are static, it might throw an exception so
// simply ignore it for now.
.filter { case (k, _) =>
.foreach { case (key, value) =>
Try(session.conf.set(key, value))
options.foreach { case (key, value) =>
session.conf.set(key, value)
* Build the [[SparkSession]].
* This will always return a newly created session.
@deprecated(message = "Please use create() instead.", since = "3.5.0")
def build(): SparkSession = create()
* Create a new [[SparkSession]].
* This will always return a newly created session.
* This method will update the default and/or active session if they are not set.
* @since 3.5.0
def create(): SparkSession = withLocalConnectServer {
val session = tryCreateSessionFromClient()
* Get or create a [[SparkSession]].
* If a session exists with the same configuration, it is returned instead of creating a new
* session.
* This method will update the default and/or active session if they are not set. This method
* will always set the specified configuration options on the session, even when it is not
* newly created.
* @since 3.5.0
def getOrCreate(): SparkSession = withLocalConnectServer {
val session = tryCreateSessionFromClient()
var existingSession = sessions.get(builder.configuration)
if (!existingSession.client.isSessionValid) {
// If the cached session has become invalid, e.g., due to a server restart, the cache
// entry is invalidated.
existingSession = sessions.get(builder.configuration)
* Returns the default SparkSession. If the previously set default SparkSession becomes
* unusable, returns None.
* @since 3.5.0
def getDefaultSession: Option[SparkSession] =
* Sets the default SparkSession.
* @since 3.5.0
def setDefaultSession(session: SparkSession): Unit = {
* Clears the default SparkSession.
* @since 3.5.0
def clearDefaultSession(): Unit = {
* Returns the active SparkSession for the current thread. If the previously set active
* SparkSession becomes unusable, returns None.
* @since 3.5.0
def getActiveSession: Option[SparkSession] =
* Changes the SparkSession that will be returned in this thread and its children when
* SparkSession.getOrCreate() is called. This can be used to ensure that a given thread receives
* an isolated SparkSession.
* @since 3.5.0
def setActiveSession(session: SparkSession): Unit = {
* Clears the active SparkSession for current thread.
* @since 3.5.0
def clearActiveSession(): Unit = {
* Returns the currently active SparkSession, otherwise the default one. If there is no default
* SparkSession, throws an exception.
* @since 3.5.0
def active: SparkSession = {
.getOrElse(throw new IllegalStateException("No active or default Spark session found"))
