/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql
import java.io.{ByteArrayOutputStream, CharArrayWriter, DataOutputStream}
import scala.annotation.varargs
import scala.collection.JavaConverters._
import scala.collection.mutable.{ArrayBuffer, HashSet}
import scala.reflect.runtime.universe.TypeTag
import scala.util.control.NonFatal
import org.apache.commons.lang3.StringUtils
import org.apache.commons.text.StringEscapeUtils
import org.apache.spark.TaskContext
import org.apache.spark.annotation.{DeveloperApi, Stable, Unstable}
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.api.java.function._
import org.apache.spark.api.python.{PythonRDD, SerDeUtil}
import org.apache.spark.api.r.RRDD
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, QueryPlanningTracker, ScalaReflection, TableIdentifier}
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.catalog.HiveTableRelation
import org.apache.spark.sql.catalyst.encoders._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.json.{JacksonGenerator, JSONOptions}
import org.apache.spark.sql.catalyst.parser.{ParseException, ParserUtils}
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.trees.{TreeNodeTag, TreePattern}
import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, IntervalUtils}
import org.apache.spark.sql.catalyst.util.TypeUtils.toSQLId
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
import org.apache.spark.sql.execution.arrow.{ArrowBatchStreamWriter, ArrowConverters}
import org.apache.spark.sql.execution.command._
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2ScanRelation, FileTable}
import org.apache.spark.sql.execution.python.EvaluatePython
import org.apache.spark.sql.execution.stat.StatFunctions
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming.DataStreamWriter
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.SchemaUtils
import org.apache.spark.storage.StorageLevel
import org.apache.spark.unsafe.array.ByteArrayMethods
import org.apache.spark.util.Utils
private[sql] object Dataset {
val curId = new java.util.concurrent.atomic.AtomicLong()
val DATASET_ID_KEY = "__dataset_id"
val COL_POS_KEY = "__col_position"
val DATASET_ID_TAG = TreeNodeTag[HashSet[Long]]("dataset_id")
def apply[T: Encoder](sparkSession: SparkSession, logicalPlan: LogicalPlan): Dataset[T] = {
val dataset = new Dataset(sparkSession, logicalPlan, implicitly[Encoder[T]])
// Eagerly bind the encoder so we verify that the encoder matches the underlying
// schema. The user will get an error if this is not the case.
// Optimization: it is guaranteed that [[InternalRow]] can be converted to [[Row]], so we
// do not do this check in that case. This check can be expensive since it requires running
// the whole [[Analyzer]] to resolve the deserializer.
if (dataset.exprEnc.clsTag.runtimeClass != classOf[Row]) {
dataset.resolvedEnc
}
dataset
}
def ofRows(sparkSession: SparkSession, logicalPlan: LogicalPlan): DataFrame =
sparkSession.withActive {
val qe = sparkSession.sessionState.executePlan(logicalPlan)
qe.assertAnalyzed()
new Dataset[Row](qe, ExpressionEncoder(qe.analyzed.schema))
}
/** A variant of ofRows that allows passing in a tracker so we can track query parsing time. */
def ofRows(sparkSession: SparkSession, logicalPlan: LogicalPlan, tracker: QueryPlanningTracker)
: DataFrame = sparkSession.withActive {
val qe = new QueryExecution(sparkSession, logicalPlan, tracker)
qe.assertAnalyzed()
new Dataset[Row](qe, ExpressionEncoder(qe.analyzed.schema))
}
}
/**
* A Dataset is a strongly typed collection of domain-specific objects that can be transformed
* in parallel using functional or relational operations. Each Dataset also has an untyped view
* called a `DataFrame`, which is a Dataset of [[Row]].
*
* Operations available on Datasets are divided into transformations and actions. Transformations
* are the ones that produce new Datasets, and actions are the ones that trigger computation and
* return results. Example transformations include map, filter, select, and aggregate (`groupBy`).
* Example actions include count, show, and writing data out to file systems.
*
* Datasets are "lazy", i.e. computations are only triggered when an action is invoked. Internally,
* a Dataset represents a logical plan that describes the computation required to produce the data.
* When an action is invoked, Spark's query optimizer optimizes the logical plan and generates a
* physical plan for efficient execution in a parallel and distributed manner. To explore the
* logical plan as well as optimized physical plan, use the `explain` function.
*
* To efficiently support domain-specific objects, an [[Encoder]] is required. The encoder maps
* the domain specific type `T` to Spark's internal type system. For example, given a class `Person`
* with two fields, `name` (string) and `age` (int), an encoder is used to tell Spark to generate
* code at runtime to serialize the `Person` object into a binary structure. This binary structure
* often has a much lower memory footprint and is optimized for efficiency in data processing
* (e.g. in a columnar format). To understand the internal binary representation for data, use the
* `schema` function.
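*
* For illustration, a minimal sketch (assuming a `spark` session with `spark.implicits._`
* imported; the `Person` case class below is hypothetical):
* {{{
* case class Person(name: String, age: Int)
* val ds = Seq(Person("Ann", 30), Person("Bob", 25)).toDS() // uses the implicit Encoder[Person]
* ds.schema.printTreeString() // shows how Person is mapped to Spark's internal types
* }}}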
*
* There are typically two ways to create a Dataset. The most common way is by pointing Spark
* to some files on storage systems, using the `read` function available on a `SparkSession`.
* {{{
* val people = spark.read.parquet("...").as[Person] // Scala
* Dataset<Person> people = spark.read().parquet("...").as(Encoders.bean(Person.class)); // Java
* }}}
*
* Datasets can also be created through transformations available on existing Datasets. For example,
* the following creates a new Dataset by applying a transformation (here, a map) to an existing one:
* {{{
* val names = people.map(_.name) // in Scala; names is a Dataset[String]
* Dataset<String> names = people.map((Person p) -> p.name, Encoders.STRING); // in Java
* }}}
*
* Dataset operations can also be untyped, through various domain-specific-language (DSL)
* functions defined in: Dataset (this class), [[Column]], and [[functions]]. These operations
* are very similar to the operations available in the data frame abstraction in R or Python.
*
* To select a column from the Dataset, use the `apply` method in Scala and `col` in Java.
* {{{
* val ageCol = people("age") // in Scala
* Column ageCol = people.col("age"); // in Java
* }}}
*
* Note that the [[Column]] type can also be manipulated through its various functions.
* {{{
* // The following creates a new column that increases everybody's age by 10.
* people("age") + 10 // in Scala
* people.col("age").plus(10); // in Java
* }}}
*
* A more concrete example in Scala:
* {{{
* // To create Dataset[Row] using SparkSession
* val people = spark.read.parquet("...")
* val department = spark.read.parquet("...")
*
* people.filter("age > 30")
* .join(department, people("deptId") === department("id"))
* .groupBy(department("name"), people("gender"))
* .agg(avg(people("salary")), max(people("age")))
* }}}
*
* and in Java:
* {{{
* // To create Dataset<Row> using SparkSession
* Dataset<Row> people = spark.read().parquet("...");
* Dataset<Row> department = spark.read().parquet("...");
*
* people.filter(people.col("age").gt(30))
* .join(department, people.col("deptId").equalTo(department.col("id")))
* .groupBy(department.col("name"), people.col("gender"))
* .agg(avg(people.col("salary")), max(people.col("age")));
* }}}
*
* @groupname basic Basic Dataset functions
* @groupname action Actions
* @groupname untypedrel Untyped transformations
* @groupname typedrel Typed transformations
*
* @since 1.6.0
*/
@Stable
class Dataset[T] private[sql](
@DeveloperApi @Unstable @transient val queryExecution: QueryExecution,
@DeveloperApi @Unstable @transient val encoder: Encoder[T])
extends Serializable {
@transient lazy val sparkSession: SparkSession = {
if (queryExecution == null || queryExecution.sparkSession == null) {
throw QueryExecutionErrors.transformationsAndActionsNotInvokedByDriverError()
}
queryExecution.sparkSession
}
// A globally unique id of this Dataset.
private val id = Dataset.curId.getAndIncrement()
queryExecution.assertAnalyzed()
// Note for Spark contributors: if adding or updating any action in `Dataset`, please make sure
// you wrap it with `withNewExecutionId` if this action doesn't call another action.
def this(sparkSession: SparkSession, logicalPlan: LogicalPlan, encoder: Encoder[T]) = {
this(sparkSession.sessionState.executePlan(logicalPlan), encoder)
}
def this(sqlContext: SQLContext, logicalPlan: LogicalPlan, encoder: Encoder[T]) = {
this(sqlContext.sparkSession, logicalPlan, encoder)
}
@transient private[sql] val logicalPlan: LogicalPlan = {
val plan = queryExecution.commandExecuted
if (sparkSession.conf.get(SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED)) {
val dsIds = plan.getTagValue(Dataset.DATASET_ID_TAG).getOrElse(new HashSet[Long])
dsIds.add(id)
plan.setTagValue(Dataset.DATASET_ID_TAG, dsIds)
}
plan
}
/**
* Currently [[ExpressionEncoder]] is the only implementation of [[Encoder]], so here we explicitly
* turn the passed-in encoder into an [[ExpressionEncoder]] and mark it implicit, so that we can use
* it when constructing new Dataset objects that have the same object type (that will be
* possibly resolved to a different schema).
*/
private[sql] implicit val exprEnc: ExpressionEncoder[T] = encoderFor(encoder)
// The resolved `ExpressionEncoder` which can be used to turn rows to objects of type T, after
// collecting rows to the driver side.
private lazy val resolvedEnc = {
exprEnc.resolveAndBind(logicalPlan.output, sparkSession.sessionState.analyzer)
}
private implicit def classTag = exprEnc.clsTag
// sqlContext must be val because a stable identifier is expected when you import implicits
@transient lazy val sqlContext: SQLContext = sparkSession.sqlContext
private[sql] def resolve(colName: String): NamedExpression = {
val resolver = sparkSession.sessionState.analyzer.resolver
queryExecution.analyzed.resolveQuoted(colName, resolver)
.getOrElse(throw QueryCompilationErrors.resolveException(colName, schema.fieldNames))
}
private[sql] def numericColumns: Seq[Expression] = {
schema.fields.filter(_.dataType.isInstanceOf[NumericType]).map { n =>
queryExecution.analyzed.resolveQuoted(n.name, sparkSession.sessionState.analyzer.resolver).get
}
}
/**
* Get rows represented as a sequence of string sequences, applying the given truncation.
*
* @param numRows Number of rows to return
* @param truncate If set to more than 0, truncates strings to `truncate` characters and
* all cells will be aligned right.
*/
private[sql] def getRows(
numRows: Int,
truncate: Int): Seq[Seq[String]] = {
val newDf = logicalPlan match {
case c: CommandResult =>
// Convert to `LocalRelation` and let `ConvertToLocalRelation` do the casting locally to
// avoid triggering a job
Dataset.ofRows(sparkSession, LocalRelation(c.output, c.rows))
case _ => toDF()
}
val castCols = newDf.logicalPlan.output.map { col =>
Column(ToPrettyString(col))
}
val data = newDf.select(castCols: _*).take(numRows + 1)
// For array values, replace Seq and Array with square brackets.
// For cells longer than `truncate` characters, replace them with the
// first `truncate - 3` characters followed by "..."
schema.fieldNames.map(SchemaUtils.escapeMetaCharacters).toSeq +: data.map { row =>
row.toSeq.map { cell =>
assert(cell != null, "ToPrettyString is not nullable and should not return null value")
// Escapes meta-characters not to break the `showString` format
val str = SchemaUtils.escapeMetaCharacters(cell.toString)
if (truncate > 0 && str.length > truncate) {
// do not show ellipses for strings shorter than 4 characters.
if (truncate < 4) str.substring(0, truncate)
else str.substring(0, truncate - 3) + "..."
} else {
str
}
}: Seq[String]
}
}
/**
* Compose the string representing rows for output
*
* @param _numRows Number of rows to show
* @param truncate If set to more than 0, truncates strings to `truncate` characters and
* all cells will be aligned right.
* @param vertical If set to true, prints output rows vertically (one line per column value).
*/
private[sql] def showString(
_numRows: Int,
truncate: Int = 20,
vertical: Boolean = false): String = {
val numRows = _numRows.max(0).min(ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH - 1)
// Get rows represented by Seq[Seq[String]]; we fetch one extra row to detect whether there is more data.
val tmpRows = getRows(numRows, truncate)
val hasMoreData = tmpRows.length - 1 > numRows
val rows = tmpRows.take(numRows + 1)
val sb = new StringBuilder
val numCols = schema.fieldNames.length
// We set a minimum column width at '3'
val minimumColWidth = 3
if (!vertical) {
// Initialise the width of each column to a minimum value
val colWidths = Array.fill(numCols)(minimumColWidth)
// Compute the width of each column
for (row <- rows) {
for ((cell, i) <- row.zipWithIndex) {
colWidths(i) = math.max(colWidths(i), Utils.stringHalfWidth(cell))
}
}
val paddedRows = rows.map { row =>
row.zipWithIndex.map { case (cell, i) =>
if (truncate > 0) {
StringUtils.leftPad(cell, colWidths(i) - Utils.stringHalfWidth(cell) + cell.length)
} else {
StringUtils.rightPad(cell, colWidths(i) - Utils.stringHalfWidth(cell) + cell.length)
}
}
}
// Create SeparateLine
val sep: String = colWidths.map("-" * _).addString(sb, "+", "+", "+\n").toString()
// column names
paddedRows.head.addString(sb, "|", "|", "|\n")
sb.append(sep)
// data
paddedRows.tail.foreach(_.addString(sb, "|", "|", "|\n"))
sb.append(sep)
} else {
// Extended display mode enabled
val fieldNames = rows.head
val dataRows = rows.tail
// Compute the width of field name and data columns
val fieldNameColWidth = fieldNames.foldLeft(minimumColWidth) { case (curMax, fieldName) =>
math.max(curMax, Utils.stringHalfWidth(fieldName))
}
val dataColWidth = dataRows.foldLeft(minimumColWidth) { case (curMax, row) =>
math.max(curMax, row.map(cell => Utils.stringHalfWidth(cell)).max)
}
dataRows.zipWithIndex.foreach { case (row, i) =>
// "+ 5" in size means a character length except for padded names and data
val rowHeader = StringUtils.rightPad(
s"-RECORD $i", fieldNameColWidth + dataColWidth + 5, "-")
sb.append(rowHeader).append("\n")
row.zipWithIndex.map { case (cell, j) =>
val fieldName = StringUtils.rightPad(fieldNames(j),
fieldNameColWidth - Utils.stringHalfWidth(fieldNames(j)) + fieldNames(j).length)
val data = StringUtils.rightPad(cell,
dataColWidth - Utils.stringHalfWidth(cell) + cell.length)
s" $fieldName | $data "
}.addString(sb, "", "\n", "\n")
}
}
// Print a footer
if (vertical && rows.tail.isEmpty) {
// In a vertical mode, print an empty row set explicitly
sb.append("(0 rows)\n")
} else if (hasMoreData) {
// For Data that has more than "numRows" records
val rowsString = if (numRows == 1) "row" else "rows"
sb.append(s"only showing top $numRows $rowsString\n")
}
sb.toString()
}
/**
* Compose the HTML representing rows for output
*
* @param _numRows Number of rows to show
* @param truncate If set to more than 0, truncates strings to `truncate` characters and
* all cells will be aligned right.
*/
private[sql] def htmlString(
_numRows: Int,
truncate: Int = 20): String = {
val numRows = _numRows.max(0).min(ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH - 1)
// Get rows represented by Seq[Seq[String]]; we fetch one extra row to detect whether there is more data.
val tmpRows = getRows(numRows, truncate)
val hasMoreData = tmpRows.length - 1 > numRows
val rows = tmpRows.take(numRows + 1)
val sb = new StringBuilder
sb.append("\n")
sb.append(rows.head.map(StringEscapeUtils.escapeHtml4)
.mkString("", " ", " \n"))
rows.tail.foreach { row =>
sb.append(row.map(StringEscapeUtils.escapeHtml4)
.mkString("", " ", " \n"))
}
sb.append("
\n")
if (hasMoreData) {
sb.append(s"only showing top $numRows ${if (numRows == 1) "row" else "rows"}\n")
}
sb.toString()
}
override def toString: String = {
try {
val builder = new StringBuilder
val fields = schema.take(2).map {
case f => s"${f.name}: ${f.dataType.simpleString(2)}"
}
builder.append("[")
builder.append(fields.mkString(", "))
if (schema.length > 2) {
if (schema.length - fields.size == 1) {
builder.append(" ... 1 more field")
} else {
builder.append(" ... " + (schema.length - 2) + " more fields")
}
}
builder.append("]").toString()
} catch {
case NonFatal(e) =>
s"Invalid tree; ${e.getMessage}:\n$queryExecution"
}
}
/**
* Converts this strongly typed collection of data to a generic `DataFrame`. In contrast to the
* strongly typed objects that Dataset operations work on, a `DataFrame` returns generic [[Row]]
* objects that allow fields to be accessed by ordinal or name.
*
* @group basic
* @since 1.6.0
*/
// This is declared with parentheses to prevent the Scala compiler from treating
// `ds.toDF("1")` as invoking this toDF and then apply on the returned DataFrame.
def toDF(): DataFrame = new Dataset[Row](queryExecution, ExpressionEncoder(schema))
/**
* Returns a new Dataset where each record has been mapped to the specified type. The
* method used to map columns depends on the type of `U`:
*
* - When `U` is a class, fields for the class will be mapped to columns of the same name
* (case sensitivity is determined by `spark.sql.caseSensitive`).
* - When `U` is a tuple, the columns will be mapped by ordinal (i.e. the first column will
* be assigned to `_1`).
* - When `U` is a primitive type (i.e. String, Int, etc), then the first column of the
* `DataFrame` will be used.
*
*
* If the schema of the Dataset does not match the desired `U` type, you can use `select`
* along with `alias` or `as` to rearrange or rename as required.
*
* Note that `as[]` only changes the view of the data that is passed into typed operations,
* such as `map()`, and does not eagerly project away any columns that are not present in
* the specified class.
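*
* A short sketch (assuming `spark.implicits._` is imported; `Person` and `df` are hypothetical):
* {{{
* case class Person(name: String, age: Int)
* val people: Dataset[Person] = df.as[Person] // columns are matched by name
* val pairs = df.select($"name", $"age").as[(String, Int)] // tuple columns are matched by ordinal
* }}}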
*
* @group basic
* @since 1.6.0
*/
def as[U : Encoder]: Dataset[U] = Dataset[U](sparkSession, logicalPlan)
/**
* Returns a new DataFrame where each row is reconciled to match the specified schema. Spark will:
*
* - Reorder columns and/or inner fields by name to match the specified schema.
* - Project away columns and/or inner fields that are not needed by the specified schema.
* Missing columns and/or inner fields (present in the specified schema but not in the input DataFrame)
* lead to failures.
* - Cast the columns and/or inner fields to match the data types in the specified schema, if
* the types are compatible, e.g., numeric to numeric (error if overflows), but not string to
* int.
* - Carry over the metadata from the specified schema, while the columns and/or inner fields
* still keep their own metadata if not overwritten by the specified schema.
* - Fail if the nullability is not compatible. For example, the column and/or inner field is
* nullable but the specified schema requires them to be not nullable.
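*
* A possible sketch (assuming a hypothetical DataFrame `df` with columns `i: int` and `s: string`):
* {{{
* import org.apache.spark.sql.types._
* val targetSchema = StructType(Seq(
* StructField("s", StringType), // reordered by name
* StructField("i", LongType))) // compatible numeric upcast from int to long
* val reconciled = df.to(targetSchema)
* }}}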
*
*
* @group basic
* @since 3.4.0
*/
def to(schema: StructType): DataFrame = withPlan {
val replaced = CharVarcharUtils.failIfHasCharVarchar(schema).asInstanceOf[StructType]
Project.matchSchema(logicalPlan, replaced, sparkSession.sessionState.conf)
}
/**
* Converts this strongly typed collection of data to generic `DataFrame` with columns renamed.
* This can be quite convenient in conversion from an RDD of tuples into a `DataFrame` with
* meaningful names. For example:
* {{{
* val rdd: RDD[(Int, String)] = ...
* rdd.toDF() // this implicit conversion creates a DataFrame with column name `_1` and `_2`
* rdd.toDF("id", "name") // this creates a DataFrame with column name "id" and "name"
* }}}
*
* @group basic
* @since 2.0.0
*/
@scala.annotation.varargs
def toDF(colNames: String*): DataFrame = {
require(schema.size == colNames.size,
"The number of columns doesn't match.\n" +
s"Old column names (${schema.size}): " + schema.fields.map(_.name).mkString(", ") + "\n" +
s"New column names (${colNames.size}): " + colNames.mkString(", "))
val newCols = logicalPlan.output.zip(colNames).map { case (oldAttribute, newName) =>
Column(oldAttribute).as(newName)
}
select(newCols : _*)
}
/**
* Returns the schema of this Dataset.
*
* @group basic
* @since 1.6.0
*/
def schema: StructType = sparkSession.withActive {
queryExecution.analyzed.schema
}
/**
* Prints the schema to the console in a nice tree format.
*
* @group basic
* @since 1.6.0
*/
def printSchema(): Unit = printSchema(Int.MaxValue)
// scalastyle:off println
/**
* Prints the schema up to the given level to the console in a nice tree format.
*
* @group basic
* @since 3.0.0
*/
def printSchema(level: Int): Unit = println(schema.treeString(level))
// scalastyle:on println
/**
* Prints the plans (logical and physical) with a format specified by a given explain mode.
*
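* For example, a minimal usage sketch:
* {{{
* ds.explain("formatted")
* }}}
*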
* @param mode specifies the expected output format of plans.
*
* - `simple` Print only a physical plan.
* - `extended`: Print both logical and physical plans.
* - `codegen`: Print a physical plan and generated codes if they are
* available.
* - `cost`: Print a logical plan and statistics if they are available.
* - `formatted`: Split explain output into two sections: a physical plan outline
* and node details.
*
* @group basic
* @since 3.0.0
*/
def explain(mode: String): Unit = sparkSession.withActive {
// Temporary views are resolved during analysis when we create a Dataset, and `ExplainCommand`
// analyzes the input query plan and resolves temporary views again. Using `ExplainCommand` here
// would probably output different query plans compared to what the Dataset actually evaluates.
// So just output QueryExecution's query plans here.
// scalastyle:off println
println(queryExecution.explainString(ExplainMode.fromString(mode)))
// scalastyle:on println
}
/**
* Prints the plans (logical and physical) to the console for debugging purposes.
*
* @param extended default `false`. If `false`, prints only the physical plan.
*
* @group basic
* @since 1.6.0
*/
def explain(extended: Boolean): Unit = if (extended) {
explain(ExtendedMode.name)
} else {
explain(SimpleMode.name)
}
/**
* Prints the physical plan to the console for debugging purposes.
*
* @group basic
* @since 1.6.0
*/
def explain(): Unit = explain(SimpleMode.name)
/**
* Returns all column names and their data types as an array.
*
* @group basic
* @since 1.6.0
*/
def dtypes: Array[(String, String)] = schema.fields.map { field =>
(field.name, field.dataType.toString)
}
/**
* Returns all column names as an array.
*
* @group basic
* @since 1.6.0
*/
def columns: Array[String] = schema.fields.map(_.name)
/**
* Returns true if the `collect` and `take` methods can be run locally
* (without any Spark executors).
*
* @group basic
* @since 1.6.0
*/
def isLocal: Boolean = logicalPlan.isInstanceOf[LocalRelation] ||
logicalPlan.isInstanceOf[CommandResult]
/**
* Returns true if the `Dataset` is empty.
*
* @group basic
* @since 2.4.0
*/
def isEmpty: Boolean = withAction("isEmpty", select().queryExecution) { plan =>
plan.executeTake(1).isEmpty
}
/**
* Returns true if this Dataset contains one or more sources that continuously
* return data as it arrives. A Dataset that reads data from a streaming source
* must be executed as a `StreamingQuery` using the `start()` method in
* `DataStreamWriter`. Methods that return a single answer, e.g. `count()` or
* `collect()`, will throw an [[AnalysisException]] when there is a streaming
* source present.
*
* @group streaming
* @since 2.0.0
*/
def isStreaming: Boolean = logicalPlan.isStreaming
/**
* Eagerly checkpoints a Dataset and returns the new Dataset. Checkpointing can be used to truncate
* the logical plan of this Dataset, which is especially useful in iterative algorithms where the
* plan may grow exponentially. It will be saved to files inside the checkpoint
* directory set with `SparkContext#setCheckpointDir`.
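*
* A minimal usage sketch (the checkpoint directory below is illustrative):
* {{{
* spark.sparkContext.setCheckpointDir("/tmp/checkpoints")
* val truncated = ds.checkpoint() // eager, reliable checkpoint; the logical plan is truncated
* }}}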
*
* @group basic
* @since 2.1.0
*/
def checkpoint(): Dataset[T] = checkpoint(eager = true, reliableCheckpoint = true)
/**
* Returns a checkpointed version of this Dataset. Checkpointing can be used to truncate the
* logical plan of this Dataset, which is especially useful in iterative algorithms where the
* plan may grow exponentially. It will be saved to files inside the checkpoint
* directory set with `SparkContext#setCheckpointDir`.
*
* @group basic
* @since 2.1.0
*/
def checkpoint(eager: Boolean): Dataset[T] = checkpoint(eager = eager, reliableCheckpoint = true)
/**
* Eagerly locally checkpoints a Dataset and returns the new Dataset. Checkpointing can be
* used to truncate the logical plan of this Dataset, which is especially useful in iterative
* algorithms where the plan may grow exponentially. Local checkpoints are written to executor
* storage and, despite being potentially faster, they are unreliable and may compromise job completion.
*
* @group basic
* @since 2.3.0
*/
def localCheckpoint(): Dataset[T] = checkpoint(eager = true, reliableCheckpoint = false)
/**
* Locally checkpoints a Dataset and returns the new Dataset. Checkpointing can be used to truncate
* the logical plan of this Dataset, which is especially useful in iterative algorithms where the
* plan may grow exponentially. Local checkpoints are written to executor storage and, despite
* being potentially faster, they are unreliable and may compromise job completion.
*
* @group basic
* @since 2.3.0
*/
def localCheckpoint(eager: Boolean): Dataset[T] = checkpoint(
eager = eager,
reliableCheckpoint = false
)
/**
* Returns a checkpointed version of this Dataset.
*
* @param eager Whether to checkpoint this Dataset immediately
* @param reliableCheckpoint Whether to create a reliable checkpoint saved to files inside the
* checkpoint directory. If false creates a local checkpoint using
* the caching subsystem
*/
private def checkpoint(eager: Boolean, reliableCheckpoint: Boolean): Dataset[T] = {
val actionName = if (reliableCheckpoint) "checkpoint" else "localCheckpoint"
withAction(actionName, queryExecution) { physicalPlan =>
val internalRdd = physicalPlan.execute().map(_.copy())
if (reliableCheckpoint) {
internalRdd.checkpoint()
} else {
internalRdd.localCheckpoint()
}
if (eager) {
internalRdd.doCheckpoint()
}
Dataset.ofRows(
sparkSession,
LogicalRDD.fromDataset(rdd = internalRdd, originDataset = this, isStreaming = isStreaming)
).as[T]
}
}
/**
* Defines an event time watermark for this [[Dataset]]. A watermark tracks a point in time
* before which we assume no more late data is going to arrive.
*
* Spark will use this watermark for several purposes:
*
* - To know when a given time window aggregation can be finalized and thus can be emitted
* when using output modes that do not allow updates.
* - To minimize the amount of state that we need to keep for on-going aggregations,
* `mapGroupsWithState` and `dropDuplicates` operators.
*
* The current watermark is computed by looking at the `MAX(eventTime)` seen across
* all of the partitions in the query minus a user specified `delayThreshold`. Due to the cost
* of coordinating this value across partitions, the actual watermark used is only guaranteed
* to be at least `delayThreshold` behind the actual event time. In some cases we may still
* process records that arrive more than `delayThreshold` late.
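*
* For example, a sketch of a windowed streaming aggregation (column names are illustrative,
* and `spark.implicits._` is assumed to be imported):
* {{{
* import org.apache.spark.sql.functions.window
* df.withWatermark("eventTime", "10 minutes")
* .groupBy(window($"eventTime", "5 minutes"))
* .count()
* }}}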
*
* @param eventTime the name of the column that contains the event time of the row.
* @param delayThreshold the minimum delay to wait for data to arrive late, relative to the latest
* record that has been processed in the form of an interval
* (e.g. "1 minute" or "5 hours"). NOTE: This should not be negative.
*
* @group streaming
* @since 2.1.0
*/
// We only accept an existing column name, not a derived column, here as a watermark that is
// defined on a derived column cannot be referenced elsewhere in the plan.
def withWatermark(eventTime: String, delayThreshold: String): Dataset[T] = withTypedPlan {
val parsedDelay = IntervalUtils.fromIntervalString(delayThreshold)
require(!IntervalUtils.isNegative(parsedDelay),
s"delay threshold ($delayThreshold) should not be negative.")
EventTimeWatermark(UnresolvedAttribute(eventTime), parsedDelay, logicalPlan)
}
/**
* Displays the Dataset in a tabular form. Strings with more than 20 characters will be truncated,
* and all cells will be aligned right. For example:
* {{{
* year month AVG('Adj Close) MAX('Adj Close)
* 1980 12 0.503218 0.595103
* 1981 01 0.523289 0.570307
* 1982 02 0.436504 0.475256
* 1983 03 0.410516 0.442194
* 1984 04 0.450090 0.483521
* }}}
*
* @param numRows Number of rows to show
*
* @group action
* @since 1.6.0
*/
def show(numRows: Int): Unit = show(numRows, truncate = true)
/**
* Displays the top 20 rows of the Dataset in a tabular form. Strings with more than 20 characters
* will be truncated, and all cells will be aligned right.
*
* @group action
* @since 1.6.0
*/
def show(): Unit = show(20)
/**
* Displays the top 20 rows of the Dataset in a tabular form.
*
* @param truncate Whether to truncate long strings. If true, strings with more than 20 characters will
* be truncated and all cells will be aligned right
*
* @group action
* @since 1.6.0
*/
def show(truncate: Boolean): Unit = show(20, truncate)
/**
* Displays the Dataset in a tabular form. For example:
* {{{
* year month AVG('Adj Close) MAX('Adj Close)
* 1980 12 0.503218 0.595103
* 1981 01 0.523289 0.570307
* 1982 02 0.436504 0.475256
* 1983 03 0.410516 0.442194
* 1984 04 0.450090 0.483521
* }}}
* @param numRows Number of rows to show
* @param truncate Whether to truncate long strings. If true, strings with more than 20 characters will
* be truncated and all cells will be aligned right
*
* @group action
* @since 1.6.0
*/
// scalastyle:off println
def show(numRows: Int, truncate: Boolean): Unit = if (truncate) {
println(showString(numRows, truncate = 20))
} else {
println(showString(numRows, truncate = 0))
}
// scalastyle:on println
/**
* Displays the Dataset in a tabular form. For example:
* {{{
* year month AVG('Adj Close) MAX('Adj Close)
* 1980 12 0.503218 0.595103
* 1981 01 0.523289 0.570307
* 1982 02 0.436504 0.475256
* 1983 03 0.410516 0.442194
* 1984 04 0.450090 0.483521
* }}}
*
* @param numRows Number of rows to show
* @param truncate If set to more than 0, truncates strings to `truncate` characters and
* all cells will be aligned right.
* @group action
* @since 1.6.0
*/
def show(numRows: Int, truncate: Int): Unit = show(numRows, truncate, vertical = false)
/**
* Displays the Dataset in a tabular form. For example:
* {{{
* year month AVG('Adj Close) MAX('Adj Close)
* 1980 12 0.503218 0.595103
* 1981 01 0.523289 0.570307
* 1982 02 0.436504 0.475256
* 1983 03 0.410516 0.442194
* 1984 04 0.450090 0.483521
* }}}
*
* If `vertical` is enabled, this command prints output rows vertically (one line per column value):
*
* {{{
* -RECORD 0-------------------
* year | 1980
* month | 12
* AVG('Adj Close) | 0.503218
* MAX('Adj Close) | 0.595103
* -RECORD 1-------------------
* year | 1981
* month | 01
* AVG('Adj Close) | 0.523289
* MAX('Adj Close) | 0.570307
* -RECORD 2-------------------
* year | 1982
* month | 02
* AVG('Adj Close) | 0.436504
* MAX('Adj Close) | 0.475256
* -RECORD 3-------------------
* year | 1983
* month | 03
* AVG('Adj Close) | 0.410516
* MAX('Adj Close) | 0.442194
* -RECORD 4-------------------
* year | 1984
* month | 04
* AVG('Adj Close) | 0.450090
* MAX('Adj Close) | 0.483521
* }}}
*
* @param numRows Number of rows to show
* @param truncate If set to more than 0, truncates strings to `truncate` characters and
* all cells will be aligned right.
* @param vertical If set to true, prints output rows vertically (one line per column value).
* @group action
* @since 2.3.0
*/
// scalastyle:off println
def show(numRows: Int, truncate: Int, vertical: Boolean): Unit =
println(showString(numRows, truncate, vertical))
// scalastyle:on println
/**
* Returns a [[DataFrameNaFunctions]] for working with missing data.
* {{{
* // Dropping rows containing any null values.
* ds.na.drop()
* }}}
*
* @group untypedrel
* @since 1.6.0
*/
def na: DataFrameNaFunctions = new DataFrameNaFunctions(toDF())
/**
* Returns a [[DataFrameStatFunctions]] for working with statistic functions.
* {{{
* // Finding frequent items in column with name 'a'.
* ds.stat.freqItems(Seq("a"))
* }}}
*
* @group untypedrel
* @since 1.6.0
*/
def stat: DataFrameStatFunctions = new DataFrameStatFunctions(toDF())
/**
* Join with another `DataFrame`.
*
* Behaves as an INNER JOIN and requires a subsequent join predicate.
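*
* For example (a join predicate is typically supplied afterwards, e.g. via `where`):
* {{{
* df1.join(df2).where($"df1Key" === $"df2Key")
* }}}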
*
* @param right Right side of the join operation.
*
* @group untypedrel
* @since 2.0.0
*/
def join(right: Dataset[_]): DataFrame = withPlan {
Join(logicalPlan, right.logicalPlan, joinType = Inner, None, JoinHint.NONE)
}
/**
* Inner equi-join with another `DataFrame` using the given column.
*
* Different from other join functions, the join column will only appear once in the output,
* i.e. similar to SQL's `JOIN USING` syntax.
*
* {{{
* // Joining df1 and df2 using the column "user_id"
* df1.join(df2, "user_id")
* }}}
*
* @param right Right side of the join operation.
* @param usingColumn Name of the column to join on. This column must exist on both sides.
*
* @note If you perform a self-join using this function without aliasing the input
* `DataFrame`s, you will NOT be able to reference any columns after the join, since
* there is no way to disambiguate which side of the join you would like to reference.
*
* @group untypedrel
* @since 2.0.0
*/
def join(right: Dataset[_], usingColumn: String): DataFrame = {
join(right, Seq(usingColumn))
}
/**
* (Java-specific) Inner equi-join with another `DataFrame` using the given columns. See the
* Scala-specific overload for more details.
*
* @param right Right side of the join operation.
* @param usingColumns Names of the columns to join on. These columns must exist on both sides.
*
* @group untypedrel
* @since 3.4.0
*/
def join(right: Dataset[_], usingColumns: Array[String]): DataFrame = {
join(right, usingColumns.toSeq)
}
/**
* (Scala-specific) Inner equi-join with another `DataFrame` using the given columns.
*
* Different from other join functions, the join columns will only appear once in the output,
* i.e. similar to SQL's `JOIN USING` syntax.
*
* {{{
* // Joining df1 and df2 using the columns "user_id" and "user_name"
* df1.join(df2, Seq("user_id", "user_name"))
* }}}
*
* @param right Right side of the join operation.
* @param usingColumns Names of the columns to join on. These columns must exist on both sides.
*
* @note If you perform a self-join using this function without aliasing the input
* `DataFrame`s, you will NOT be able to reference any columns after the join, since
* there is no way to disambiguate which side of the join you would like to reference.
*
* @group untypedrel
* @since 2.0.0
*/
def join(right: Dataset[_], usingColumns: Seq[String]): DataFrame = {
join(right, usingColumns, "inner")
}
/**
* Equi-join with another `DataFrame` using the given column. A cross join with a predicate
* is specified as an inner join. If you would explicitly like to perform a cross join use the
* `crossJoin` method.
*
* Different from other join functions, the join column will only appear once in the output,
* i.e. similar to SQL's `JOIN USING` syntax.
*
* @param right Right side of the join operation.
* @param usingColumn Name of the column to join on. This column must exist on both sides.
* @param joinType Type of join to perform. Default `inner`. Must be one of:
* `inner`, `cross`, `outer`, `full`, `fullouter`, `full_outer`, `left`,
* `leftouter`, `left_outer`, `right`, `rightouter`, `right_outer`,
* `semi`, `leftsemi`, `left_semi`, `anti`, `leftanti`, `left_anti`.
*
* @note If you perform a self-join using this function without aliasing the input
* `DataFrame`s, you will NOT be able to reference any columns after the join, since
* there is no way to disambiguate which side of the join you would like to reference.
*
* @group untypedrel
* @since 3.4.0
*/
def join(right: Dataset[_], usingColumn: String, joinType: String): DataFrame = {
join(right, Seq(usingColumn), joinType)
}
/**
* (Java-specific) Equi-join with another `DataFrame` using the given columns. See the
* Scala-specific overload for more details.
*
* @param right Right side of the join operation.
* @param usingColumns Names of the columns to join on. These columns must exist on both sides.
* @param joinType Type of join to perform. Default `inner`. Must be one of:
* `inner`, `cross`, `outer`, `full`, `fullouter`, `full_outer`, `left`,
* `leftouter`, `left_outer`, `right`, `rightouter`, `right_outer`,
* `semi`, `leftsemi`, `left_semi`, `anti`, `leftanti`, `left_anti`.
*
* @group untypedrel
* @since 3.4.0
*/
def join(right: Dataset[_], usingColumns: Array[String], joinType: String): DataFrame = {
join(right, usingColumns.toSeq, joinType)
}
/**
* (Scala-specific) Equi-join with another `DataFrame` using the given columns. A cross join
* with a predicate is specified as an inner join. If you would explicitly like to perform a
* cross join use the `crossJoin` method.
*
* Different from other join functions, the join columns will only appear once in the output,
* i.e. similar to SQL's `JOIN USING` syntax.
*
* @param right Right side of the join operation.
* @param usingColumns Names of the columns to join on. These columns must exist on both sides.
* @param joinType Type of join to perform. Default `inner`. Must be one of:
* `inner`, `cross`, `outer`, `full`, `fullouter`, `full_outer`, `left`,
* `leftouter`, `left_outer`, `right`, `rightouter`, `right_outer`,
* `semi`, `leftsemi`, `left_semi`, `anti`, `leftanti`, `left_anti`.
*
* @note If you perform a self-join using this function without aliasing the input
* `DataFrame`s, you will NOT be able to reference any columns after the join, since
* there is no way to disambiguate which side of the join you would like to reference.
*
* @group untypedrel
* @since 2.0.0
*/
def join(right: Dataset[_], usingColumns: Seq[String], joinType: String): DataFrame = {
// Analyze the self join. The assumption is that the analyzer will disambiguate left vs right
// by creating a new instance for one of the branches.
val joined = sparkSession.sessionState.executePlan(
Join(logicalPlan, right.logicalPlan, joinType = JoinType(joinType), None, JoinHint.NONE))
.analyzed.asInstanceOf[Join]
withPlan {
Join(
joined.left,
joined.right,
UsingJoin(JoinType(joinType), usingColumns.toIndexedSeq),
None,
JoinHint.NONE)
}
}
/**
* Inner join with another `DataFrame`, using the given join expression.
*
* {{{
* // The following two are equivalent:
* df1.join(df2, $"df1Key" === $"df2Key")
* df1.join(df2).where($"df1Key" === $"df2Key")
* }}}
*
* @group untypedrel
* @since 2.0.0
*/
def join(right: Dataset[_], joinExprs: Column): DataFrame = join(right, joinExprs, "inner")
/**
* Finds the trivially true predicates and automatically resolves them to both sides.
*/
private def resolveSelfJoinCondition(
right: Dataset[_],
joinExprs: Option[Column],
joinType: String): Join = {
// Note that in this function, we introduce a hack in the case of self-join to automatically
// resolve ambiguous join conditions into ones that might make sense [SPARK-6231].
// Consider this case: df.join(df, df("key") === df("key"))
// Since df("key") === df("key") is a trivially true condition, this actually becomes a
// cartesian join. However, most likely users expect to perform a self join using "key".
// With that assumption, this hack turns the trivially true condition into equality on join
// keys that are resolved to both sides.
// Trigger analysis so in the case of self-join, the analyzer will clone the plan.
// After the cloning, left and right side will have distinct expression ids.
val plan = withPlan(
Join(logicalPlan, right.logicalPlan,
JoinType(joinType), joinExprs.map(_.expr), JoinHint.NONE))
.queryExecution.analyzed.asInstanceOf[Join]
// If auto self join alias is disabled, return the plan.
if (!sparkSession.sessionState.conf.dataFrameSelfJoinAutoResolveAmbiguity) {
return plan
}
// If left/right have no output set intersection, return the plan.
val lanalyzed = this.queryExecution.analyzed
val ranalyzed = right.queryExecution.analyzed
if (lanalyzed.outputSet.intersect(ranalyzed.outputSet).isEmpty) {
return plan
}
// Otherwise, find the trivially true predicates and automatically resolve them to both sides.
// By the time we get here, since we have already run analysis, all attributes should've been
// resolved and become AttributeReference.
JoinWith.resolveSelfJoinCondition(sparkSession.sessionState.analyzer.resolver, plan)
}
/**
* Join with another `DataFrame`, using the given join expression. The following performs
* a full outer join between `df1` and `df2`.
*
* {{{
* // Scala:
* import org.apache.spark.sql.functions._
* df1.join(df2, $"df1Key" === $"df2Key", "outer")
*
* // Java:
* import static org.apache.spark.sql.functions.*;
* df1.join(df2, col("df1Key").equalTo(col("df2Key")), "outer");
* }}}
*
* @param right Right side of the join.
* @param joinExprs Join expression.
* @param joinType Type of join to perform. Default `inner`. Must be one of:
* `inner`, `cross`, `outer`, `full`, `fullouter`, `full_outer`, `left`,
* `leftouter`, `left_outer`, `right`, `rightouter`, `right_outer`,
* `semi`, `leftsemi`, `left_semi`, `anti`, `leftanti`, `left_anti`.
*
* @group untypedrel
* @since 2.0.0
*/
def join(right: Dataset[_], joinExprs: Column, joinType: String): DataFrame = {
withPlan {
resolveSelfJoinCondition(right, Some(joinExprs), joinType)
}
}
/**
* Explicit cartesian join with another `DataFrame`.
*
* @param right Right side of the join operation.
*
* @note Cartesian joins are very expensive without an extra filter that can be pushed down.
*
* @group untypedrel
* @since 2.1.0
*/
def crossJoin(right: Dataset[_]): DataFrame = withPlan {
Join(logicalPlan, right.logicalPlan, joinType = Cross, None, JoinHint.NONE)
}
/**
* Joins this Dataset returning a `Tuple2` for each pair where `condition` evaluates to
* true.
*
* This is similar to the relation `join` function with one important difference in the
* result schema. Since `joinWith` preserves objects present on either side of the join, the
* result schema is similarly nested into a tuple under the column names `_1` and `_2`.
*
* This type of join can be useful both for preserving type-safety with the original object
* types as well as working with relational data where either side of the join has column
* names in common.
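*
* A short sketch (the case classes and Datasets below are hypothetical):
* {{{
* case class Person(name: String, deptId: Int)
* case class Department(id: Int, name: String)
* val joined: Dataset[(Person, Department)] =
* people.joinWith(departments, people("deptId") === departments("id"), "inner")
* }}}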
*
* @param other Right side of the join.
* @param condition Join expression.
* @param joinType Type of join to perform. Default `inner`. Must be one of:
* `inner`, `cross`, `outer`, `full`, `fullouter`,`full_outer`, `left`,
* `leftouter`, `left_outer`, `right`, `rightouter`, `right_outer`.
*
* @group typedrel
* @since 1.6.0
*/
def joinWith[U](other: Dataset[U], condition: Column, joinType: String): Dataset[(T, U)] = {
// Creates a Join node and resolves it first, to get the join condition resolved, self-join resolved,
// etc.
val joined = sparkSession.sessionState.executePlan(
Join(
this.logicalPlan,
other.logicalPlan,
JoinType(joinType),
Some(condition.expr),
JoinHint.NONE)).analyzed.asInstanceOf[Join]
implicit val tuple2Encoder: Encoder[(T, U)] =
ExpressionEncoder
.tuple(Seq(this.exprEnc, other.exprEnc), useNullSafeDeserializer = true)
.asInstanceOf[Encoder[(T, U)]]
withTypedPlan(JoinWith.typedJoinWith(
joined,
sqlContext.conf.dataFrameSelfJoinAutoResolveAmbiguity,
sparkSession.sessionState.analyzer.resolver,
this.exprEnc.isSerializedAsStructForTopLevel,
other.exprEnc.isSerializedAsStructForTopLevel))
}
/**
* Using an inner equi-join, joins this Dataset with another Dataset, returning a `Tuple2` for
* each pair where `condition` evaluates to true.
*
* @param other Right side of the join.
* @param condition Join expression.
*
* @group typedrel
* @since 1.6.0
*/
def joinWith[U](other: Dataset[U], condition: Column): Dataset[(T, U)] = {
joinWith(other, condition, "inner")
}
// TODO(SPARK-22947): Fix the DataFrame API.
private[sql] def joinAsOf(
other: Dataset[_],
leftAsOf: Column,
rightAsOf: Column,
usingColumns: Seq[String],
joinType: String,
tolerance: Column,
allowExactMatches: Boolean,
direction: String): DataFrame = {
val joinExprs = usingColumns.map { column =>
EqualTo(resolve(column), other.resolve(column))
}.reduceOption(And).map(Column.apply).orNull
joinAsOf(other, leftAsOf, rightAsOf, joinExprs, joinType,
tolerance, allowExactMatches, direction)
}
// TODO(SPARK-22947): Fix the DataFrame API.
private[sql] def joinAsOf(
other: Dataset[_],
leftAsOf: Column,
rightAsOf: Column,
joinExprs: Column,
joinType: String,
tolerance: Column,
allowExactMatches: Boolean,
direction: String): DataFrame = {
val joined = resolveSelfJoinCondition(other, Option(joinExprs), joinType)
val leftAsOfExpr = leftAsOf.expr.transformUp {
case a: AttributeReference if logicalPlan.outputSet.contains(a) =>
val index = logicalPlan.output.indexWhere(_.exprId == a.exprId)
joined.left.output(index)
}
val rightAsOfExpr = rightAsOf.expr.transformUp {
case a: AttributeReference if other.logicalPlan.outputSet.contains(a) =>
val index = other.logicalPlan.output.indexWhere(_.exprId == a.exprId)
joined.right.output(index)
}
withPlan {
AsOfJoin(
joined.left, joined.right,
leftAsOfExpr, rightAsOfExpr,
joined.condition,
joined.joinType,
Option(tolerance).map(_.expr),
allowExactMatches,
AsOfJoinDirection(direction)
)
}
}
/**
* Returns a new Dataset with each partition sorted by the given expressions.
*
* This is the same operation as "SORT BY" in SQL (Hive QL).
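*
* For example (assuming a column named `key`; rows are sorted within each partition, not globally):
* {{{
* ds.sortWithinPartitions("key")
* }}}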
*
* @group typedrel
* @since 2.0.0
*/
@scala.annotation.varargs
def sortWithinPartitions(sortCol: String, sortCols: String*): Dataset[T] = {
sortWithinPartitions((sortCol +: sortCols).map(Column(_)) : _*)
}
/**
* Returns a new Dataset with each partition sorted by the given expressions.
*
* This is the same operation as "SORT BY" in SQL (Hive QL).
*
* @group typedrel
* @since 2.0.0
*/
@scala.annotation.varargs
def sortWithinPartitions(sortExprs: Column*): Dataset[T] = {
sortInternal(global = false, sortExprs)
}
/**
* Returns a new Dataset sorted by the specified column, all in ascending order.
* {{{
* // The following 3 are equivalent
* ds.sort("sortcol")
* ds.sort($"sortcol")
* ds.sort($"sortcol".asc)
* }}}
*
* @group typedrel
* @since 2.0.0
*/
@scala.annotation.varargs
def sort(sortCol: String, sortCols: String*): Dataset[T] = {
sort((sortCol +: sortCols).map(Column(_)) : _*)
}
/**
* Returns a new Dataset sorted by the given expressions. For example:
* {{{
* ds.sort($"col1", $"col2".desc)
* }}}
*
* @group typedrel
* @since 2.0.0
*/
@scala.annotation.varargs
def sort(sortExprs: Column*): Dataset[T] = {
sortInternal(global = true, sortExprs)
}
/**
* Returns a new Dataset sorted by the given expressions.
* This is an alias of the `sort` function.
*
* @group typedrel
* @since 2.0.0
*/
@scala.annotation.varargs
def orderBy(sortCol: String, sortCols: String*): Dataset[T] = sort(sortCol, sortCols : _*)
/**
* Returns a new Dataset sorted by the given expressions.
* This is an alias of the `sort` function.
*
* @group typedrel
* @since 2.0.0
*/
@scala.annotation.varargs
def orderBy(sortExprs: Column*): Dataset[T] = sort(sortExprs : _*)
/**
* Selects a column based on the column name and returns it as a [[Column]].
*
* @note The column name can also refer to a nested column like `a.b`.
*
* @group untypedrel
* @since 2.0.0
*/
def apply(colName: String): Column = col(colName)
/**
* Specifies some hint on the current Dataset. As an example, the following code specifies
* that one of the plans can be broadcast:
*
* {{{
* df1.join(df2.hint("broadcast"))
* }}}
*
* @group basic
* @since 2.2.0
*/
@scala.annotation.varargs
def hint(name: String, parameters: Any*): Dataset[T] = withTypedPlan {
UnresolvedHint(name, parameters, logicalPlan)
}
/**
* Selects a column based on the column name and returns it as a [[Column]].
*
* @note The column name can also refer to a nested column like `a.b`.
*
* @group untypedrel
* @since 2.0.0
*/
def col(colName: String): Column = colName match {
case "*" =>
Column(ResolvedStar(queryExecution.analyzed.output))
case _ =>
if (sqlContext.conf.supportQuotedRegexColumnName) {
colRegex(colName)
} else {
Column(addDataFrameIdToCol(resolve(colName)))
}
}
/**
* Selects a metadata column based on its logical column name, and returns it as a [[Column]].
*
* A metadata column can be accessed this way even if the underlying data source defines a data
* column with a conflicting name.
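*
* For example, file-based sources expose a `_metadata` column that can be selected this way
* (a sketch; availability of metadata columns depends on the underlying source):
* {{{
* df.select(df.metadataColumn("_metadata"))
* }}}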
*
* @group untypedrel
* @since 3.5.0
*/
def metadataColumn(colName: String): Column =
Column(queryExecution.analyzed.getMetadataAttributeByName(colName))
// Attach the dataset id and column position to the column reference, so that we can detect
// ambiguous self-join correctly. See the rule `DetectAmbiguousSelfJoin`.
// This must be called before we return a `Column` that contains `AttributeReference`.
// Note that, the metadata added here are only available in the analyzer, as the analyzer rule
// `DetectAmbiguousSelfJoin` will remove it.
private def addDataFrameIdToCol(expr: NamedExpression): NamedExpression = {
val newExpr = expr transform {
case a: AttributeReference
if sparkSession.conf.get(SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED) =>
val metadata = new MetadataBuilder()
.withMetadata(a.metadata)
.putLong(Dataset.DATASET_ID_KEY, id)
.putLong(Dataset.COL_POS_KEY, logicalPlan.output.indexWhere(a.semanticEquals))
.build()
a.withMetadata(metadata)
}
newExpr.asInstanceOf[NamedExpression]
}
/**
* Selects a column based on the column name specified as a regex and returns it as a [[Column]].
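*
* A usage sketch (the backquoted regex below is illustrative):
* {{{
* ds.select(ds.colRegex("`(col1)?+.+`"))
* }}}
*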
* @group untypedrel
* @since 2.3.0
*/
def colRegex(colName: String): Column = {
val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
colName match {
case ParserUtils.escapedIdentifier(columnNameRegex) =>
Column(UnresolvedRegex(columnNameRegex, None, caseSensitive))
case ParserUtils.qualifiedEscapedIdentifier(nameParts, columnNameRegex) =>
Column(UnresolvedRegex(columnNameRegex, Some(nameParts), caseSensitive))
case _ =>
Column(addDataFrameIdToCol(resolve(colName)))
}
}
/**
* Returns a new Dataset with an alias set.
*
* @group typedrel
* @since 1.6.0
*/
def as(alias: String): Dataset[T] = withTypedPlan {
SubqueryAlias(alias, logicalPlan)
}
/**
* (Scala-specific) Returns a new Dataset with an alias set.
*
* @group typedrel
* @since 2.0.0
*/
def as(alias: Symbol): Dataset[T] = as(alias.name)
/**
* Returns a new Dataset with an alias set. Same as `as`.
*
* @group typedrel
* @since 2.0.0
*/
def alias(alias: String): Dataset[T] = as(alias)
/**
* (Scala-specific) Returns a new Dataset with an alias set. Same as `as`.
*
* @group typedrel
* @since 2.0.0
*/
def alias(alias: Symbol): Dataset[T] = as(alias)
/**
* Selects a set of column based expressions.
* {{{
* ds.select($"colA", $"colB" + 1)
* }}}
*
* @group untypedrel
* @since 2.0.0
*/
@scala.annotation.varargs
def select(cols: Column*): DataFrame = withPlan {
val untypedCols = cols.map {
case typedCol: TypedColumn[_, _] =>
// Checks if a `TypedColumn` has been inserted with
// specific input type and schema by `withInputType`.
val needInputType = typedCol.expr.exists {
case ta: TypedAggregateExpression if ta.inputDeserializer.isEmpty => true
case _ => false
}
if (!needInputType) {
typedCol
} else {
throw QueryCompilationErrors.cannotPassTypedColumnInUntypedSelectError(typedCol.toString)
}
case other => other
}
Project(untypedCols.map(_.named), logicalPlan)
}
/**
* Selects a set of columns. This is a variant of `select` that can only select
* existing columns using column names (i.e. cannot construct expressions).
*
* {{{
* // The following two are equivalent:
* ds.select("colA", "colB")
* ds.select($"colA", $"colB")
* }}}
*
* @group untypedrel
* @since 2.0.0
*/
@scala.annotation.varargs
def select(col: String, cols: String*): DataFrame = select((col +: cols).map(Column(_)) : _*)
/**
* Selects a set of SQL expressions. This is a variant of `select` that accepts
* SQL expressions.
*
* {{{
* // The following are equivalent:
* ds.selectExpr("colA", "colB as newName", "abs(colC)")
* ds.select(expr("colA"), expr("colB as newName"), expr("abs(colC)"))
* }}}
*
* @group untypedrel
* @since 2.0.0
*/
@scala.annotation.varargs
def selectExpr(exprs: String*): DataFrame = sparkSession.withActive {
select(exprs.map { expr =>
Column(sparkSession.sessionState.sqlParser.parseExpression(expr))
}: _*)
}
/**
* Returns a new Dataset by computing the given [[Column]] expression for each element.
*
* {{{
* val ds = Seq(1, 2, 3).toDS()
* val newDS = ds.select(expr("value + 1").as[Int])
* }}}
*
* @group typedrel
* @since 1.6.0
*/
def select[U1](c1: TypedColumn[T, U1]): Dataset[U1] = {
implicit val encoder = c1.encoder
val project = Project(c1.withInputType(exprEnc, logicalPlan.output).named :: Nil, logicalPlan)
if (!encoder.isSerializedAsStructForTopLevel) {
new Dataset[U1](sparkSession, project, encoder)
} else {
// Flattens inner fields of U1
new Dataset[Tuple1[U1]](sparkSession, project, ExpressionEncoder.tuple(encoder)).map(_._1)
}
}
/**
* Internal helper function for building typed selects that return tuples. For simplicity and
* code reuse, we do this without the help of the type system and then use helper functions
* that cast appropriately for the user facing interface.
*/
protected def selectUntyped(columns: TypedColumn[_, _]*): Dataset[_] = {
val encoders = columns.map(_.encoder)
val namedColumns =
columns.map(_.withInputType(exprEnc, logicalPlan.output).named)
val execution = new QueryExecution(sparkSession, Project(namedColumns, logicalPlan))
new Dataset(execution, ExpressionEncoder.tuple(encoders))
}
/**
* Returns a new Dataset by computing the given [[Column]] expressions for each element.
*
* @group typedrel
* @since 1.6.0
*/
def select[U1, U2](c1: TypedColumn[T, U1], c2: TypedColumn[T, U2]): Dataset[(U1, U2)] =
selectUntyped(c1, c2).asInstanceOf[Dataset[(U1, U2)]]
/**
* Returns a new Dataset by computing the given [[Column]] expressions for each element.
*
* @group typedrel
* @since 1.6.0
*/
def select[U1, U2, U3](
c1: TypedColumn[T, U1],
c2: TypedColumn[T, U2],
c3: TypedColumn[T, U3]): Dataset[(U1, U2, U3)] =
selectUntyped(c1, c2, c3).asInstanceOf[Dataset[(U1, U2, U3)]]
/**
* Returns a new Dataset by computing the given [[Column]] expressions for each element.
*
* @group typedrel
* @since 1.6.0
*/
def select[U1, U2, U3, U4](
c1: TypedColumn[T, U1],
c2: TypedColumn[T, U2],
c3: TypedColumn[T, U3],
c4: TypedColumn[T, U4]): Dataset[(U1, U2, U3, U4)] =
selectUntyped(c1, c2, c3, c4).asInstanceOf[Dataset[(U1, U2, U3, U4)]]
/**
* Returns a new Dataset by computing the given [[Column]] expressions for each element.
*
* @group typedrel
* @since 1.6.0
*/
def select[U1, U2, U3, U4, U5](
c1: TypedColumn[T, U1],
c2: TypedColumn[T, U2],
c3: TypedColumn[T, U3],
c4: TypedColumn[T, U4],
c5: TypedColumn[T, U5]): Dataset[(U1, U2, U3, U4, U5)] =
selectUntyped(c1, c2, c3, c4, c5).asInstanceOf[Dataset[(U1, U2, U3, U4, U5)]]
/**
* Filters rows using the given condition.
* {{{
* // The following are equivalent:
* peopleDs.filter($"age" > 15)
* peopleDs.where($"age" > 15)
* }}}
*
* @group typedrel
* @since 1.6.0
*/
def filter(condition: Column): Dataset[T] = withTypedPlan {
Filter(condition.expr, logicalPlan)
}
/**
* Filters rows using the given SQL expression.
* {{{
* peopleDs.filter("age > 15")
* }}}
*
* @group typedrel
* @since 1.6.0
*/
def filter(conditionExpr: String): Dataset[T] = sparkSession.withActive {
filter(Column(sparkSession.sessionState.sqlParser.parseExpression(conditionExpr)))
}
/**
* Filters rows using the given condition. This is an alias for `filter`.
* {{{
* // The following are equivalent:
* peopleDs.filter($"age" > 15)
* peopleDs.where($"age" > 15)
* }}}
*
* @group typedrel
* @since 1.6.0
*/
def where(condition: Column): Dataset[T] = filter(condition)
/**
* Filters rows using the given SQL expression.
* {{{
* peopleDs.where("age > 15")
* }}}
*
* @group typedrel
* @since 1.6.0
*/
def where(conditionExpr: String): Dataset[T] = filter(conditionExpr)
/**
* Groups the Dataset using the specified columns, so we can run aggregation on them. See
* [[RelationalGroupedDataset]] for all the available aggregate functions.
*
* {{{
* // Compute the average for all numeric columns grouped by department.
* ds.groupBy($"department").avg()
*
* // Compute the max age and average salary, grouped by department and gender.
* ds.groupBy($"department", $"gender").agg(Map(
* "salary" -> "avg",
* "age" -> "max"
* ))
* }}}
*
* @group untypedrel
* @since 2.0.0
*/
@scala.annotation.varargs
def groupBy(cols: Column*): RelationalGroupedDataset = {
RelationalGroupedDataset(toDF(), cols.map(_.expr), RelationalGroupedDataset.GroupByType)
}
/**
* Create a multi-dimensional rollup for the current Dataset using the specified columns,
* so we can run aggregation on them.
* See [[RelationalGroupedDataset]] for all the available aggregate functions.
*
* {{{
* // Compute the average for all numeric columns rolled up by department and group.
* ds.rollup($"department", $"group").avg()
*
* // Compute the max age and average salary, rolled up by department and gender.
* ds.rollup($"department", $"gender").agg(Map(
* "salary" -> "avg",
* "age" -> "max"
* ))
* }}}
*
* @group untypedrel
* @since 2.0.0
*/
@scala.annotation.varargs
def rollup(cols: Column*): RelationalGroupedDataset = {
RelationalGroupedDataset(toDF(), cols.map(_.expr), RelationalGroupedDataset.RollupType)
}
/**
* Create a multi-dimensional cube for the current Dataset using the specified columns,
* so we can run aggregation on them.
* See [[RelationalGroupedDataset]] for all the available aggregate functions.
*
* {{{
* // Compute the average for all numeric columns cubed by department and group.
* ds.cube($"department", $"group").avg()
*
* // Compute the max age and average salary, cubed by department and gender.
* ds.cube($"department", $"gender").agg(Map(
* "salary" -> "avg",
* "age" -> "max"
* ))
* }}}
*
* @group untypedrel
* @since 2.0.0
*/
@scala.annotation.varargs
def cube(cols: Column*): RelationalGroupedDataset = {
RelationalGroupedDataset(toDF(), cols.map(_.expr), RelationalGroupedDataset.CubeType)
}
/**
* Groups the Dataset using the specified columns, so that we can run aggregation on them.
* See [[RelationalGroupedDataset]] for all the available aggregate functions.
*
* This is a variant of groupBy that can only group by existing columns using column names
* (i.e. cannot construct expressions).
*
* {{{
* // Compute the average for all numeric columns grouped by department.
* ds.groupBy("department").avg()
*
* // Compute the max age and average salary, grouped by department and gender.
* ds.groupBy($"department", $"gender").agg(Map(
* "salary" -> "avg",
* "age" -> "max"
* ))
* }}}
* @group untypedrel
* @since 2.0.0
*/
@scala.annotation.varargs
def groupBy(col1: String, cols: String*): RelationalGroupedDataset = {
val colNames: Seq[String] = col1 +: cols
RelationalGroupedDataset(
toDF(), colNames.map(colName => resolve(colName)), RelationalGroupedDataset.GroupByType)
}
/**
* (Scala-specific)
* Reduces the elements of this Dataset using the specified binary function. The given `func`
* must be commutative and associative or the result may be non-deterministic.
*
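* For example, a minimal sketch assuming a Dataset[Int] named `ds`:
* {{{
* val total = ds.reduce(_ + _)
* }}}
*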
* @group action
* @since 1.6.0
*/
def reduce(func: (T, T) => T): T = withNewRDDExecutionId("reduce") {
rdd.reduce(func)
}
/**
* (Java-specific)
* Reduces the elements of this Dataset using the specified binary function. The given `func`
* must be commutative and associative or the result may be non-deterministic.
*
* @group action
* @since 1.6.0
*/
def reduce(func: ReduceFunction[T]): T = reduce(func.call(_, _))
/**
* (Scala-specific)
* Returns a [[KeyValueGroupedDataset]] where the data is grouped by the given key `func`.
*
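* For example, a minimal sketch assuming a Dataset[String] named `ds` with
* `spark.implicits._` in scope:
* {{{
* // Counts how many times each distinct string occurs.
* val counts = ds.groupByKey(identity).count()
* }}}
*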
* @group typedrel
* @since 2.0.0
*/
def groupByKey[K: Encoder](func: T => K): KeyValueGroupedDataset[K, T] = {
val withGroupingKey = AppendColumns(func, logicalPlan)
val executed = sparkSession.sessionState.executePlan(withGroupingKey)
new KeyValueGroupedDataset(
encoderFor[K],
encoderFor[T],
executed,
logicalPlan.output,
withGroupingKey.newColumns)
}
/**
* (Java-specific)
* Returns a [[KeyValueGroupedDataset]] where the data is grouped by the given key `func`.
*
* @group typedrel
* @since 2.0.0
*/
def groupByKey[K](func: MapFunction[T, K], encoder: Encoder[K]): KeyValueGroupedDataset[K, T] =
groupByKey(func.call(_))(encoder)
/**
* Create a multi-dimensional rollup for the current Dataset using the specified columns,
* so we can run aggregation on them.
* See [[RelationalGroupedDataset]] for all the available aggregate functions.
*
* This is a variant of rollup that can only group by existing columns using column names
* (i.e. cannot construct expressions).
*
* {{{
* // Compute the average for all numeric columns rolled up by department and group.
* ds.rollup("department", "group").avg()
*
* // Compute the max age and average salary, rolled up by department and gender.
* ds.rollup($"department", $"gender").agg(Map(
* "salary" -> "avg",
* "age" -> "max"
* ))
* }}}
*
* @group untypedrel
* @since 2.0.0
*/
@scala.annotation.varargs
def rollup(col1: String, cols: String*): RelationalGroupedDataset = {
val colNames: Seq[String] = col1 +: cols
RelationalGroupedDataset(
toDF(), colNames.map(colName => resolve(colName)), RelationalGroupedDataset.RollupType)
}
/**
* Create a multi-dimensional cube for the current Dataset using the specified columns,
* so we can run aggregation on them.
* See [[RelationalGroupedDataset]] for all the available aggregate functions.
*
* This is a variant of cube that can only group by existing columns using column names
* (i.e. cannot construct expressions).
*
* {{{
* // Compute the average for all numeric columns cubed by department and group.
* ds.cube("department", "group").avg()
*
* // Compute the max age and average salary, cubed by department and gender.
* ds.cube($"department", $"gender").agg(Map(
* "salary" -> "avg",
* "age" -> "max"
* ))
* }}}
* @group untypedrel
* @since 2.0.0
*/
@scala.annotation.varargs
def cube(col1: String, cols: String*): RelationalGroupedDataset = {
val colNames: Seq[String] = col1 +: cols
RelationalGroupedDataset(
toDF(), colNames.map(colName => resolve(colName)), RelationalGroupedDataset.CubeType)
}
/**
* (Scala-specific) Aggregates on the entire Dataset without groups.
* {{{
* // ds.agg(...) is a shorthand for ds.groupBy().agg(...)
* ds.agg("age" -> "max", "salary" -> "avg")
* ds.groupBy().agg("age" -> "max", "salary" -> "avg")
* }}}
*
* @group untypedrel
* @since 2.0.0
*/
def agg(aggExpr: (String, String), aggExprs: (String, String)*): DataFrame = {
groupBy().agg(aggExpr, aggExprs : _*)
}
/**
* (Scala-specific) Aggregates on the entire Dataset without groups.
* {{{
* // ds.agg(...) is a shorthand for ds.groupBy().agg(...)
* ds.agg(Map("age" -> "max", "salary" -> "avg"))
* ds.groupBy().agg(Map("age" -> "max", "salary" -> "avg"))
* }}}
*
* @group untypedrel
* @since 2.0.0
*/
def agg(exprs: Map[String, String]): DataFrame = groupBy().agg(exprs)
/**
* (Java-specific) Aggregates on the entire Dataset without groups.
* {{{
* // ds.agg(...) is a shorthand for ds.groupBy().agg(...)
* ds.agg(Map("age" -> "max", "salary" -> "avg"))
* ds.groupBy().agg(Map("age" -> "max", "salary" -> "avg"))
* }}}
*
* @group untypedrel
* @since 2.0.0
*/
def agg(exprs: java.util.Map[String, String]): DataFrame = groupBy().agg(exprs)
/**
* Aggregates on the entire Dataset without groups.
* {{{
* // ds.agg(...) is a shorthand for ds.groupBy().agg(...)
* ds.agg(max($"age"), avg($"salary"))
* ds.groupBy().agg(max($"age"), avg($"salary"))
* }}}
*
* @group untypedrel
* @since 2.0.0
*/
@scala.annotation.varargs
def agg(expr: Column, exprs: Column*): DataFrame = groupBy().agg(expr, exprs : _*)
/**
* Unpivot a DataFrame from wide format to long format, optionally leaving identifier columns set.
* This is the reverse to `groupBy(...).pivot(...).agg(...)`, except for the aggregation,
* which cannot be reversed.
*
* This function is useful to massage a DataFrame into a format where some
* columns are identifier columns ("ids"), while all other columns ("values")
* are "unpivoted" to the rows, leaving just two non-id columns, named as given
* by `variableColumnName` and `valueColumnName`.
*
* {{{
* val df = Seq((1, 11, 12L), (2, 21, 22L)).toDF("id", "int", "long")
* df.show()
* // output:
* // +---+---+----+
* // | id|int|long|
* // +---+---+----+
* // | 1| 11| 12|
* // | 2| 21| 22|
* // +---+---+----+
*
* df.unpivot(Array($"id"), Array($"int", $"long"), "variable", "value").show()
* // output:
* // +---+--------+-----+
* // | id|variable|value|
* // +---+--------+-----+
* // | 1| int| 11|
* // | 1| long| 12|
* // | 2| int| 21|
* // | 2| long| 22|
* // +---+--------+-----+
* // schema:
* //root
* // |-- id: integer (nullable = false)
* // |-- variable: string (nullable = false)
* // |-- value: long (nullable = true)
* }}}
*
* When no "id" columns are given, the unpivoted DataFrame consists of only the
* "variable" and "value" columns.
*
* All "value" columns must share a least common data type. Unless they are the same data type,
* all "value" columns are cast to the nearest common data type. For instance,
* types `IntegerType` and `LongType` are cast to `LongType`, while `IntegerType` and `StringType`
* do not have a common data type and `unpivot` fails with an `AnalysisException`.
*
* @param ids Id columns
* @param values Value columns to unpivot
* @param variableColumnName Name of the variable column
* @param valueColumnName Name of the value column
*
* @group untypedrel
* @since 3.4.0
*/
def unpivot(
ids: Array[Column],
values: Array[Column],
variableColumnName: String,
valueColumnName: String): DataFrame = withPlan {
Unpivot(
Some(ids.map(_.named)),
Some(values.map(v => Seq(v.named))),
None,
variableColumnName,
Seq(valueColumnName),
logicalPlan
)
}
/**
* Unpivot a DataFrame from wide format to long format, optionally leaving identifier columns set.
* This is the reverse to `groupBy(...).pivot(...).agg(...)`, except for the aggregation,
* which cannot be reversed.
*
* @see `org.apache.spark.sql.Dataset.unpivot(Array, Array, String, String)`
*
* This is equivalent to calling `Dataset#unpivot(Array, Array, String, String)`
* where `values` is set to all non-id columns that exist in the DataFrame.
*
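* For instance, reusing the `df` from the example of the four-argument overload above
* (with `spark.implicits._` in scope for the `$` interpolator), "int" and "long" are
* unpivoted because they are the only non-id columns:
* {{{
* df.unpivot(Array($"id"), "variable", "value")
* }}}
*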
* @param ids Id columns
* @param variableColumnName Name of the variable column
* @param valueColumnName Name of the value column
*
* @group untypedrel
* @since 3.4.0
*/
def unpivot(
ids: Array[Column],
variableColumnName: String,
valueColumnName: String): DataFrame = withPlan {
Unpivot(
Some(ids.map(_.named)),
None,
None,
variableColumnName,
Seq(valueColumnName),
logicalPlan
)
}
/**
* Called from Python as Seq[Column] are easier to create via py4j than Array[Column].
* We use Array[Column] for unpivot rather than Seq[Column] as those are Java-friendly.
*/
private[sql] def unpivotWithSeq(
ids: Seq[Column],
values: Seq[Column],
variableColumnName: String,
valueColumnName: String): DataFrame =
unpivot(ids.toArray, values.toArray, variableColumnName, valueColumnName)
/**
* Called from Python as Seq[Column] are easier to create via py4j than Array[Column].
* We use Array[Column] for unpivot rather than Seq[Column] as those are Java-friendly.
*/
private[sql] def unpivotWithSeq(
ids: Seq[Column],
variableColumnName: String,
valueColumnName: String): DataFrame =
unpivot(ids.toArray, variableColumnName, valueColumnName)
/**
* Unpivot a DataFrame from wide format to long format, optionally leaving identifier columns set.
* This is the reverse to `groupBy(...).pivot(...).agg(...)`, except for the aggregation,
* which cannot be reversed. This is an alias for `unpivot`.
*
* @see `org.apache.spark.sql.Dataset.unpivot(Array, Array, String, String)`
*
* @param ids Id columns
* @param values Value columns to unpivot
* @param variableColumnName Name of the variable column
* @param valueColumnName Name of the value column
*
* @group untypedrel
* @since 3.4.0
*/
def melt(
ids: Array[Column],
values: Array[Column],
variableColumnName: String,
valueColumnName: String): DataFrame =
unpivot(ids, values, variableColumnName, valueColumnName)
/**
* Unpivot a DataFrame from wide format to long format, optionally leaving identifier columns set.
* This is the reverse to `groupBy(...).pivot(...).agg(...)`, except for the aggregation,
* which cannot be reversed. This is an alias for `unpivot`.
*
* @see `org.apache.spark.sql.Dataset.unpivot(Array, Array, String, String)`
*
* This is equivalent to calling `Dataset#unpivot(Array, Array, String, String)`
* where `values` is set to all non-id columns that exist in the DataFrame.
*
* @param ids Id columns
* @param variableColumnName Name of the variable column
* @param valueColumnName Name of the value column
*
* @group untypedrel
* @since 3.4.0
*/
def melt(
ids: Array[Column],
variableColumnName: String,
valueColumnName: String): DataFrame =
unpivot(ids, variableColumnName, valueColumnName)
/**
* Define (named) metrics to observe on the Dataset. This method returns an 'observed' Dataset
* that returns the same result as the input, with the following guarantees:
*
* - It will compute the defined aggregates (metrics) on all the data that is flowing through
* the Dataset at that point.
* - It will report the value of the defined aggregate columns as soon as we reach a completion
* point. A completion point is either the end of a query (batch mode) or the end of a streaming
* epoch. The value of the aggregates only reflects the data processed since the previous
* completion point.
*
* Please note that continuous execution is currently not supported.
*
* The metrics columns must either contain a literal (e.g. lit(42)), or should contain one or
* more aggregate functions (e.g. sum(a) or sum(a + b) + avg(c) - lit(1)). Expressions that
* contain references to the input Dataset's columns must always be wrapped in an aggregate
* function.
*
* A user can observe these metrics by adding either a
* [[org.apache.spark.sql.streaming.StreamingQueryListener]] or a
* [[org.apache.spark.sql.util.QueryExecutionListener]] to the Spark session.
*
* {{{
* // Monitor the metrics using a listener.
* spark.streams.addListener(new StreamingQueryListener() {
* override def onQueryStarted(event: QueryStartedEvent): Unit = {}
* override def onQueryProgress(event: QueryProgressEvent): Unit = {
* event.progress.observedMetrics.asScala.get("my_event").foreach { row =>
* // Trigger if the number of errors exceeds 5 percent
* val num_rows = row.getAs[Long]("rc")
* val num_error_rows = row.getAs[Long]("erc")
* val ratio = num_error_rows.toDouble / num_rows
* if (ratio > 0.05) {
* // Trigger alert
* }
* }
* }
* override def onQueryTerminated(event: QueryTerminatedEvent): Unit = {}
* })
* // Observe row count (rc) and error row count (erc) in the streaming Dataset
* val observed_ds = ds.observe("my_event", count(lit(1)).as("rc"), count($"error").as("erc"))
* observed_ds.writeStream.format("...").start()
* }}}
*
* @group typedrel
* @since 3.0.0
*/
@varargs
def observe(name: String, expr: Column, exprs: Column*): Dataset[T] = withTypedPlan {
CollectMetrics(name, (expr +: exprs).map(_.named), logicalPlan, id)
}
/**
* Observe (named) metrics through an `org.apache.spark.sql.Observation` instance.
* This is equivalent to calling `observe(String, Column, Column*)` but does not require
* adding `org.apache.spark.sql.util.QueryExecutionListener` to the spark session.
* This method does not support streaming datasets.
*
* A user can retrieve the metrics by accessing `org.apache.spark.sql.Observation.get`.
*
* {{{
* // Observe row count (rows) and highest id (maxid) in the Dataset while writing it
* val observation = Observation("my_metrics")
* val observed_ds = ds.observe(observation, count(lit(1)).as("rows"), max($"id").as("maxid"))
* observed_ds.write.parquet("ds.parquet")
* val metrics = observation.get
* }}}
*
* @throws IllegalArgumentException If this is a streaming Dataset (this.isStreaming == true)
*
* @group typedrel
* @since 3.3.0
*/
@varargs
def observe(observation: Observation, expr: Column, exprs: Column*): Dataset[T] = {
observation.on(this, expr, exprs: _*)
}
/**
* Returns a new Dataset by taking the first `n` rows. The difference between this function
* and `head` is that `head` is an action and returns an array (by triggering query execution)
* while `limit` returns a new Dataset.
*
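* For example:
* {{{
* // A new Dataset containing at most the first 10 rows of ds.
* val first10 = ds.limit(10)
* }}}
*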
* @group typedrel
* @since 2.0.0
*/
def limit(n: Int): Dataset[T] = withTypedPlan {
Limit(Literal(n), logicalPlan)
}
/**
* Returns a new Dataset by skipping the first `n` rows.
*
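* For example:
* {{{
* // A new Dataset that skips the first 5 rows of ds.
* val rest = ds.offset(5)
* }}}
*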
* @group typedrel
* @since 3.4.0
*/
def offset(n: Int): Dataset[T] = withTypedPlan {
Offset(Literal(n), logicalPlan)
}
// This breaks caching, but it's usually ok because it addresses a very specific use case:
// using union to union many files or partitions.
private def combineUnions(plan: LogicalPlan): LogicalPlan = {
plan.transformDownWithPruning(_.containsPattern(TreePattern.UNION)) {
case Distinct(u: Union) =>
Distinct(flattenUnion(u, isUnionDistinct = true))
// Only handle distinct-like 'Deduplicate', where the keys == output
case Deduplicate(keys: Seq[Attribute], u: Union) if AttributeSet(keys) == u.outputSet =>
Deduplicate(keys, flattenUnion(u, true))
case u: Union =>
flattenUnion(u, isUnionDistinct = false)
}
}
private def flattenUnion(u: Union, isUnionDistinct: Boolean): Union = {
var changed = false
// We only need to look at the direct children of Union, as the nested adjacent Unions should
// have been combined already by previous `Dataset#union` transformations.
val newChildren = u.children.flatMap {
case Distinct(Union(children, byName, allowMissingCol))
if isUnionDistinct && byName == u.byName && allowMissingCol == u.allowMissingCol =>
changed = true
children
// Only handle distinct-like 'Deduplicate', where the keys == output
case Deduplicate(keys: Seq[Attribute], child @ Union(children, byName, allowMissingCol))
if AttributeSet(keys) == child.outputSet && isUnionDistinct && byName == u.byName &&
allowMissingCol == u.allowMissingCol =>
changed = true
children
case Union(children, byName, allowMissingCol)
if !isUnionDistinct && byName == u.byName && allowMissingCol == u.allowMissingCol =>
changed = true
children
case other =>
Seq(other)
}
if (changed) {
val newUnion = Union(newChildren)
newUnion.copyTagsFrom(u)
newUnion
} else {
u
}
}
/**
* Returns a new Dataset containing union of rows in this Dataset and another Dataset.
*
* This is equivalent to `UNION ALL` in SQL. To do a SQL-style set union (that does
* deduplication of elements), use this function followed by a [[distinct]].
*
* Also as standard in SQL, this function resolves columns by position (not by name):
*
* {{{
* val df1 = Seq((1, 2, 3)).toDF("col0", "col1", "col2")
* val df2 = Seq((4, 5, 6)).toDF("col1", "col2", "col0")
* df1.union(df2).show
*
* // output:
* // +----+----+----+
* // |col0|col1|col2|
* // +----+----+----+
* // | 1| 2| 3|
* // | 4| 5| 6|
* // +----+----+----+
* }}}
*
* Notice that the column positions in the schema aren't necessarily matched with the
* fields in the strongly typed objects in a Dataset. This function resolves columns
* by their positions in the schema, not the fields in the strongly typed objects. Use
* [[unionByName]] to resolve columns by field name in the typed objects.
*
* @group typedrel
* @since 2.0.0
*/
def union(other: Dataset[T]): Dataset[T] = withSetOperator {
combineUnions(Union(logicalPlan, other.logicalPlan))
}
/**
* Returns a new Dataset containing union of rows in this Dataset and another Dataset.
* This is an alias for `union`.
*
* This is equivalent to `UNION ALL` in SQL. To do a SQL-style set union (that does
* deduplication of elements), use this function followed by a [[distinct]].
*
* Also as standard in SQL, this function resolves columns by position (not by name).
*
* @group typedrel
* @since 2.0.0
*/
def unionAll(other: Dataset[T]): Dataset[T] = union(other)
/**
* Returns a new Dataset containing union of rows in this Dataset and another Dataset.
*
* This is different from both `UNION ALL` and `UNION DISTINCT` in SQL. To do a SQL-style set
* union (that does deduplication of elements), use this function followed by a [[distinct]].
*
* The difference between this function and [[union]] is that this function
* resolves columns by name (not by position):
*
* {{{
* val df1 = Seq((1, 2, 3)).toDF("col0", "col1", "col2")
* val df2 = Seq((4, 5, 6)).toDF("col1", "col2", "col0")
* df1.unionByName(df2).show
*
* // output:
* // +----+----+----+
* // |col0|col1|col2|
* // +----+----+----+
* // | 1| 2| 3|
* // | 6| 4| 5|
* // +----+----+----+
* }}}
*
* Note that this supports nested columns in struct and array types. Nested columns in map types
* are not currently supported.
*
* @group typedrel
* @since 2.3.0
*/
def unionByName(other: Dataset[T]): Dataset[T] = unionByName(other, false)
/**
* Returns a new Dataset containing union of rows in this Dataset and another Dataset.
*
* The difference between this function and [[union]] is that this function
* resolves columns by name (not by position).
*
* When the parameter `allowMissingColumns` is `true`, the set of column names
* in this and other `Dataset` can differ; missing columns will be filled with null.
* Further, the missing columns of this `Dataset` will be added at the end
* in the schema of the union result:
*
* {{{
* val df1 = Seq((1, 2, 3)).toDF("col0", "col1", "col2")
* val df2 = Seq((4, 5, 6)).toDF("col1", "col0", "col3")
* df1.unionByName(df2, true).show
*
* // output: "col3" is missing at left df1 and added at the end of schema.
* // +----+----+----+----+
* // |col0|col1|col2|col3|
* // +----+----+----+----+
* // | 1| 2| 3|NULL|
* // | 5| 4|NULL| 6|
* // +----+----+----+----+
*
* df2.unionByName(df1, true).show
*
* // output: "col2" is missing at left df2 and added at the end of schema.
* // +----+----+----+----+
* // |col1|col0|col3|col2|
* // +----+----+----+----+
* // | 4| 5| 6|NULL|
* // | 2| 1|NULL| 3|
* // +----+----+----+----+
* }}}
*
* Note that this supports nested columns in struct and array types. With `allowMissingColumns`,
* missing nested columns of struct columns with the same name will also be filled with null
* values and added to the end of struct. Nested columns in map types are not currently
* supported.
*
* @group typedrel
* @since 3.1.0
*/
def unionByName(other: Dataset[T], allowMissingColumns: Boolean): Dataset[T] = withSetOperator {
// We need to resolve the by-name Union first, as the underlying Unions are already resolved
// and we can only combine adjacent Unions if they are all resolved.
val resolvedUnion = sparkSession.sessionState.executePlan(
Union(logicalPlan :: other.logicalPlan :: Nil, true, allowMissingColumns))
combineUnions(resolvedUnion.analyzed)
}
/**
* Returns a new Dataset containing rows only in both this Dataset and another Dataset.
* This is equivalent to `INTERSECT` in SQL.
*
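* For illustration, a minimal sketch with `spark.implicits._` in scope:
* {{{
* val a = Seq(1, 2, 3, 4).toDS()
* val b = Seq(3, 4, 5).toDS()
* a.intersect(b) // contains 3 and 4
* }}}
*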
* @note Equality checking is performed directly on the encoded representation of the data
* and thus is not affected by a custom `equals` function defined on `T`.
*
* @group typedrel
* @since 1.6.0
*/
def intersect(other: Dataset[T]): Dataset[T] = withSetOperator {
Intersect(logicalPlan, other.logicalPlan, isAll = false)
}
/**
* Returns a new Dataset containing rows only in both this Dataset and another Dataset while
* preserving the duplicates.
* This is equivalent to `INTERSECT ALL` in SQL.
*
* @note Equality checking is performed directly on the encoded representation of the data
* and thus is not affected by a custom `equals` function defined on `T`. Also as standard
* in SQL, this function resolves columns by position (not by name).
*
* @group typedrel
* @since 2.4.0
*/
def intersectAll(other: Dataset[T]): Dataset[T] = withSetOperator {
Intersect(logicalPlan, other.logicalPlan, isAll = true)
}
/**
* Returns a new Dataset containing rows in this Dataset but not in another Dataset.
* This is equivalent to `EXCEPT DISTINCT` in SQL.
*
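* For illustration, a minimal sketch with `spark.implicits._` in scope:
* {{{
* val a = Seq(1, 2, 2, 3).toDS()
* val b = Seq(2, 3).toDS()
* a.except(b) // contains only 1
* }}}
*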
* @note Equality checking is performed directly on the encoded representation of the data
* and thus is not affected by a custom `equals` function defined on `T`.
*
* @group typedrel
* @since 2.0.0
*/
def except(other: Dataset[T]): Dataset[T] = withSetOperator {
Except(logicalPlan, other.logicalPlan, isAll = false)
}
/**
* Returns a new Dataset containing rows in this Dataset but not in another Dataset while
* preserving the duplicates.
* This is equivalent to `EXCEPT ALL` in SQL.
*
* @note Equality checking is performed directly on the encoded representation of the data
* and thus is not affected by a custom `equals` function defined on `T`. Also as standard in
* SQL, this function resolves columns by position (not by name).
*
* @group typedrel
* @since 2.4.0
*/
def exceptAll(other: Dataset[T]): Dataset[T] = withSetOperator {
Except(logicalPlan, other.logicalPlan, isAll = true)
}
/**
* Returns a new [[Dataset]] by sampling a fraction of rows (without replacement),
* using a user-supplied seed.
*
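* For example, drawing roughly 10% of the rows with a fixed seed:
* {{{
* val sampled = ds.sample(0.1, 42L)
* }}}
*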
* @param fraction Fraction of rows to generate, range [0.0, 1.0].
* @param seed Seed for sampling.
*
* @note This is NOT guaranteed to provide exactly the fraction of the count
* of the given [[Dataset]].
*
* @group typedrel
* @since 2.3.0
*/
def sample(fraction: Double, seed: Long): Dataset[T] = {
sample(withReplacement = false, fraction = fraction, seed = seed)
}
/**
* Returns a new [[Dataset]] by sampling a fraction of rows (without replacement),
* using a random seed.
*
* @param fraction Fraction of rows to generate, range [0.0, 1.0].
*
* @note This is NOT guaranteed to provide exactly the fraction of the count
* of the given [[Dataset]].
*
* @group typedrel
* @since 2.3.0
*/
def sample(fraction: Double): Dataset[T] = {
sample(withReplacement = false, fraction = fraction)
}
/**
* Returns a new [[Dataset]] by sampling a fraction of rows, using a user-supplied seed.
*
* @param withReplacement Sample with replacement or not.
* @param fraction Fraction of rows to generate, range [0.0, 1.0].
* @param seed Seed for sampling.
*
* @note This is NOT guaranteed to provide exactly the fraction of the count
* of the given [[Dataset]].
*
* @group typedrel
* @since 1.6.0
*/
def sample(withReplacement: Boolean, fraction: Double, seed: Long): Dataset[T] = {
withTypedPlan {
Sample(0.0, fraction, withReplacement, seed, logicalPlan)
}
}
/**
* Returns a new [[Dataset]] by sampling a fraction of rows, using a random seed.
*
* @param withReplacement Sample with replacement or not.
* @param fraction Fraction of rows to generate, range [0.0, 1.0].
*
* @note This is NOT guaranteed to provide exactly the fraction of the total count
* of the given [[Dataset]].
*
* @group typedrel
* @since 1.6.0
*/
def sample(withReplacement: Boolean, fraction: Double): Dataset[T] = {
sample(withReplacement, fraction, Utils.random.nextLong)
}
/**
* Randomly splits this Dataset with the provided weights.
*
* @param weights weights for splits, will be normalized if they don't sum to 1.
* @param seed Seed for sampling.
*
* For Java API, use [[randomSplitAsList]].
*
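* A common use is splitting into training and test sets, e.g.:
* {{{
* val Array(train, test) = ds.randomSplit(Array(0.8, 0.2), seed = 42L)
* }}}
*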
* @group typedrel
* @since 2.0.0
*/
def randomSplit(weights: Array[Double], seed: Long): Array[Dataset[T]] = {
require(weights.forall(_ >= 0),
s"Weights must be nonnegative, but got ${weights.mkString("[", ",", "]")}")
require(weights.sum > 0,
s"Sum of weights must be positive, but got ${weights.mkString("[", ",", "]")}")
// It is possible that the underlying dataframe doesn't guarantee the ordering of rows in its
// constituent partitions each time a split is materialized which could result in
// overlapping splits. To prevent this, we explicitly sort each input partition to make the
// ordering deterministic. Note that MapTypes cannot be sorted and are explicitly pruned out
// from the sort order.
val sortOrder = logicalPlan.output
.filter(attr => RowOrdering.isOrderable(attr.dataType))
.map(SortOrder(_, Ascending))
val plan = if (sortOrder.nonEmpty) {
Sort(sortOrder, global = false, logicalPlan)
} else {
// SPARK-12662: If sort order is empty, we materialize the dataset to guarantee determinism
cache()
logicalPlan
}
val sum = weights.sum
val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _)
normalizedCumWeights.sliding(2).map { x =>
new Dataset[T](
sparkSession, Sample(x(0), x(1), withReplacement = false, seed, plan), encoder)
}.toArray
}
/**
* Returns a Java list that contains randomly split Dataset with the provided weights.
*
* @param weights weights for splits, will be normalized if they don't sum to 1.
* @param seed Seed for sampling.
*
* @group typedrel
* @since 2.0.0
*/
def randomSplitAsList(weights: Array[Double], seed: Long): java.util.List[Dataset[T]] = {
val values = randomSplit(weights, seed)
java.util.Arrays.asList(values : _*)
}
/**
* Randomly splits this Dataset with the provided weights.
*
* @param weights weights for splits, will be normalized if they don't sum to 1.
* @group typedrel
* @since 2.0.0
*/
def randomSplit(weights: Array[Double]): Array[Dataset[T]] = {
randomSplit(weights, Utils.random.nextLong)
}
/**
* Randomly splits this Dataset with the provided weights. Provided for the Python API.
*
* @param weights weights for splits, will be normalized if they don't sum to 1.
* @param seed Seed for sampling.
*/
private[spark] def randomSplit(weights: List[Double], seed: Long): Array[Dataset[T]] = {
randomSplit(weights.toArray, seed)
}
/**
* (Scala-specific) Returns a new Dataset where each row has been expanded to zero or more
* rows by the provided function. This is similar to a `LATERAL VIEW` in HiveQL. The columns of
* the input row are implicitly joined with each row that is output by the function.
*
* Given that this is deprecated, as an alternative, you can explode columns either using
* `functions.explode()` or `flatMap()`. The following example uses these alternatives to count
* the number of books that contain a given word:
*
* {{{
* case class Book(title: String, words: String)
* val ds: Dataset[Book]
*
* val allWords = ds.select($"title", explode(split($"words", " ")).as("word"))
*
* val bookCountPerWord = allWords.groupBy("word").agg(count_distinct("title"))
* }}}
*
* Using `flatMap()` this can similarly be exploded as:
*
* {{{
* ds.flatMap(_.words.split(" "))
* }}}
*
* @group untypedrel
* @since 2.0.0
*/
@deprecated("use flatMap() or select() with functions.explode() instead", "2.0.0")
def explode[A <: Product : TypeTag](input: Column*)(f: Row => TraversableOnce[A]): DataFrame = {
val elementSchema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType]
val convert = CatalystTypeConverters.createToCatalystConverter(elementSchema)
val rowFunction =
f.andThen(_.map(convert(_).asInstanceOf[InternalRow]))
val generator = UserDefinedGenerator(elementSchema, rowFunction, input.map(_.expr))
withPlan {
Generate(generator, unrequiredChildIndex = Nil, outer = false,
qualifier = None, generatorOutput = Nil, logicalPlan)
}
}
/**
* (Scala-specific) Returns a new Dataset where a single column has been expanded to zero
* or more rows by the provided function. This is similar to a `LATERAL VIEW` in HiveQL. All
* columns of the input row are implicitly joined with each value that is output by the function.
*
* Given that this is deprecated, as an alternative, you can explode columns either using
* `functions.explode()`:
*
* {{{
* ds.select(explode(split($"words", " ")).as("word"))
* }}}
*
* or `flatMap()`:
*
* {{{
* ds.flatMap(_.words.split(" "))
* }}}
*
* @group untypedrel
* @since 2.0.0
*/
@deprecated("use flatMap() or select() with functions.explode() instead", "2.0.0")
def explode[A, B : TypeTag](inputColumn: String, outputColumn: String)(f: A => TraversableOnce[B])
: DataFrame = {
val dataType = ScalaReflection.schemaFor[B].dataType
val attributes = AttributeReference(outputColumn, dataType)() :: Nil
// TODO handle the metadata?
val elementSchema = attributes.toStructType
def rowFunction(row: Row): TraversableOnce[InternalRow] = {
val convert = CatalystTypeConverters.createToCatalystConverter(dataType)
f(row(0).asInstanceOf[A]).map(o => InternalRow(convert(o)))
}
val generator = UserDefinedGenerator(elementSchema, rowFunction, apply(inputColumn).expr :: Nil)
withPlan {
Generate(generator, unrequiredChildIndex = Nil, outer = false,
qualifier = None, generatorOutput = Nil, logicalPlan)
}
}
/**
* Returns a new Dataset by adding a column or replacing the existing column that has
* the same name.
*
* `column`'s expression must only refer to attributes supplied by this Dataset. It is an
* error to add a column that refers to some other Dataset.
*
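* For illustration, a minimal sketch assuming a DataFrame `df` with an `age` column and
* `spark.implicits._` in scope (for the `$` column interpolator):
* {{{
* df.withColumn("age_plus_one", $"age" + 1)
* }}}
*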
* @note this method introduces a projection internally. Therefore, calling it multiple times,
* for instance, via loops in order to add multiple columns can generate big plans which
* can cause performance issues and even `StackOverflowException`. To avoid this,
* use `select` with the multiple columns at once.
*
* @group untypedrel
* @since 2.0.0
*/
def withColumn(colName: String, col: Column): DataFrame = withColumns(Seq(colName), Seq(col))
/**
* (Scala-specific) Returns a new Dataset by adding columns or replacing the existing columns
* that have the same names.
*
* `colsMap` is a map of column name to column. The columns must only refer to attributes
* supplied by this Dataset. It is an error to add columns that refer to some other Dataset.
*
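* For illustration, a minimal sketch assuming a DataFrame `df` with an `age` column and
* `spark.implicits._` in scope:
* {{{
* df.withColumns(Map(
*   "age_plus_one" -> ($"age" + 1),
*   "is_adult" -> ($"age" >= 18)
* ))
* }}}
*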
* @group untypedrel
* @since 3.3.0
*/
def withColumns(colsMap: Map[String, Column]): DataFrame = {
val (colNames, newCols) = colsMap.toSeq.unzip
withColumns(colNames, newCols)
}
/**
* (Java-specific) Returns a new Dataset by adding columns or replacing the existing columns
* that have the same names.
*
* `colsMap` is a map of column name to column. The columns must only refer to attributes
* supplied by this Dataset. It is an error to add columns that refer to some other Dataset.
*
* @group untypedrel
* @since 3.3.0
*/
def withColumns(colsMap: java.util.Map[String, Column]): DataFrame = withColumns(
colsMap.asScala.toMap
)
/**
* Returns a new Dataset by adding columns or replacing the existing columns that have
* the same names.
*/
private[spark] def withColumns(colNames: Seq[String], cols: Seq[Column]): DataFrame = {
require(colNames.size == cols.size,
s"The size of column names: ${colNames.size} isn't equal to " +
s"the size of columns: ${cols.size}")
SchemaUtils.checkColumnNameDuplication(
colNames,
sparkSession.sessionState.conf.caseSensitiveAnalysis)
val resolver = sparkSession.sessionState.analyzer.resolver
val output = queryExecution.analyzed.output
val columnSeq = colNames.zip(cols)
val replacedAndExistingColumns = output.map { field =>
columnSeq.find { case (colName, _) =>
resolver(field.name, colName)
} match {
case Some((colName: String, col: Column)) => col.as(colName)
case _ => Column(field)
}
}
val newColumns = columnSeq.filter { case (colName, col) =>
!output.exists(f => resolver(f.name, colName))
}.map { case (colName, col) => col.as(colName) }
select(replacedAndExistingColumns ++ newColumns : _*)
}
/**
* Returns a new Dataset by adding columns with metadata.
*/
private[spark] def withColumns(
colNames: Seq[String],
cols: Seq[Column],
metadata: Seq[Metadata]): DataFrame = {
require(colNames.size == metadata.size,
s"The size of column names: ${colNames.size} isn't equal to " +
s"the size of metadata elements: ${metadata.size}")
val newCols = colNames.zip(cols).zip(metadata).map { case ((colName, col), metadata) =>
col.as(colName, metadata)
}
withColumns(colNames, newCols)
}
/**
* Returns a new Dataset by adding a column with metadata.
*/
private[spark] def withColumn(colName: String, col: Column, metadata: Metadata): DataFrame =
withColumns(Seq(colName), Seq(col), Seq(metadata))
/**
* Returns a new Dataset with a column renamed.
* This is a no-op if the schema doesn't contain `existingName`.
*
* @group untypedrel
* @since 2.0.0
*/
def withColumnRenamed(existingName: String, newName: String): DataFrame = {
val resolver = sparkSession.sessionState.analyzer.resolver
val output = queryExecution.analyzed.output
val shouldRename = output.exists(f => resolver(f.name, existingName))
if (shouldRename) {
val columns = output.map { col =>
if (resolver(col.name, existingName)) {
Column(col).as(newName)
} else {
Column(col)
}
}
select(columns : _*)
} else {
toDF()
}
}
/**
* (Scala-specific)
* Returns a new Dataset with columns renamed.
* This is a no-op if the schema doesn't contain the existing column name(s).
*
* `colsMap` is a map of existing column name and new column name.
*
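* For example, with hypothetical column names:
* {{{
* df.withColumnsRenamed(Map("colA" -> "colX", "colB" -> "colY"))
* }}}
*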
* @throws AnalysisException if there are duplicate names in resulting projection
*
* @group untypedrel
* @since 3.4.0
*/
@throws[AnalysisException]
def withColumnsRenamed(colsMap: Map[String, String]): DataFrame = {
val resolver = sparkSession.sessionState.analyzer.resolver
val output: Seq[NamedExpression] = queryExecution.analyzed.output
val projectList = colsMap.foldLeft(output) {
case (attrs, (existingName, newName)) =>
attrs.map(attr =>
if (resolver(attr.name, existingName)) {
Alias(attr, newName)()
} else {
attr
}
)
}
SchemaUtils.checkColumnNameDuplication(
projectList.map(_.name),
sparkSession.sessionState.conf.caseSensitiveAnalysis)
withPlan(Project(projectList, logicalPlan))
}
/**
* (Java-specific)
* Returns a new Dataset with columns renamed.
* This is a no-op if the schema doesn't contain the existing column name(s).
*
* `colsMap` is a map of existing column name and new column name.
*
* @group untypedrel
* @since 3.4.0
*/
def withColumnsRenamed(colsMap: java.util.Map[String, String]): DataFrame =
withColumnsRenamed(colsMap.asScala.toMap)
/**
* Returns a new Dataset by updating an existing column with metadata.
*
* @group untypedrel
* @since 3.3.0
*/
def withMetadata(columnName: String, metadata: Metadata): DataFrame = {
withColumn(columnName, col(columnName), metadata)
}
/**
* Returns a new Dataset with a column dropped. This is a no-op if schema doesn't contain
* column name.
*
* This method can only be used to drop top-level columns. The colName string is treated
* literally without further interpretation.
*
* Note: `drop(colName)` has different semantics from `drop(col(colName))`, for example:
* 1. Multiple columns have the same colName:
* {{{
* val df1 = spark.range(0, 2).withColumn("key1", lit(1))
* val df2 = spark.range(0, 2).withColumn("key2", lit(2))
* val df3 = df1.join(df2)
*
* df3.show
* // +---+----+---+----+
* // | id|key1| id|key2|
* // +---+----+---+----+
* // | 0| 1| 0| 2|
* // | 0| 1| 1| 2|
* // | 1| 1| 0| 2|
* // | 1| 1| 1| 2|
* // +---+----+---+----+
*
* df3.drop("id").show()
* // output: the two 'id' columns are both dropped.
* // +----+----+
* // |key1|key2|
* // +----+----+
* // | 1| 2|
* // | 1| 2|
* // | 1| 2|
* // | 1| 2|
* // +----+----+
*
* df3.drop(col("id")).show()
* // ...AnalysisException: [AMBIGUOUS_REFERENCE] Reference `id` is ambiguous...
* }}}
*
* 2. The colName contains special characters, like a dot.
* {{{
* val df = spark.range(0, 2).withColumn("a.b.c", lit(1))
*
* df.show()
* // +---+-----+
* // | id|a.b.c|
* // +---+-----+
* // | 0| 1|
* // | 1| 1|
* // +---+-----+
*
* df.drop("a.b.c").show()
* // +---+
* // | id|
* // +---+
* // | 0|
* // | 1|
* // +---+
*
* df.drop(col("a.b.c")).show()
* // no column matches the expression 'a.b.c'
* // +---+-----+
* // | id|a.b.c|
* // +---+-----+
* // | 0| 1|
* // | 1| 1|
* // +---+-----+
* }}}
*
* @group untypedrel
* @since 2.0.0
*/
def drop(colName: String): DataFrame = {
drop(Seq(colName) : _*)
}
/**
* Returns a new Dataset with columns dropped.
* This is a no-op if schema doesn't contain column name(s).
*
* This method can only be used to drop top-level columns. The colName string is treated
* literally without further interpretation.
*
* @group untypedrel
* @since 2.0.0
*/
@scala.annotation.varargs
def drop(colNames: String*): DataFrame = {
val resolver = sparkSession.sessionState.analyzer.resolver
val allColumns = queryExecution.analyzed.output
val remainingCols = allColumns.filter { attribute =>
colNames.forall(n => !resolver(attribute.name, n))
}.map(attribute => Column(attribute))
if (remainingCols.size == allColumns.size) {
toDF()
} else {
this.select(remainingCols: _*)
}
}
/**
* Returns a new Dataset with column dropped.
*
* This method can only be used to drop a top-level column.
* This version of drop accepts a [[Column]] rather than a name.
* This is a no-op if the Dataset doesn't have a column
* with an equivalent expression.
*
* Note: `drop(col(colName))` has different semantics from `drop(colName)`;
* please refer to `Dataset#drop(colName: String)`.
*
* @group untypedrel
* @since 2.0.0
*/
def drop(col: Column): DataFrame = {
drop(col, Seq.empty : _*)
}
/**
* Returns a new Dataset with columns dropped.
*
* This method can only be used to drop top-level columns.
* This is a no-op if the Dataset doesn't have columns
* with equivalent expressions.
*
* @group untypedrel
* @since 3.4.0
*/
@scala.annotation.varargs
def drop(col: Column, cols: Column*): DataFrame = withPlan {
DataFrameDropColumns((col +: cols).map(_.expr), logicalPlan)
}
/**
* Returns a new Dataset that contains only the unique rows from this Dataset.
* This is an alias for `distinct`.
*
* For a static batch [[Dataset]], it just drops duplicate rows. For a streaming [[Dataset]], it
* will keep all data across triggers as intermediate state to drop duplicate rows. You can use
* [[withWatermark]] to limit how late the duplicate data can be, and the system will accordingly
* limit the state. In addition, data that is too late (older than the watermark) will be dropped
* to avoid any possibility of duplicates.
*
* @group typedrel
* @since 2.0.0
*/
def dropDuplicates(): Dataset[T] = dropDuplicates(this.columns)
/**
* (Scala-specific) Returns a new Dataset with duplicate rows removed, considering only
* the subset of columns.
*
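* For example, assuming `name` and `age` columns exist:
* {{{
* // Keeps one row per distinct (name, age) pair.
* ds.dropDuplicates(Seq("name", "age"))
* }}}
*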
* For a static batch [[Dataset]], it just drops duplicate rows. For a streaming [[Dataset]], it
* will keep all data across triggers as intermediate state to drop duplicate rows. You can use
* [[withWatermark]] to limit how late the duplicate data can be, and the system will accordingly
* limit the state. In addition, data that is too late (older than the watermark) will be dropped
* to avoid any possibility of duplicates.
*
* @group typedrel
* @since 2.0.0
*/
def dropDuplicates(colNames: Seq[String]): Dataset[T] = withTypedPlan {
val groupCols = groupColsFromDropDuplicates(colNames)
Deduplicate(groupCols, logicalPlan)
}
/**
* Returns a new Dataset with duplicate rows removed, considering only
* the subset of columns.
*
* For a static batch [[Dataset]], it just drops duplicate rows. For a streaming [[Dataset]], it
* will keep all data across triggers as intermediate state to drop duplicate rows. You can use
* [[withWatermark]] to limit how late the duplicate data can be, and the system will accordingly
* limit the state. In addition, data that is too late (older than the watermark) will be dropped
* to avoid any possibility of duplicates.
*
* @group typedrel
* @since 2.0.0
*/
def dropDuplicates(colNames: Array[String]): Dataset[T] = dropDuplicates(colNames.toSeq)
/**
* Returns a new [[Dataset]] with duplicate rows removed, considering only
* the subset of columns.
*
* For a static batch [[Dataset]], it just drops duplicate rows. For a streaming [[Dataset]], it
* will keep all data across triggers as intermediate state to drop duplicate rows. You can use
* [[withWatermark]] to limit how late the duplicate data can be, and the system will accordingly
* limit the state. In addition, data that is too late (older than the watermark) will be dropped
* to avoid any possibility of duplicates.
*
* @group typedrel
* @since 2.0.0
*/
@scala.annotation.varargs
def dropDuplicates(col1: String, cols: String*): Dataset[T] = {
val colNames: Seq[String] = col1 +: cols
dropDuplicates(colNames)
}
/**
* Returns a new Dataset with duplicate rows removed, within the watermark.
*
* This only works with streaming [[Dataset]], and watermark for the input [[Dataset]] must be
* set via [[withWatermark]].
*
* For a streaming [[Dataset]], this will keep all data across triggers as intermediate state
* to drop duplicate rows. The state will be kept to guarantee the semantics: "Events are
* deduplicated as long as the time distance between the earliest and latest events is smaller
* than the delay threshold of the watermark." Users are encouraged to set the delay threshold
* of the watermark longer than the max timestamp difference among duplicated events.
*
* Note: data that is too late (older than the watermark) will be dropped.
*
* @group typedrel
* @since 3.5.0
*/
def dropDuplicatesWithinWatermark(): Dataset[T] = {
dropDuplicatesWithinWatermark(this.columns)
}
/**
* Returns a new Dataset with duplicate rows removed, considering only the subset of columns,
* within the watermark.
*
* This only works with streaming [[Dataset]], and watermark for the input [[Dataset]] must be
* set via [[withWatermark]].
*
* For a streaming [[Dataset]], this will keep all data across triggers as intermediate state
* to drop duplicate rows. The state will be kept to guarantee the semantics: "Events are
* deduplicated as long as the time distance between the earliest and latest events is smaller
* than the delay threshold of the watermark." Users are encouraged to set the delay threshold
* of the watermark longer than the max timestamp difference among duplicated events.
*
* Note: data that is too late (older than the watermark) will be dropped.
*
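* A minimal sketch, assuming a streaming Dataset `streamingDs` with an event-time column
* `eventTime` and an id column `guid`:
* {{{
* streamingDs
*   .withWatermark("eventTime", "10 minutes")
*   .dropDuplicatesWithinWatermark(Seq("guid"))
* }}}
*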
* @group typedrel
* @since 3.5.0
*/
def dropDuplicatesWithinWatermark(colNames: Seq[String]): Dataset[T] = withTypedPlan {
val groupCols = groupColsFromDropDuplicates(colNames)
// UnsupportedOperationChecker will fail the query if this is called with batch Dataset.
DeduplicateWithinWatermark(groupCols, logicalPlan)
}
/**
* Returns a new Dataset with duplicate rows removed, considering only the subset of columns,
* within the watermark.
*
* This only works with streaming [[Dataset]], and watermark for the input [[Dataset]] must be
* set via [[withWatermark]].
*
* For a streaming [[Dataset]], this will keep all data across triggers as intermediate state
* to drop duplicate rows. The state will be kept to guarantee the semantics: "Events are
* deduplicated as long as the time distance between the earliest and latest events is smaller
* than the delay threshold of the watermark." Users are encouraged to set the delay threshold
* of the watermark longer than the max timestamp difference among duplicated events.
*
* Note: data that is too late (older than the watermark) will be dropped.
*
* @group typedrel
* @since 3.5.0
*/
def dropDuplicatesWithinWatermark(colNames: Array[String]): Dataset[T] = {
dropDuplicatesWithinWatermark(colNames.toSeq)
}
/**
* Returns a new Dataset with duplicate rows removed, considering only the subset of columns,
* within the watermark.
*
* This only works with streaming [[Dataset]], and watermark for the input [[Dataset]] must be
* set via [[withWatermark]].
*
* For a streaming [[Dataset]], this will keep all data across triggers as intermediate state
* to drop duplicate rows. The state will be kept to guarantee the semantics: "Events are
* deduplicated as long as the time distance between the earliest and latest events is smaller
* than the delay threshold of the watermark." Users are encouraged to set the delay threshold
* of the watermark longer than the max timestamp difference among duplicated events.
*
* Note: data that is too late (older than the watermark) will be dropped.
*
* @group typedrel
* @since 3.5.0
*/
@scala.annotation.varargs
def dropDuplicatesWithinWatermark(col1: String, cols: String*): Dataset[T] = {
val colNames: Seq[String] = col1 +: cols
dropDuplicatesWithinWatermark(colNames)
}
private def groupColsFromDropDuplicates(colNames: Seq[String]): Seq[Attribute] = {
val resolver = sparkSession.sessionState.analyzer.resolver
val allColumns = queryExecution.analyzed.output
// SPARK-31990: We must keep `toSet.toSeq` here because of the backward compatibility issue
// (the Streaming's state store depends on the `groupCols` order).
colNames.toSet.toSeq.flatMap { (colName: String) =>
// It is possible that more than one column has the same name,
// so we call filter instead of find.
val cols = allColumns.filter(col => resolver(col.name, colName))
if (cols.isEmpty) {
throw QueryCompilationErrors.cannotResolveColumnNameAmongAttributesError(
colName, schema.fieldNames.mkString(", "))
}
cols
}
}
/**
* Computes basic statistics for numeric and string columns, including count, mean, stddev, min,
* and max. If no columns are given, this function computes statistics for all numerical or
* string columns.
*
* This function is meant for exploratory data analysis, as we make no guarantee about the
* backward compatibility of the schema of the resulting Dataset. If you want to
* programmatically compute summary statistics, use the `agg` function instead.
*
* {{{
* ds.describe("age", "height").show()
*
* // output:
* // summary age height
* // count 10.0 10.0
* // mean 53.3 178.05
* // stddev 11.6 15.7
* // min 18.0 163.0
* // max 92.0 192.0
* }}}
*
* Use [[summary]] for expanded statistics and control over which statistics to compute.
*
* @param cols Columns to compute statistics on.
*
* @group action
* @since 1.6.0
*/
@scala.annotation.varargs
def describe(cols: String*): DataFrame = {
val selected = if (cols.isEmpty) this else select(cols.head, cols.tail: _*)
selected.summary("count", "mean", "stddev", "min", "max")
}
/**
* Computes specified statistics for numeric and string columns. Available statistics are:
*
* - count
* - mean
* - stddev
* - min
* - max
* - arbitrary approximate percentiles specified as a percentage (e.g. 75%)
* - count_distinct
* - approx_count_distinct
*
*
* If no statistics are given, this function computes count, mean, stddev, min,
* approximate quartiles (percentiles at 25%, 50%, and 75%), and max.
*
* This function is meant for exploratory data analysis, as we make no guarantee about the
* backward compatibility of the schema of the resulting Dataset. If you want to
* programmatically compute summary statistics, use the `agg` function instead.
*
* {{{
* ds.summary().show()
*
* // output:
* // summary age height
* // count 10.0 10.0
* // mean 53.3 178.05
* // stddev 11.6 15.7
* // min 18.0 163.0
* // 25% 24.0 176.0
* // 50% 24.0 176.0
* // 75% 32.0 180.0
* // max 92.0 192.0
* }}}
*
* {{{
* ds.summary("count", "min", "25%", "75%", "max").show()
*
* // output:
* // summary age height
* // count 10.0 10.0
* // min 18.0 163.0
* // 25% 24.0 176.0
* // 75% 32.0 180.0
* // max 92.0 192.0
* }}}
*
* To do a summary for specific columns first select them:
*
* {{{
* ds.select("age", "height").summary().show()
* }}}
*
* Specify statistics to output custom summaries:
*
* {{{
* ds.summary("count", "count_distinct").show()
* }}}
*
* The distinct count isn't included by default.
*
* You can also run approximate distinct counts which are faster:
*
* {{{
* ds.summary("count", "approx_count_distinct").show()
* }}}
*
* See also [[describe]] for basic statistics.
*
* @param statistics Statistics from above list to be computed.
*
* @group action
* @since 2.3.0
*/
@scala.annotation.varargs
def summary(statistics: String*): DataFrame = StatFunctions.summary(this, statistics.toSeq)
/**
* Returns the first `n` rows.
*
* @note this method should only be used if the resulting array is expected to be small, as
* all the data is loaded into the driver's memory.
*
* @group action
* @since 1.6.0
*/
def head(n: Int): Array[T] = withAction("head", limit(n).queryExecution)(collectFromPlan)
/**
* Returns the first row.
* @group action
* @since 1.6.0
*/
def head(): T = head(1).head
/**
* Returns the first row. Alias for head().
* @group action
* @since 1.6.0
*/
def first(): T = head()
/**
* Concise syntax for chaining custom transformations.
* {{{
* def featurize(ds: Dataset[T]): Dataset[U] = ...
*
* ds
* .transform(featurize)
* .transform(...)
* }}}
*
* @group typedrel
* @since 1.6.0
*/
def transform[U](t: Dataset[T] => Dataset[U]): Dataset[U] = t(this)
/**
* (Scala-specific)
* Returns a new Dataset that only contains elements where `func` returns `true`.
*
* @group typedrel
* @since 1.6.0
*/
def filter(func: T => Boolean): Dataset[T] = {
withTypedPlan(TypedFilter(func, logicalPlan))
}
/**
* (Java-specific)
* Returns a new Dataset that only contains elements where `func` returns `true`.
*
* @group typedrel
* @since 1.6.0
*/
def filter(func: FilterFunction[T]): Dataset[T] = {
withTypedPlan(TypedFilter(func, logicalPlan))
}
/**
* (Scala-specific)
* Returns a new Dataset that contains the result of applying `func` to each element.
*
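* For example, a minimal sketch assuming a Dataset[String] named `ds` with
* `spark.implicits._` in scope:
* {{{
* val lengths = ds.map(_.length)
* }}}
*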
* @group typedrel
* @since 1.6.0
*/
def map[U : Encoder](func: T => U): Dataset[U] = withTypedPlan {
MapElements[T, U](func, logicalPlan)
}
/**
* (Java-specific)
* Returns a new Dataset that contains the result of applying `func` to each element.
*
* @group typedrel
* @since 1.6.0
*/
def map[U](func: MapFunction[T, U], encoder: Encoder[U]): Dataset[U] = {
implicit val uEnc = encoder
withTypedPlan(MapElements[T, U](func, logicalPlan))
}
/**
* (Scala-specific)
* Returns a new Dataset that contains the result of applying `func` to each partition.
*
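* For example, a minimal sketch assuming a Dataset[Int] named `ds` with
* `spark.implicits._` in scope:
* {{{
* // Emits one per-partition sum.
* val partitionSums = ds.mapPartitions(it => Iterator(it.sum))
* }}}
*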
* @group typedrel
* @since 1.6.0
*/
def mapPartitions[U : Encoder](func: Iterator[T] => Iterator[U]): Dataset[U] = {
new Dataset[U](
sparkSession,
MapPartitions[T, U](func, logicalPlan),
implicitly[Encoder[U]])
}
/**
* (Java-specific)
* Returns a new Dataset that contains the result of applying `f` to each partition.
*
* @group typedrel
* @since 1.6.0
*/
def mapPartitions[U](f: MapPartitionsFunction[T, U], encoder: Encoder[U]): Dataset[U] = {
val func: (Iterator[T]) => Iterator[U] = x => f.call(x.asJava).asScala
mapPartitions(func)(encoder)
}
/**
* Returns a new `DataFrame` that contains the result of applying a serialized R function
* `func` to each partition.
*/
private[sql] def mapPartitionsInR(
func: Array[Byte],
packageNames: Array[Byte],
broadcastVars: Array[Broadcast[Object]],
schema: StructType): DataFrame = {
val rowEncoder = encoder.asInstanceOf[ExpressionEncoder[Row]]
Dataset.ofRows(
sparkSession,
MapPartitionsInR(func, packageNames, broadcastVars, schema, rowEncoder, logicalPlan))
}
/**
* Applies a Scalar iterator Pandas UDF to each partition. The user-defined function
* defines a transformation: `iter(pandas.DataFrame)` -> `iter(pandas.DataFrame)`.
* Each partition is processed as an iterator of DataFrames (batches).
*
* This function uses Apache Arrow as serialization format between Java executors and Python
* workers.
*/
private[sql] def mapInPandas(func: PythonUDF, isBarrier: Boolean = false): DataFrame = {
Dataset.ofRows(
sparkSession,
MapInPandas(
func,
toAttributes(func.dataType.asInstanceOf[StructType]),
logicalPlan,
isBarrier))
}
/**
* Applies a function to each partition in Arrow format. The user-defined function
* defines a transformation: `iter(pyarrow.RecordBatch)` -> `iter(pyarrow.RecordBatch)`.
* Each partition is processed as an iterator of `pyarrow.RecordBatch`es (batches).
*/
private[sql] def pythonMapInArrow(func: PythonUDF, isBarrier: Boolean = false): DataFrame = {
Dataset.ofRows(
sparkSession,
PythonMapInArrow(
func,
toAttributes(func.dataType.asInstanceOf[StructType]),
logicalPlan,
isBarrier))
}
/**
* (Scala-specific)
* Returns a new Dataset by first applying a function to all elements of this Dataset,
* and then flattening the results.
*
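* For example, a minimal sketch assuming a Dataset[String] of lines named `ds` with
* `spark.implicits._` in scope:
* {{{
* val words = ds.flatMap(_.split(" "))
* }}}
*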
* @group typedrel
* @since 1.6.0
*/
def flatMap[U : Encoder](func: T => TraversableOnce[U]): Dataset[U] =
mapPartitions(_.flatMap(func))
/**
* (Java-specific)
* Returns a new Dataset by first applying a function to all elements of this Dataset,
* and then flattening the results.
*
* @group typedrel
* @since 1.6.0
*/
def flatMap[U](f: FlatMapFunction[T, U], encoder: Encoder[U]): Dataset[U] = {
val func: (T) => Iterator[U] = x => f.call(x).asScala
flatMap(func)(encoder)
}
/**
* Applies a function `f` to all rows.
*
* @group action
* @since 1.6.0
*/
def foreach(f: T => Unit): Unit = withNewRDDExecutionId("foreach") {
rdd.foreach(f)
}
/**
* (Java-specific)
* Runs `func` on each element of this Dataset.
*
* @group action
* @since 1.6.0
*/
def foreach(func: ForeachFunction[T]): Unit = foreach(func.call(_))
/**
* Applies a function `f` to each partition of this Dataset.
*
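* For example, a minimal sketch assuming a Dataset[String] named `ds`:
* {{{
* ds.foreachPartition((it: Iterator[String]) => it.foreach(println))
* }}}
*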
* @group action
* @since 1.6.0
*/
def foreachPartition(f: Iterator[T] => Unit): Unit = withNewRDDExecutionId("foreachPartition") {
rdd.foreachPartition(f)
}
/**
* (Java-specific)
* Runs `func` on each partition of this Dataset.
*
* @group action
* @since 1.6.0
*/
def foreachPartition(func: ForeachPartitionFunction[T]): Unit = {
foreachPartition((it: Iterator[T]) => func.call(it.asJava))
}
/**
* Returns the first `n` rows in the Dataset.
*
* Running take requires moving data into the application's driver process, and doing so with
* a very large `n` can crash the driver process with OutOfMemoryError.
*
* @group action
* @since 1.6.0
*/
def take(n: Int): Array[T] = head(n)
/**
* Returns the last `n` rows in the Dataset.
*
* Running tail requires moving data into the application's driver process, and doing so with
* a very large `n` can crash the driver process with OutOfMemoryError.
*
* @group action
* @since 3.0.0
*/
def tail(n: Int): Array[T] = withAction(
"tail", withTypedPlan(Tail(Literal(n), logicalPlan)).queryExecution)(collectFromPlan)
/**
* Returns the first `n` rows in the Dataset as a list.
*
* Running take requires moving data into the application's driver process, and doing so with
* a very large `n` can crash the driver process with OutOfMemoryError.
*
* @group action
* @since 1.6.0
*/
def takeAsList(n: Int): java.util.List[T] = java.util.Arrays.asList(take(n) : _*)
/**
* Returns an array that contains all rows in this Dataset.
*
* Running collect requires moving all the data into the application's driver process, and
* doing so on a very large dataset can crash the driver process with OutOfMemoryError.
*
* For Java API, use [[collectAsList]].
*
* @group action
* @since 1.6.0
*/
def collect(): Array[T] = withAction("collect", queryExecution)(collectFromPlan)
/**
* Returns a Java list that contains all rows in this Dataset.
*
* Running collect requires moving all the data into the application's driver process, and
* doing so on a very large dataset can crash the driver process with OutOfMemoryError.
*
* @group action
* @since 1.6.0
*/
def collectAsList(): java.util.List[T] = withAction("collectAsList", queryExecution) { plan =>
val values = collectFromPlan(plan)
java.util.Arrays.asList(values : _*)
}
/**
* Returns an iterator that contains all rows in this Dataset.
*
* The iterator will consume as much memory as the largest partition in this Dataset.
*
* @note This results in multiple Spark jobs, and if the input Dataset is the result
* of a wide transformation (e.g. a join with different partitioners), the input
* Dataset should be cached first to avoid recomputation.
*
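* For example (an illustrative sketch; `ds` is an assumed existing `Dataset[String]`):
* {{{
*   val it = ds.toLocalIterator()
*   while (it.hasNext) {
*     println(it.next())
*   }
* }}}
*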
* @group action
* @since 2.0.0
*/
def toLocalIterator(): java.util.Iterator[T] = {
withAction("toLocalIterator", queryExecution) { plan =>
val fromRow = resolvedEnc.createDeserializer()
plan.executeToIterator().map(fromRow).asJava
}
}
/**
* Returns the number of rows in the Dataset.
* @group action
* @since 1.6.0
*/
def count(): Long = withAction("count", groupBy().count().queryExecution) { plan =>
plan.executeCollect().head.getLong(0)
}
/**
* Returns a new Dataset that has exactly `numPartitions` partitions.
*
* @group typedrel
* @since 1.6.0
*/
def repartition(numPartitions: Int): Dataset[T] = withTypedPlan {
Repartition(numPartitions, shuffle = true, logicalPlan)
}
private def repartitionByExpression(
numPartitions: Option[Int],
partitionExprs: Seq[Column]): Dataset[T] = {
// The underlying `LogicalPlan` operator special-cases all-`SortOrder` arguments.
// However, we don't want to complicate the semantics of this API method.
// Instead, let's give users a friendly error message, pointing them to the new method.
val sortOrders = partitionExprs.filter(_.expr.isInstanceOf[SortOrder])
if (sortOrders.nonEmpty) throw new IllegalArgumentException(
s"""Invalid partitionExprs specified: $sortOrders
|For range partitioning use repartitionByRange(...) instead.
""".stripMargin)
withTypedPlan {
RepartitionByExpression(partitionExprs.map(_.expr), logicalPlan, numPartitions)
}
}
/**
* Returns a new Dataset partitioned by the given partitioning expressions into
* `numPartitions`. The resulting Dataset is hash partitioned.
*
* This is the same operation as "DISTRIBUTE BY" in SQL (Hive QL).
*
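* For example, to hash-partition a DataFrame into 10 partitions by a column (an
* illustrative sketch; `df` and the column name `department` are assumptions):
* {{{
*   df.repartition(10, org.apache.spark.sql.functions.col("department"))
* }}}
*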
* @group typedrel
* @since 2.0.0
*/
@scala.annotation.varargs
def repartition(numPartitions: Int, partitionExprs: Column*): Dataset[T] = {
repartitionByExpression(Some(numPartitions), partitionExprs)
}
/**
* Returns a new Dataset partitioned by the given partitioning expressions, using
* `spark.sql.shuffle.partitions` as number of partitions.
* The resulting Dataset is hash partitioned.
*
* This is the same operation as "DISTRIBUTE BY" in SQL (Hive QL).
*
* @group typedrel
* @since 2.0.0
*/
@scala.annotation.varargs
def repartition(partitionExprs: Column*): Dataset[T] = {
repartitionByExpression(None, partitionExprs)
}
private def repartitionByRange(
numPartitions: Option[Int],
partitionExprs: Seq[Column]): Dataset[T] = {
require(partitionExprs.nonEmpty, "At least one partition-by expression must be specified.")
val sortOrder: Seq[SortOrder] = partitionExprs.map(_.expr match {
case expr: SortOrder => expr
case expr: Expression => SortOrder(expr, Ascending)
})
withTypedPlan {
RepartitionByExpression(sortOrder, logicalPlan, numPartitions)
}
}
/**
* Returns a new Dataset partitioned by the given partitioning expressions into
* `numPartitions`. The resulting Dataset is range partitioned.
*
* At least one partition-by expression must be specified.
* When no explicit sort order is specified, "ascending nulls first" is assumed.
* Note that the rows are not sorted within each partition of the resulting Dataset.
*
* Note that due to performance reasons this method uses sampling to estimate the ranges.
* Hence, the output may not be consistent, since sampling can return different values.
* The sample size can be controlled by the config
* `spark.sql.execution.rangeExchange.sampleSizePerPartition`.
*
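* For example, to range-partition a DataFrame into 10 partitions by a column in
* descending order (an illustrative sketch; `df` and the column name `age` are assumptions):
* {{{
*   df.repartitionByRange(10, org.apache.spark.sql.functions.col("age").desc)
* }}}
*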
* @group typedrel
* @since 2.3.0
*/
@scala.annotation.varargs
def repartitionByRange(numPartitions: Int, partitionExprs: Column*): Dataset[T] = {
repartitionByRange(Some(numPartitions), partitionExprs)
}
/**
* Returns a new Dataset partitioned by the given partitioning expressions, using
* `spark.sql.shuffle.partitions` as number of partitions.
* The resulting Dataset is range partitioned.
*
* At least one partition-by expression must be specified.
* When no explicit sort order is specified, "ascending nulls first" is assumed.
* Note that the rows are not sorted within each partition of the resulting Dataset.
*
* Note that due to performance reasons this method uses sampling to estimate the ranges.
* Hence, the output may not be consistent, since sampling can return different values.
* The sample size can be controlled by the config
* `spark.sql.execution.rangeExchange.sampleSizePerPartition`.
*
* @group typedrel
* @since 2.3.0
*/
@scala.annotation.varargs
def repartitionByRange(partitionExprs: Column*): Dataset[T] = {
repartitionByRange(None, partitionExprs)
}
/**
* Returns a new Dataset that has exactly `numPartitions` partitions, when fewer partitions
* are requested. If a larger number of partitions is requested, it will stay at the current
* number of partitions. Similar to coalesce defined on an `RDD`, this operation results in
* a narrow dependency, e.g. if you go from 1000 partitions to 100 partitions, there will not
* be a shuffle; instead, each of the 100 new partitions will claim 10 of the current partitions.
*
* However, if you're doing a drastic coalesce, e.g. to numPartitions = 1,
* this may result in your computation taking place on fewer nodes than
* you like (e.g. one node in the case of numPartitions = 1). To avoid this,
* you can call repartition. This will add a shuffle step, but means the
* current upstream partitions will be executed in parallel (per whatever
* the current partitioning is).
*
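* For example, to write the result out as a single file (an illustrative sketch; `df`
* and the output path are assumptions):
* {{{
*   df.coalesce(1).write.csv("/tmp/single-file-output")
* }}}
*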
* @group typedrel
* @since 1.6.0
*/
def coalesce(numPartitions: Int): Dataset[T] = withTypedPlan {
Repartition(numPartitions, shuffle = false, logicalPlan)
}
/**
* Returns a new Dataset that contains only the unique rows from this Dataset.
* This is an alias for `dropDuplicates`.
*
* Note that for a streaming [[Dataset]], this method returns distinct rows only once
* regardless of the output mode, so the behavior may differ from that of `DISTINCT` in SQL
* against a streaming [[Dataset]].
*
* @note Equality checking is performed directly on the encoded representation of the data
* and thus is not affected by a custom `equals` function defined on `T`.
*
* @group typedrel
* @since 2.0.0
*/
def distinct(): Dataset[T] = dropDuplicates()
/**
* Persist this Dataset with the default storage level (`MEMORY_AND_DISK`).
*
* @group basic
* @since 1.6.0
*/
def persist(): this.type = {
sparkSession.sharedState.cacheManager.cacheQuery(this)
this
}
/**
* Persist this Dataset with the default storage level (`MEMORY_AND_DISK`).
*
* @group basic
* @since 1.6.0
*/
def cache(): this.type = persist()
/**
* Persist this Dataset with the given storage level.
* @param newLevel One of: `MEMORY_ONLY`, `MEMORY_AND_DISK`, `MEMORY_ONLY_SER`,
* `MEMORY_AND_DISK_SER`, `DISK_ONLY`, `MEMORY_ONLY_2`,
* `MEMORY_AND_DISK_2`, etc.
*
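* For example (an illustrative sketch; `ds` is an assumed existing Dataset):
* {{{
*   import org.apache.spark.storage.StorageLevel
*   ds.persist(StorageLevel.MEMORY_ONLY)
* }}}
*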
* @group basic
* @since 1.6.0
*/
def persist(newLevel: StorageLevel): this.type = {
sparkSession.sharedState.cacheManager.cacheQuery(this, None, newLevel)
this
}
/**
* Get the Dataset's current storage level, or StorageLevel.NONE if not persisted.
*
* @group basic
* @since 2.1.0
*/
def storageLevel: StorageLevel = {
sparkSession.sharedState.cacheManager.lookupCachedData(this).map { cachedData =>
cachedData.cachedRepresentation.cacheBuilder.storageLevel
}.getOrElse(StorageLevel.NONE)
}
/**
* Mark the Dataset as non-persistent, and remove all blocks for it from memory and disk.
* This will not un-persist any cached data that is built upon this Dataset.
*
* @param blocking Whether to block until all blocks are deleted.
*
* @group basic
* @since 1.6.0
*/
def unpersist(blocking: Boolean): this.type = {
sparkSession.sharedState.cacheManager.uncacheQuery(
sparkSession, logicalPlan, cascade = false, blocking)
this
}
/**
* Mark the Dataset as non-persistent, and remove all blocks for it from memory and disk.
* This will not un-persist any cached data that is built upon this Dataset.
*
* @group basic
* @since 1.6.0
*/
def unpersist(): this.type = unpersist(blocking = false)
// Represents the `QueryExecution` used to produce the content of the Dataset as an `RDD`.
@transient private lazy val rddQueryExecution: QueryExecution = {
val deserialized = CatalystSerde.deserialize[T](logicalPlan)
sparkSession.sessionState.executePlan(deserialized)
}
/**
* Represents the content of the Dataset as an `RDD` of `T`.
*
* @group basic
* @since 1.6.0
*/
lazy val rdd: RDD[T] = {
val objectType = exprEnc.deserializer.dataType
rddQueryExecution.toRdd.mapPartitions { rows =>
rows.map(_.get(0, objectType).asInstanceOf[T])
}
}
/**
* Returns the content of the Dataset as a `JavaRDD` of `T`s.
* @group basic
* @since 1.6.0
*/
def toJavaRDD: JavaRDD[T] = rdd.toJavaRDD()
/**
* Returns the content of the Dataset as a `JavaRDD` of `T`s.
* @group basic
* @since 1.6.0
*/
def javaRDD: JavaRDD[T] = toJavaRDD
/**
* Registers this Dataset as a temporary table using the given name. The lifetime of this
* temporary table is tied to the [[SparkSession]] that was used to create this Dataset.
*
* @group basic
* @since 1.6.0
*/
@deprecated("Use createOrReplaceTempView(viewName) instead.", "2.0.0")
def registerTempTable(tableName: String): Unit = {
createOrReplaceTempView(tableName)
}
/**
* Creates a local temporary view using the given name. The lifetime of this
* temporary view is tied to the [[SparkSession]] that was used to create this Dataset.
*
* Local temporary view is session-scoped. Its lifetime is the lifetime of the session that
* created it, i.e. it will be automatically dropped when the session terminates. It's not
* tied to any database, i.e. we can't use `db1.view1` to reference a local temporary view.
*
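* For example, to register a view and query it with SQL (an illustrative sketch; `df`,
* the view name `people`, and the `spark` session variable are assumptions):
* {{{
*   df.createTempView("people")
*   spark.sql("SELECT * FROM people")
* }}}
*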
* @throws AnalysisException if the view name is invalid or already exists
*
* @group basic
* @since 2.0.0
*/
@throws[AnalysisException]
def createTempView(viewName: String): Unit = withPlan {
createTempViewCommand(viewName, replace = false, global = false)
}
/**
* Creates a local temporary view using the given name. The lifetime of this
* temporary view is tied to the [[SparkSession]] that was used to create this Dataset.
*
* @group basic
* @since 2.0.0
*/
def createOrReplaceTempView(viewName: String): Unit = withPlan {
createTempViewCommand(viewName, replace = true, global = false)
}
/**
* Creates a global temporary view using the given name. The lifetime of this
* temporary view is tied to this Spark application.
*
* Global temporary view is cross-session. Its lifetime is the lifetime of the Spark application,
* i.e. it will be automatically dropped when the application terminates. It's tied to a system
* preserved database `global_temp`, and we must use the qualified name to refer to a global temp
* view, e.g. `SELECT * FROM global_temp.view1`.
*
* @throws AnalysisException if the view name is invalid or already exists
*
* @group basic
* @since 2.1.0
*/
@throws[AnalysisException]
def createGlobalTempView(viewName: String): Unit = withPlan {
createTempViewCommand(viewName, replace = false, global = true)
}
/**
* Creates or replaces a global temporary view using the given name. The lifetime of this
* temporary view is tied to this Spark application.
*
* Global temporary view is cross-session. Its lifetime is the lifetime of the Spark application,
* i.e. it will be automatically dropped when the application terminates. It's tied to a system
* preserved database `global_temp`, and we must use the qualified name to refer to a global temp
* view, e.g. `SELECT * FROM global_temp.view1`.
*
* @group basic
* @since 2.2.0
*/
def createOrReplaceGlobalTempView(viewName: String): Unit = withPlan {
createTempViewCommand(viewName, replace = true, global = true)
}
private def createTempViewCommand(
viewName: String,
replace: Boolean,
global: Boolean): CreateViewCommand = sparkSession.withActive {
val viewType = if (global) GlobalTempView else LocalTempView
val identifier = try {
sparkSession.sessionState.sqlParser.parseMultipartIdentifier(viewName)
} catch {
case _: ParseException => throw QueryCompilationErrors.invalidViewNameError(viewName)
}
if (!SQLConf.get.allowsTempViewCreationWithMultipleNameparts && identifier.size > 1) {
// Temporary view names should NOT contain database prefix like "database.table"
throw new AnalysisException(
errorClass = "TEMP_VIEW_NAME_TOO_MANY_NAME_PARTS",
messageParameters = Map("actualName" -> viewName))
}
CreateViewCommand(
name = TableIdentifier(identifier.last),
userSpecifiedColumns = Nil,
comment = None,
properties = Map.empty,
originalText = None,
plan = logicalPlan,
allowExisting = false,
replace = replace,
viewType = viewType,
isAnalyzed = true)
}
/**
* Interface for saving the content of the non-streaming Dataset out into external storage.
*
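* For example, to save the Dataset as Parquet (an illustrative sketch; `df` and the
* output path are assumptions):
* {{{
*   df.write.mode("overwrite").parquet("/tmp/output")
* }}}
*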
* @group basic
* @since 1.6.0
*/
def write: DataFrameWriter[T] = {
if (isStreaming) {
logicalPlan.failAnalysis(
errorClass = "CALL_ON_STREAMING_DATASET_UNSUPPORTED",
messageParameters = Map("methodName" -> toSQLId("write")))
}
new DataFrameWriter[T](this)
}
/**
* Create a write configuration builder for v2 sources.
*
* This builder is used to configure and execute write operations. For example, to append to an
* existing table, run:
*
* {{{
* df.writeTo("catalog.db.table").append()
* }}}
*
* This can also be used to create or replace existing tables:
*
* {{{
* df.writeTo("catalog.db.table").partitionedBy($"col").createOrReplace()
* }}}
*
* @group basic
* @since 3.0.0
*/
def writeTo(table: String): DataFrameWriterV2[T] = {
// TODO: streaming could be adapted to use this interface
if (isStreaming) {
logicalPlan.failAnalysis(
errorClass = "CALL_ON_STREAMING_DATASET_UNSUPPORTED",
messageParameters = Map("methodName" -> toSQLId("writeTo")))
}
new DataFrameWriterV2[T](table, this)
}
/**
* Interface for saving the content of the streaming Dataset out into external storage.
*
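* For example, to continuously print a streaming Dataset to the console (an illustrative
* sketch; `streamingDs` is an assumed streaming Dataset):
* {{{
*   streamingDs.writeStream
*     .format("console")
*     .outputMode("append")
*     .start()
* }}}
*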
* @group basic
* @since 2.0.0
*/
def writeStream: DataStreamWriter[T] = {
if (!isStreaming) {
logicalPlan.failAnalysis(
errorClass = "WRITE_STREAM_NOT_ALLOWED",
messageParameters = Map.empty)
}
new DataStreamWriter[T](this)
}
/**
* Returns the content of the Dataset as a Dataset of JSON strings.
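*
* For example (an illustrative sketch; `df` is an assumed existing DataFrame):
* {{{
*   df.toJSON.show(false)
* }}}
*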
* @since 2.0.0
*/
def toJSON: Dataset[String] = {
val rowSchema = exprEnc.schema
val sessionLocalTimeZone = sparkSession.sessionState.conf.sessionLocalTimeZone
mapPartitions { iter =>
val writer = new CharArrayWriter()
// create the Generator without separator inserted between 2 records
val gen = new JacksonGenerator(rowSchema, writer,
new JSONOptions(Map.empty[String, String], sessionLocalTimeZone))
new Iterator[String] {
private val toRow = exprEnc.createSerializer()
override def hasNext: Boolean = iter.hasNext
override def next(): String = {
gen.write(toRow(iter.next()))
gen.flush()
val json = writer.toString
if (hasNext) {
writer.reset()
} else {
gen.close()
}
json
}
}
} (Encoders.STRING)
}
/**
* Returns a best-effort snapshot of the files that compose this Dataset. This method simply
* asks each constituent BaseRelation for its respective files and takes the union of all results.
* Depending on the source relations, this may not find all input files. Duplicates are removed.
*
* @group basic
* @since 2.0.0
*/
def inputFiles: Array[String] = {
val files: Seq[String] = queryExecution.optimizedPlan.collect {
case LogicalRelation(fsBasedRelation: FileRelation, _, _, _) =>
fsBasedRelation.inputFiles
case fr: FileRelation =>
fr.inputFiles
case r: HiveTableRelation =>
r.tableMeta.storage.locationUri.map(_.toString).toArray
case DataSourceV2ScanRelation(DataSourceV2Relation(table: FileTable, _, _, _, _),
_, _, _, _) =>
table.fileIndex.inputFiles
}.flatten
files.toSet.toArray
}
/**
* Returns `true` when the logical query plans inside both [[Dataset]]s are equal and
* therefore return the same results.
*
* @note The equality comparison here is simplified by tolerating cosmetic differences
* such as attribute names.
* @note This API can compare both [[Dataset]]s very quickly but can still return `false` for
* [[Dataset]]s that return the same results, for instance, when they are computed from
* different plans. Such false negatives are acceptable for use cases such as caching.
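*
* For example (an illustrative sketch; `spark` is an assumed active SparkSession):
* {{{
*   val df1 = spark.range(10)
*   val df2 = spark.range(10)
*   df1.sameSemantics(df2)   // true: the plans are semantically equal
* }}}
*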
* @since 3.1.0
*/
@DeveloperApi
def sameSemantics(other: Dataset[T]): Boolean = {
queryExecution.analyzed.sameResult(other.queryExecution.analyzed)
}
/**
* Returns a `hashCode` of the logical query plan against this [[Dataset]].
*
* @note Unlike the standard `hashCode`, the hash is calculated against the query plan
* simplified by tolerating the cosmetic differences such as attribute names.
* @since 3.1.0
*/
@DeveloperApi
def semanticHash(): Int = {
queryExecution.analyzed.semanticHash()
}
////////////////////////////////////////////////////////////////////////////
// For Python API
////////////////////////////////////////////////////////////////////////////
/**
* Adds a new long column named `name` whose value increases by one per row (0, 1, 2, ...).
* This is used for the 'distributed-sequence' default index in the pandas API on Spark.
*/
private[sql] def withSequenceColumn(name: String) = {
select(Column(DistributedSequenceID()).alias(name), col("*"))
}
/**
* Converts a JavaRDD to a PythonRDD.
*/
private[sql] def javaToPython: JavaRDD[Array[Byte]] = {
val structType = schema // capture it for closure
val rdd = queryExecution.toRdd.map(EvaluatePython.toJava(_, structType))
EvaluatePython.javaToPython(rdd)
}
private[sql] def collectToPython(): Array[Any] = {
EvaluatePython.registerPicklers()
withAction("collectToPython", queryExecution) { plan =>
val toJava: (Any) => Any = EvaluatePython.toJava(_, schema)
val iter: Iterator[Array[Byte]] = new SerDeUtil.AutoBatchedPickler(
plan.executeCollect().iterator.map(toJava))
PythonRDD.serveIterator(iter, "serve-DataFrame")
}
}
private[sql] def tailToPython(n: Int): Array[Any] = {
EvaluatePython.registerPicklers()
withAction("tailToPython", queryExecution) { plan =>
val toJava: (Any) => Any = EvaluatePython.toJava(_, schema)
val iter: Iterator[Array[Byte]] = new SerDeUtil.AutoBatchedPickler(
plan.executeTail(n).iterator.map(toJava))
PythonRDD.serveIterator(iter, "serve-DataFrame")
}
}
private[sql] def getRowsToPython(
_numRows: Int,
truncate: Int): Array[Any] = {
EvaluatePython.registerPicklers()
val numRows = _numRows.max(0).min(ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH - 1)
val rows = getRows(numRows, truncate).map(_.toArray).toArray
val toJava: (Any) => Any = EvaluatePython.toJava(_, ArrayType(ArrayType(StringType)))
val iter: Iterator[Array[Byte]] = new SerDeUtil.AutoBatchedPickler(
rows.iterator.map(toJava))
PythonRDD.serveIterator(iter, "serve-GetRows")
}
/**
* Collect a Dataset as Arrow batches and serve the stream to SparkR. It sends
* Arrow batches in an ordered manner with buffering. This is unavoidable because R
* lacks an API for reading batches directly from a socket. See ARROW-4512.
* Eventually, this code should be deduplicated with `collectAsArrowToPython`.
*/
private[sql] def collectAsArrowToR(): Array[Any] = {
val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone
RRDD.serveToStream("serve-Arrow") { outputStream =>
withAction("collectAsArrowToR", queryExecution) { plan =>
val buffer = new ByteArrayOutputStream()
val out = new DataOutputStream(outputStream)
val batchWriter =
new ArrowBatchStreamWriter(schema, buffer, timeZoneId, errorOnDuplicatedFieldNames = true)
val arrowBatchRdd = toArrowBatchRdd(plan)
val numPartitions = arrowBatchRdd.partitions.length
// Store collection results for worst case of 1 to N-1 partitions
val results = new Array[Array[Array[Byte]]](Math.max(0, numPartitions - 1))
var lastIndex = -1 // index of last partition written
// Handler to eagerly write partitions to Python in order
def handlePartitionBatches(index: Int, arrowBatches: Array[Array[Byte]]): Unit = {
// If result is from next partition in order
if (index - 1 == lastIndex) {
batchWriter.writeBatches(arrowBatches.iterator)
lastIndex += 1
// Write stored partitions that come next in order
while (lastIndex < results.length && results(lastIndex) != null) {
batchWriter.writeBatches(results(lastIndex).iterator)
results(lastIndex) = null
lastIndex += 1
}
// After last batch, end the stream
if (lastIndex == results.length) {
batchWriter.end()
val batches = buffer.toByteArray
out.writeInt(batches.length)
out.write(batches)
}
} else {
// Store partitions received out of order
results(index - 1) = arrowBatches
}
}
sparkSession.sparkContext.runJob(
arrowBatchRdd,
(ctx: TaskContext, it: Iterator[Array[Byte]]) => it.toArray,
0 until numPartitions,
handlePartitionBatches)
}
}
}
/**
* Collect a Dataset as Arrow batches and serve the stream to PySpark. It sends
* Arrow batches in an unordered manner without buffering, followed by the batch order
* information at the end. The batches should be reordered on the Python side.
*/
private[sql] def collectAsArrowToPython: Array[Any] = {
val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone
val errorOnDuplicatedFieldNames =
sparkSession.sessionState.conf.pandasStructHandlingMode == "legacy"
PythonRDD.serveToStream("serve-Arrow") { outputStream =>
withAction("collectAsArrowToPython", queryExecution) { plan =>
val out = new DataOutputStream(outputStream)
val batchWriter =
new ArrowBatchStreamWriter(schema, out, timeZoneId, errorOnDuplicatedFieldNames)
// Batches ordered by (index of partition, batch index in that partition) tuple
val batchOrder = ArrayBuffer.empty[(Int, Int)]
// Handler to eagerly write batches to Python as they arrive, un-ordered
val handlePartitionBatches = (index: Int, arrowBatches: Array[Array[Byte]]) =>
if (arrowBatches.nonEmpty) {
// Write all batches (can be more than 1) in the partition, store the batch order tuple
batchWriter.writeBatches(arrowBatches.iterator)
arrowBatches.indices.foreach {
partitionBatchIndex => batchOrder.append((index, partitionBatchIndex))
}
}
Utils.tryWithSafeFinally {
val arrowBatchRdd = toArrowBatchRdd(plan)
sparkSession.sparkContext.runJob(
arrowBatchRdd,
(it: Iterator[Array[Byte]]) => it.toArray,
handlePartitionBatches)
} {
// After processing all partitions, end the batch stream
batchWriter.end()
// Write batch order indices
out.writeInt(batchOrder.length)
// Sort by (index of partition, batch index in that partition) tuple to get the
// overall_batch_index from 0 to N-1 batches, which can be used to put the
// transferred batches in the correct order
batchOrder.zipWithIndex.sortBy(_._1).foreach { case (_, overallBatchIndex) =>
out.writeInt(overallBatchIndex)
}
}
}
}
}
private[sql] def toPythonIterator(prefetchPartitions: Boolean = false): Array[Any] = {
withNewExecutionId {
PythonRDD.toLocalIteratorAndServe(javaToPython.rdd, prefetchPartitions)
}
}
////////////////////////////////////////////////////////////////////////////
// Private Helpers
////////////////////////////////////////////////////////////////////////////
/**
* Wrap a Dataset action to track all Spark jobs in the body so that we can connect them with
* an execution.
*/
private def withNewExecutionId[U](body: => U): U = {
SQLExecution.withNewExecutionId(queryExecution)(body)
}
/**
* Wrap an action of the Dataset's RDD to track all Spark jobs in the body so that we can connect
* them with an execution. Before performing the action, the metrics of the executed plan will be
* reset.
*/
private def withNewRDDExecutionId[U](name: String)(body: => U): U = {
SQLExecution.withNewExecutionId(rddQueryExecution, Some(name)) {
rddQueryExecution.executedPlan.resetMetrics()
body
}
}
/**
* Wrap a Dataset action to track the QueryExecution and time cost, report them to the
* user-registered callback functions, and convert assertion failures and NPEs into
* the internal error exception.
*/
private def withAction[U](name: String, qe: QueryExecution)(action: SparkPlan => U) = {
SQLExecution.withNewExecutionId(qe, Some(name)) {
QueryExecution.withInternalError(s"""The "$name" action failed.""") {
qe.executedPlan.resetMetrics()
action(qe.executedPlan)
}
}
}
/**
* Collect all elements from a spark plan.
*/
private def collectFromPlan(plan: SparkPlan): Array[T] = {
val fromRow = resolvedEnc.createDeserializer()
plan.executeCollect().map(fromRow)
}
private def sortInternal(global: Boolean, sortExprs: Seq[Column]): Dataset[T] = {
val sortOrder: Seq[SortOrder] = sortExprs.map { col =>
col.expr match {
case expr: SortOrder =>
expr
case expr: Expression =>
SortOrder(expr, Ascending)
}
}
withTypedPlan {
Sort(sortOrder, global = global, logicalPlan)
}
}
/** A convenient function to wrap a logical plan and produce a DataFrame. */
@inline private def withPlan(logicalPlan: LogicalPlan): DataFrame = {
Dataset.ofRows(sparkSession, logicalPlan)
}
/** A convenient function to wrap a logical plan and produce a Dataset. */
@inline private def withTypedPlan[U : Encoder](logicalPlan: LogicalPlan): Dataset[U] = {
Dataset(sparkSession, logicalPlan)
}
/** A convenient function to wrap a set based logical plan and produce a Dataset. */
@inline private def withSetOperator[U : Encoder](logicalPlan: LogicalPlan): Dataset[U] = {
if (classTag.runtimeClass.isAssignableFrom(classOf[Row])) {
// Set operators widen types (change the schema), so we cannot reuse the row encoder.
Dataset.ofRows(sparkSession, logicalPlan).asInstanceOf[Dataset[U]]
} else {
Dataset(sparkSession, logicalPlan)
}
}
/** Convert to an RDD of serialized ArrowRecordBatches. */
private[sql] def toArrowBatchRdd(plan: SparkPlan): RDD[Array[Byte]] = {
val schemaCaptured = this.schema
val maxRecordsPerBatch = sparkSession.sessionState.conf.arrowMaxRecordsPerBatch
val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone
val errorOnDuplicatedFieldNames =
sparkSession.sessionState.conf.pandasStructHandlingMode == "legacy"
plan.execute().mapPartitionsInternal { iter =>
val context = TaskContext.get()
ArrowConverters.toBatchIterator(
iter, schemaCaptured, maxRecordsPerBatch, timeZoneId, errorOnDuplicatedFieldNames, context)
}
}
// This is only used in tests, for now.
private[sql] def toArrowBatchRdd: RDD[Array[Byte]] = {
toArrowBatchRdd(queryExecution.executedPlan)
}
}