All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.spark.sql.avro.SchemaConverters.scala Maven / Gradle / Ivy

The newest version!
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.avro

import java.util.Locale

import scala.collection.JavaConverters._
import scala.collection.mutable

import org.apache.avro.{LogicalTypes, Schema, SchemaBuilder}
import org.apache.avro.LogicalTypes.{Date, Decimal, LocalTimestampMicros, LocalTimestampMillis, TimestampMicros, TimestampMillis}
import org.apache.avro.Schema.Type._

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.types._
import org.apache.spark.sql.types.Decimal.minBytesForPrecision

/**
 * This object contains methods that are used to convert Spark SQL schemas to Avro schemas and
 * vice versa.
 *
 * Avro to Spark SQL conversion is exposed through the `toSqlType` overloads; Spark SQL to Avro
 * conversion through `toAvroType`.
 */
@DeveloperApi
object SchemaConverters {
  private lazy val nullSchema = Schema.create(Schema.Type.NULL)

  /**
   * Internal wrapper for SQL data type and nullability.
   *
   * @since 2.4.0
   */
  case class SchemaType(dataType: DataType, nullable: Boolean)

  /**
   * Converts an Avro schema to a corresponding Spark SQL schema.
   *
   * @param avroSchema the Avro schema to convert
   * @param useStableIdForUnionType when true, struct fields generated for a complex union are
   *                                named `member_<avro type name>`; when false they are named
   *                                `member0`, `member1`, ... by position
   * @return the Spark SQL data type together with its nullability
   * @since 4.0.0
   */
  def toSqlType(avroSchema: Schema, useStableIdForUnionType: Boolean): SchemaType = {
    toSqlTypeHelper(avroSchema, Set.empty, useStableIdForUnionType)
  }

  /**
   * Converts an Avro schema to a corresponding Spark SQL schema, naming the fields of complex
   * unions by position (`member0`, `member1`, ...).
   *
   * @param avroSchema the Avro schema to convert
   * @return the Spark SQL data type together with its nullability
   * @since 2.4.0
   */
  def toSqlType(avroSchema: Schema): SchemaType = {
    toSqlType(avroSchema, useStableIdForUnionType = false)
  }

  /**
   * Converts an Avro schema to a corresponding Spark SQL schema, reading the union field naming
   * mode from data source options.
   *
   * @param avroSchema the Avro schema to convert
   * @param options options parsed with `AvroOptions`; only `useStableIdForUnionType` is read here
   * @return the Spark SQL data type together with its nullability
   */
  def toSqlType(avroSchema: Schema, options: Map[String, String]): SchemaType = {
    toSqlType(avroSchema, AvroOptions(options).useStableIdForUnionType)
  }

  // The property specifies Catalyst type of the given field
  private val CATALYST_TYPE_PROP_NAME = "spark.sql.catalyst.type"

  /**
   * Returns the Catalyst type recorded in the schema's `spark.sql.catalyst.type` property, or
   * `default` when the property is not set. The property value is parsed with
   * `CatalystSqlParser.parseDataType`.
   */
  private def catalystTypeOrElse(avroSchema: Schema, default: DataType): DataType = {
    // `getProp` returns null for a missing property; Option(...) turns that into None.
    Option(avroSchema.getProp(CATALYST_TYPE_PROP_NAME))
      .map(typeText => CatalystSqlParser.parseDataType(typeText))
      .getOrElse(default)
  }

