All downloads are free. Search and download functionality uses the official Maven repository.

org.apache.spark.sql.internal.DataFrameWriterV2Impl.scala Maven / Gradle / Ivy

The newest version!
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.internal

import java.util

import scala.collection.mutable
import scala.jdk.CollectionConverters.MapHasAsScala

import org.apache.spark.annotation.Experimental
import org.apache.spark.sql.{Column, DataFrame, DataFrameWriterV2, Dataset}
import org.apache.spark.sql.catalyst.analysis.{NoSuchTableException, UnresolvedFunction, UnresolvedIdentifier, UnresolvedRelation}
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Literal}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.connector.catalog.TableWritePrivilege._
import org.apache.spark.sql.connector.expressions._
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.types.IntegerType

/**
 * Implementation of [[DataFrameWriterV2]] that writes a [[org.apache.spark.sql.Dataset]] to
 * external storage using the v2 API.
 *
 * Builder methods (`using`, `option`, `tableProperty`, `partitionedBy`, `clusterBy`, ...) only
 * accumulate state; the terminal methods (`create`, `replace`, `createOrReplace`, `append`,
 * `overwrite`, `overwritePartitions`) turn that state into a logical write command and run it.
 *
 * @since 3.0.0
 */
@Experimental
final class DataFrameWriterV2Impl[T] private[sql](table: String, ds: Dataset[T])
    extends DataFrameWriterV2[T] {

  private val df: DataFrame = ds.toDF()

  private val sparkSession = ds.sparkSession
  import sparkSession.expression

  // Multi-part identifier (e.g. catalog.db.table) parsed once from the user-supplied name.
  private val tableName = sparkSession.sessionState.sqlParser.parseMultipartIdentifier(table)

  // The (unanalyzed) plan of the data being written; used as the query of every write command.
  private val logicalPlan = df.queryExecution.logical

  private var provider: Option[String] = None

  private val options = new mutable.HashMap[String, String]()

  private val properties = new mutable.HashMap[String, String]()

  private var partitioning: Option[Seq[Transform]] = None

  private var clustering: Option[ClusterByTransform] = None

  /** @inheritdoc */
  override def using(provider: String): this.type = {
    this.provider = Some(provider)
    this
  }

  /** @inheritdoc */
  override def option(key: String, value: String): this.type = {
    this.options.put(key, value)
    this
  }

  /** @inheritdoc */
  override def options(options: scala.collection.Map[String, String]): this.type = {
    // Later entries overwrite earlier ones with the same key.
    options.foreach {
      case (key, value) =>
        this.options.put(key, value)
    }
    this
  }

  /** @inheritdoc */
  override def options(options: util.Map[String, String]): this.type = {
    // Delegate to the Scala-map overload, which already returns `this`.
    this.options(options.asScala)
  }

  /** @inheritdoc */
  override def tableProperty(property: String, value: String): this.type = {
    this.properties.put(property, value)
    this
  }

  /** @inheritdoc */
  @scala.annotation.varargs
  override def partitionedBy(column: Column, columns: Column*): this.type = {
    def ref(name: String): NamedReference = LogicalExpressions.parseReference(name)

    // Translate each column expression into a connector Transform. Only bare attributes and the
    // recognized transform functions (years/months/days/hours/bucket) over a single attribute
    // are accepted; anything else is rejected with a compilation error.
    val asTransforms = (column +: columns).map(expression).map {
      case PartitionTransform.YEARS(Seq(attr: Attribute)) =>
        LogicalExpressions.years(ref(attr.name))
      case PartitionTransform.MONTHS(Seq(attr: Attribute)) =>
        LogicalExpressions.months(ref(attr.name))
      case PartitionTransform.DAYS(Seq(attr: Attribute)) =>
        LogicalExpressions.days(ref(attr.name))
      case PartitionTransform.HOURS(Seq(attr: Attribute)) =>
        LogicalExpressions.hours(ref(attr.name))
      case PartitionTransform.BUCKET(Seq(Literal(numBuckets: Int, IntegerType), attr: Attribute)) =>
        LogicalExpressions.bucket(numBuckets, Array(ref(attr.name)))
      case PartitionTransform.BUCKET(Seq(numBuckets, e)) =>
        // Two-argument bucket() whose first argument is not an integer literal.
        throw QueryCompilationErrors.invalidBucketsNumberError(numBuckets.toString, e.toString)
      case attr: Attribute =>
        LogicalExpressions.identity(ref(attr.name))
      case expr =>
        throw QueryCompilationErrors.invalidPartitionTransformationError(expr)
    }

    this.partitioning = Some(asTransforms)
    validatePartitioning()
    this
  }

  /** @inheritdoc */
  @scala.annotation.varargs
  override def clusterBy(colName: String, colNames: String*): this.type = {
    this.clustering =
      Some(ClusterByTransform((colName +: colNames).map(col => FieldReference(col))))
    validatePartitioning()
    this
  }

  /**
   * Validate that clusterBy is not used with partitionedBy.
   *
   * Called after either setter so the error surfaces regardless of call order.
   */
  private def validatePartitioning(): Unit = {
    if (partitioning.nonEmpty && clustering.nonEmpty) {
      throw QueryCompilationErrors.clusterByWithPartitionedBy()
    }
  }

  /** @inheritdoc */
  override def create(): Unit = {
    runCommand(
      CreateTableAsSelect(
        UnresolvedIdentifier(tableName),
        tableTransforms,
        logicalPlan,
        unresolvedTableSpec(),
        writeOptions = options.toMap,
        ignoreIfExists = false))
  }

  /** @inheritdoc */
  override def replace(): Unit = {
    internalReplace(orCreate = false)
  }

  /** @inheritdoc */
  override def createOrReplace(): Unit = {
    internalReplace(orCreate = true)
  }

  /** @inheritdoc */
  @throws(classOf[NoSuchTableException])
  override def append(): Unit = {
    val append = AppendData.byName(
      UnresolvedRelation(tableName).requireWritePrivileges(Seq(INSERT)),
      logicalPlan, options.toMap)
    runCommand(append)
  }

  /** @inheritdoc */
  @throws(classOf[NoSuchTableException])
  override def overwrite(condition: Column): Unit = {
    // Overwriting requires both INSERT and DELETE privileges on the target table.
    val overwrite = OverwriteByExpression.byName(
      UnresolvedRelation(tableName).requireWritePrivileges(Seq(INSERT, DELETE)),
      logicalPlan, expression(condition), options.toMap)
    runCommand(overwrite)
  }

  /** @inheritdoc */
  @throws(classOf[NoSuchTableException])
  override def overwritePartitions(): Unit = {
    val dynamicOverwrite = OverwritePartitionsDynamic.byName(
      UnresolvedRelation(tableName).requireWritePrivileges(Seq(INSERT, DELETE)),
      logicalPlan, options.toMap)
    runCommand(dynamicOverwrite)
  }

  /**
   * Wrap an action to track the QueryExecution and time cost, then report to the user-registered
   * callback functions.
   *
   * The new QueryExecution reuses the source DataFrame's tracker, and `assertCommandExecuted()`
   * is relied on to actually run `command`.
   */
  private def runCommand(command: LogicalPlan): Unit = {
    val qe = new QueryExecution(sparkSession, command, df.queryExecution.tracker)
    qe.assertCommandExecuted()
  }

  /**
   * Shared implementation of `replace()` and `createOrReplace()`.
   *
   * @param orCreate whether the table may be created if it does not already exist
   */
  private def internalReplace(orCreate: Boolean): Unit = {
    runCommand(ReplaceTableAsSelect(
      UnresolvedIdentifier(tableName),
      tableTransforms,
      logicalPlan,
      unresolvedTableSpec(),
      writeOptions = options.toMap,
      orCreate = orCreate))
  }

  /**
   * Table spec shared by the create and replace commands.
   *
   * The spec carries the accumulated table properties and provider, with no location, comment,
   * serde or option expressions, and marks the table as not external. It is a `def` (not a `val`)
   * so that it reflects every builder call made before the terminal operation.
   */
  private def unresolvedTableSpec(): UnresolvedTableSpec = UnresolvedTableSpec(
    properties = properties.toMap,
    provider = provider,
    optionExpression = OptionList(Seq.empty),
    location = None,
    comment = None,
    serde = None,
    external = false)

  /**
   * All table-layout transforms: the partition transforms (if any), followed by the clustering
   * transform (if any).
   *
   * At most one of the two is non-empty; `validatePartitioning` enforces this.
   */
  private def tableTransforms: Seq[Transform] = partitioning.getOrElse(Seq.empty) ++ clustering
}

private object PartitionTransform {

  /**
   * Extractor that recognizes a call to a single-part, internal function named `name`.
   *
   * The call must be non-distinct, have no filter, not ignore nulls and have no ordering within
   * group. On a match, it yields the call's argument expressions; any other expression yields
   * `None`.
   *
   * @param name the unqualified function name to match, e.g. "years"
   */
  class ExtractTransform(name: String) {
    private val expectedNameParts: Seq[String] = Seq(name)

    def unapply(e: Expression): Option[Seq[Expression]] = e match {
      case UnresolvedFunction(nameParts, args, false, None, false, Nil, true)
          if nameParts == expectedNameParts =>
        Option(args)
      case _ =>
        None
    }
  }

  // One extractor per partition transform understood by DataFrameWriterV2Impl.partitionedBy.
  val HOURS = new ExtractTransform("hours")
  val DAYS = new ExtractTransform("days")
  val MONTHS = new ExtractTransform("months")
  val YEARS = new ExtractTransform("years")
  val BUCKET = new ExtractTransform("bucket")
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy