All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.databricks.spark.xml.util.XSDToSchema.scala Maven / Gradle / Ivy

The newest version!
/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.databricks.spark.xml.util

import java.io.{File, FileInputStream, InputStreamReader, StringReader}
import java.nio.charset.StandardCharsets
import java.nio.file.Path

import scala.jdk.CollectionConverters._
import org.apache.spark.sql.types._
import org.apache.ws.commons.schema._
import org.apache.ws.commons.schema.constants.Constants

import com.databricks.spark.xml.XmlOptions

/**
 * Utility to generate a Spark schema from an XSD. Not all XSD schemas are simple tabular schemas,
 * so not all elements or XSDs are supported.
 */
object XSDToSchema {

  /**
   * Reads a schema from an XSD file.
   * Note that if the schema consists of one complex parent type which you want to use as
   * the row tag schema, then you will need to extract the schema of the single resulting
   * struct in the resulting StructType, and use its StructType as your schema.
   *
   * @param xsdFile XSD file
   * @return Spark-compatible schema
   */
  def read(xsdFile: File): StructType = {
    val xmlSchemaCollection = new XmlSchemaCollection()
    xmlSchemaCollection.setBaseUri(xsdFile.getParent)
    val xmlSchema = xmlSchemaCollection.read(
      new InputStreamReader(new FileInputStream(xsdFile), StandardCharsets.UTF_8))

    getStructType(xmlSchema)
  }

  /**
   * Reads a schema from an XSD file.
   * Note that if the schema consists of one complex parent type which you want to use as
   * the row tag schema, then you will need to extract the schema of the single resulting
   * struct in the resulting StructType, and use its StructType as your schema.
   *
   * @param xsdFile XSD file
   * @return Spark-compatible schema
   */
  def read(xsdFile: Path): StructType = read(xsdFile.toFile)

  /**
   * Reads a schema from an XSD as a string.
   * Note that if the schema consists of one complex parent type which you want to use as
   * the row tag schema, then you will need to extract the schema of the single resulting
   * struct in the resulting StructType, and use its StructType as your schema.
   *
   * @param xsdString XSD as a string
   * @return Spark-compatible schema
   */
  def read(xsdString: String): StructType = {
    val xmlSchema = new XmlSchemaCollection().read(new StringReader(xsdString))
    getStructType(xmlSchema)
  }


  private def getStructField(xmlSchema: XmlSchema, schemaType: XmlSchemaType): StructField = {
    schemaType match {
      // xs:simpleType
      case simpleType: XmlSchemaSimpleType =>
        val schemaType = simpleType.getContent match {
          case restriction: XmlSchemaSimpleTypeRestriction =>
            val qName = simpleType.getQName match {
              case null => restriction.getBaseTypeName
              case n => n
            }

            // Hacky, is there a better way? see if the type is known as a custom
            // type and use that if so, assuming along the way it's a simple restriction
            val typeName = xmlSchema.getSchemaTypes.asScala.get(qName).map { s =>
                s.asInstanceOf[XmlSchemaSimpleType].
                  getContent.asInstanceOf[XmlSchemaSimpleTypeRestriction].getBaseTypeName
              }.getOrElse(qName)

            typeName match {
              case Constants.XSD_BOOLEAN => BooleanType
              case Constants.XSD_DECIMAL =>
                val facets = restriction.getFacets.asScala
                val fracDigits = facets.collectFirst {
                  case facet: XmlSchemaFractionDigitsFacet => facet.getValue.toString.toInt
                }.getOrElse(18)
                val totalDigits = facets.collectFirst {
                  case facet: XmlSchemaTotalDigitsFacet => facet.getValue.toString.toInt
                }.getOrElse(38)
                DecimalType(totalDigits, math.min(totalDigits, fracDigits))
              case Constants.XSD_UNSIGNEDLONG => DecimalType(38, 0)
              case Constants.XSD_DOUBLE => DoubleType
              case Constants.XSD_FLOAT => FloatType
              case Constants.XSD_BYTE => ByteType
              case Constants.XSD_SHORT |
                   Constants.XSD_UNSIGNEDBYTE => ShortType
              case Constants.XSD_INTEGER |
                   Constants.XSD_NEGATIVEINTEGER |
                   Constants.XSD_NONNEGATIVEINTEGER |
                   Constants.XSD_NONPOSITIVEINTEGER |
                   Constants.XSD_POSITIVEINTEGER |
                   Constants.XSD_UNSIGNEDSHORT => IntegerType
              case Constants.XSD_LONG |
                   Constants.XSD_UNSIGNEDINT => LongType
              case Constants.XSD_DATE => DateType
              case Constants.XSD_DATETIME => TimestampType
              case _ => StringType
            }
          case _ => StringType
        }
        StructField("baseName", schemaType)

      // xs:complexType
      case complexType: XmlSchemaComplexType =>
        complexType.getContentModel match {
          case content: XmlSchemaSimpleContent =>
            // xs:simpleContent
            content.getContent match {
              case extension: XmlSchemaSimpleContentExtension =>
                val baseStructField = getStructField(xmlSchema,
                  xmlSchema.getParent.getTypeByQName(extension.getBaseTypeName))
                val value = StructField("_VALUE", baseStructField.dataType)
                val attributes = extension.getAttributes.asScala.map {
                  case attribute: XmlSchemaAttribute =>
                    val baseStructField = getStructField(xmlSchema,
                      xmlSchema.getParent.getTypeByQName(attribute.getSchemaTypeName))
                    StructField(s"_${attribute.getName}", baseStructField.dataType,
                      attribute.getUse != XmlSchemaUse.REQUIRED)
                }.toSeq
                StructField(complexType.getName, StructType(value +: attributes))
              case unsupported =>
                throw new IllegalArgumentException(s"Unsupported content: $unsupported")
            }
          case content: XmlSchemaComplexContent =>
            val complexContent = content.getContent
            complexContent match {
              case extension: XmlSchemaComplexContentExtension =>
                val baseStructField = getStructField(xmlSchema,
                  xmlSchema.getParent.getTypeByQName(extension.getBaseTypeName))
                val baseFields = baseStructField.dataType match {
                  case structType: StructType => structType.fields
                  case others =>
                    throw new IllegalArgumentException(
                      s"Non-StructType in ComplexContentExtension: $others"
                    )
                }

                val extendedFields = getStructFieldsFromParticle(extension.getParticle, xmlSchema)
                StructField(
                  schemaType.getQName.getLocalPart,
                  StructType(baseFields ++ extendedFields)
                )
              case unsupported =>
                throw new IllegalArgumentException(s"Unsupported content: $unsupported")
            }
          case null =>
            val childFields = getStructFieldsFromParticle(complexType.getParticle, xmlSchema)
            val attributes = complexType.getAttributes.asScala.map {
              case attribute: XmlSchemaAttribute =>
                val attributeType = attribute.getSchemaTypeName match {
                  case null =>
                    StringType
                  case t =>
                    getStructField(xmlSchema, xmlSchema.getParent.getTypeByQName(t)).dataType
                }
                StructField(s"_${attribute.getName}", attributeType,
                  attribute.getUse != XmlSchemaUse.REQUIRED)
            }.toSeq
            StructField(complexType.getName, StructType(childFields ++ attributes))
          case unsupported =>
            throw new IllegalArgumentException(s"Unsupported content model: $unsupported")
        }
      case unsupported =>
        throw new IllegalArgumentException(s"Unsupported schema element type: $unsupported")
    }
  }