  /**
   * Recursive worker behind the `toSqlType` overloads.
   *
   * @param avroSchema the Avro schema (or sub-schema) to convert
   * @param existingRecordNames full names of the records enclosing `avroSchema` on the current
   *                            path; used to reject recursive record references
   * @param useStableIdForUnionType see the public `toSqlType` overload
   * @throws IncompatibleSchemaException on a recursive record reference, on a name clash while
   *                                     generating stable union field names, or on an
   *                                     unsupported Avro type
   */
  private def toSqlTypeHelper(
      avroSchema: Schema,
      existingRecordNames: Set[String],
      useStableIdForUnionType: Boolean): SchemaType = {
    avroSchema.getType match {
      case INT => avroSchema.getLogicalType match {
        case _: Date => SchemaType(DateType, nullable = false)
        case _ =>
          // Plain int unless toAvroType tagged it with a Catalyst type (e.g. an interval).
          SchemaType(catalystTypeOrElse(avroSchema, IntegerType), nullable = false)
      }
      case STRING => SchemaType(StringType, nullable = false)
      case BOOLEAN => SchemaType(BooleanType, nullable = false)
      case BYTES | FIXED => avroSchema.getLogicalType match {
        // For FIXED type, if the precision requires more bytes than fixed size, the logical
        // type will be null, which is handled by Avro library.
        case d: Decimal => SchemaType(DecimalType(d.getPrecision, d.getScale), nullable = false)
        case _ => SchemaType(BinaryType, nullable = false)
      }

      case DOUBLE => SchemaType(DoubleType, nullable = false)
      case FLOAT => SchemaType(FloatType, nullable = false)
      case LONG => avroSchema.getLogicalType match {
        case d: CustomDecimal =>
          SchemaType(DecimalType(d.precision, d.scale), nullable = false)
        case _: TimestampMillis | _: TimestampMicros => SchemaType(TimestampType, nullable = false)
        case _: LocalTimestampMillis | _: LocalTimestampMicros =>
          SchemaType(TimestampNTZType, nullable = false)
        case _ =>
          // Plain long unless toAvroType tagged it with a Catalyst type (e.g. an interval).
          SchemaType(catalystTypeOrElse(avroSchema, LongType), nullable = false)
      }

      case ENUM => SchemaType(StringType, nullable = false)

      case NULL => SchemaType(NullType, nullable = true)

      case RECORD =>
        if (existingRecordNames.contains(avroSchema.getFullName)) {
          throw new IncompatibleSchemaException(s"""
            |Found recursive reference in Avro schema, which can not be processed by Spark:
            |${avroSchema.toString(true)}
          """.stripMargin)
        }
        val newRecordNames = existingRecordNames + avroSchema.getFullName
        val fields = avroSchema.getFields.asScala.map { f =>
          val schemaType = toSqlTypeHelper(f.schema(), newRecordNames, useStableIdForUnionType)
          StructField(f.name, schemaType.dataType, schemaType.nullable)
        }

        SchemaType(StructType(fields.toArray), nullable = false)

      case ARRAY =>
        val schemaType = toSqlTypeHelper(
          avroSchema.getElementType,
          existingRecordNames,
          useStableIdForUnionType)
        SchemaType(
          ArrayType(schemaType.dataType, containsNull = schemaType.nullable),
          nullable = false)

      case MAP =>
        val schemaType = toSqlTypeHelper(avroSchema.getValueType,
          existingRecordNames, useStableIdForUnionType)
        SchemaType(
          MapType(StringType, schemaType.dataType, valueContainsNull = schemaType.nullable),
          nullable = false)

      case UNION =>
        if (avroSchema.getTypes.asScala.exists(_.getType == NULL)) {
          // In case of a union with null, eliminate it and make a recursive call on what is
          // left: the sole remaining branch, or a new union of the remaining branches.
          val remainingUnionTypes = AvroUtils.nonNullUnionBranches(avroSchema)
          val nonNullSchema = if (remainingUnionTypes.size == 1) {
            remainingUnionTypes.head
          } else {
            Schema.createUnion(remainingUnionTypes.asJava)
          }
          toSqlTypeHelper(nonNullSchema, existingRecordNames, useStableIdForUnionType)
            .copy(nullable = true)
        } else avroSchema.getTypes.asScala.map(_.getType).toSeq match {
          case Seq(_) =>
            // A single-branch union is just that branch.
            toSqlTypeHelper(avroSchema.getTypes.get(0),
              existingRecordNames, useStableIdForUnionType)
          case Seq(t1, t2) if Set(t1, t2) == Set(INT, LONG) =>
            SchemaType(LongType, nullable = false)
          case Seq(t1, t2) if Set(t1, t2) == Set(FLOAT, DOUBLE) =>
            SchemaType(DoubleType, nullable = false)
          case _ =>
            // When useStableIdForUnionType is false, convert complex unions to struct
            // types where field names are member0, member1, etc. This is consistent with the
            // behavior when converting between Avro and Parquet.
            // If useStableIdForUnionType is true, include type name in field names
            // so that users can drop or add fields and keep field name stable.
            val fieldNameSet : mutable.Set[String] = mutable.Set()
            val fields = avroSchema.getTypes.asScala.zipWithIndex.map {
              case (s, i) =>
                val schemaType = toSqlTypeHelper(s, existingRecordNames, useStableIdForUnionType)

                val fieldName = if (useStableIdForUnionType) {
                  // Avro's field name may be case sensitive, so field names for two named type
                  // could be "a" and "A" and we need to distinguish them. In this case, we throw
                  // an exception.
                  // The stable field name is "member_" followed by the Avro type name; names
                  // are compared case-insensitively (lower-cased with Locale.ROOT).
                  val tempFieldName = s"member_${s.getName}"
                  if (!fieldNameSet.add(tempFieldName.toLowerCase(Locale.ROOT))) {
                    throw new IncompatibleSchemaException(
                      "Cannot generate stable identifier for Avro union type due to name " +
                      s"conflict of type name ${s.getName}")
                  }
                  tempFieldName
                } else {
                  s"member$i"
                }

                // All fields are nullable because only one of them is set at a time
                StructField(fieldName, schemaType.dataType, nullable = true)
            }

            SchemaType(StructType(fields.toArray), nullable = false)
        }

      case other => throw new IncompatibleSchemaException(s"Unsupported type $other")
    }
  }

  /**
   * Converts a Spark SQL schema to a corresponding Avro schema.
   *
   * Notable mappings: `ByteType`, `ShortType` and `IntegerType` all become Avro `int`;
   * decimals become a `fixed` type carrying the decimal logical type; year-month and day-time
   * intervals become `int` and `long` tagged with the `spark.sql.catalyst.type` property.
   * Only maps keyed by `StringType` are matched; any other map key type falls through to the
   * final case and is rejected.
   *
   * @param catalystType the Spark SQL data type to convert
   * @param nullable whether the value may be null; if true (and the type is not `NullType`)
   *                 the result is a union of the converted schema and `null`, in that order
   * @param recordName name given to the Avro record for a struct, and used as the prefix of
   *                   the `fixed` type name for decimals
   * @param nameSpace Avro namespace for records; the empty string means no namespace
   * @return the Avro schema
   * @throws IncompatibleSchemaException if `catalystType` matches none of the handled cases
   * @since 2.4.0
   */
  def toAvroType(
      catalystType: DataType,
      nullable: Boolean = false,
      recordName: String = "topLevelRecord",
      nameSpace: String = "")
    : Schema = {
    val builder = SchemaBuilder.builder()

    val schema = catalystType match {
      case BooleanType => builder.booleanType()
      case ByteType | ShortType | IntegerType => builder.intType()
      case LongType => builder.longType()
      case DateType =>
        LogicalTypes.date().addToSchema(builder.intType())
      case TimestampType =>
        LogicalTypes.timestampMicros().addToSchema(builder.longType())
      case TimestampNTZType =>
        LogicalTypes.localTimestampMicros().addToSchema(builder.longType())

      case FloatType => builder.floatType()
      case DoubleType => builder.doubleType()
      case StringType => builder.stringType()
      case NullType => builder.nullType()
      case d: DecimalType =>
        val avroType = LogicalTypes.decimal(d.precision, d.scale)
        val fixedSize = minBytesForPrecision(d.precision)
        // Need to avoid naming conflict for the fixed fields
        val name = nameSpace match {
          case "" => s"$recordName.fixed"
          case _ => s"$nameSpace.$recordName.fixed"
        }
        avroType.addToSchema(SchemaBuilder.fixed(name).size(fixedSize))

      case BinaryType => builder.bytesType()
      case ArrayType(et, containsNull) =>
        builder.array()
          .items(toAvroType(et, containsNull, recordName, nameSpace))
      case MapType(StringType, vt, valueContainsNull) =>
        builder.map()
          .values(toAvroType(vt, valueContainsNull, recordName, nameSpace))
      case st: StructType =>
        // Nested types are placed in a namespace derived from this record's own name.
        val childNameSpace = if (nameSpace != "") s"$nameSpace.$recordName" else recordName
        val fieldsAssembler = builder.record(recordName).namespace(nameSpace).fields()
        st.foreach { f =>
          val fieldAvroType =
            toAvroType(f.dataType, f.nullable, f.name, childNameSpace)
          fieldsAssembler.name(f.name).`type`(fieldAvroType).noDefault()
        }
        fieldsAssembler.endRecord()

      case ym: YearMonthIntervalType =>
        // Tag the int with the Catalyst type name so toSqlType can restore the interval type.
        val ymIntervalType = builder.intType()
        ymIntervalType.addProp(CATALYST_TYPE_PROP_NAME, ym.typeName)
        ymIntervalType
      case dt: DayTimeIntervalType =>
        // Tag the long with the Catalyst type name so toSqlType can restore the interval type.
        val dtIntervalType = builder.longType()
        dtIntervalType.addProp(CATALYST_TYPE_PROP_NAME, dt.typeName)
        dtIntervalType

      // This should never happen.
      case other => throw new IncompatibleSchemaException(s"Unexpected type $other.")
    }
    // NullType is already the Avro null schema, so it is never wrapped in a union with null.
    if (nullable && catalystType != NullType) {
      Schema.createUnion(schema, nullSchema)
    } else {
      schema
    }
  }
}

/**
 * Raised by the schema converters in this file when a schema cannot be converted, e.g. a
 * recursive Avro record reference or an unsupported type.
 *
 * @param msg description of the incompatibility
 * @param ex underlying cause; defaults to `null`, meaning no cause
 */
private[avro] class IncompatibleSchemaException(
    msg: String,
    ex: Throwable = null)
  extends Exception(msg, ex)

/**
 * Exception carrying only a message, visible within the `avro` package.
 *
 * NOTE(review): not thrown anywhere in this file; presumably raised elsewhere in the package
 * for Avro types that are not supported. Confirm against callers.
 *
 * @param msg description of the problem
 */
private[avro] class UnsupportedAvroTypeException(msg: String)
  extends Exception(msg)




© 2015 - 2024 Weber Informatics LLC | Privacy Policy