/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/*
 * This file contains code from the Apache Spark project (original license above).
 * It contains modifications, which are licensed as follows:
 */

/*
 * Copyright (2020) The Delta Lake Project Authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.delta.util

import java.lang.{ Double => JDouble, Long => JLong }
import java.math.{ BigDecimal => JBigDecimal }
import java.util.{ Locale, TimeZone }

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.util.Try

import org.apache.hadoop.fs.Path

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions.{ Attribute, Cast, Literal }
import org.apache.spark.sql.catalyst.util.{ CaseInsensitiveMap, DateTimeUtils }
import org.apache.spark.sql.types._

/**
  * This file is forked from [[org.apache.spark.sql.execution.datasources.PartitioningUtils]].
  */

// In open-source Apache Spark, PartitionPath is defined as
//
//  case class PartitionPath(values: InternalRow, path: Path)
//
// but in Databricks we use a different representation where the Path is stored as a String
// and converted back to a Path only when read. This significantly cuts memory consumption because
// Hadoop Path objects are heavyweight. See SC-7591 for details.
object PartitionPath {
  // Used only in tests:
  def apply(values: InternalRow, path: String): PartitionPath = {
    // Roundtrip through `new Path` to ensure any normalization done there is applied:
    apply(values, new Path(path))
  }

  def apply(values: InternalRow, path: Path): PartitionPath = {
    new PartitionPath(values, path.toString)
  }
}

/**
  * Holds a directory in a partitioned collection of files as well as the partition values
  * in the form of a Row.  Before scanning, the files at `path` need to be enumerated.
  */
class PartitionPath private (val values: InternalRow, val pathStr: String) {
  // Note: this isn't a case class because we don't want to have a public apply() method which
  // accepts a string. The goal is to force every value stored in `pathStr` to have gone through
  // a `new Path(...).toString` to ensure that canonicalization / normalization has taken place.
  def path: Path                           = new Path(pathStr)
  def withNewValues(newValues: InternalRow): PartitionPath = {
    new PartitionPath(newValues, pathStr)
  }
  override def equals(other: Any): Boolean = other match {
    case that: PartitionPath => values == that.values && pathStr == that.pathStr
    case _                   => false
  }
  override def hashCode(): Int = {
    (values, pathStr).hashCode()
  }
  override def toString: String = {
    s"PartitionPath($values, $pathStr)"
  }
}
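
// A minimal usage sketch (illustrative only; the values and path below are hypothetical and not
// part of this file). It shows that `pathStr` stays a cheap String while `path` rebuilds a
// Hadoop Path on demand:
//
//   import org.apache.spark.unsafe.types.UTF8String
//   val part = PartitionPath(
//     InternalRow.fromSeq(Seq(1, UTF8String.fromString("hello"))),
//     "hdfs://host:9000/table/a=1/b=hello")
//   part.pathStr   // "hdfs://host:9000/table/a=1/b=hello" (normalized via `new Path(...).toString`)
//   part.path      // a freshly constructed org.apache.hadoop.fs.Path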

case class PartitionSpec(partitionColumns: StructType, partitions: Seq[PartitionPath])

object PartitionSpec {
  val emptySpec = PartitionSpec(StructType(Seq.empty[StructField]), Seq.empty[PartitionPath])
}

private[delta] object PartitionUtils {

  val timestampPartitionPattern = "yyyy-MM-dd HH:mm:ss[.S]"

  case class PartitionValues(columnNames: Seq[String], literals: Seq[Literal]) {
    require(columnNames.size == literals.size)
  }

  import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.{
    escapePathName,
    unescapePathName,
    DEFAULT_PARTITION_NAME
  }

  /**
    * Given a group of qualified paths, tries to parse them and returns a partition specification.
    * For example, given:
    * {{{
    *   hdfs://<host>:<port>/path/to/partition/a=1/b=hello/c=3.14
    *   hdfs://<host>:<port>/path/to/partition/a=2/b=world/c=6.28
    * }}}
    * it returns:
    * {{{
    *   PartitionSpec(
    *     partitionColumns = StructType(
    *       StructField(name = "a", dataType = IntegerType, nullable = true),
    *       StructField(name = "b", dataType = StringType, nullable = true),
    *       StructField(name = "c", dataType = DoubleType, nullable = true)),
    *     partitions = Seq(
    *       Partition(
    *         values = Row(1, "hello", 3.14),
    *         path = "hdfs://:/path/to/partition/a=1/b=hello/c=3.14"),
    *       Partition(
    *         values = Row(2, "world", 6.28),
    *         path = "hdfs://:/path/to/partition/a=2/b=world/c=6.28")))
    * }}}
    */
  def parsePartitions(
    paths: Seq[Path],
    typeInference: Boolean,
    basePaths: Set[Path],
    userSpecifiedSchema: Option[StructType],
    caseSensitive: Boolean,
    validatePartitionColumns: Boolean,
    timeZoneId: String
  ): PartitionSpec = {
    parsePartitions(
      paths,
      typeInference,
      basePaths,
      userSpecifiedSchema,
      caseSensitive,
      validatePartitionColumns,
      DateTimeUtils.getTimeZone(timeZoneId)
    )
  }

  def parsePartitions(
    paths: Seq[Path],
    typeInference: Boolean,
    basePaths: Set[Path],
    userSpecifiedSchema: Option[StructType],
    caseSensitive: Boolean,
    validatePartitionColumns: Boolean,
    timeZone: TimeZone
  ): PartitionSpec = {
    val userSpecifiedDataTypes = if (userSpecifiedSchema.isDefined) {
      val nameToDataType = userSpecifiedSchema.get.fields.map(f => f.name -> f.dataType).toMap
      if (!caseSensitive) {
        CaseInsensitiveMap(nameToDataType)
      } else {
        nameToDataType
      }
    } else {
      Map.empty[String, DataType]
    }

    // SPARK-26990: use user specified field names if case insensitive.
    val userSpecifiedNames = if (userSpecifiedSchema.isDefined && !caseSensitive) {
      CaseInsensitiveMap(userSpecifiedSchema.get.fields.map(f => f.name -> f.name).toMap)
    } else {
      Map.empty[String, String]
    }

    val dateFormatter                             = DateFormatter()
    val timestampFormatter                        = TimestampFormatter(timestampPartitionPattern, timeZone)
    // First, we need to parse every partition's path and see if we can find partition values.
    val (partitionValues, optDiscoveredBasePaths) = paths.map { path =>
      parsePartition(
        path,
        typeInference,
        basePaths,
        userSpecifiedDataTypes,
        validatePartitionColumns,
        timeZone,
        dateFormatter,
        timestampFormatter
      )
    }.unzip

    // We create pairs of (path -> path's partition value) here
    // If the corresponding partition value is None, the pair will be skipped
    val pathsWithPartitionValues = paths.zip(partitionValues).flatMap(x => x._2.map(x._1 -> _))

    if (pathsWithPartitionValues.isEmpty) {
      // This dataset is not partitioned.
      PartitionSpec.emptySpec
    } else {
      // This dataset is partitioned. We need to check whether all partitions have the same
      // partition columns and resolve potential type conflicts.

      // Check if there is conflicting directory structure.
      // For the paths such as:
      // var paths = Seq(
      //   "hdfs://host:9000/invalidPath",
      //   "hdfs://host:9000/path/a=10/b=20",
      //   "hdfs://host:9000/path/a=10.5/b=hello")
      // It will be recognised as conflicting directory structure:
      //   "hdfs://host:9000/invalidPath"
      //   "hdfs://host:9000/path"
      // TODO: Selective case sensitivity.
      val discoveredBasePaths = optDiscoveredBasePaths.flatten.map(_.toString.toLowerCase())
      assert(
        discoveredBasePaths.distinct.size == 1,
        "Conflicting directory structures detected. Suspicious paths:\b" +
          discoveredBasePaths.distinct.mkString("\n\t", "\n\t", "\n\n") +
          "If provided paths are partition directories, please set " +
          "\"basePath\" in the options of the data source to specify the " +
          "root directory of the table. If there are multiple root directories, " +
          "please load them separately and then union them."
      )

      val resolvedPartitionValues =
        resolvePartitions(pathsWithPartitionValues, caseSensitive, timeZone)

      // Creates the StructType which represents the partition columns.
      val fields = {
        val PartitionValues(columnNames, literals) = resolvedPartitionValues.head
        columnNames.zip(literals).map { case (name, Literal(_, dataType)) =>
          // We always assume partition columns are nullable since we've no idea whether null values
          // will be appended in the future.
          val resultName     = userSpecifiedNames.getOrElse(name, name)
          val resultDataType = userSpecifiedDataTypes.getOrElse(name, dataType)
          StructField(resultName, resultDataType, nullable = true)
        }
      }

      // Finally, we create `Partition`s based on paths and resolved partition values.
      val partitions =
        resolvedPartitionValues.zip(pathsWithPartitionValues).map { case (PartitionValues(_, literals), (path, _)) =>
          PartitionPath(InternalRow.fromSeq(literals.map(_.value)), path)
        }

      PartitionSpec(StructType(fields), partitions)
    }
  }
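
  // A hedged invocation sketch (paths, base path, and time zone are illustrative, not taken from
  // this file). With type inference enabled, the directory "a=1/b=hello" would be discovered as
  // partition columns a: IntegerType and b: StringType:
  //
  //   val spec = parsePartitions(
  //     Seq(new Path("hdfs://host:9000/table/a=1/b=hello")),  // paths
  //     true,                                                  // typeInference
  //     Set(new Path("hdfs://host:9000/table")),               // basePaths
  //     None,                                                  // userSpecifiedSchema
  //     false,                                                 // caseSensitive
  //     true,                                                  // validatePartitionColumns
  //     "UTC")                                                 // timeZoneId
  //   // spec.partitionColumns: StructType(a: int, b: string); spec.partitions has one entry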

  /**
    * Parses a single partition, returns column names and values of each partition column, also
    * the path when we stop partition discovery.  For example, given:
    * {{{
    *   path = hdfs://<host>:<port>/path/to/partition/a=42/b=hello/c=3.14
    * }}}
    * it returns the partition:
    * {{{
    *   PartitionValues(
    *     Seq("a", "b", "c"),
    *     Seq(
    *       Literal.create(42, IntegerType),
    *       Literal.create("hello", StringType),
    *       Literal.create(3.14, DoubleType)))
    * }}}
    * and the path when we stop the discovery is:
    * {{{
    *   hdfs://<host>:<port>/path/to/partition
    * }}}
    */
  def parsePartition(
    path: Path,
    typeInference: Boolean,
    basePaths: Set[Path],
    userSpecifiedDataTypes: Map[String, DataType],
    validatePartitionColumns: Boolean,
    timeZone: TimeZone,
    dateFormatter: DateFormatter,
    timestampFormatter: TimestampFormatter
  ): (Option[PartitionValues], Option[Path]) = {
    val columns           = ArrayBuffer.empty[(String, Literal)]
    // Old Hadoop versions don't have `Path.isRoot`
    var finished          = path.getParent == null
    // currentPath is the current path that we will use to parse partition column value.
    var currentPath: Path = path

    while (!finished) {
      // Sometimes (e.g., when speculative execution is enabled), temporary directories may be
      // left uncleaned. Here we simply ignore them.
      if (currentPath.getName.toLowerCase(Locale.ROOT) == "_temporary") {
        return (None, None)
      }

      if (basePaths.contains(currentPath)) {
        // If currentPath is one of the base paths, we should stop.
        finished = true
      } else {
        // Let's say currentPath is a path of "/table/a=1/", currentPath.getName will give us a=1.
        // Once we get the string, we try to parse it and find the partition column and value.
        val maybeColumn =
          parsePartitionColumn(
            currentPath.getName,
            typeInference,
            userSpecifiedDataTypes,
            validatePartitionColumns,
            timeZone,
            dateFormatter,
            timestampFormatter
          )
        maybeColumn.foreach(columns += _)

        // Now, we determine if we should stop.
        // We stop when we hit either of the following cases:
        //  - In this iteration, we could not parse a partition column name and value from the
        //    directory name, i.e. maybeColumn is None, and columns is not empty. We check whether
        //    columns is empty to handle cases like /table/a=1/_temporary/something (we still need
        //    to find a=1 in this case).
        //  - currentPath already represents the top-level dir, i.e. currentPath.getParent == null.
        //    For the example of "/table/a=1/", the top-level dir is "/table".
        finished = (maybeColumn.isEmpty && columns.nonEmpty) || currentPath.getParent == null

        if (!finished) {
          // For the above example, currentPath will be "/table/".
          currentPath = currentPath.getParent
        }
      }
    }

    if (columns.isEmpty) {
      (None, Some(path))
    } else {
      val (columnNames, values) = columns.reverse.unzip
      (Some(PartitionValues(columnNames, values)), Some(currentPath))
    }
  }

  private def parsePartitionColumn(
    columnSpec: String,
    typeInference: Boolean,
    userSpecifiedDataTypes: Map[String, DataType],
    validatePartitionColumns: Boolean,
    timeZone: TimeZone,
    dateFormatter: DateFormatter,
    timestampFormatter: TimestampFormatter
  ): Option[(String, Literal)] = {
    val equalSignIndex = columnSpec.indexOf('=')
    if (equalSignIndex == -1) {
      None
    } else {
      val columnName = unescapePathName(columnSpec.take(equalSignIndex))
      assert(columnName.nonEmpty, s"Empty partition column name in '$columnSpec'")

      val rawColumnValue = columnSpec.drop(equalSignIndex + 1)
      assert(rawColumnValue.nonEmpty, s"Empty partition column value in '$columnSpec'")

      val literal = if (userSpecifiedDataTypes.contains(columnName)) {
        // SPARK-26188: if the user provides a schema for this column, get the column value without
        //              inference and then cast it to the user-specified data type.
        val dataType           = userSpecifiedDataTypes(columnName)
        val columnValueLiteral =
          inferPartitionColumnValue(rawColumnValue, false, timeZone, dateFormatter, timestampFormatter)
        val columnValue        = columnValueLiteral.eval()
        val castedValue        = Cast(columnValueLiteral, dataType, Option(timeZone.getID)).eval()
        if (validatePartitionColumns && columnValue != null && castedValue == null) {
          throw new RuntimeException(
            s"Failed to cast value `$columnValue` to `$dataType` " +
              s"for partition column `$columnName`"
          )
        }
        Literal.create(castedValue, dataType)
      } else {
        inferPartitionColumnValue(rawColumnValue, typeInference, timeZone, dateFormatter, timestampFormatter)
      }
      Some(columnName -> literal)
    }
  }

  /**
    * Given a partition path fragment, e.g. `fieldOne=1/fieldTwo=2`, returns a parsed spec
    * for that fragment as a `TablePartitionSpec`, e.g. `Map(("fieldOne", "1"), ("fieldTwo", "2"))`.
    */
  def parsePathFragment(pathFragment: String): TablePartitionSpec = {
    parsePathFragmentAsSeq(pathFragment).toMap
  }

  /**
    * Given a partition path fragment, e.g. `fieldOne=1/fieldTwo=2`, returns a parsed spec
    * for that fragment as a `Seq[(String, String)]`, e.g.
    * `Seq(("fieldOne", "1"), ("fieldTwo", "2"))`.
    */
  def parsePathFragmentAsSeq(pathFragment: String): Seq[(String, String)] = {
    pathFragment.split("/").map { kv =>
      val pair = kv.split("=", 2)
      (unescapePathName(pair(0)), unescapePathName(pair(1)))
    }
  }

  /**
    * This is the inverse of parsePathFragment().
    */
  def getPathFragment(spec: TablePartitionSpec, partitionSchema: StructType): String = {
    partitionSchema.map { field =>
      escapePathName(field.name) + "=" + escapePathName(spec(field.name))
    }.mkString("/")
  }

  def getPathFragment(spec: TablePartitionSpec, partitionColumns: Seq[Attribute]): String = {
    getPathFragment(spec, StructType.fromAttributes(partitionColumns))
  }
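
  // Round-trip sketch (the spec and schema below are illustrative):
  //
  //   val spec   = parsePathFragment("a=1/b=hello")   // Map("a" -> "1", "b" -> "hello")
  //   val schema = StructType(Seq(
  //     StructField("a", StringType), StructField("b", StringType)))
  //   getPathFragment(spec, schema)                    // "a=1/b=hello"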

  /**
    * Normalize the column names in partition specification, w.r.t. the real partition column names
    * and case sensitivity. e.g., if the partition spec has a column named `monTh`, and there is a
    * partition column named `month`, and it's case insensitive, we will normalize `monTh` to
    * `month`.
    */
  def normalizePartitionSpec[T](
    partitionSpec: Map[String, T],
    partColNames: Seq[String],
    tblName: String,
    resolver: Resolver
  ): Map[String, T] = {
    val normalizedPartSpec = partitionSpec.toSeq.map { case (key, value) =>
      val normalizedKey = partColNames.find(resolver(_, key)).getOrElse {
        throw new AnalysisException(s"$key is not a valid partition column in table $tblName.")
      }
      normalizedKey -> value
    }

    checkColumnNameDuplication(normalizedPartSpec.map(_._1), "in the partition schema", resolver)

    normalizedPartSpec.toMap
  }
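
  // Illustrative sketch (table and column names are hypothetical): with the case-insensitive
  // resolver, a key spelled `monTh` is normalized to the real partition column name `month`,
  // while an unknown key raises an AnalysisException:
  //
  //   normalizePartitionSpec(
  //     Map("monTh" -> "12"), Seq("month"), "events", caseInsensitiveResolution)
  //   // => Map("month" -> "12")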

  /**
    * Resolves possible type conflicts between partitions by up-casting "lower" types using
    * [[findWiderTypeForPartitionColumn]].
    */
  def resolvePartitions(
    pathsWithPartitionValues: Seq[(Path, PartitionValues)],
    caseSensitive: Boolean,
    timeZone: TimeZone
  ): Seq[PartitionValues] = {
    if (pathsWithPartitionValues.isEmpty) {
      Seq.empty
    } else {
      val partColNames = if (caseSensitive) {
        pathsWithPartitionValues.map(_._2.columnNames)
      } else {
        pathsWithPartitionValues.map(_._2.columnNames.map(_.toLowerCase()))
      }
      assert(partColNames.distinct.size == 1, listConflictingPartitionColumns(pathsWithPartitionValues))

      // Resolves possible type conflicts for each column
      val values         = pathsWithPartitionValues.map(_._2)
      val columnCount    = values.head.columnNames.size
      val resolvedValues = (0 until columnCount).map { i =>
        resolveTypeConflicts(values.map(_.literals(i)), timeZone)
      }

      // Fills resolved literals back to each partition
      values.zipWithIndex.map { case (d, index) =>
        d.copy(literals = resolvedValues.map(_(index)))
      }
    }
  }
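
  // Sketch: two partitions whose column `a` was inferred with different types are widened to a
  // single type per findWiderTypeForPartitionColumn (IntegerType and LongType below widen to
  // LongType). The paths shown are hypothetical:
  //
  //   resolvePartitions(
  //     Seq(
  //       new Path("hdfs://host:9000/t/a=1") ->
  //         PartitionValues(Seq("a"), Seq(Literal(1))),
  //       new Path("hdfs://host:9000/t/a=10000000000") ->
  //         PartitionValues(Seq("a"), Seq(Literal(10000000000L)))),
  //     caseSensitive = false,
  //     timeZone = TimeZone.getTimeZone("UTC"))
  //   // => both PartitionValues now carry LongType literals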

  def listConflictingPartitionColumns(pathWithPartitionValues: Seq[(Path, PartitionValues)]): String = {
    val distinctPartColNames = pathWithPartitionValues.map(_._2.columnNames).distinct

    def groupByKey[K, V](seq: Seq[(K, V)]): Map[K, Iterable[V]] =
      seq.groupBy { case (key, _) => key }.mapValues(_.map { case (_, value) => value })

    val partColNamesToPaths                                     = groupByKey(pathWithPartitionValues.map { case (path, partValues) =>
      partValues.columnNames -> path
    })

    val distinctPartColLists = distinctPartColNames.map(_.mkString(", ")).zipWithIndex.map { case (names, index) =>
      s"Partition column name list #$index: $names"
    }

    // Lists out those non-leaf partition directories that also contain files
    val suspiciousPaths = distinctPartColNames.sortBy(_.length).flatMap(partColNamesToPaths)

    s"Conflicting partition column names detected:\n" +
      distinctPartColLists.mkString("\n\t", "\n\t", "\n\n") +
      "For partitioned table directories, data files should only live in leaf directories.\n" +
      "And directories at the same level should have the same partition column name.\n" +
      "Please check the following directories for unexpected files or " +
      "inconsistent partition column names:\n" +
      suspiciousPaths.map("\t" + _).mkString("\n", "\n", "")
  }

  // scalastyle:off line.size.limit
  /**
    * Converts a string to a [[Literal]] with automatic type inference. Currently only supports
    * [[NullType]], [[IntegerType]], [[LongType]], [[DoubleType]], [[DecimalType]], [[DateType]],
    * [[TimestampType]], and [[StringType]].
    *
    * When resolving conflicts, it follows the table below:
    *
    * +--------------------+-------------------+-------------------+-------------------+--------------------+------------+---------------+---------------+------------+
    * | InputA \ InputB    | NullType          | IntegerType       | LongType          | DecimalType(38,0)* | DoubleType | DateType      | TimestampType | StringType |
    * +--------------------+-------------------+-------------------+-------------------+--------------------+------------+---------------+---------------+------------+
    * | NullType           | NullType          | IntegerType       | LongType          | DecimalType(38,0)  | DoubleType | DateType      | TimestampType | StringType |
    * | IntegerType        | IntegerType       | IntegerType       | LongType          | DecimalType(38,0)  | DoubleType | StringType    | StringType    | StringType |
    * | LongType           | LongType          | LongType          | LongType          | DecimalType(38,0)  | StringType | StringType    | StringType    | StringType |
    * | DecimalType(38,0)* | DecimalType(38,0) | DecimalType(38,0) | DecimalType(38,0) | DecimalType(38,0)  | StringType | StringType    | StringType    | StringType |
    * | DoubleType         | DoubleType        | DoubleType        | StringType        | StringType         | DoubleType | StringType    | StringType    | StringType |
    * | DateType           | DateType          | StringType        | StringType        | StringType         | StringType | DateType      | TimestampType | StringType |
    * | TimestampType      | TimestampType     | StringType        | StringType        | StringType         | StringType | TimestampType | TimestampType | StringType |
    * | StringType         | StringType        | StringType        | StringType        | StringType         | StringType | StringType    | StringType    | StringType |
    * +--------------------+-------------------+-------------------+-------------------+--------------------+------------+---------------+---------------+------------+
    * Note that, for DecimalType(38,0)*, the table above intentionally does not cover all other
    * combinations of scales and precisions, because currently we only infer decimal types for
    * integral values (i.e. `BigInteger`/`BigInt`); for example, 1.1 is inferred as a double.
    */
  // scalastyle:on line.size.limit
  def inferPartitionColumnValue(
    raw: String,
    typeInference: Boolean,
    timeZone: TimeZone,
    dateFormatter: DateFormatter,
    timestampFormatter: TimestampFormatter
  ): Literal = {
    val decimalTry = Try {
      // `BigDecimal` conversion can fail when the raw value is not a valid number.
      val bigDecimal = new JBigDecimal(raw)
      // Reduce the decimal cases by disallowing values with a positive scale (e.g. `1.1`).
      require(bigDecimal.scale <= 0)
      // `DecimalType` conversion can fail when
      //   1. The precision is bigger than 38.
      //   2. scale is bigger than precision.
      Literal(bigDecimal)
    }

    val dateTry = Try {
      // Try to parse the date; if no exception occurs, this is a candidate to be resolved as
      // DateType.
      dateFormatter.parse(raw)
      // SPARK-23436: Casting the string to a date may still return null if a bad date is provided.
      // This can happen because DateFormat.parse may not consume the entire input string, so a
      // value with extra characters after the date can still parse successfully.
      // We need to check that we can cast the raw string since we later can use Cast to get
      // the partition values with the right DataType (see
      // org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex.inferPartitioning)
      val dateValue = Cast(Literal(raw), DateType).eval()
      // Disallow DateType if the cast returned null
      require(dateValue != null)
      Literal.create(dateValue, DateType)
    }

    val timestampTry = Try {
      val unescapedRaw   = unescapePathName(raw)
      // Try to parse the timestamp; if no exception occurs, this is a candidate to be resolved as
      // TimestampType.
      timestampFormatter.parse(unescapedRaw)
      // SPARK-23436: see comment for date
      val timestampValue = Cast(Literal(unescapedRaw), TimestampType, Some(timeZone.getID)).eval()
      // Disallow TimestampType if the cast returned null
      require(timestampValue != null)
      Literal.create(timestampValue, TimestampType)
    }

    if (typeInference) {
      // First tries integral types
      Try(Literal.create(Integer.parseInt(raw), IntegerType))
        .orElse(Try(Literal.create(JLong.parseLong(raw), LongType)))
        .orElse(decimalTry)
        // Then falls back to fractional types
        .orElse(Try(Literal.create(JDouble.parseDouble(raw), DoubleType)))
        // Then falls back to date/timestamp types
        .orElse(timestampTry)
        .orElse(dateTry)
        // Then falls back to string
        .getOrElse {
          if (raw == DEFAULT_PARTITION_NAME) {
            Literal.create(null, NullType)
          } else {
            Literal.create(unescapePathName(raw), StringType)
          }
        }
    } else {
      if (raw == DEFAULT_PARTITION_NAME) {
        Literal.create(null, NullType)
      } else {
        Literal.create(unescapePathName(raw), StringType)
      }
    }
  }
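
  // Illustrative results with typeInference = true (tz, dateFmt, and tsFmt stand for a time zone
  // and the formatters built in parsePartitions above; the names are placeholders for this sketch):
  //
  //   inferPartitionColumnValue("42", true, tz, dateFmt, tsFmt)           // Literal(42, IntegerType)
  //   inferPartitionColumnValue("10000000000", true, tz, dateFmt, tsFmt)  // Literal(_, LongType)
  //   inferPartitionColumnValue("3.14", true, tz, dateFmt, tsFmt)         // Literal(3.14, DoubleType)
  //   inferPartitionColumnValue(DEFAULT_PARTITION_NAME, true, tz, dateFmt, tsFmt)
  //   // => Literal(null, NullType)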

  def validatePartitionColumn(schema: StructType, partitionColumns: Seq[String], caseSensitive: Boolean): Unit = {
    checkColumnNameDuplication(partitionColumns, "in the partition columns", caseSensitive)

    partitionColumnsSchema(schema, partitionColumns, caseSensitive).foreach { field =>
      field.dataType match {
        case _: AtomicType => // OK
        case _             => throw new AnalysisException(s"Cannot use ${field.dataType} for partition column")
      }
    }

    if (partitionColumns.nonEmpty && partitionColumns.size == schema.fields.length) {
      throw new AnalysisException(s"Cannot use all columns for partition columns")
    }
  }
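
  // Sketch (schema and columns are illustrative): partitioning is rejected for non-atomic column
  // types and when every column of the schema is used as a partition column:
  //
  //   val schema = StructType(Seq(StructField("id", LongType), StructField("day", StringType)))
  //   validatePartitionColumn(schema, Seq("day"), caseSensitive = false)        // OK
  //   validatePartitionColumn(schema, Seq("id", "day"), caseSensitive = false)  // AnalysisException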

  def partitionColumnsSchema(schema: StructType, partitionColumns: Seq[String], caseSensitive: Boolean): StructType = {
    val equality = columnNameEquality(caseSensitive)
    StructType(partitionColumns.map { col =>
      schema.find(f => equality(f.name, col)).getOrElse {
        val schemaCatalog = schema.catalogString
        throw new AnalysisException(s"Partition column `$col` not found in schema $schemaCatalog")
      }
    }).asNullable
  }

  def mergeDataAndPartitionSchema(
    dataSchema: StructType,
    partitionSchema: StructType,
    caseSensitive: Boolean
  ): (StructType, Map[String, StructField]) = {
    val overlappedPartCols = mutable.Map.empty[String, StructField]
    partitionSchema.foreach { partitionField =>
      val partitionFieldName = getColName(partitionField, caseSensitive)
      if (dataSchema.exists(getColName(_, caseSensitive) == partitionFieldName)) {
        overlappedPartCols += partitionFieldName -> partitionField
      }
    }

    // When data and partition schemas have overlapping columns, the output
    // schema respects the order of the data schema for the overlapping columns, and it
    // respects the data types of the partition schema.
    // `HadoopFsRelation` will be mapped to `FileSourceScanExec`, which always output
    // all the partition columns physically. Here we need to make sure the final schema
    // contains all the partition columns.
    val fullSchema =
      StructType(
        dataSchema.map(f => overlappedPartCols.getOrElse(getColName(f, caseSensitive), f)) ++
          partitionSchema.filterNot(f => overlappedPartCols.contains(getColName(f, caseSensitive)))
      )
    (fullSchema, overlappedPartCols.toMap)
  }
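
  // Illustrative behavior (field names are hypothetical): an overlapping column keeps its position
  // from the data schema but takes its type from the partition schema, and the remaining partition
  // columns are appended at the end:
  //
  //   val data = StructType(Seq(StructField("id", LongType), StructField("p", StringType)))
  //   val part = StructType(Seq(StructField("p", IntegerType)))
  //   mergeDataAndPartitionSchema(data, part, caseSensitive = false)
  //   // => (StructType(id: long, p: int), Map("p" -> StructField("p", IntegerType)))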

  def getColName(f: StructField, caseSensitive: Boolean): String = {
    if (caseSensitive) {
      f.name
    } else {
      f.name.toLowerCase(Locale.ROOT)
    }
  }

  private def columnNameEquality(caseSensitive: Boolean): (String, String) => Boolean = {
    if (caseSensitive) {
      org.apache.spark.sql.catalyst.analysis.caseSensitiveResolution
    } else {
      org.apache.spark.sql.catalyst.analysis.caseInsensitiveResolution
    }
  }

  /**
    * Given a collection of [[Literal]]s, resolves possible type conflicts by
    * [[findWiderTypeForPartitionColumn]].
    */
  private def resolveTypeConflicts(literals: Seq[Literal], timeZone: TimeZone): Seq[Literal] = {
    val litTypes    = literals.map(_.dataType)
    val desiredType = litTypes.reduce(findWiderTypeForPartitionColumn)

    literals.map { case l @ Literal(_, dataType) =>
      Literal.create(Cast(l, desiredType, Some(timeZone.getID)).eval(), desiredType)
    }
  }

  /**
    * Type widening rule for partition column types. It is similar to
    * [[TypeCoercion.findWiderTypeForTwo]] but the main difference is that here we disallow
    * precision loss when widening double/long and decimal, and fall back to string.
    */
  private val findWiderTypeForPartitionColumn: (DataType, DataType) => DataType = {
    case (DoubleType, _: DecimalType) | (_: DecimalType, DoubleType) => StringType
    case (DoubleType, LongType) | (LongType, DoubleType)             => StringType
    case (t1, t2)                                                    => TypeCoercion.findWiderTypeForTwo(t1, t2).getOrElse(StringType)
  }
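
  // For instance, per the rules above: integer and long widen to long via TypeCoercion, while
  // double vs. long and double vs. decimal deliberately fall back to string to avoid silent
  // precision loss:
  //
  //   findWiderTypeForPartitionColumn(IntegerType, LongType)          // LongType
  //   findWiderTypeForPartitionColumn(DoubleType, LongType)           // StringType
  //   findWiderTypeForPartitionColumn(DoubleType, DecimalType(38, 0)) // StringType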

  /** The methods below are forked from [[org.apache.spark.sql.util.SchemaUtils]] */

  /**
    * Checks if input column names have duplicate identifiers. This throws an exception if
    * the duplication exists.
    *
    * @param columnNames column names to check
    * @param colType column type name, used in an exception message
    * @param resolver resolver used to determine if two identifiers are equal
    */
  def checkColumnNameDuplication(columnNames: Seq[String], colType: String, resolver: Resolver): Unit = {
    checkColumnNameDuplication(columnNames, colType, isCaseSensitiveAnalysis(resolver))
  }

  /**
    * Checks if input column names have duplicate identifiers. This throws an exception if
    * the duplication exists.
    *
    * @param columnNames column names to check
    * @param colType column type name, used in an exception message
    * @param caseSensitiveAnalysis whether duplication checks should be case sensitive or not
    */
  def checkColumnNameDuplication(columnNames: Seq[String], colType: String, caseSensitiveAnalysis: Boolean): Unit = {
    // scalastyle:off caselocale
    val names = if (caseSensitiveAnalysis) columnNames else columnNames.map(_.toLowerCase)
    // scalastyle:on caselocale
    if (names.distinct.length != names.length) {
      val duplicateColumns = names.groupBy(identity).collect {
        case (x, ys) if ys.length > 1 => s"`$x`"
      }
      throw new AnalysisException(s"Found duplicate column(s) $colType: ${duplicateColumns.mkString(", ")}")
    }
  }
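
  // Sketch: under case-insensitive analysis, "a" and "A" collide and an AnalysisException is
  // thrown; under case-sensitive analysis the same input passes:
  //
  //   checkColumnNameDuplication(Seq("a", "A"), "in the partition columns", false)
  //   // => AnalysisException: Found duplicate column(s) in the partition columns: `a`
  //   checkColumnNameDuplication(Seq("a", "A"), "in the partition columns", true)   // OK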

  // Returns true if a given resolver is case-sensitive
  private def isCaseSensitiveAnalysis(resolver: Resolver): Boolean = {
    if (resolver == caseSensitiveResolution) {
      true
    } else if (resolver == caseInsensitiveResolution) {
      false
    } else {
      sys.error(
        "A resolver to check if two identifiers are equal must be " +
          "`caseSensitiveResolution` or `caseInsensitiveResolution` in o.a.s.sql.catalyst."
      )
    }
  }
}



