/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql
import java.net.URI
import java.nio.file.{Files, Paths}
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.{AtomicLong, AtomicReference}
import scala.jdk.CollectionConverters._
import scala.reflect.runtime.universe.TypeTag
import scala.util.Try
import com.google.common.cache.{CacheBuilder, CacheLoader}
import io.grpc.ClientInterceptor
import org.apache.arrow.memory.RootAllocator
import org.apache.spark.annotation.{DeveloperApi, Experimental, Since}
import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.ExecutePlanResponse
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalog.Catalog
import org.apache.spark.sql.catalyst.{JavaTypeInference, ScalaReflection}
import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, RowEncoder}
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BoxedLongEncoder, UnboundRowEncoder}
import org.apache.spark.sql.connect.client.{ClassFinder, CloseableIterator, SparkConnectClient, SparkResult}
import org.apache.spark.sql.connect.client.SparkConnectClient.Configuration
import org.apache.spark.sql.connect.client.arrow.ArrowSerializer
import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.internal.{CatalogImpl, SessionCleaner, SqlApiConf}
import org.apache.spark.sql.internal.ColumnNodeToProtoConverter.{toExpr, toTypedExpr}
import org.apache.spark.sql.streaming.DataStreamReader
import org.apache.spark.sql.streaming.StreamingQueryManager
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.ArrayImplicits._
/**
* The entry point to programming Spark with the Dataset and DataFrame API.
*
* In environments where this has been created up front (e.g. REPL, notebooks), use the builder to
* get an existing session:
*
* {{{
* SparkSession.builder().getOrCreate()
* }}}
*
* The builder can also be used to create a new session:
*
* {{{
* SparkSession.builder
* .remote("sc://localhost:15001/myapp")
* .getOrCreate()
* }}}
*/
class SparkSession private[sql] (
private[sql] val client: SparkConnectClient,
private val planIdGenerator: AtomicLong)
extends api.SparkSession[Dataset]
with Logging {
private[this] val allocator = new RootAllocator()
private[sql] lazy val cleaner = new SessionCleaner(this)
// A unique session ID for this session, obtained from the client.
private[sql] def sessionId: String = client.sessionId
lazy val version: String = {
client.analyze(proto.AnalyzePlanRequest.AnalyzeCase.SPARK_VERSION).getSparkVersion.getVersion
}
private[sql] val observationRegistry = new ConcurrentHashMap[Long, Observation]()
private[sql] def hijackServerSideSessionIdForTesting(suffix: String) = {
client.hijackServerSideSessionIdForTesting(suffix)
}
/**
* Runtime configuration interface for Spark.
*
* This is the interface through which the user can get and set all Spark configurations that
* are relevant to Spark SQL. When getting the value of a config, this defaults to the value set
* on the server, if any.
*
* @since 3.4.0
*/
val conf: RuntimeConfig = new RuntimeConfig(client)
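// A minimal usage sketch of the runtime configuration interface (illustrative only; assumes
// an existing session named `spark`; the option key is a standard Spark SQL runtime option):
//   spark.conf.set("spark.sql.shuffle.partitions", "8")
//   spark.conf.get("spark.sql.shuffle.partitions") // "8"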
/** @inheritdoc */
@transient
val emptyDataFrame: DataFrame = emptyDataset(UnboundRowEncoder)
/** @inheritdoc */
def emptyDataset[T: Encoder]: Dataset[T] = createDataset[T](Nil)
private def createDataset[T](encoder: AgnosticEncoder[T], data: Iterator[T]): Dataset[T] = {
newDataset(encoder) { builder =>
if (data.nonEmpty) {
val arrowData = ArrowSerializer.serialize(data, encoder, allocator, timeZoneId)
if (arrowData.size() <= conf.get(SqlApiConf.LOCAL_RELATION_CACHE_THRESHOLD_KEY).toInt) {
builder.getLocalRelationBuilder
.setSchema(encoder.schema.json)
.setData(arrowData)
} else {
val hash = client.cacheLocalRelation(arrowData, encoder.schema.json)
builder.getCachedLocalRelationBuilder
.setHash(hash)
}
} else {
builder.getLocalRelationBuilder
.setSchema(encoder.schema.json)
}
}
}
/** @inheritdoc */
def createDataFrame[A <: Product: TypeTag](data: Seq[A]): DataFrame = {
createDataset(ScalaReflection.encoderFor[A], data.iterator).toDF()
}
/** @inheritdoc */
def createDataFrame(rows: java.util.List[Row], schema: StructType): DataFrame = {
createDataset(RowEncoder.encoderFor(schema), rows.iterator().asScala).toDF()
}
/** @inheritdoc */
def createDataFrame(data: java.util.List[_], beanClass: Class[_]): DataFrame = {
val encoder = JavaTypeInference.encoderFor(beanClass.asInstanceOf[Class[Any]])
createDataset(encoder, data.iterator().asScala).toDF()
}
/** @inheritdoc */
def createDataset[T: Encoder](data: Seq[T]): Dataset[T] = {
createDataset(encoderFor[T], data.iterator)
}
/** @inheritdoc */
def createDataset[T: Encoder](data: java.util.List[T]): Dataset[T] = {
createDataset(data.asScala.toSeq)
}
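// Sketch of creating Datasets/DataFrames from local data (illustrative only; `Person` is a
// hypothetical case class and `spark` an existing session). The local rows are serialized
// with Arrow and either inlined into the plan or cached on the server, as implemented above:
//   case class Person(name: String, age: Int)
//   val df = spark.createDataFrame(Seq(Person("Alice", 23), Person("Bob", 42)))
//   import spark.implicits._
//   val ds = spark.createDataset(Seq(Person("Alice", 23)))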
/** @inheritdoc */
@Experimental
def sql(sqlText: String, args: Array[_]): DataFrame = newDataFrame { builder =>
// Send the SQL once to the server and then check the output.
val cmd = newCommand(b =>
b.setSqlCommand(
proto.SqlCommand
.newBuilder()
.setSql(sqlText)
.addAllPosArguments(args.map(lit(_).expr).toImmutableArraySeq.asJava)))
val plan = proto.Plan.newBuilder().setCommand(cmd)
val responseIter = client.execute(plan.build())
try {
val response = responseIter
.find(_.hasSqlCommandResult)
.getOrElse(throw new RuntimeException("SQLCommandResult must be present"))
// Update the builder with the values from the result.
builder.mergeFrom(response.getSqlCommandResult.getRelation)
} finally {
// consume the rest of the iterator
responseIter.foreach(_ => ())
}
}
/** @inheritdoc */
@Experimental
def sql(sqlText: String, args: Map[String, Any]): DataFrame = {
sql(sqlText, args.asJava)
}
/** @inheritdoc */
@Experimental
override def sql(sqlText: String, args: java.util.Map[String, Any]): DataFrame = newDataFrame {
builder =>
// Send the SQL once to the server and then check the output.
val cmd = newCommand(b =>
b.setSqlCommand(
proto.SqlCommand
.newBuilder()
.setSql(sqlText)
.putAllNamedArguments(args.asScala.map { case (k, v) => (k, lit(v).expr) }.asJava)))
val plan = proto.Plan.newBuilder().setCommand(cmd)
val responseIter = client.execute(plan.build())
try {
val response = responseIter
.find(_.hasSqlCommandResult)
.getOrElse(throw new RuntimeException("SQLCommandResult must be present"))
// Update the builder with the values from the result.
builder.mergeFrom(response.getSqlCommandResult.getRelation)
} finally {
// consume the rest of the iterator
responseIter.foreach(_ => ())
}
}
/** @inheritdoc */
override def sql(query: String): DataFrame = {
sql(query, Array.empty)
}
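// Hedged examples of the parameterized sql() variants above (illustrative only; the queries
// use the built-in range() table function, so no user tables are assumed):
//   // Positional parameters
//   spark.sql("SELECT * FROM range(10) WHERE id > ?", Array(5))
//   // Named parameters
//   spark.sql("SELECT * FROM range(10) WHERE id > :threshold", Map("threshold" -> 5))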
/** @inheritdoc */
def read: DataFrameReader = new DataFrameReader(this)
/**
* Returns a `DataStreamReader` that can be used to read streaming data in as a `DataFrame`.
* {{{
* sparkSession.readStream.parquet("/path/to/directory/of/parquet/files")
* sparkSession.readStream.schema(schema).json("/path/to/directory/of/json/files")
* }}}
*
* @since 3.5.0
*/
def readStream: DataStreamReader = new DataStreamReader(this)
lazy val streams: StreamingQueryManager = new StreamingQueryManager(this)
/** @inheritdoc */
lazy val catalog: Catalog = new CatalogImpl(this)
/** @inheritdoc */
def table(tableName: String): DataFrame = {
read.table(tableName)
}
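// Quick sketch of catalog access and table lookup (illustrative only; "my_table" is a
// hypothetical table name that must already exist on the server):
//   spark.catalog.listTables().show()
//   val df = spark.table("my_table")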
/** @inheritdoc */
def range(end: Long): Dataset[java.lang.Long] = range(0, end)
/** @inheritdoc */
def range(start: Long, end: Long): Dataset[java.lang.Long] = {
range(start, end, step = 1)
}
/** @inheritdoc */
def range(start: Long, end: Long, step: Long): Dataset[java.lang.Long] = {
range(start, end, step, None)
}
/** @inheritdoc */
def range(start: Long, end: Long, step: Long, numPartitions: Int): Dataset[java.lang.Long] = {
range(start, end, step, Option(numPartitions))
}
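// Examples of the range() overloads (illustrative only):
//   spark.range(5).collect()                  // 0, 1, 2, 3, 4
//   spark.range(0, 10, 2).collect()           // 0, 2, 4, 6, 8
//   spark.range(0, 10, 2, numPartitions = 2)  // same values, split into 2 partitions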
/** @inheritdoc */
lazy val udf: UDFRegistration = new UDFRegistration(this)
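// Sketch of registering and invoking a UDF (illustrative only; "plusOne" is a hypothetical
// function name):
//   spark.udf.register("plusOne", (x: Long) => x + 1)
//   spark.sql("SELECT plusOne(id) FROM range(3)").show()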
// scalastyle:off
// Disable style checker so "implicits" object can start with lowercase i
/**
* (Scala-specific) Implicit methods available in Scala for converting common names and Symbols
* into [[Column]]s, and for converting common Scala objects into `DataFrame`s.
*
* {{{
* val sparkSession = SparkSession.builder.getOrCreate()
* import sparkSession.implicits._
* }}}
*
* @since 3.4.0
*/
object implicits extends SQLImplicits(this) with Serializable
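// Typical use of the implicits (illustrative only; assumes an existing session `spark`):
//   import spark.implicits._
//   val ds = Seq(1, 2, 3).toDS()
//   val df = Seq(("a", 1), ("b", 2)).toDF("key", "value")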
// scalastyle:on
/** @inheritdoc */
def newSession(): SparkSession = {
SparkSession.builder().client(client.copy()).create()
}
private def range(
start: Long,
end: Long,
step: Long,
numPartitions: Option[Int]): Dataset[java.lang.Long] = {
newDataset(BoxedLongEncoder) { builder =>
val rangeBuilder = builder.getRangeBuilder
.setStart(start)
.setEnd(end)
.setStep(step)
numPartitions.foreach(rangeBuilder.setNumPartitions)
}
}
@Since("4.0.0")
@DeveloperApi
def newDataFrame(f: proto.Relation.Builder => Unit): DataFrame = {
newDataset(UnboundRowEncoder)(f)
}
@Since("4.0.0")
@DeveloperApi
def newDataset[T](encoder: AgnosticEncoder[T])(
f: proto.Relation.Builder => Unit): Dataset[T] = {
val builder = proto.Relation.newBuilder()
f(builder)
builder.getCommonBuilder.setPlanId(planIdGenerator.getAndIncrement())
val plan = proto.Plan.newBuilder().setRoot(builder).build()
new Dataset[T](this, plan, encoder)
}
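// Hedged developer-API sketch: building a DataFrame directly from a proto relation, which
// mirrors what the range() helpers above do internally (illustrative only):
//   val df = spark.newDataFrame { builder =>
//     builder.getRangeBuilder.setStart(0).setEnd(10).setStep(1)
//   }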
private[sql] def newCommand[T](f: proto.Command.Builder => Unit): proto.Command = {
val builder = proto.Command.newBuilder()
f(builder)
builder.build()
}
private[sql] def analyze(
plan: proto.Plan,
method: proto.AnalyzePlanRequest.AnalyzeCase,
explainMode: Option[proto.AnalyzePlanRequest.Explain.ExplainMode] = None)
: proto.AnalyzePlanResponse = {
client.analyze(method, Some(plan), explainMode)
}
private[sql] def analyze(
f: proto.AnalyzePlanRequest.Builder => Unit): proto.AnalyzePlanResponse = {
val builder = proto.AnalyzePlanRequest.newBuilder()
f(builder)
client.analyze(builder)
}
private[sql] def sameSemantics(plan: proto.Plan, otherPlan: proto.Plan): Boolean = {
client.sameSemantics(plan, otherPlan).getSameSemantics.getResult
}
private[sql] def semanticHash(plan: proto.Plan): Int = {
client.semanticHash(plan).getSemanticHash.getResult
}
private[sql] def timeZoneId: String = conf.get(SqlApiConf.SESSION_LOCAL_TIMEZONE_KEY)
private[sql] def execute[T](plan: proto.Plan, encoder: AgnosticEncoder[T]): SparkResult[T] = {
val value = client.execute(plan)
new SparkResult(
value,
allocator,
encoder,
timeZoneId,
Some(setMetricsAndUnregisterObservation))
}
private[sql] def execute(f: proto.Relation.Builder => Unit): Unit = {
val builder = proto.Relation.newBuilder()
f(builder)
builder.getCommonBuilder.setPlanId(planIdGenerator.getAndIncrement())
val plan = proto.Plan.newBuilder().setRoot(builder).build()
// .foreach forces the iterator to be consumed and closed
client.execute(plan).foreach(_ => ())
}
@Since("4.0.0")
@DeveloperApi
def execute(command: proto.Command): Seq[ExecutePlanResponse] = {
val plan = proto.Plan.newBuilder().setCommand(command).build()
// .toSeq forces the iterator to be consumed and closed. In addition, all progress
// messages are ignored.
client.execute(plan).filter(!_.hasExecutionProgress).toSeq
}
private[sql] def execute(plan: proto.Plan): CloseableIterator[ExecutePlanResponse] =
client.execute(plan)
private[sql] def registerUdf(udf: proto.CommonInlineUserDefinedFunction): Unit = {
val command = proto.Command.newBuilder().setRegisterFunction(udf).build()
execute(command)
}
/** @inheritdoc */
@Experimental
override def addArtifact(path: String): Unit = client.addArtifact(path)
/** @inheritdoc */
@Experimental
override def addArtifact(uri: URI): Unit = client.addArtifact(uri)
/** @inheritdoc */
@Experimental
override def addArtifact(bytes: Array[Byte], target: String): Unit = {
client.addArtifact(bytes, target)
}
/** @inheritdoc */
@Experimental
override def addArtifact(source: String, target: String): Unit = {
client.addArtifact(source, target)
}
/** @inheritdoc */
@Experimental
@scala.annotation.varargs
override def addArtifacts(uri: URI*): Unit = client.addArtifacts(uri)
/**
* Register a ClassFinder for dynamically generated classes.
* @since 3.5.0
*/
@Experimental
def registerClassFinder(finder: ClassFinder): Unit = client.registerClassFinder(finder)
/**
* This resets the plan id generator so we can produce plans that are comparable.
*
* For testing only!
*/
private[sql] def resetPlanIdGenerator(): Unit = {
planIdGenerator.set(0)
}
/**
* Interrupt all operations of this session currently running on the connected server.
*
* @return
* sequence of operationIds of interrupted operations. Note: there is still a possibility of an
* operation finishing just as it is interrupted.
*
* @since 3.5.0
*/
def interruptAll(): Seq[String] = {
client.interruptAll().getInterruptedIdsList.asScala.toSeq
}
/**
* Interrupt all operations of this session with the given operation tag.
*
* @return
* sequence of operationIds of interrupted operations. Note: there is still a possibility of an
* operation finishing just as it is interrupted.
*
* @since 3.5.0
*/
def interruptTag(tag: String): Seq[String] = {
client.interruptTag(tag).getInterruptedIdsList.asScala.toSeq
}
/**
* Interrupt an operation of this session with the given operationId.
*
* @return
* sequence of operationIds of interrupted operations. Note: there is still a possibility of an
* operation finishing just as it is interrupted.
*
* @since 3.5.0
*/
def interruptOperation(operationId: String): Seq[String] = {
client.interruptOperation(operationId).getInterruptedIdsList.asScala.toSeq
}
/**
* Close the [[SparkSession]].
*
* Release the current session and close the GRPC connection to the server. The API will not
* error if any of these operations fail. Closing a closed session is a no-op.
*
* Close the allocator. Fail if there are still open SparkResults.
*
* @since 3.4.0
*/
override def close(): Unit = {
if (releaseSessionOnClose) {
try {
client.releaseSession()
} catch {
case e: Exception => logWarning("session.stop: Failed to release session", e)
}
}
try {
client.shutdown()
} catch {
case e: Exception => logWarning("session.stop: Failed to shutdown the client", e)
}
allocator.close()
SparkSession.onSessionClose(this)
}
/**
* Add a tag to be assigned to all the operations started by this thread in this session.
*
* Often, a unit of execution in an application consists of multiple Spark executions.
* Application programmers can use this method to group all those jobs together and give them a
* group tag. The application can use `org.apache.spark.sql.SparkSession.interruptTag` to cancel
* all running executions with this tag. For example:
* {{{
* // In the main thread:
* spark.addTag("myjobs")
* spark.range(10).map(i => { Thread.sleep(10); i }).collect()
*
* // In a separate thread:
* spark.interruptTag("myjobs")
* }}}
*
* There may be multiple tags present at the same time, so different parts of the application may
* use different tags to perform cancellation at different levels of granularity.
*
* @param tag
* The tag to be added. Cannot contain ',' (comma) character or be an empty string.
*
* @since 3.5.0
*/
def addTag(tag: String): Unit = {
client.addTag(tag)
}
/**
* Remove a tag previously added to be assigned to all the operations started by this thread in
* this session. Noop if such a tag was not added earlier.
*
* @param tag
* The tag to be removed. Cannot contain ',' (comma) character or be an empty string.
*
* @since 3.5.0
*/
def removeTag(tag: String): Unit = {
client.removeTag(tag)
}
/**
* Get the tags that are currently set to be assigned to all the operations started by this
* thread.
*
* @since 3.5.0
*/
def getTags(): Set[String] = {
client.getTags()
}
/**
* Clear the current thread's operation tags.
*
* @since 3.5.0
*/
def clearTags(): Unit = {
client.clearTags()
}
/**
* We cannot deserialize a connect [[SparkSession]] because of a class clash on the server side.
* We null out the instance for now.
*/
private def writeReplace(): Any = null
/**
* Set to false to prevent client.releaseSession on close() (testing only)
*/
private[sql] var releaseSessionOnClose = true
private[sql] def registerObservation(planId: Long, observation: Observation): Unit = {
observation.markRegistered()
observationRegistry.putIfAbsent(planId, observation)
}
private[sql] def setMetricsAndUnregisterObservation(planId: Long, metrics: Row): Unit = {
val observationOrNull = observationRegistry.remove(planId)
if (observationOrNull != null) {
observationOrNull.setMetricsAndNotify(metrics)
}
}
implicit class RichColumn(c: Column) {
def expr: proto.Expression = toExpr(c)
def typedExpr[T](e: Encoder[T]): proto.Expression = toTypedExpr(c, e)
}
}
// The minimal builder needed to create a Spark session.
// TODO: implement all methods mentioned in the scaladoc of [[SparkSession]]
object SparkSession extends Logging {
private val MAX_CACHED_SESSIONS = 100
private val planIdGenerator = new AtomicLong
private var server: Option[Process] = None
private[sql] val sparkOptions = sys.props.filter { p =>
p._1.startsWith("spark.") && p._2.nonEmpty
}.toMap
private val sessions = CacheBuilder
.newBuilder()
.weakValues()
.maximumSize(MAX_CACHED_SESSIONS)
.build(new CacheLoader[Configuration, SparkSession] {
override def load(c: Configuration): SparkSession = create(c)
})
/** The active SparkSession for the current thread. */
private val activeThreadSession = new InheritableThreadLocal[SparkSession]
/** Reference to the root SparkSession. */
private val defaultSession = new AtomicReference[SparkSession]
/**
* Set the (global) default [[SparkSession]], and (thread-local) active [[SparkSession]] when
* they are not set yet or the associated [[SparkConnectClient]] is unusable.
*/
private def setDefaultAndActiveSession(session: SparkSession): Unit = {
val currentDefault = defaultSession.getAcquire
if (currentDefault == null || !currentDefault.client.isSessionValid) {
// Update `defaultSession` if it is null or the contained session is not valid. There is a
// chance that the following `compareAndSet` fails if a new default session has just been set,
// but that does not matter since that event has happened after this method was invoked.
defaultSession.compareAndSet(currentDefault, session)
}
if (getActiveSession.isEmpty) {
setActiveSession(session)
}
}
/**
* Start a local Spark Connect server, if needed, before evaluating the given block.
*/
private[sql] def withLocalConnectServer[T](f: => T): T = {
synchronized {
val remoteString = sparkOptions
.get("spark.remote")
.orElse(Option(System.getProperty("spark.remote"))) // Set from Spark Submit
.orElse(sys.env.get(SparkConnectClient.SPARK_REMOTE))
val maybeConnectScript =
Option(System.getenv("SPARK_HOME")).map(Paths.get(_, "sbin", "start-connect-server.sh"))
if (server.isEmpty &&
remoteString.exists(_.startsWith("local")) &&
maybeConnectScript.exists(Files.exists(_))) {
server = Some {
val args =
Seq(maybeConnectScript.get.toString, "--master", remoteString.get) ++ sparkOptions
.filter(p => !p._1.startsWith("spark.remote"))
.flatMap { case (k, v) => Seq("--conf", s"$k=$v") }
val pb = new ProcessBuilder(args: _*)
// Remove SPARK_REMOTE so that the spark-sql jar is not excluded from the server's classpath
pb.environment().remove(SparkConnectClient.SPARK_REMOTE)
pb.start()
}
// Let the server start. We will directly request to set the configurations, and this
// sleep makes the retries less noisy.
Thread.sleep(2000L)
System.setProperty("spark.remote", "sc://localhost")
// scalastyle:off runtimeaddshutdownhook
Runtime.getRuntime.addShutdownHook(new Thread() {
override def run(): Unit = if (server.isDefined) {
new ProcessBuilder(maybeConnectScript.get.toString)
.start()
}
})
// scalastyle:on runtimeaddshutdownhook
}
}
f
}
/**
* Create a new [[SparkSession]] based on the connect client [[Configuration]].
*/
private[sql] def create(configuration: Configuration): SparkSession = {
new SparkSession(configuration.toSparkConnectClient, planIdGenerator)
}
/**
* Hook called when a session is closed.
*/
private[sql] def onSessionClose(session: SparkSession): Unit = {
sessions.invalidate(session.client.configuration)
defaultSession.compareAndSet(session, null)
if (getActiveSession.contains(session)) {
clearActiveSession()
}
}
/**
* Creates a [[SparkSession.Builder]] for constructing a [[SparkSession]].
*
* @since 3.4.0
*/
def builder(): Builder = new Builder()
class Builder() extends Logging {
// Initialize the connection string of the Spark Connect client builder from SPARK_REMOTE
// by default, if it exists. The connection string can be overridden using
// the remote() function, as it takes precedence over the SPARK_REMOTE environment variable.
private val builder = SparkConnectClient.builder().loadFromEnvironment()
private var client: SparkConnectClient = _
private[this] val options = new scala.collection.mutable.HashMap[String, String]
def remote(connectionString: String): Builder = {
builder.connectionString(connectionString)
this
}
/**
* Add an interceptor to be used during channel creation.
*
* Note that interceptors added last are executed first by gRPC.
*
* @since 3.5.0
*/
def interceptor(interceptor: ClientInterceptor): Builder = {
builder.interceptor(interceptor)
this
}
private[sql] def client(client: SparkConnectClient): Builder = {
this.client = client
this
}
/**
* Sets a config option. Options set using this method are automatically propagated to the
* Spark Connect session. Only runtime options are supported.
*
* @since 3.5.0
*/
def config(key: String, value: String): Builder = synchronized {
options += key -> value
this
}
/**
* Sets a config option. Options set using this method are automatically propagated to the
* Spark Connect session. Only runtime options are supported.
*
* @since 3.5.0
*/
def config(key: String, value: Long): Builder = synchronized {
options += key -> value.toString
this
}
/**
* Sets a config option. Options set using this method are automatically propagated to the
* Spark Connect session. Only runtime options are supported.
*
* @since 3.5.0
*/
def config(key: String, value: Double): Builder = synchronized {
options += key -> value.toString
this
}
/**
* Sets a config option. Options set using this method are automatically propagated to the
* Spark Connect session. Only runtime options are supported.
*
* @since 3.5.0
*/
def config(key: String, value: Boolean): Builder = synchronized {
options += key -> value.toString
this
}
/**
* Sets config options from the given map. Options set using this method are automatically
* propagated to the Spark Connect session. Only runtime options are supported.
*
* @since 3.5.0
*/
def config(map: Map[String, Any]): Builder = synchronized {
map.foreach { kv: (String, Any) =>
{
options += kv._1 -> kv._2.toString
}
}
this
}
/**
* Sets config options from the given map. Options set using this method are automatically
* propagated to the Spark Connect session. Only runtime options are supported.
*
* @since 3.5.0
*/
def config(map: java.util.Map[String, Any]): Builder = synchronized {
config(map.asScala.toMap)
}
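// Builder usage sketch combining remote() and config() (the endpoint and option value are
// illustrative only; 15002 is merely an example port):
//   val spark = SparkSession
//     .builder()
//     .remote("sc://localhost:15002")
//     .config("spark.sql.ansi.enabled", true)
//     .getOrCreate()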
@deprecated("enableHiveSupport does not work in Spark Connect")
def enableHiveSupport(): Builder = this
@deprecated("master does not work in Spark Connect, please use remote instead")
def master(master: String): Builder = this
@deprecated("appName does not work in Spark Connect")
def appName(name: String): Builder = this
private def tryCreateSessionFromClient(): Option[SparkSession] = {
if (client != null && client.isSessionValid) {
Option(new SparkSession(client, planIdGenerator))
} else {
None
}
}
private def applyOptions(session: SparkSession): Unit = {
// Only attempts to set Spark SQL configurations.
// If the configurations are static, it might throw an exception so
// simply ignore it for now.
sparkOptions
.filter { case (k, _) =>
k.startsWith("spark.sql.")
}
.foreach { case (key, value) =>
Try(session.conf.set(key, value))
}
options.foreach { case (key, value) =>
session.conf.set(key, value)
}
}
/**
* Build the [[SparkSession]].
*
* This will always return a newly created session.
*/
@deprecated(message = "Please use create() instead.", since = "3.5.0")
def build(): SparkSession = create()
/**
* Create a new [[SparkSession]].
*
* This will always return a newly created session.
*
* This method will update the default and/or active session if they are not set.
*
* @since 3.5.0
*/
def create(): SparkSession = withLocalConnectServer {
val session = tryCreateSessionFromClient()
.getOrElse(SparkSession.this.create(builder.configuration))
setDefaultAndActiveSession(session)
applyOptions(session)
session
}
/**
* Get or create a [[SparkSession]].
*
* If a session with the same configuration already exists, that session is returned instead of
* creating a new one.
*
* This method will update the default and/or active session if they are not set. This method
* will always set the specified configuration options on the session, even when it is not
* newly created.
*
* @since 3.5.0
*/
def getOrCreate(): SparkSession = withLocalConnectServer {
val session = tryCreateSessionFromClient()
.getOrElse({
var existingSession = sessions.get(builder.configuration)
if (!existingSession.client.isSessionValid) {
// If the cached session has become invalid, e.g., due to a server restart, the cache
// entry is invalidated.
sessions.invalidate(builder.configuration)
existingSession = sessions.get(builder.configuration)
}
existingSession
})
setDefaultAndActiveSession(session)
applyOptions(session)
session
}
}
/**
* Returns the default SparkSession. If the previously set default SparkSession becomes
* unusable, returns None.
*
* @since 3.5.0
*/
def getDefaultSession: Option[SparkSession] =
Option(defaultSession.get()).filter(_.client.isSessionValid)
/**
* Sets the default SparkSession.
*
* @since 3.5.0
*/
def setDefaultSession(session: SparkSession): Unit = {
defaultSession.set(session)
}
/**
* Clears the default SparkSession.
*
* @since 3.5.0
*/
def clearDefaultSession(): Unit = {
defaultSession.set(null)
}
/**
* Returns the active SparkSession for the current thread. If the previously set active
* SparkSession becomes unusable, returns None.
*
* @since 3.5.0
*/
def getActiveSession: Option[SparkSession] =
Option(activeThreadSession.get()).filter(_.client.isSessionValid)
/**
* Changes the SparkSession that will be returned in this thread and its children when
* SparkSession.getOrCreate() is called. This can be used to ensure that a given thread receives
* an isolated SparkSession.
*
* @since 3.5.0
*/
def setActiveSession(session: SparkSession): Unit = {
activeThreadSession.set(session)
}
/**
* Clears the active SparkSession for current thread.
*
* @since 3.5.0
*/
def clearActiveSession(): Unit = {
activeThreadSession.remove()
}
/**
* Returns the currently active SparkSession, otherwise the default one. If there is no default
* SparkSession, throws an exception.
*
* @since 3.5.0
*/
def active: SparkSession = {
getActiveSession
.orElse(getDefaultSession)
.getOrElse(throw new IllegalStateException("No active or default Spark session found"))
}
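// Sketch of default/active session management (illustrative only; assumes `spark` was
// created with the builder):
//   SparkSession.setActiveSession(spark)
//   assert(SparkSession.active eq spark)
//   SparkSession.clearActiveSession()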
}