All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.dimajix.spark.sql.SchemaUtils.scala Maven / Gradle / Ivy

There is a newer version: 1.2.0-synapse3.3-spark3.3-hadoop3.3
Show newest version
/*
 * Copyright (C) 2018 The Flowman Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.dimajix.spark.sql

import java.util.Locale

import org.apache.spark.sql.Column
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.catalyst.parser.ParseException
import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.functions.rpad
import org.apache.spark.sql.functions.substring
import org.apache.spark.sql.types.ArrayType
import org.apache.spark.sql.types.CharType
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.types.MapType
import org.apache.spark.sql.types.Metadata
import org.apache.spark.sql.types.MetadataBuilder
import org.apache.spark.sql.types.StringType
import org.apache.spark.sql.types.StructField
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.types.VarcharType

import com.dimajix.spark.sql.catalyst.parser.CustomSqlParser


object SchemaUtils {
    /**
     * Metadata key under which `replaceCharVarchar` stores the original CHAR/VARCHAR type string of a field
     * whose type has been replaced by STRING. This key is compatible with Spark 3.x.
     */
    val CHAR_VARCHAR_TYPE_STRING_METADATA_KEY : String = "__CHAR_VARCHAR_TYPE_STRING"

    /**
      * Helper method for applying an optional schema to a given DataFrame. This will apply the types and order
      * of the target schema. Columns are matched by name case-insensitively. Missing fields will be imputed by
      * NULL values if `insertNulls` is set.
      *
      * CHAR(n) fields are blank-padded (and truncated) to exactly n characters, VARCHAR(n) fields are truncated
      * to at most n characters, and all other fields are simply cast to their target type.
      *
      * @param df the input DataFrame
      * @param schema the optional target schema; if `None`, `df` is returned unchanged
      * @param insertNulls if true, fields missing in `df` are filled with NULL values, otherwise an exception is thrown
      * @return a DataFrame with exactly the fields of `schema`, in the same order
      * @throws IllegalArgumentException if a field of `schema` is missing in `df` and `insertNulls` is false
      */
    def applySchema(df:DataFrame, schema:Option[StructType], insertNulls:Boolean=true) : DataFrame = {
        require(df != null)
        require(schema != null)

        // References a top-level column of df by its exact name. Names containing dots or backticks are quoted
        // (embedded backticks are doubled), since Spark would otherwise interpret a dot as a struct field access.
        def inputColumn(name:String) : Column = {
            if (name.exists(c => c == '.' || c == '`'))
                df("`" + name.replace("`", "``") + "`")
            else
                df(name)
        }

        def applyType(col:Column, field:StructField) : Column = {
            field.dataType match {
                // rpad pads with blanks and also truncates, yielding exactly n characters
                case CharType(n) => rpad(col.cast(StringType), n, " ").as(field.name, field.metadata)
                // SQL substring positions are 1-based
                case VarcharType(n) => substring(col.cast(StringType), 1, n).as(field.name, field.metadata)
                case _ => col.cast(field.dataType).as(field.name, field.metadata)
            }
        }

        def applyStruct(target:StructType) : DataFrame = {
            val dfFieldsByName = df.schema.map(f => f.name.toLowerCase(Locale.ROOT) -> f).toMap
            val columns = target.map { field =>
                val col = dfFieldsByName.get(field.name.toLowerCase(Locale.ROOT))
                    // Use the real name of the matched input column, so that resolution also works when
                    // spark.sql.caseSensitive is enabled and the names differ only in case
                    .map(f => inputColumn(f.name))
                    .getOrElse {
                        if (!insertNulls)
                            throw new IllegalArgumentException(s"Missing column '${field.name}' in input DataFrame")
                        lit(null)
                    }
                applyType(col, field)
            }
            df.select(columns: _*)
        }

        schema.map(applyStruct).getOrElse(df)
    }

    /**
      * Finds a specific (possibly nested) field in a schema. The name may contain dots to descend into nested
      * structs; arrays are transparently descended into via their element type. Matching is case-insensitive.
      *
      * @param struct the schema to search in
      * @param name the dot-separated name of the field
      * @return the field if it exists, otherwise `None`
      * @throws NoSuchElementException if an intermediate name segment refers to a field which is neither a
      *                                struct nor an array
      */
    def find(struct:StructType, name:String) : Option[StructField] = {
        def findField(field:StructField, head:String, tail:Seq[String]) : Option[StructField] = {
            field.dataType match {
                case st:StructType => findStruct(st, head, tail)
                case at:ArrayType => findField(StructField("element", at.elementType), head, tail)
                case _ => throw new NoSuchElementException(s"Cannot descend field ${field.name} - it is neither struct nor array")
            }
        }
        def findStruct(struct:StructType, head:String, tail:Seq[String]) : Option[StructField] = {
            struct.fields
                .find(_.name.toLowerCase(Locale.ROOT) == head)
                .flatMap { field =>
                    if (tail.isEmpty)
                        Some(field)
                    else
                        findField(field, tail.head, tail.tail)
                }
        }

        // Note that split may return an empty array (e.g. for "."), in which case nothing can be found
        name.toLowerCase(Locale.ROOT).split('.').toList match {
            case head :: tail => findStruct(struct, head, tail)
            case Nil => None
        }
    }

    /**
      * Truncates all field comments (including those of fields nested in structs, arrays and maps) to a maximum
      * length. This may be required for Hive tables.
      *
      * @param schema the schema to process
      * @param maxLength the maximum number of characters to keep of each comment
      * @return a copy of the schema with truncated comments
      */
    def truncateComments(schema:StructType, maxLength:Int) : StructType = {
        def processType(dataType:DataType) : DataType = {
            dataType match {
                case st:StructType => truncateComments(st, maxLength)
                case ar:ArrayType => ar.copy(elementType = processType(ar.elementType))
                case mt:MapType => mt.copy(keyType = processType(mt.keyType), valueType = processType(mt.valueType))
                case dt:DataType => dt
            }
        }
        def truncate(field:StructField) : StructField = {
            val metadata = field.getComment()
                .map(comment => new MetadataBuilder()
                    .withMetadata(field.metadata)
                    .putString("comment", comment.take(maxLength))
                    .build()
                ).getOrElse(field.metadata)
            val dataType = processType(field.dataType)
            field.copy(dataType = dataType, metadata = metadata)
        }
        StructType(schema.fields.map(truncate))
    }

    /**
     * Returns true if the predicate `f` holds for the given data type itself or for any data type nested inside
     * it (struct field types, array element types, map key and value types).
     */
    def existsRecursively(dt:DataType, f: (DataType) => Boolean): Boolean = {
        dt match {
            case struct: StructType => f(dt) || struct.fields.exists(field => existsRecursively(field.dataType, f))
            case array: ArrayType => f(dt) || existsRecursively(array.elementType, f)
            case map: MapType => f(dt) || existsRecursively(map.keyType, f) || existsRecursively(map.valueType, f)
            case _ => f(dt)
        }
    }

    /**
     * Returns true if the given data type is CharType/VarcharType or has nested CharType/VarcharType.
     */
    def hasCharVarchar(dt: DataType): Boolean = {
        existsRecursively(dt, {
            case _:CharType | _:VarcharType => true
            case _ => false
        })
    }

    /**
     * This will normalize a given schema in the sense that all field names are converted to lowercase and all
     * metadata is stripped except the comments. The function will also replace all CHAR and VARCHAR columns by
     * STRING columns, but without providing metadata for reconstructing the original schema.
     *
     * This function can be used to compare schemas with regards to required migrations, when the storage system
     * does not support VARCHAR/CHAR types.
     *
     * @param schema the schema to normalize
     * @return the normalized schema
     */
    def normalize(schema:StructType) : StructType = {
        StructType(schema.fields.map(normalize))
    }
    def normalize(field:StructField) : StructField = {
        val metadata = field.getComment()
            .map(c => new MetadataBuilder().putString("comment", c).build())
            .getOrElse(Metadata.empty)
        StructField(field.name.toLowerCase(Locale.ROOT), normalize(field.dataType), field.nullable, metadata=metadata)
    }
    private def normalize(dtype:DataType) : DataType = {
        dtype match {
            case struct:StructType => normalize(struct)
            case array:ArrayType => ArrayType(normalize(array.elementType), array.containsNull)
            case map:MapType => MapType(normalize(map.keyType), normalize(map.valueType), map.valueContainsNull)
            case _:CharType => StringType
            case _:VarcharType => StringType
            case dt:DataType => dt
        }
    }

    /**
     * Replaces all occurrences of VARCHAR and CHAR types by STRING types. For every field whose type contains a
     * CHAR/VARCHAR type (directly or nested), the original type string is stored in the field metadata under
     * `CHAR_VARCHAR_TYPE_STRING_METADATA_KEY`, so that it can be restored by `recoverCharVarchar`.
     *
     * @param schema the schema to process
     * @return the schema without any CHAR/VARCHAR types
     */
    def replaceCharVarchar(schema:StructType) : StructType = {
        StructType(schema.fields.map(replaceCharVarchar))
    }
    def replaceCharVarchar(field:StructField) : StructField = {
        val metadata = if (hasCharVarchar(field.dataType)) {
            new MetadataBuilder()
                .withMetadata(field.metadata)
                .putString(CHAR_VARCHAR_TYPE_STRING_METADATA_KEY, field.dataType.catalogString)
                .build()
        } else {
            field.metadata
        }
        StructField(field.name, replaceCharVarchar(field.dataType), field.nullable, metadata=metadata)
    }
    private def replaceCharVarchar(dtype:DataType) : DataType = {
        dtype match {
            case struct:StructType => replaceCharVarchar(struct)
            case array:ArrayType => ArrayType(replaceCharVarchar(array.elementType), array.containsNull)
            case map:MapType => MapType(replaceCharVarchar(map.keyType), replaceCharVarchar(map.valueType), map.valueContainsNull)
            case _:CharType => StringType
            case _:VarcharType => StringType
            case dt:DataType => dt
        }
    }

    /**
     * Returns true if the field carries extended CHAR/VARCHAR type information in its metadata.
     */
    def hasExtendedTypeinfo(field:StructField) : Boolean = {
        hasExtendedTypeinfo(field.metadata)
    }
    /**
     * Returns true if the metadata contains extended CHAR/VARCHAR type information.
     */
    def hasExtendedTypeinfo(metadata:Metadata) : Boolean = {
        metadata.contains(CHAR_VARCHAR_TYPE_STRING_METADATA_KEY)
    }

    /**
     * Removes the extended CHAR/VARCHAR type information from the metadata of all fields, including nested ones
     * inside structs, arrays and maps. The data types themselves are not changed.
     *
     * @param struct the schema to process
     * @return the schema without extended type information
     */
    def dropExtendedTypeInfo(struct:StructType) : StructType = {
        StructType(struct.fields.map(dropExtendedTypeInfo))
    }
    def dropExtendedTypeInfo(field:StructField) : StructField = {
        val meta = dropExtendedTypeInfo(field.metadata)
        val dt = dropExtendedTypeInfo(field.dataType)
        field.copy(dataType = dt, metadata = meta)
    }
    def dropExtendedTypeInfo(metadata: Metadata): Metadata = {
        if (metadata.contains(CHAR_VARCHAR_TYPE_STRING_METADATA_KEY)) {
            new MetadataBuilder()
                .withMetadata(metadata)
                .remove(CHAR_VARCHAR_TYPE_STRING_METADATA_KEY)
                .build()
        }
        else {
            metadata
        }
    }
    private def dropExtendedTypeInfo(dtype: DataType): DataType = {
        dtype match {
            case struct: StructType => dropExtendedTypeInfo(struct)
            case array: ArrayType => ArrayType(dropExtendedTypeInfo(array.elementType), array.containsNull)
            case map: MapType => MapType(dropExtendedTypeInfo(map.keyType), dropExtendedTypeInfo(map.valueType), map.valueContainsNull)
            case dt: DataType => dt
        }
    }

    /**
     * Recovers the original CHAR/VARCHAR types from a struct, which was previously cleaned via replaceCharVarchar.
     * The extended type information is removed from the metadata of all recovered fields.
     *
     * @param schema the schema to process
     * @return the schema with CHAR/VARCHAR types restored where type information was available
     */
    def recoverCharVarchar(schema:StructType) : StructType = {
        StructType(schema.map(recoverCharVarchar))
    }
    def recoverCharVarchar(field:StructField) : StructField = {
        val dt = field.dataType match {
            case _:StringType|_:ArrayType|_:MapType if hasExtendedTypeinfo(field.metadata) =>
                val typeString = field.metadata.getString(CHAR_VARCHAR_TYPE_STRING_METADATA_KEY)
                try {
                    CustomSqlParser.parseDataType(typeString)
                } catch {
                    // Work around bad field names, which will fire a parse-exception
                    case _: ParseException =>
                        recoverCharVarchar(field.dataType)
                }
            // Descend into structs, arrays and maps so that nested fields are recovered as well
            case other: DataType => recoverCharVarchar(other)
        }

        field.copy(dataType = dt, metadata = dropExtendedTypeInfo(field.metadata))
    }
    private def recoverCharVarchar(dataType: DataType) : DataType = {
        dataType match {
            case struct: StructType => struct.copy(fields = struct.fields.map(recoverCharVarchar))
            case array: ArrayType => ArrayType(recoverCharVarchar(array.elementType), array.containsNull)
            case map: MapType => MapType(recoverCharVarchar(map.keyType), recoverCharVarchar(map.valueType), map.valueContainsNull)
            case dt: DataType => dt
        }
    }

    /**
      * Removes all metadata from a Spark schema, including nested fields. Useful for comparing results in unit tests.
      *
      * @param schema the schema to process
      * @return the schema without any field metadata
      */
    def dropMetadata(schema:StructType) : StructType = {
        def processType(dataType:DataType) : DataType = {
            dataType match {
                case st:StructType => dropMetadata(st)
                case ar:ArrayType => ar.copy(elementType = processType(ar.elementType))
                case mt:MapType => mt.copy(keyType = processType(mt.keyType), valueType = processType(mt.valueType))
                case dt:DataType => dt
            }
        }

        val fields = schema.fields.map { field =>
            field.copy(dataType = processType(field.dataType), metadata = Metadata.empty)
        }
        StructType(fields)
    }

    /**
      * Converts the given Spark schema to a lower case schema, i.e. all field names (including nested ones) are
      * converted to lower case. Types, nullability and metadata are kept.
      *
      * @param schema the schema to process
      * @return the schema with lower case field names
      */
    def toLowerCase(schema:StructType) : StructType = {
        StructType(schema.fields.map(toLowerCase))
    }
    private def toLowerCase(field:StructField) : StructField = {
        StructField(field.name.toLowerCase(Locale.ROOT), toLowerCase(field.dataType), field.nullable, field.metadata)
    }
    private def toLowerCase(dtype:DataType) : DataType = {
        dtype match {
            case struct:StructType => toLowerCase(struct)
            case array:ArrayType => ArrayType(toLowerCase(array.elementType), array.containsNull)
            case map:MapType => MapType(toLowerCase(map.keyType), toLowerCase(map.valueType), map.valueContainsNull)
            case dt:DataType => dt
        }
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy