/*
 * Copyright (2020) The Delta Lake Project Authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.delta.schema

import scala.collection.Set._
import scala.collection.mutable
import scala.util.control.NonFatal

import org.apache.spark.sql.delta.DeltaErrors

import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.analysis.{ Resolver, UnresolvedAttribute }
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.execution.datasources.parquet.ParquetSchemaConverter
import org.apache.spark.sql.functions.{ col, struct }
import org.apache.spark.sql.types._

object SchemaUtils {
  // We use case insensitive resolution while writing into Delta
  val DELTA_COL_RESOLVER: (String, String) => Boolean =
    org.apache.spark.sql.catalyst.analysis.caseInsensitiveResolution
  private val ARRAY_ELEMENT_INDEX                     = 0
  private val MAP_KEY_INDEX                           = 0
  private val MAP_VALUE_INDEX                         = 1

  /**
    * Finds `StructField`s that match a given check `f`. Returns the path to the column, and the
    * field.
    *
    * @param checkComplexTypes While `StructType` is also a complex type, since we're returning
    *                          StructFields, we definitely recurse into StructTypes. This flag
    *                          defines whether we should recurse into ArrayType and MapType.
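    *
    * For example (an illustrative sketch):
    * {{{
    *   val schema = new StructType()
    *     .add("a", new StructType().add("b", IntegerType))
    *     .add("c", IntegerType)
    *   // Returns Seq((Seq("a"), <field b>), (Seq(), <field c>)); paths exclude the field's own name
    *   filterRecursively(schema, checkComplexTypes = true)(_.dataType == IntegerType)
    * }}}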
    */
  def filterRecursively(schema: StructType, checkComplexTypes: Boolean)(
    f: StructField => Boolean
  ): Seq[(Seq[String], StructField)] = {
    def recurseIntoComplexTypes(complexType: DataType, columnStack: Seq[String]): Seq[(Seq[String], StructField)] =
      complexType match {
        case s: StructType                     =>
          s.fields.flatMap { sf =>
            val includeLevel = if (f(sf)) Seq((columnStack, sf)) else Nil
            includeLevel ++ recurseIntoComplexTypes(sf.dataType, columnStack :+ sf.name)
          }
        case a: ArrayType if checkComplexTypes => recurseIntoComplexTypes(a.elementType, columnStack)
        case m: MapType if checkComplexTypes   =>
          recurseIntoComplexTypes(m.keyType, columnStack :+ "key") ++
            recurseIntoComplexTypes(m.valueType, columnStack :+ "value")
        case _                                 => Nil
      }

    recurseIntoComplexTypes(schema, Nil)
  }

  /** Copied over from DataType for visibility reasons. */
  def typeExistsRecursively(dt: DataType)(f: DataType => Boolean): Boolean = dt match {
    case s: StructType =>
      f(s) || s.fields.exists(field => typeExistsRecursively(field.dataType)(f))
    case a: ArrayType  =>
      f(a) || typeExistsRecursively(a.elementType)(f)
    case m: MapType    =>
      f(m) || typeExistsRecursively(m.keyType)(f) || typeExistsRecursively(m.valueType)(f)
    case other         =>
      f(other)
  }

  /** Turns the data types to nullable in a recursive manner for nested columns. */
  def typeAsNullable(dt: DataType): DataType = dt match {
    case s: StructType                                  => s.asNullable
    case a @ ArrayType(s: StructType, _)                => a.copy(s.asNullable, containsNull = true)
    case a: ArrayType                                   => a.copy(containsNull = true)
    case m @ MapType(s1: StructType, s2: StructType, _) =>
      m.copy(s1.asNullable, s2.asNullable, valueContainsNull = true)
    case m @ MapType(s1: StructType, _, _)              =>
      m.copy(keyType = s1.asNullable, valueContainsNull = true)
    case m @ MapType(_, s2: StructType, _)              =>
      m.copy(valueType = s2.asNullable, valueContainsNull = true)
    case other                                          => other
  }

  /**
    * Drops NullType columns from the DataFrame if they exist. We don't have an easy way of
    * generating new MapType and ArrayType values with NullType dropped, therefore if these types
    * contain NullType in their elements, we will throw an AnalysisException.
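    *
    * For example (an illustrative sketch):
    * {{{
    *   // df schema: <a: INT, b: NULL>  =>  result schema: <a: INT>
    *   val cleaned = dropNullTypeColumns(df)
    *   // df schema: <a: ARRAY<NULL>>   =>  throws AnalysisException
    * }}}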
    */
  def dropNullTypeColumns(df: DataFrame): DataFrame = {
    val schema                                                              = df.schema
    if (!typeExistsRecursively(schema)(_.isInstanceOf[NullType])) return df
    def generateSelectExpr(sf: StructField, nameStack: Seq[String]): Column = sf.dataType match {
      case st: StructType                                                     =>
        val nested = st.fields.flatMap { f =>
          if (f.dataType.isInstanceOf[NullType]) {
            None
          } else {
            Some(generateSelectExpr(f, nameStack :+ sf.name))
          }
        }
        struct(nested: _*).alias(sf.name)
      case a: ArrayType if typeExistsRecursively(a)(_.isInstanceOf[NullType]) =>
        val colName = UnresolvedAttribute.apply(nameStack :+ sf.name).name
        throw new AnalysisException(
          s"Found nested NullType in column $colName which is of ArrayType. Delta doesn't " +
            "support writing NullType in complex types."
        )
      case m: MapType if typeExistsRecursively(m)(_.isInstanceOf[NullType])   =>
        val colName = UnresolvedAttribute.apply(nameStack :+ sf.name).name
        throw new AnalysisException(
          s"Found nested NullType in column $colName which is of MapType. Delta doesn't " +
            "support writing NullType in complex types."
        )
      case _                                                                  =>
        val colName = UnresolvedAttribute.apply(nameStack :+ sf.name).name
        col(colName).alias(sf.name)
    }

    val selectExprs = schema.flatMap { f =>
      if (f.dataType.isInstanceOf[NullType]) None else Some(generateSelectExpr(f, Nil))
    }
    df.select(selectExprs: _*)
  }

  /**
    * Returns all column names in this schema as a flat list. For example, a schema like:
    *   | - a
    *   | | - 1
    *   | | - 2
    *   | - b
    *   | - c
    *   | | - nest
    *   |   | - 3
    *   will get flattened to: "a", "a.1", "a.2", "b", "c", "c.nest", "c.nest.3"
    */
  def explodeNestedFieldNames(schema: StructType): Seq[String] = {
    def explode(schema: StructType): Seq[Seq[String]] = {
      def recurseIntoComplexTypes(complexType: DataType): Seq[Seq[String]] = {
        complexType match {
          case s: StructType => explode(s)
          case a: ArrayType  => recurseIntoComplexTypes(a.elementType)
          case m: MapType    =>
            recurseIntoComplexTypes(m.keyType).map(Seq("key") ++ _) ++
              recurseIntoComplexTypes(m.valueType).map(Seq("value") ++ _)
          case _             => Nil
        }
      }

      schema.flatMap {
        case StructField(name, s: StructType, _, _) =>
          Seq(Seq(name)) ++ explode(s).map(nested => Seq(name) ++ nested)
        case StructField(name, a: ArrayType, _, _)  =>
          Seq(Seq(name)) ++ recurseIntoComplexTypes(a).map(nested => Seq(name) ++ nested)
        case StructField(name, m: MapType, _, _)    =>
          Seq(Seq(name)) ++ recurseIntoComplexTypes(m).map(nested => Seq(name) ++ nested)
        case f                                      => Seq(f.name) :: Nil
      }
    }

    explode(schema).map(UnresolvedAttribute.apply(_).name)
  }

  /**
    * Checks if input column names have duplicate identifiers. This throws an exception if
    * the duplication exists.
    *
    * @param schema the schema to check for duplicates
    * @param colType column type name, used in an exception message
    */
  def checkColumnNameDuplication(schema: StructType, colType: String): Unit = {
    val columnNames = explodeNestedFieldNames(schema)
    // scalastyle:off caselocale
    val names       = columnNames.map(_.toLowerCase)
    // scalastyle:on caselocale
    if (names.distinct.length != names.length) {
      val duplicateColumns = names.groupBy(identity).collect {
        case (x, ys) if ys.length > 1 => s"$x"
      }
      throw new AnalysisException(s"Found duplicate column(s) $colType: ${duplicateColumns.mkString(", ")}")
    }
  }

  /**
    * Rewrite the query field names according to the table schema. This method assumes that all
    * schema validation checks have been made and this is the last operation before writing into
    * Delta.
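    *
    * For example (illustrative): if the table schema is <Abc: INT> and the data schema is
    * <abc: INT>, the data column is aliased back to the table's casing:
    * {{{
    *   normalizeColumnNames(baseSchema, data) // result schema: <Abc: INT>
    * }}}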
    */
  def normalizeColumnNames(baseSchema: StructType, data: Dataset[_]): DataFrame = {
    val dataSchema  = data.schema
    val dataFields  = explodeNestedFieldNames(dataSchema).toSet
    val tableFields = explodeNestedFieldNames(baseSchema).toSet
    if (dataFields.subsetOf(tableFields)) {
      data.toDF()
    } else {
      // Check that nested columns don't need renaming. We can't handle that right now
      val topLevelDataFields = dataFields.map(UnresolvedAttribute.parseAttributeName(_).head)
      if (topLevelDataFields.subsetOf(tableFields)) {
        val columnsThatNeedRenaming = dataFields -- tableFields
        throw new AnalysisException(
          "Nested fields need renaming to avoid data loss. " +
            s"Fields:\n${columnsThatNeedRenaming.mkString("[", ", ", "]")}.\n" +
            s"Original schema:\n${baseSchema.treeString}"
        )
      }

      val baseFields       = toFieldMap(baseSchema)
      val aliasExpressions = dataSchema.map { field =>
        val originalCase = baseFields.getOrElse(
          field.name,
          throw new AnalysisException(s"Can't resolve column ${field.name} in ${baseSchema.treeString}")
        )
        if (originalCase.name != field.name) {
          functions.col(field.name).as(originalCase.name)
        } else {
          functions.col(field.name)
        }
      }
      data.select(aliasExpressions: _*)
    }
  }

  /**
    * As the Delta snapshots update, the schema may change as well. This method defines whether the
    * new schema of a Delta table can be used with a previously analyzed LogicalPlan. Our
    * rules are to return false if the new schema:
    *   - Drops any column that was present in the DataFrame schema
    *   - Converts nullable=false to nullable=true for any column
    *   - Changes the datatype of any column
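    *
    * For example (an illustrative sketch):
    * {{{
    *   val existing = new StructType().add("a", IntegerType, nullable = false)
    *   isReadCompatible(existing, existing.add("b", LongType))            // true: extra columns are fine
    *   isReadCompatible(existing, new StructType().add("a", IntegerType)) // false: relaxed nullability
    * }}}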
    */
  def isReadCompatible(existingSchema: StructType, readSchema: StructType): Boolean = {

    def isDatatypeReadCompatible(existing: DataType, newtype: DataType): Boolean = {
      (existing, newtype) match {
        case (e: StructType, n: StructType) =>
          isReadCompatible(e, n)
        case (e: ArrayType, n: ArrayType)   =>
          // if existing elements are non-nullable, so should be the new element
          (e.containsNull || !n.containsNull) &&
            isDatatypeReadCompatible(e.elementType, n.elementType)
        case (e: MapType, n: MapType)       =>
          // if existing value is non-nullable, so should be the new value
          (e.valueContainsNull || !n.valueContainsNull) &&
            isDatatypeReadCompatible(e.keyType, n.keyType) &&
            isDatatypeReadCompatible(e.valueType, n.valueType)
        case (a, b)                         => a == b
      }
    }

    def isStructReadCompatible(existing: StructType, newtype: StructType): Boolean = {
      val existingFields     = toFieldMap(existing)
      // scalastyle:off caselocale
      val existingFieldNames = existing.fieldNames.map(_.toLowerCase).toSet
      assert(
        existingFieldNames.size == existing.length,
        "Delta tables don't allow field names that only differ by case"
      )
      val newFields          = newtype.fieldNames.map(_.toLowerCase).toSet
      assert(newFields.size == newtype.length, "Delta tables don't allow field names that only differ by case")
      // scalastyle:on caselocale

      if (!existingFieldNames.subsetOf(newFields)) {
        // Dropped a column that was present in the DataFrame schema
        return false
      }
      newtype.forall { newField =>
        // new fields are fine, they just won't be returned
        existingFields.get(newField.name).forall { existingField =>
          // we know the name matches modulo case - now verify exact match
          (existingField.name == newField.name
          // if existing value is non-nullable, so should be the new value
          && (existingField.nullable || !newField.nullable)
          // and the type of the field must be compatible, too
          && isDatatypeReadCompatible(existingField.dataType, newField.dataType))
        }
      }
    }

    isStructReadCompatible(existingSchema, readSchema)
  }

  /**
    * Compare an existing schema to a specified new schema and
    * return a message describing the first difference found, if any:
    *   - different field name or datatype
    *   - different metadata
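    *
    * For example (an illustrative sketch):
    * {{{
    *   val existing  = new StructType().add("a", IntegerType)
    *   val specified = new StructType().add("a", LongType)
    *   // Seq("Specified type for a is different from existing schema: ...")
    *   reportDifferences(existing, specified)
    * }}}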
    */
  def reportDifferences(existingSchema: StructType, specifiedSchema: StructType): Seq[String] = {

    def canOrNot(can: Boolean) = if (can) "can" else "can not"
    def isOrNon(b: Boolean)    = if (b) "" else "non-"

    def missingFieldsMessage(fields: Set[String]): String = {
      s"Specified schema is missing field(s): ${fields.mkString(", ")}"
    }
    def additionalFieldsMessage(fields: Set[String]): String = {
      s"Specified schema has additional field(s): ${fields.mkString(", ")}"
    }
    def fieldNullabilityMessage(field: String, specified: Boolean, existing: Boolean): String = {
      s"Field $field is ${isOrNon(specified)}nullable in specified " +
        s"schema but ${isOrNon(existing)}nullable in existing schema."
    }
    def arrayNullabilityMessage(field: String, specified: Boolean, existing: Boolean): String = {
      s"Array field $field ${canOrNot(specified)} contain null in specified schema " +
        s"but ${canOrNot(existing)} in existing schema"
    }
    def valueNullabilityMessage(field: String, specified: Boolean, existing: Boolean): String = {
      s"Map field $field ${canOrNot(specified)} contain null values in specified schema " +
        s"but ${canOrNot(existing)} in existing schema"
    }
    def metadataDifferentMessage(field: String, specified: Metadata, existing: Metadata): String = {
      s"""Specified metadata for field $field is different from existing schema:
         |Specified: $specified
         |Existing:  $existing""".stripMargin
    }
    def typeDifferenceMessage(field: String, specified: DataType, existing: DataType): String = {
      s"""Specified type for $field is different from existing schema:
         |Specified: ${specified.typeName}
         |Existing:  ${existing.typeName}""".stripMargin
    }

    // prefix represents the nested field(s) containing this schema
    def structDifference(existing: StructType, specified: StructType, prefix: String): Seq[String] = {

      // 1. ensure set of fields is the same
      val existingFieldNames  = existing.fieldNames.toSet
      val specifiedFieldNames = specified.fieldNames.toSet

      val missingFields      = existingFieldNames diff specifiedFieldNames
      val missingFieldsDiffs =
        if (missingFields.isEmpty) Nil
        else Seq(missingFieldsMessage(missingFields.map(prefix + _)))

      val extraFields      = specifiedFieldNames diff existingFieldNames
      val extraFieldsDiffs =
        if (extraFields.isEmpty) Nil
        else Seq(additionalFieldsMessage(extraFields.map(prefix + _)))

      // 2. for each common field, ensure it has the same type and metadata
      val existingFields  = toFieldMap(existing)
      val specifiedFields = toFieldMap(specified)
      val fieldsDiffs     = (existingFieldNames intersect specifiedFieldNames).flatMap((name: String) =>
        fieldDifference(existingFields(name), specifiedFields(name), prefix)
      )

      missingFieldsDiffs ++ extraFieldsDiffs ++ fieldsDiffs
    }

    def fieldDifference(existing: StructField, specified: StructField, prefix: String): Seq[String] = {

      val name             = s"$prefix${existing.name}"
      val nullabilityDiffs =
        if (existing.nullable == specified.nullable) Nil
        else Seq(fieldNullabilityMessage(s"$name", specified.nullable, existing.nullable))
      val metadataDiffs    =
        if (existing.metadata == specified.metadata) Nil
        else Seq(metadataDifferentMessage(s"$name", specified.metadata, existing.metadata))
      val typeDiffs        =
        typeDifference(existing.dataType, specified.dataType, name)

      nullabilityDiffs ++ metadataDiffs ++ typeDiffs
    }

    def typeDifference(existing: DataType, specified: DataType, field: String): Seq[String] = {

      (existing, specified) match {
        case (e: StructType, s: StructType) => structDifference(e, s, s"$field.")
        case (e: ArrayType, s: ArrayType)   => arrayDifference(e, s, s"$field[]")
        case (e: MapType, s: MapType)       => mapDifference(e, s, s"$field")
        case (e, s) if e != s               => Seq(typeDifferenceMessage(field, s, e))
        case _                              => Nil
      }
    }

    def arrayDifference(existing: ArrayType, specified: ArrayType, field: String): Seq[String] = {

      val elementDiffs     =
        typeDifference(existing.elementType, specified.elementType, field)
      val nullabilityDiffs =
        if (existing.containsNull == specified.containsNull) Nil
        else Seq(arrayNullabilityMessage(field, specified.containsNull, existing.containsNull))

      elementDiffs ++ nullabilityDiffs
    }

    def mapDifference(existing: MapType, specified: MapType, field: String): Seq[String] = {

      val keyDiffs         =
        typeDifference(existing.keyType, specified.keyType, s"$field[key]")
      val valueDiffs       =
        typeDifference(existing.valueType, specified.valueType, s"$field[value]")
      val nullabilityDiffs =
        if (existing.valueContainsNull == specified.valueContainsNull) Nil
        else Seq(valueNullabilityMessage(field, specified.valueContainsNull, existing.valueContainsNull))

      keyDiffs ++ valueDiffs ++ nullabilityDiffs
    }

    structDifference(existingSchema, specifiedSchema, "")
  }

  /**
    * Returns the given column's position within the given `schema`, along with the size of the
    * last schema reached (the number of fields for a struct, 0 for a leaf column). The length of
    * the returned position list corresponds to how deeply nested the column is.
    *
    * For ArrayType: accessing the array's element adds a position 0 to the position list.
    * e.g. accessing a.element.y would have the result -> Seq(..., positionOfA, 0, positionOfY)
    *
    * For MapType: accessing the map's key adds a position 0 to the position list.
    * e.g. accessing m.key.y would have the result -> Seq(..., positionOfM, 0, positionOfY)
    *
    * For MapType: accessing the map's value adds a position 1 to the position list.
    * e.g. accessing m.value.y would have the result -> Seq(..., positionOfM, 1, positionOfY)
    *
    * @param column The column to search for in the given struct. If the length of `column` is
    *               greater than 1, we expect to enter a nested field.
    * @param schema The current struct we are looking at.
    * @param resolver The resolver to find the column.
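    *
    * For example (an illustrative sketch):
    * {{{
    *   // schema: <a: STRUCT<x: INT, y: INT>, b: INT>
    *   findColumnPosition(Seq("a"), schema)      // (Seq(0), 2): "a" is a struct with 2 fields
    *   findColumnPosition(Seq("a", "y"), schema) // (Seq(0, 1), 0): "y" is a leaf column
    * }}}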
    */
  def findColumnPosition(
    column: Seq[String],
    schema: StructType,
    resolver: Resolver = DELTA_COL_RESOLVER
  ): (Seq[Int], Int) = {
    def find(column: Seq[String], schema: StructType, stack: Seq[String]): (Seq[Int], Int) = {
      if (column.isEmpty) return (Nil, schema.size)
      val thisCol              = column.head
      lazy val columnPath      = UnresolvedAttribute(stack :+ thisCol).name
      val pos                  = schema.indexWhere(f => resolver(f.name, thisCol))
      if (pos == -1) {
        throw new IndexOutOfBoundsException(columnPath)
      }
      val colTail              = column.tail
      val (children, lastSize) = (colTail, schema(pos).dataType) match {
        case (_, s: StructType)                                          =>
          find(colTail, s, stack :+ thisCol)
        case (Seq("element", _ @_*), ArrayType(s: StructType, _))        =>
          val (child, size) = find(colTail.tail, s, stack :+ thisCol)
          (ARRAY_ELEMENT_INDEX +: child, size)
        case (Seq(), ArrayType(s: StructType, _))                        =>
          find(colTail, s, stack :+ thisCol)
        case (Seq(), ArrayType(_, _))                                    =>
          (Seq(ARRAY_ELEMENT_INDEX), 0)
        case (_, ArrayType(_, _))                                        =>
          throw new AnalysisException(
            s"""An ArrayType was found. In order to access elements of an ArrayType, specify
               |${prettyFieldName(stack ++ Seq(thisCol, "element"))}
               |Instead of ${prettyFieldName(stack ++ Seq(thisCol))}
               """.stripMargin
          )
        case (Seq(), MapType(_, _, _))                                   =>
          (Nil, 2)
        case (Seq("key", _ @_*), MapType(keyType: StructType, _, _))     =>
          val (child, size) = find(colTail.tail, keyType, stack :+ thisCol)
          (MAP_KEY_INDEX +: child, size)
        case (Seq("key"), MapType(_, _, _))                              =>
          (Seq(MAP_KEY_INDEX), 0)
        case (Seq("value", _ @_*), MapType(_, valueType: StructType, _)) =>
          val (child, size) = find(colTail.tail, valueType, stack :+ thisCol)
          (MAP_VALUE_INDEX +: child, size)
        case (Seq("value"), MapType(_, _, _))                            =>
          (Seq(MAP_VALUE_INDEX), 0)
        case (_, MapType(_, _, _))                                       =>
          throw new AnalysisException(
            s"""A MapType was found. In order to access the key or value of a MapType, specify one
               |of:
               |${prettyFieldName(stack ++ Seq(thisCol, "key"))} or
               |${prettyFieldName(stack ++ Seq(thisCol, "value"))}
               |followed by the name of the column (only if that column is a struct type).
               |e.g. mymap.key.mykey
               |If the column is a basic type, mymap.key or mymap.value is sufficient.
              """.stripMargin
          )
        case (_, o)                                                      =>
          if (column.length > 1) {
            throw new AnalysisException(
              s"""Expected $columnPath to be a nested data type, but found $o. Was looking for the
                 |index of ${prettyFieldName(column)} in a nested field
              """.stripMargin
            )
          }
          (Nil, 0)
      }
      (Seq(pos) ++ children, lastSize)
    }

    try {
      find(column, schema, Nil)
    } catch {
      case i: IndexOutOfBoundsException =>
        throw DeltaErrors.columnNotInSchemaException(i.getMessage, schema)
      case e: AnalysisException         =>
        throw new AnalysisException(e.getMessage + s":\n${schema.treeString}")
    }
  }

  /**
    * Pretty print the column path passed in.
    */
  def prettyFieldName(columnPath: Seq[String]): String = {
    UnresolvedAttribute(columnPath).name
  }

  /**
    * Add `column` to the specified `position` in `schema`.
    * @param position A Seq of ordinals on where this column should go. It is a Seq to denote
    *                 positions in nested columns (0-based). For example:
    *
    *                 tableSchema: <a:STRUCT<a1,a2,a3>, b,c:STRUCT<c1,c3>>
    *                 column: c2
    *                 position: Seq(2, 1)
    *                 will return
    *                 result: <a:STRUCT<a1,a2,a3>, b,c:STRUCT<c1,c2,c3>>
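    *
    *                 In code (an illustrative sketch):
    *                 {{{
    *                   // inserts c2 into the struct c at nested position 1
    *                   val withC2 = addColumn(schema, StructField("c2", IntegerType), Seq(2, 1))
    *                 }}}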
    */
  def addColumn(schema: StructType, column: StructField, position: Seq[Int]): StructType = {
    require(position.nonEmpty, s"Don't know where to add the column $column")
    val slicePosition = position.head
    if (slicePosition < 0) {
      throw new AnalysisException(s"Index $slicePosition to add column $column is lower than 0")
    }
    val length        = schema.length
    if (slicePosition > length) {
      throw new AnalysisException(s"Index $slicePosition to add column $column is larger than struct length: $length")
    }
    if (slicePosition == length) {
      if (position.length > 1) {
        throw new AnalysisException(s"Struct not found at position $slicePosition")
      }
      return StructType(schema :+ column)
    }
    val pre           = schema.take(slicePosition)
    if (position.length > 1) {
      val posTail = position.tail
      val mid     = schema(slicePosition) match {
        case StructField(name, f: StructType, nullable, metadata)                          =>
          if (!column.nullable && nullable) {
            throw new AnalysisException(
              "A non-nullable nested field can't be added to a nullable parent. Please set the " +
                "nullability of the parent column accordingly."
            )
          }
          StructField(name, addColumn(f, column, posTail), nullable, metadata)
        case StructField(name, ArrayType(f: StructType, containsNull), nullable, metadata) =>
          if (!column.nullable && nullable) {
            throw new AnalysisException(
              "A non-nullable nested field can't be added to a nullable parent. Please set the " +
                "nullability of the parent column accordingly."
            )
          }

          if (posTail.head != ARRAY_ELEMENT_INDEX) {
            throw new AnalysisException(
              s"""Incorrectly accessing an ArrayType. Use arrayname.element.elementname position to
                 |add to an array.
               """.stripMargin
            )
          }

          StructField(name, ArrayType(addColumn(f, column, posTail.tail), containsNull), nullable, metadata)
        case StructField(name, map @ MapType(_, _, _), nullable, metadata)                 =>
          if (!column.nullable && nullable) {
            throw new AnalysisException(
              "A non-nullable nested field can't be added to a nullable parent. Please set the " +
                "nullability of the parent column accordingly."
            )
          }

          val addedMap = (posTail.head, map) match {
            case (MAP_KEY_INDEX, MapType(key: StructType, v, nullability))     =>
              MapType(addColumn(key, column, posTail.tail), v, nullability)
            case (MAP_VALUE_INDEX, MapType(k, value: StructType, nullability)) =>
              MapType(k, addColumn(value, column, posTail.tail), nullability)
            case _                                                             =>
              throw new AnalysisException(
                s"Cannot add ${column.name} because its parent is not a StructType."
              )
          }
          StructField(name, addedMap, nullable, metadata)
        case o                                                                             =>
          throw new AnalysisException(
            s"Cannot add ${column.name} because its parent is not a " +
              s"StructType. Found ${o.dataType}"
          )
      }
      StructType(pre ++ Seq(mid) ++ schema.slice(slicePosition + 1, length))
    } else {
      StructType(pre ++ Seq(column) ++ schema.slice(slicePosition, length))
    }
  }

  // TODO @pranavanand: This method is no longer being used by AlterTable. If transformColumnsStruct
  // works sufficiently, remove this method
  /**
    * Drops the column at the specified `position` in `schema` and returns the updated schema
    * along with the dropped column.
    * @param position A Seq of ordinals identifying the column to drop. It is a Seq to denote
    *                 positions in nested columns (0-based). For example:
    *
    *                 tableSchema: <a:STRUCT<a1,a2,a3>, b,c:STRUCT<c1,c2,c3>>
    *                 position: Seq(2, 1)
    *                 will return
    *                 result: <a:STRUCT<a1,a2,a3>, b,c:STRUCT<c1,c3>>
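    *
    *                 In code (an illustrative sketch):
    *                 {{{
    *                   // drops c2 from the struct c and returns it
    *                   val (newSchema, dropped) = dropColumn(schema, Seq(2, 1))
    *                 }}}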
    */
  def dropColumn(schema: StructType, position: Seq[Int]): (StructType, StructField) = {
    require(position.nonEmpty, "Don't know where to drop the column")
    val slicePosition = position.head
    if (slicePosition < 0) {
      throw new AnalysisException(s"Index $slicePosition to drop column is lower than 0")
    }
    val length        = schema.length
    if (slicePosition >= length) {
      throw new AnalysisException(
        s"Index $slicePosition to drop column equals to or is larger than struct length: $length"
      )
    }
    val pre           = schema.take(slicePosition)
    if (position.length > 1) {
      val (mid, original) = schema(slicePosition) match {
        case StructField(name, f: StructType, nullable, metadata) =>
          val (dropped, original) = dropColumn(f, position.tail)
          (StructField(name, dropped, nullable, metadata), original)
        case o                                                    =>
          throw new AnalysisException(s"Can only drop nested columns from StructType. Found: $o")
      }
      (StructType(pre ++ Seq(mid) ++ schema.slice(slicePosition + 1, length)), original)
    } else {
      (StructType(pre ++ schema.slice(slicePosition + 1, length)), schema(slicePosition))
    }
  }

  /**
    * Check if the two data types can be changed.
    *
    * @return None if the data types can be changed, otherwise Some(err) containing the reason.
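    *
    * A sketch (illustrative):
    * {{{
    *   canChangeDataType(IntegerType, IntegerType, DELTA_COL_RESOLVER) // None
    *   canChangeDataType(IntegerType, LongType, DELTA_COL_RESOLVER)    // Some("changing data type ...")
    * }}}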
    */
  def canChangeDataType(
    from: DataType,
    to: DataType,
    resolver: Resolver,
    columnPath: Seq[String] = Seq.empty
  ): Option[String] = {
    def verify(cond: Boolean, err: => String): Unit = {
      if (!cond) {
        throw new AnalysisException(err)
      }
    }

    def verifyNullability(fn: Boolean, tn: Boolean, columnPath: Seq[String]): Unit = {
      verify(tn || !fn, s"tightening nullability of ${UnresolvedAttribute(columnPath).name}")
    }

    def check(fromDt: DataType, toDt: DataType, columnPath: Seq[String]): Unit = {
      (fromDt, toDt) match {
        case (ArrayType(fromElement, fn), ArrayType(toElement, tn)) =>
          verifyNullability(fn, tn, columnPath)
          check(fromElement, toElement, columnPath)

        case (MapType(fromKey, fromValue, fn), MapType(toKey, toValue, tn)) =>
          verifyNullability(fn, tn, columnPath)
          check(fromKey, toKey, columnPath :+ "key")
          check(fromValue, toValue, columnPath :+ "value")

        case (StructType(fromFields), StructType(toFields)) =>
          val remainingFields = fromFields.to[mutable.Set]
          toFields.foreach { toField =>
            fromFields.find(field => resolver(field.name, toField.name)) match {
              case Some(fromField) =>
                remainingFields -= fromField

                val newPath = columnPath :+ fromField.name
                verifyNullability(fromField.nullable, toField.nullable, newPath)
                check(fromField.dataType, toField.dataType, newPath)
              case None            =>
                verify(
                  toField.nullable,
                  "adding non-nullable column " +
                    UnresolvedAttribute(columnPath :+ toField.name).name
                )
            }
          }
          verify(
            remainingFields.isEmpty,
            s"dropping column(s) [${remainingFields.map(_.name).mkString(", ")}]" +
              (if (columnPath.nonEmpty) s" from ${UnresolvedAttribute(columnPath).name}" else "")
          )

        case (fromDataType, toDataType) =>
          verify(
            fromDataType == toDataType,
            s"changing data type of ${UnresolvedAttribute(columnPath).name} " +
              s"from $fromDataType to $toDataType"
          )
      }
    }

    try {
      check(from, to, columnPath)
      None
    } catch {
      case e: AnalysisException =>
        Some(e.message)
    }
  }

  /**
    * Copies nested data type changes from `to` onto `from`. The result follows `to`'s structure
    * and field nullability, while container nullability and matching fields' metadata and
    * comments are carried over from `from`.
    */
  def changeDataType(from: DataType, to: DataType, resolver: Resolver): DataType = {
    (from, to) match {
      case (ArrayType(fromElement, fn), ArrayType(toElement, _)) =>
        ArrayType(changeDataType(fromElement, toElement, resolver), fn)

      case (MapType(fromKey, fromValue, fn), MapType(toKey, toValue, _)) =>
        MapType(changeDataType(fromKey, toKey, resolver), changeDataType(fromValue, toValue, resolver), fn)

      case (StructType(fromFields), StructType(toFields)) =>
        StructType(
          toFields.map { toField =>
            fromFields
              .find(field => resolver(field.name, toField.name))
              .map { fromField =>
                toField
                  .getComment()
                  .map(fromField.withComment)
                  .getOrElse(fromField)
                  .copy(
                    dataType = changeDataType(fromField.dataType, toField.dataType, resolver),
                    nullable = toField.nullable
                  )
              }
              .getOrElse(toField)
          }
        )

      case (_, toDataType) => toDataType
    }
  }

  /**
    * Check whether we can write to the Delta table, which has `tableSchema`, using a query that has
    * `dataSchema`. Our rules are that:
    *   - `dataSchema` may be missing columns or have additional columns
    *   - We don't trust the nullability in `dataSchema`. Assume fields are nullable.
    *   - We only allow nested StructType expansions. For all other complex types, we check for
    *     strict equality
    *   - `dataSchema` can't have duplicate column names. Columns that only differ by case are also
    *     not allowed.
    * The following merging strategy is applied:
    *  - The name of the current field is used.
    *  - The data types are merged by calling this function.
    *  - We respect the current field's nullability.
    *  - The metadata is current field's metadata.
    *
    * Schema merging occurs in a case insensitive manner. Hence, column names that only differ
    * by case are not accepted in the `dataSchema`.
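    *
    * For example (an illustrative sketch):
    * {{{
    *   val table = new StructType().add("a", IntegerType)
    *   val data  = new StructType().add("A", IntegerType).add("b", StringType)
    *   // <a: INT, b: STRING> -- new fields are appended; existing fields keep the table's casing
    *   mergeSchemas(table, data)
    * }}}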
    */
  def mergeSchemas(tableSchema: StructType, dataSchema: StructType): StructType = {
    checkColumnNameDuplication(dataSchema, "in the data to save")
    def merge(current: DataType, update: DataType): DataType = {
      (current, update) match {
        case (StructType(currentFields), StructType(updateFields))                                        =>
          // Merge existing fields.
          val updateFieldMap       = toFieldMap(updateFields)
          val updatedCurrentFields = currentFields.map { currentField =>
            updateFieldMap.get(currentField.name) match {
              case Some(updateField) =>
                try {
                  StructField(
                    currentField.name,
                    merge(currentField.dataType, updateField.dataType),
                    currentField.nullable,
                    currentField.metadata
                  )
                } catch {
                  case NonFatal(e) =>
                    throw new AnalysisException(
                      s"Failed to merge fields '${currentField.name}' " +
                        s"and '${updateField.name}'. " + e.getMessage
                    )
                }
              case None              =>
                // Retain the old field.
                currentField
            }
          }

          // Identify the newly added fields.
          val nameToFieldMap = toFieldMap(currentFields)
          val newFields      = updateFields.filterNot(f => nameToFieldMap.contains(f.name))

          // Create the merged struct, the new fields are appended at the end of the struct.
          StructType(updatedCurrentFields ++ newFields)
        case (ArrayType(currentElementType, currentContainsNull), ArrayType(updateElementType, _))        =>
          ArrayType(merge(currentElementType, updateElementType), currentContainsNull)
        case (
              MapType(currentKeyType, currentElementType, currentContainsNull),
              MapType(updateKeyType, updateElementType, _)
            ) =>
          MapType(
            merge(currentKeyType, updateKeyType),
            merge(currentElementType, updateElementType),
            currentContainsNull
          )
        case (DecimalType.Fixed(leftPrecision, leftScale), DecimalType.Fixed(rightPrecision, rightScale)) =>
          if ((leftPrecision == rightPrecision) && (leftScale == rightScale)) {
            current
          } else if ((leftPrecision != rightPrecision) && (leftScale != rightScale)) {
            throw new AnalysisException(
              "Failed to merge decimal types with incompatible " +
                s"precision $leftPrecision and $rightPrecision & scale $leftScale and $rightScale"
            )
          } else if (leftPrecision != rightPrecision) {
            throw new AnalysisException(
              "Failed to merge decimal types with incompatible " +
                s"precision $leftPrecision and $rightPrecision"
            )
          } else {
            throw new AnalysisException(
              "Failed to merge decimal types with incompatible " +
                s"scale $leftScale and $rightScale"
            )
          }
        case _ if current == update                                                                       =>
          current

        // Parquet physically stores ByteType, ShortType and IntType as IntType, so when a parquet
        // column is of one of these three types, you can read this column as any of these three
        // types. Since Parquet doesn't complain, we should also allow upcasting among these
        // three types when merging schemas.
        case (ByteType, ShortType)                                                                        => ShortType
        case (ByteType, IntegerType)                                                                      => IntegerType

        case (ShortType, ByteType)    => ShortType
        case (ShortType, IntegerType) => IntegerType

        case (IntegerType, ShortType) => IntegerType
        case (IntegerType, ByteType)  => IntegerType

        case (NullType, _) =>
          update
        case (_, NullType) =>
          current
        case _             =>
          throw new AnalysisException(s"Failed to merge incompatible data types $current and $update")
      }
    }
    merge(tableSchema, dataSchema).asInstanceOf[StructType]
  }

  private def toFieldMap(fields: Seq[StructField]): Map[String, StructField] = {
    CaseInsensitiveMap(fields.map(field => field.name -> field).toMap)
  }

  /**
    * Transform (nested) columns in a schema.
    *
    * @param schema to transform.
    * @param tf function to apply.
    * @return the transformed schema.
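    *
    * A sketch (illustrative): mark every (nested) column as nullable.
    * {{{
    *   transformColumns(schema) { (_, field, _) => field.copy(nullable = true) }
    * }}}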
    */
  def transformColumns(schema: StructType)(tf: (Seq[String], StructField, Resolver) => StructField): StructType = {
    def transform[E <: DataType](path: Seq[String], dt: E): E = {
      val newDt = dt match {
        case StructType(fields)                             =>
          StructType(fields.map { field =>
            val newField = tf(path, field, DELTA_COL_RESOLVER)
            newField.copy(dataType = transform(path :+ newField.name, newField.dataType))
          })
        case ArrayType(elementType, containsNull)           =>
          ArrayType(transform(path, elementType), containsNull)
        case MapType(keyType, valueType, valueContainsNull) =>
          MapType(transform(path :+ "key", keyType), transform(path :+ "value", valueType), valueContainsNull)
        case other                                          => other
      }
      newDt.asInstanceOf[E]
    }
    transform(Seq.empty, schema)
  }

  /**
    * Transform (nested) columns in a schema. Runs the transform function on every nested
    * StructType that contains a field named `colName`.
    *
    * @param schema to transform.
    * @param colName the field name a StructType must contain for `tf` to be applied to it.
    * @param tf function to apply on the StructType.
    * @return the transformed schema.
    */
  def transformColumnsStructs(schema: StructType, colName: String)(
    tf: (Seq[String], StructType, Resolver) => Seq[StructField]
  ): StructType = {
    def transform[E <: DataType](path: Seq[String], dt: E): E = {
      val newDt = dt match {
        case struct @ StructType(fields)                    =>
          val newFields = if (fields.exists(_.name == colName)) {
            tf(path, struct, DELTA_COL_RESOLVER)
          } else {
            fields.toSeq
          }

          StructType(newFields.map { field =>
            field.copy(dataType = transform(path :+ field.name, field.dataType))
          })
        case ArrayType(elementType, containsNull)           =>
          ArrayType(transform(path :+ "element", elementType), containsNull)
        case MapType(keyType, valueType, valueContainsNull) =>
          MapType(transform(path :+ "key", keyType), transform(path :+ "value", valueType), valueContainsNull)
        case other                                          => other
      }
      newDt.asInstanceOf[E]
    }
    transform(Seq.empty, schema)
  }

  /**
    * Transform (nested) columns in a schema using the given path and parameter pairs. The transform
    * function is only invoked when a field's path matches one of the input paths.
    *
    * @param schema to transform
    * @param input paths and parameter pairs. The paths point to fields we want to transform. The
    *              parameters will be passed to the transform function for a matching field.
    * @param tf function to apply per matched field. This function takes the field path, the field
    *           itself and the input names and payload pairs that matched the field name. It should
    *           return a new field.
    * @tparam E the type of the payload used for transforming fields.
    * @return the transformed schema.
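    *
    * A sketch (illustrative): set a comment on the field at path a.b.
    * {{{
    *   transformColumns(schema, Seq(Seq("a", "b") -> "a comment")) { (_, field, matches) =>
    *     field.withComment(matches.head._2)
    *   }
    * }}}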
    */
  def transformColumns[E](schema: StructType, input: Seq[(Seq[String], E)])(
    tf: (Seq[String], StructField, Seq[(Seq[String], E)]) => StructField
  ): StructType = {
    // scalastyle:off caselocale
    val inputLookup = input.groupBy(_._1.map(_.toLowerCase))
    SchemaUtils.transformColumns(schema) { (path, field, resolver) =>
      // Find the parameters that match this field name.
      val fullPath           = path :+ field.name
      val normalizedFullPath = fullPath.map(_.toLowerCase)
      val matches            = inputLookup.get(normalizedFullPath).toSeq.flatMap {
        // Keep only the input name(s) that actually match the field name(s). Note
        // that the Map guarantees that the zipped sequences have the same size.
        _.filter(_._1.zip(fullPath).forall(resolver.tupled))
      }
      if (matches.nonEmpty) {
        tf(path, field, matches)
      } else {
        field
      }
    }
    // scalastyle:on caselocale
  }

  /**
    * Verifies that the column names are acceptable by Parquet and henceforth Delta. Parquet doesn't
    * accept the characters ' ,;{}()\n\t'. We ensure that neither the data columns nor the partition
    * columns have these characters.
    */
  def checkFieldNames(names: Seq[String]): Unit = {
    ParquetSchemaConverter.checkFieldNames(names)
    // The method checkFieldNames doesn't have a valid regex to search for '\n'. That should be
    // fixed in Apache Spark, and we can remove this additional check here.
    names.find(_.contains("\n")).foreach(col => throw DeltaErrors.invalidColumnName(col))
  }
}