  private def getStructType(xmlSchema: XmlSchema): StructType = {
    StructType(xmlSchema.getElements.asScala.toSeq.map { case (_, schemaElement) =>
      val schemaType = schemaElement.getSchemaType
      // if (schemaType.isAnonymous) {
      //   schemaType.setName(qName.getLocalPart)
      // }
      val rootType = getStructField(xmlSchema, schemaType)
      StructField(schemaElement.getName, rootType.dataType, schemaElement.getMinOccurs == 0)
    })
  }

  private def getStructFieldsFromParticle(
      particle: XmlSchemaParticle,
      xmlSchema: XmlSchema
  ): Seq[StructField] = {
    particle match {
        // xs:all
        case all: XmlSchemaAll =>
          all.getItems.asScala.map {
            case element: XmlSchemaElement =>
              val baseStructField = getStructField(xmlSchema, element.getSchemaType)
              val nullable = element.getMinOccurs == 0
              if (element.getMaxOccurs == 1) {
                StructField(element.getName, baseStructField.dataType, nullable)
              } else {
                StructField(element.getName, ArrayType(baseStructField.dataType), nullable)
              }
          }.toSeq
        // xs:choice
        case choice: XmlSchemaChoice =>
          choice.getItems.asScala.map {
            case element: XmlSchemaElement =>
              val baseStructField = getStructField(xmlSchema, element.getSchemaType)
              if (element.getMaxOccurs == 1) {
                StructField(element.getName, baseStructField.dataType, true)
              } else {
                StructField(element.getName, ArrayType(baseStructField.dataType), true)
              }
            case any: XmlSchemaAny =>
              val dataType = if (any.getMaxOccurs > 1) ArrayType(StringType) else StringType
              StructField(XmlOptions.DEFAULT_WILDCARD_COL_NAME, dataType, true)
          }.toSeq
        // xs:sequence
        case sequence: XmlSchemaSequence =>
          // flatten xs:choice nodes
          sequence.getItems.asScala.flatMap {
            _ match {
              case choice: XmlSchemaChoice =>
                choice.getItems.asScala.map { e =>
                  val xme = e.asInstanceOf[XmlSchemaElement]
                  val baseType = getStructField(xmlSchema, xme.getSchemaType).dataType
                  val dataType = if (xme.getMaxOccurs > 1) ArrayType(baseType) else baseType
                  StructField(xme.getName, dataType, true)
                }
              case e: XmlSchemaElement =>
                val refQName = e.getRef.getTargetQName
                val baseType =
                  if (refQName != null) {
                    getStructField(
                      xmlSchema,
                      xmlSchema.getParent.getElementByQName(refQName).getSchemaType).dataType
                  }
                  else getStructField(xmlSchema, e.getSchemaType).dataType
                val dataType = if (e.getMaxOccurs > 1) ArrayType(baseType) else baseType
                val nullable = e.getMinOccurs == 0
                val structFieldName =
                  Option(refQName).map(_.getLocalPart).getOrElse(e.getName)
                Seq(StructField(structFieldName, dataType, nullable))
              case any: XmlSchemaAny =>
                val dataType =
                  if (any.getMaxOccurs > 1) ArrayType(StringType) else StringType
                val nullable = any.getMinOccurs == 0
                Seq(StructField(XmlOptions.DEFAULT_WILDCARD_COL_NAME, dataType, nullable))
              case unsupported =>
                throw new IllegalArgumentException(s"Unsupported item: $unsupported")
            }
          }.toSeq
        case null =>
          Seq.empty
        case unsupported =>
          throw new IllegalArgumentException(s"Unsupported particle: $unsupported")
      }
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy