/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql
import java.io.Closeable
import java.net.URI
import java.util.concurrent.TimeUnit._
import java.util.concurrent.atomic.{AtomicLong, AtomicReference}
import scala.collection.JavaConverters._
import scala.reflect.runtime.universe.TypeTag
import com.google.common.cache.{CacheBuilder, CacheLoader}
import io.grpc.ClientInterceptor
import org.apache.arrow.memory.RootAllocator
import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.ExecutePlanResponse
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalog.Catalog
import org.apache.spark.sql.catalyst.{JavaTypeInference, ScalaReflection}
import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, RowEncoder}
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BoxedLongEncoder, UnboundRowEncoder}
import org.apache.spark.sql.connect.client.{ClassFinder, SparkConnectClient, SparkResult}
import org.apache.spark.sql.connect.client.SparkConnectClient.Configuration
import org.apache.spark.sql.connect.client.arrow.ArrowSerializer
import org.apache.spark.sql.connect.client.util.Cleaner
import org.apache.spark.sql.connect.common.LiteralValueProtoConverter.toLiteralProto
import org.apache.spark.sql.internal.{CatalogImpl, SqlApiConf}
import org.apache.spark.sql.streaming.DataStreamReader
import org.apache.spark.sql.streaming.StreamingQueryManager
import org.apache.spark.sql.types.StructType
/**
* The entry point to programming Spark with the Dataset and DataFrame API.
*
* In environments where this has been created up front (e.g. REPL, notebooks), use the builder
* to get an existing session:
*
* {{{
* SparkSession.builder().getOrCreate()
* }}}
*
* The builder can also be used to create a new session:
*
* {{{
* SparkSession.builder
* .remote("sc://localhost:15001/myapp")
* .getOrCreate()
* }}}
*/
class SparkSession private[sql] (
private[sql] val client: SparkConnectClient,
private val cleaner: Cleaner,
private val planIdGenerator: AtomicLong)
extends Serializable
with Closeable
with Logging {
private[this] val allocator = new RootAllocator()
// A unique session ID for this session, provided by the client.
private[sql] def sessionId: String = client.sessionId
/**
* The version of Spark running on the connected server.
*/
lazy val version: String = {
client.analyze(proto.AnalyzePlanRequest.AnalyzeCase.SPARK_VERSION).getSparkVersion.getVersion
}
/**
* Runtime configuration interface for Spark.
*
* This is the interface through which the user can get and set all Spark configurations that
* are relevant to Spark SQL. When getting the value of a config, this defaults to the value set
* in the server, if any.
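*
* For example (an illustrative sketch; the config key shown is just an example of a runtime
* option):
* {{{
*   spark.conf.set("spark.sql.shuffle.partitions", "200")
*   spark.conf.get("spark.sql.shuffle.partitions")   // "200"
* }}}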
*
* @since 3.4.0
*/
val conf: RuntimeConfig = new RuntimeConfig(client)
/**
* Executes some code block and prints to stdout the time taken to execute the block. This is
* available in Scala only and is used primarily for interactive testing and debugging.
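*
* For example (an illustrative sketch; assumes `spark` is an existing session):
* {{{
*   val count = spark.time {
*     spark.range(1000 * 1000).count()
*   }
*   // Time taken: ... ms
* }}}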
*
* @since 3.4.0
*/
def time[T](f: => T): T = {
val start = System.nanoTime()
val ret = f
val end = System.nanoTime()
// scalastyle:off println
println(s"Time taken: ${NANOSECONDS.toMillis(end - start)} ms")
// scalastyle:on println
ret
}
/**
* Returns a `DataFrame` with no rows or columns.
*
* @since 3.4.0
*/
@transient
val emptyDataFrame: DataFrame = emptyDataset(UnboundRowEncoder)
/**
* Creates a new [[Dataset]] of type T containing zero elements.
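*
* For example (an illustrative sketch; assumes `import spark.implicits._` is in scope to
* provide the encoder):
* {{{
*   val ds = spark.emptyDataset[String]   // a Dataset[String] with zero rows
* }}}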
*
* @since 3.4.0
*/
def emptyDataset[T: Encoder]: Dataset[T] = createDataset[T](Nil)
private def createDataset[T](encoder: AgnosticEncoder[T], data: Iterator[T]): Dataset[T] = {
newDataset(encoder) { builder =>
if (data.nonEmpty) {
val arrowData = ArrowSerializer.serialize(data, encoder, allocator, timeZoneId)
if (arrowData.size() <= conf.get(SqlApiConf.LOCAL_RELATION_CACHE_THRESHOLD_KEY).toInt) {
builder.getLocalRelationBuilder
.setSchema(encoder.schema.json)
.setData(arrowData)
} else {
val hash = client.cacheLocalRelation(arrowData, encoder.schema.json)
builder.getCachedLocalRelationBuilder
.setHash(hash)
}
} else {
builder.getLocalRelationBuilder
.setSchema(encoder.schema.json)
}
}
}
/**
* Creates a `DataFrame` from a local Seq of Product.
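*
* For example (an illustrative sketch; the `Person` case class is hypothetical):
* {{{
*   case class Person(name: String, age: Long)
*   val df = spark.createDataFrame(Seq(Person("Andy", 30), Person("Justin", 19)))
* }}}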
*
* @since 3.4.0
*/
def createDataFrame[A <: Product: TypeTag](data: Seq[A]): DataFrame = {
createDataset(ScalaReflection.encoderFor[A], data.iterator).toDF()
}
/**
* :: DeveloperApi :: Creates a `DataFrame` from a `java.util.List` containing [[Row]]s using
* the given schema. It is important to make sure that the structure of every [[Row]] of the
* provided List matches the provided schema. Otherwise, there will be a runtime exception.
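*
* For example (an illustrative sketch; names and values are hypothetical):
* {{{
*   import java.util.Arrays
*   val schema = new StructType().add("name", "string").add("age", "long")
*   val rows = Arrays.asList(Row("Andy", 30L), Row("Justin", 19L))
*   val df = spark.createDataFrame(rows, schema)
* }}}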
*
* @since 3.4.0
*/
def createDataFrame(rows: java.util.List[Row], schema: StructType): DataFrame = {
createDataset(RowEncoder.encoderFor(schema), rows.iterator().asScala).toDF()
}
/**
* Applies a schema to a List of Java Beans.
*
* WARNING: Since there is no guaranteed ordering for fields in a Java Bean, SELECT * queries
* will return the columns in an undefined order.
* @since 3.4.0
*/
def createDataFrame(data: java.util.List[_], beanClass: Class[_]): DataFrame = {
val encoder = JavaTypeInference.encoderFor(beanClass.asInstanceOf[Class[Any]])
createDataset(encoder, data.iterator().asScala).toDF()
}
/**
* Creates a [[Dataset]] from a local Seq of data of a given type. This method requires an
* encoder (to convert a JVM object of type `T` to and from the internal Spark SQL
* representation) that is generally created automatically through implicits from a
* `SparkSession`, or can be created explicitly by calling static methods on [[Encoders]].
*
* ==Example==
*
* {{{
*
* import spark.implicits._
* case class Person(name: String, age: Long)
* val data = Seq(Person("Michael", 29), Person("Andy", 30), Person("Justin", 19))
* val ds = spark.createDataset(data)
*
* ds.show()
* // +-------+---+
* // | name|age|
* // +-------+---+
* // |Michael| 29|
* // | Andy| 30|
* // | Justin| 19|
* // +-------+---+
* }}}
*
* @since 3.4.0
*/
def createDataset[T: Encoder](data: Seq[T]): Dataset[T] = {
createDataset(encoderFor[T], data.iterator)
}
/**
* Creates a [[Dataset]] from a `java.util.List` of a given type. This method requires an
* encoder (to convert a JVM object of type `T` to and from the internal Spark SQL
* representation) that is generally created automatically through implicits from a
* `SparkSession`, or can be created explicitly by calling static methods on [[Encoders]].
*
* ==Java Example==
*
* {{{
* List data = Arrays.asList("hello", "world");
* Dataset ds = spark.createDataset(data, Encoders.STRING());
* }}}
*
* @since 3.4.0
*/
def createDataset[T: Encoder](data: java.util.List[T]): Dataset[T] = {
createDataset(data.asScala.toSeq)
}
/**
* Executes a SQL query substituting positional parameters by the given arguments, returning the
* result as a `DataFrame`. This API eagerly runs DDL/DML commands, but not SELECT queries.
*
* @param sqlText
* A SQL statement with positional parameters to execute.
* @param args
* An array of Java/Scala objects that can be converted to SQL literal expressions. See
* "Supported Data Types" in the Spark SQL documentation for supported value types in
* Scala/Java. For example: 1, "Steven", LocalDate.of(2023, 4, 2). A value can also be a
* `Column` of a literal expression, in which case it is taken as is.
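*
* For example (an illustrative sketch; the table and values are hypothetical):
* {{{
*   spark.sql("SELECT * FROM people WHERE age > ? AND name = ?", Array(18, "Steven"))
* }}}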
*
* @since 3.5.0
*/
@Experimental
def sql(sqlText: String, args: Array[_]): DataFrame = newDataFrame { builder =>
// Send the SQL once to the server and then check the output.
val cmd = newCommand(b =>
b.setSqlCommand(
proto.SqlCommand
.newBuilder()
.setSql(sqlText)
.addAllPosArgs(args.map(toLiteralProto).toIterable.asJava)))
val plan = proto.Plan.newBuilder().setCommand(cmd)
// .toBuffer forces the iterator to be fully consumed and closed
val responseSeq = client.execute(plan.build()).toBuffer.toSeq
val response = responseSeq
.find(_.hasSqlCommandResult)
.getOrElse(throw new RuntimeException("SQLCommandResult must be present"))
// Update the builder with the values from the result.
builder.mergeFrom(response.getSqlCommandResult.getRelation)
}
/**
* Executes a SQL query substituting named parameters by the given arguments, returning the
* result as a `DataFrame`. This API eagerly runs DDL/DML commands, but not SELECT queries.
*
* @param sqlText
* A SQL statement with named parameters to execute.
* @param args
* A map of parameter names to Java/Scala objects that can be converted to SQL literal
* expressions. See "Supported Data Types" in the Spark SQL documentation for supported value
* types in Scala/Java. For example, map keys: "rank", "name", "birthdate"; map values: 1,
* "Steven", LocalDate.of(2023, 4, 2). A map value can also be a `Column` of a literal
* expression, in which case it is taken as is.
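*
* For example (an illustrative sketch; the table and values are hypothetical):
* {{{
*   spark.sql(
*     "SELECT * FROM people WHERE age > :minAge AND name = :name",
*     Map("minAge" -> 18, "name" -> "Steven"))
* }}}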
*
* @since 3.4.0
*/
@Experimental
def sql(sqlText: String, args: Map[String, Any]): DataFrame = {
sql(sqlText, args.asJava)
}
/**
* Executes a SQL query substituting named parameters by the given arguments, returning the
* result as a `DataFrame`. This API eagerly runs DDL/DML commands, but not SELECT queries.
*
* @param sqlText
* A SQL statement with named parameters to execute.
* @param args
* A map of parameter names to Java/Scala objects that can be converted to SQL literal
* expressions. See "Supported Data Types" in the Spark SQL documentation for supported value
* types in Scala/Java. For example, map keys: "rank", "name", "birthdate"; map values: 1,
* "Steven", LocalDate.of(2023, 4, 2). A map value can also be a `Column` of a literal
* expression, in which case it is taken as is.
*
* @since 3.4.0
*/
@Experimental
def sql(sqlText: String, args: java.util.Map[String, Any]): DataFrame = newDataFrame {
builder =>
// Send the SQL once to the server and then check the output.
val cmd = newCommand(b =>
b.setSqlCommand(
proto.SqlCommand
.newBuilder()
.setSql(sqlText)
.putAllArgs(args.asScala.mapValues(toLiteralProto).toMap.asJava)))
val plan = proto.Plan.newBuilder().setCommand(cmd)
// .toBuffer forces the iterator to be fully consumed and closed
val responseSeq = client.execute(plan.build()).toBuffer.toSeq
val response = responseSeq
.find(_.hasSqlCommandResult)
.getOrElse(throw new RuntimeException("SQLCommandResult must be present"))
// Update the builder with the values from the result.
builder.mergeFrom(response.getSqlCommandResult.getRelation)
}
/**
* Executes a SQL query using Spark, returning the result as a `DataFrame`. This API eagerly
* runs DDL/DML commands, but not SELECT queries.
*
* @since 3.4.0
*/
def sql(query: String): DataFrame = {
sql(query, Array.empty)
}
/**
* Returns a [[DataFrameReader]] that can be used to read non-streaming data in as a
* `DataFrame`.
* {{{
* sparkSession.read.parquet("/path/to/file.parquet")
* sparkSession.read.schema(schema).json("/path/to/file.json")
* }}}
*
* @since 3.4.0
*/
def read: DataFrameReader = new DataFrameReader(this)
/**
* Returns a `DataStreamReader` that can be used to read streaming data in as a `DataFrame`.
* {{{
* sparkSession.readStream.parquet("/path/to/directory/of/parquet/files")
* sparkSession.readStream.schema(schema).json("/path/to/directory/of/json/files")
* }}}
*
* @since 3.5.0
*/
def readStream: DataStreamReader = new DataStreamReader(this)
/**
* Returns a `StreamingQueryManager` that allows managing all the streaming queries started on
* this session.
*
* @since 3.5.0
*/
lazy val streams: StreamingQueryManager = new StreamingQueryManager(this)
/**
* Interface through which the user may create, drop, alter or query underlying databases,
* tables, functions etc.
*
* @since 3.5.0
*/
lazy val catalog: Catalog = new CatalogImpl(this)
/**
* Returns the specified table/view as a `DataFrame`. If it's a table, it must support batch
* reading and the returned DataFrame is the batch scan query plan of this table. If it's a
* view, the returned DataFrame is simply the query plan of the view, which can either be a
* batch or streaming query plan.
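*
* For example (an illustrative sketch; the table name is hypothetical):
* {{{
*   val df = spark.table("sales.transactions")
* }}}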
*
* @param tableName
* is either a qualified or unqualified name that designates a table or view. If a database is
* specified, it identifies the table/view from the database. Otherwise, it first attempts to
* find a temporary view with the given name and then match the table/view from the current
* database. Note that the global temporary view database is also valid here.
* @since 3.4.0
*/
def table(tableName: String): DataFrame = {
read.table(tableName)
}
/**
* Creates a [[Dataset]] with a single `LongType` column named `id`, containing elements in a
* range from 0 to `end` (exclusive) with step value 1.
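*
* For example:
* {{{
*   spark.range(5).collect()   // Array(0, 1, 2, 3, 4)
* }}}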
*
* @since 3.4.0
*/
def range(end: Long): Dataset[java.lang.Long] = range(0, end)
/**
* Creates a [[Dataset]] with a single `LongType` column named `id`, containing elements in a
* range from `start` to `end` (exclusive) with step value 1.
*
* @since 3.4.0
*/
def range(start: Long, end: Long): Dataset[java.lang.Long] = {
range(start, end, step = 1)
}
/**
* Creates a [[Dataset]] with a single `LongType` column named `id`, containing elements in a
* range from `start` to `end` (exclusive) with a step value.
*
* @since 3.4.0
*/
def range(start: Long, end: Long, step: Long): Dataset[java.lang.Long] = {
range(start, end, step, None)
}
/**
* Creates a [[Dataset]] with a single `LongType` column named `id`, containing elements in a
* range from `start` to `end` (exclusive) with a step value, using the specified number of
* partitions.
*
* @since 3.4.0
*/
def range(start: Long, end: Long, step: Long, numPartitions: Int): Dataset[java.lang.Long] = {
range(start, end, step, Option(numPartitions))
}
/**
* A collection of methods for registering user-defined functions (UDF).
*
* The following example registers a Scala closure as UDF:
* {{{
* sparkSession.udf.register("myUDF", (arg1: Int, arg2: String) => arg2 + arg1)
* }}}
*
* The following example registers a UDF in Java:
* {{{
* sparkSession.udf().register("myUDF",
* (Integer arg1, String arg2) -> arg2 + arg1,
* DataTypes.StringType);
* }}}
*
* @note
* The user-defined functions must be deterministic. Due to optimization, duplicate
* invocations may be eliminated or the function may even be invoked more times than it is
* present in the query.
*
* @since 3.5.0
*/
lazy val udf: UDFRegistration = new UDFRegistration(this)
// scalastyle:off
// Disable style checker so "implicits" object can start with lowercase i
/**
* (Scala-specific) Implicit methods available in Scala for converting common names and
* [[Symbol]]s into [[Column]]s, and for converting common Scala objects into `DataFrame`s.
*
* {{{
* val sparkSession = SparkSession.builder.getOrCreate()
* import sparkSession.implicits._
* }}}
*
* @since 3.4.0
*/
object implicits extends SQLImplicits(this) with Serializable
// scalastyle:on
/**
* Returns a newly created SparkSession based on a copy of this session's client, sharing the
* same connection configuration.
*/
def newSession(): SparkSession = {
SparkSession.builder().client(client.copy()).create()
}
private def range(
start: Long,
end: Long,
step: Long,
numPartitions: Option[Int]): Dataset[java.lang.Long] = {
newDataset(BoxedLongEncoder) { builder =>
val rangeBuilder = builder.getRangeBuilder
.setStart(start)
.setEnd(end)
.setStep(step)
numPartitions.foreach(rangeBuilder.setNumPartitions)
}
}
private[sql] def newDataFrame(f: proto.Relation.Builder => Unit): DataFrame = {
newDataset(UnboundRowEncoder)(f)
}
private[sql] def newDataset[T](encoder: AgnosticEncoder[T])(
f: proto.Relation.Builder => Unit): Dataset[T] = {
val builder = proto.Relation.newBuilder()
f(builder)
builder.getCommonBuilder.setPlanId(planIdGenerator.getAndIncrement())
val plan = proto.Plan.newBuilder().setRoot(builder).build()
new Dataset[T](this, plan, encoder)
}
/**
* Creates a `DataFrame` from a user-supplied relation extension (a protobuf `Any` message)
* that the connected server knows how to handle.
*/
@DeveloperApi
def newDataFrame(extension: com.google.protobuf.Any): DataFrame = {
newDataset(extension, UnboundRowEncoder)
}
/**
* Creates a [[Dataset]] from a user-supplied relation extension (a protobuf `Any` message),
* decoding the result with the given encoder.
*/
@DeveloperApi
def newDataset[T](
extension: com.google.protobuf.Any,
encoder: AgnosticEncoder[T]): Dataset[T] = {
newDataset(encoder)(_.setExtension(extension))
}
private[sql] def newCommand[T](f: proto.Command.Builder => Unit): proto.Command = {
val builder = proto.Command.newBuilder()
f(builder)
builder.build()
}
private[sql] def analyze(
plan: proto.Plan,
method: proto.AnalyzePlanRequest.AnalyzeCase,
explainMode: Option[proto.AnalyzePlanRequest.Explain.ExplainMode] = None)
: proto.AnalyzePlanResponse = {
client.analyze(method, Some(plan), explainMode)
}
private[sql] def analyze(
f: proto.AnalyzePlanRequest.Builder => Unit): proto.AnalyzePlanResponse = {
val builder = proto.AnalyzePlanRequest.newBuilder()
f(builder)
client.analyze(builder)
}
private[sql] def sameSemantics(plan: proto.Plan, otherPlan: proto.Plan): Boolean = {
client.sameSemantics(plan, otherPlan).getSameSemantics.getResult
}
private[sql] def semanticHash(plan: proto.Plan): Int = {
client.semanticHash(plan).getSemanticHash.getResult
}
private[sql] def timeZoneId: String = conf.get(SqlApiConf.SESSION_LOCAL_TIMEZONE_KEY)
private[sql] def execute[T](plan: proto.Plan, encoder: AgnosticEncoder[T]): SparkResult[T] = {
val value = client.execute(plan)
val result = new SparkResult(value, allocator, encoder, timeZoneId)
cleaner.register(result)
result
}
private[sql] def execute(f: proto.Relation.Builder => Unit): Unit = {
val builder = proto.Relation.newBuilder()
f(builder)
builder.getCommonBuilder.setPlanId(planIdGenerator.getAndIncrement())
val plan = proto.Plan.newBuilder().setRoot(builder).build()
// .toBuffer forces the iterator to be fully consumed and closed
client.execute(plan).toBuffer
}
private[sql] def execute(command: proto.Command): Seq[ExecutePlanResponse] = {
val plan = proto.Plan.newBuilder().setCommand(command).build()
// .toBuffer forces the iterator to be fully consumed and closed
client.execute(plan).toBuffer.toSeq
}
private[sql] def registerUdf(udf: proto.CommonInlineUserDefinedFunction): Unit = {
val command = proto.Command.newBuilder().setRegisterFunction(udf).build()
execute(command)
}
/**
* Executes a user-supplied command extension (a protobuf `Any` message) on the connected
* server.
*/
@DeveloperApi
def execute(extension: com.google.protobuf.Any): Unit = {
val command = proto.Command.newBuilder().setExtension(extension).build()
execute(command)
}
/**
* Add a single artifact to the client session.
*
* Currently only local files with extensions .jar and .class are supported.
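*
* For example (an illustrative sketch; the path is hypothetical):
* {{{
*   spark.addArtifact("/path/to/my-udfs.jar")
* }}}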
*
* @since 3.4.0
*/
@Experimental
def addArtifact(path: String): Unit = client.addArtifact(path)
/**
* Add a single artifact to the client session.
*
* Currently only local files with extensions .jar and .class are supported.
*
* @since 3.4.0
*/
@Experimental
def addArtifact(uri: URI): Unit = client.addArtifact(uri)
/**
* Add one or more artifacts to the session.
*
* Currently only local files with extensions .jar and .class are supported.
*
* @since 3.4.0
*/
@Experimental
@scala.annotation.varargs
def addArtifacts(uri: URI*): Unit = client.addArtifacts(uri)
/**
* Register a [[ClassFinder]] for dynamically generated classes.
*
* @since 3.5.0
*/
@Experimental
def registerClassFinder(finder: ClassFinder): Unit = client.registerClassFinder(finder)
/**
* This resets the plan id generator so we can produce plans that are comparable.
*
* For testing only!
*/
private[sql] def resetPlanIdGenerator(): Unit = {
planIdGenerator.set(0)
}
/**
* Interrupt all operations of this session currently running on the connected server.
*
* @return
* sequence of operationIds of interrupted operations. Note: there is still a possibility of
* an operation finishing just as it is interrupted.
*
* @since 3.5.0
*/
def interruptAll(): Seq[String] = {
client.interruptAll().getInterruptedIdsList.asScala.toSeq
}
/**
* Interrupt all operations of this session with the given operation tag.
*
* @return
* sequence of operationIds of interrupted operations. Note: there is still a possibility of
* an operation finishing just as it is interrupted.
*
* @since 3.5.0
*/
def interruptTag(tag: String): Seq[String] = {
client.interruptTag(tag).getInterruptedIdsList.asScala.toSeq
}
/**
* Interrupt an operation of this session with the given operationId.
*
* @return
* sequence of operationIds of interrupted operations. Note: there is still a possibility of
* an operation finishing just as it is interrupted.
*
* @since 3.5.0
*/
def interruptOperation(operationId: String): Seq[String] = {
client.interruptOperation(operationId).getInterruptedIdsList.asScala.toSeq
}
/**
* Synonym for `close()`.
*
* @since 3.4.0
*/
def stop(): Unit = close()
/**
* Close the [[SparkSession]]. This closes the connection and the allocator. The latter will
* throw an exception if there are still open [[SparkResult]]s.
*
* @since 3.4.0
*/
override def close(): Unit = {
client.shutdown()
allocator.close()
SparkSession.onSessionClose(this)
}
/**
* Add a tag to be assigned to all the operations started by this thread in this session.
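*
* For example (an illustrative sketch; the tag and query are hypothetical, and the interrupt
* would typically come from another thread):
* {{{
*   spark.addTag("nightly-etl")
*   spark.sql("INSERT INTO events_agg SELECT * FROM events")
*   // elsewhere: spark.interruptTag("nightly-etl")
* }}}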
*
* @param tag
* The tag to be added. Cannot contain ',' (comma) character or be an empty string.
*
* @since 3.5.0
*/
def addTag(tag: String): Unit = {
client.addTag(tag)
}
/**
* Remove a tag previously added to be assigned to all the operations started by this thread in
* this session. This is a no-op if such a tag was not added earlier.
*
* @param tag
* The tag to be removed. Cannot contain ',' (comma) character or be an empty string.
*
* @since 3.5.0
*/
def removeTag(tag: String): Unit = {
client.removeTag(tag)
}
/**
* Get the tags that are currently set to be assigned to all the operations started by this
* thread.
*
* @since 3.5.0
*/
def getTags(): Set[String] = {
client.getTags()
}
/**
* Clear the current thread's operation tags.
*
* @since 3.5.0
*/
def clearTags(): Unit = {
client.clearTags()
}
/**
* We cannot deserialize a connect [[SparkSession]] because of a class clash on the server side.
* We null out the instance for now.
*/
private def writeReplace(): Any = null
}
// The minimal builder needed to create a Spark session.
// TODO: implement all methods mentioned in the scaladoc of [[SparkSession]]
object SparkSession extends Logging {
private val MAX_CACHED_SESSIONS = 100
private val planIdGenerator = new AtomicLong
private val sessions = CacheBuilder
.newBuilder()
.weakValues()
.maximumSize(MAX_CACHED_SESSIONS)
.build(new CacheLoader[Configuration, SparkSession] {
override def load(c: Configuration): SparkSession = create(c)
})
/** The active SparkSession for the current thread. */
private val activeThreadSession = new InheritableThreadLocal[SparkSession]
/** Reference to the root SparkSession. */
private val defaultSession = new AtomicReference[SparkSession]
/**
* Set the (global) default [[SparkSession]], and (thread-local) active [[SparkSession]] when
* they are not set yet.
*/
private def setDefaultAndActiveSession(session: SparkSession): Unit = {
defaultSession.compareAndSet(null, session)
if (getActiveSession.isEmpty) {
setActiveSession(session)
}
}
/**
* Create a new [[SparkSession]] based on the connect client [[Configuration]].
*/
private[sql] def create(configuration: Configuration): SparkSession = {
new SparkSession(configuration.toSparkConnectClient, cleaner, planIdGenerator)
}
/**
* Hook called when a session is closed.
*/
private[sql] def onSessionClose(session: SparkSession): Unit = {
sessions.invalidate(session.client.configuration)
defaultSession.compareAndSet(session, null)
if (getActiveSession.contains(session)) {
clearActiveSession()
}
}
/**
* Creates a [[SparkSession.Builder]] for constructing a [[SparkSession]].
*
* @since 3.4.0
*/
def builder(): Builder = new Builder()
private[sql] lazy val cleaner = {
val cleaner = new Cleaner
cleaner.start()
cleaner
}
class Builder() extends Logging {
// Initialize the connection string of the Spark Connect client builder from SPARK_REMOTE
// by default, if it exists. The connection string can be overridden using
// the remote() function, as it takes precedence over the SPARK_REMOTE environment variable.
private val builder = SparkConnectClient.builder().loadFromEnvironment()
private var client: SparkConnectClient = _
private[this] val options = new scala.collection.mutable.HashMap[String, String]
def remote(connectionString: String): Builder = {
builder.connectionString(connectionString)
this
}
/**
* Add an interceptor [[ClientInterceptor]] to be used during channel creation.
*
* Note that interceptors added last are executed first by gRPC.
*
* @since 3.5.0
*/
def interceptor(interceptor: ClientInterceptor): Builder = {
builder.interceptor(interceptor)
this
}
private[sql] def client(client: SparkConnectClient): Builder = {
this.client = client
this
}
/**
* Sets a config option. Options set using this method are automatically propagated to the
* Spark Connect session. Only runtime options are supported.
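*
* For example (an illustrative sketch; the connection string and config value are
* hypothetical):
* {{{
*   val spark = SparkSession.builder()
*     .remote("sc://localhost")
*     .config("spark.sql.ansi.enabled", "true")
*     .getOrCreate()
* }}}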
*
* @since 3.5.0
*/
def config(key: String, value: String): Builder = synchronized {
options += key -> value
this
}
/**
* Sets a config option. Options set using this method are automatically propagated to the
* Spark Connect session. Only runtime options are supported.
*
* @since 3.5.0
*/
def config(key: String, value: Long): Builder = synchronized {
options += key -> value.toString
this
}
/**
* Sets a config option. Options set using this method are automatically propagated to the
* Spark Connect session. Only runtime options are supported.
*
* @since 3.5.0
*/
def config(key: String, value: Double): Builder = synchronized {
options += key -> value.toString
this
}
/**
* Sets a config option. Options set using this method are automatically propagated to the
* Spark Connect session. Only runtime options are supported.
*
* @since 3.5.0
*/
def config(key: String, value: Boolean): Builder = synchronized {
options += key -> value.toString
this
}
/**
* Sets config options from a map. Options set using this method are automatically propagated
* to the Spark Connect session. Only runtime options are supported.
*
* @since 3.5.0
*/
def config(map: Map[String, Any]): Builder = synchronized {
map.foreach { case (key, value) =>
options += key -> value.toString
}
this
}
/**
* Sets config options from a Java map. Options set using this method are automatically
* propagated to the Spark Connect session. Only runtime options are supported.
*
* @since 3.5.0
*/
def config(map: java.util.Map[String, Any]): Builder = synchronized {
config(map.asScala.toMap)
}
@deprecated("enableHiveSupport does not work in Spark Connect")
def enableHiveSupport(): Builder = this
@deprecated("master does not work in Spark Connect, please use remote instead")
def master(master: String): Builder = this
@deprecated("appName does not work in Spark Connect")
def appName(name: String): Builder = this
private def tryCreateSessionFromClient(): Option[SparkSession] = {
if (client != null) {
Option(new SparkSession(client, cleaner, planIdGenerator))
} else {
None
}
}
private def applyOptions(session: SparkSession): Unit = {
options.foreach { case (key, value) =>
session.conf.set(key, value)
}
}
/**
* Build the [[SparkSession]].
*
* This will always return a newly created session.
*/
@deprecated(message = "Please use create() instead.", since = "3.5.0")
def build(): SparkSession = create()
/**
* Create a new [[SparkSession]].
*
* This will always return a newly created session.
*
* This method will update the default and/or active session if they are not set.
*
* @since 3.5.0
*/
def create(): SparkSession = {
val session = tryCreateSessionFromClient()
.getOrElse(SparkSession.this.create(builder.configuration))
setDefaultAndActiveSession(session)
applyOptions(session)
session
}
/**
* Get or create a [[SparkSession]].
*
* If a session already exists with the same configuration, that session is returned instead of
* creating a new one.
*
* This method will update the default and/or active session if they are not set. This method
* will always set the specified configuration options on the session, even when it is not
* newly created.
*
* @since 3.5.0
*/
def getOrCreate(): SparkSession = {
val session = tryCreateSessionFromClient()
.getOrElse(sessions.get(builder.configuration))
setDefaultAndActiveSession(session)
applyOptions(session)
session
}
}
/**
* Returns the default SparkSession.
*
* @since 3.5.0
*/
def getDefaultSession: Option[SparkSession] = Option(defaultSession.get())
/**
* Sets the default SparkSession.
*
* @since 3.5.0
*/
def setDefaultSession(session: SparkSession): Unit = {
defaultSession.set(session)
}
/**
* Clears the default SparkSession.
*
* @since 3.5.0
*/
def clearDefaultSession(): Unit = {
defaultSession.set(null)
}
/**
* Returns the active SparkSession for the current thread.
*
* @since 3.5.0
*/
def getActiveSession: Option[SparkSession] = Option(activeThreadSession.get())
/**
* Changes the SparkSession that will be returned in this thread and its children when
* SparkSession.getOrCreate() is called. This can be used to ensure that a given thread receives
* an isolated SparkSession.
*
* @since 3.5.0
*/
def setActiveSession(session: SparkSession): Unit = {
activeThreadSession.set(session)
}
/**
* Clears the active SparkSession for the current thread.
*
* @since 3.5.0
*/
def clearActiveSession(): Unit = {
activeThreadSession.remove()
}
/**
* Returns the currently active SparkSession, otherwise the default one. If there is no default
* SparkSession, throws an exception.
*
* @since 3.5.0
*/
def active: SparkSession = {
getActiveSession
.orElse(getDefaultSession)
.getOrElse(throw new IllegalStateException("No active or default Spark session found"))
}
}