
/*
* Copyright 2014 Databricks
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.databricks.spark.xml.parsers
import java.io.StringReader
import javax.xml.stream.XMLEventReader
import javax.xml.stream.events.{Attribute, Characters, EndElement, StartElement, XMLEvent}
import javax.xml.transform.stream.StreamSource
import javax.xml.validation.Schema
import scala.collection.mutable.ArrayBuffer
import scala.collection.JavaConverters._
import scala.util.control.NonFatal
import scala.util.Try
import org.slf4j.LoggerFactory
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._
import com.databricks.spark.xml.util.TypeCast._
import com.databricks.spark.xml.XmlOptions
import com.databricks.spark.xml.util._
/**
 * Wraps the StAX parser and drives the iteration that turns XML records into rows.
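 *
 * A minimal usage sketch (illustrative only: `sc`, the sample schema, and the no-arg
 * `XmlOptions()` stand in for whatever the caller has built; the object is `private[xml]`,
 * so this only compiles from within the package):
 * {{{
 *   val xml: RDD[String] = sc.parallelize(Seq("<book><title>Spark</title></book>"))
 *   val schema = StructType(Seq(StructField("title", StringType)))
 *   val rows: RDD[Row] = StaxXmlParser.parse(xml, schema, XmlOptions())
 * }}}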
*/
private[xml] object StaxXmlParser extends Serializable {
private val logger = LoggerFactory.getLogger(StaxXmlParser.getClass)
def parse(
xml: RDD[String],
schema: StructType,
options: XmlOptions): RDD[Row] = {
xml.mapPartitions { iter =>
val xsdSchema = Option(options.rowValidationXSDPath).map(ValidatorUtil.getSchema)
iter.flatMap { xml =>
doParseColumn(xml, schema, options, options.parseMode, xsdSchema)
}
}
}
def parseColumn(xml: String, schema: StructType, options: XmlOptions): Row = {
    // The user-specified schema from from_xml, etc. will typically not include a
// "corrupted record" column. In PERMISSIVE mode, which puts bad records in
// such a column, this would cause an error. In this mode, if such a column
// is not manually specified, then fall back to DROPMALFORMED, which will return
// null column values where parsing fails.
val parseMode =
if (options.parseMode == PermissiveMode &&
!schema.fields.exists(_.name == options.columnNameOfCorruptRecord)) {
DropMalformedMode
} else {
options.parseMode
}
val xsdSchema = Option(options.rowValidationXSDPath).map(ValidatorUtil.getSchema)
doParseColumn(xml, schema, options, parseMode, xsdSchema).orNull
}
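  // Sketch of how the fallback above surfaces through the public API (hypothetical DataFrame
  // `df` with a string column "payload"; from_xml comes from com.databricks.spark.xml.functions):
  //   val schema = StructType(Seq(StructField("title", StringType)))  // no corrupt-record column
  //   df.select(from_xml($"payload", schema))
  // PERMISSIVE is requested by default, but since `schema` has no corrupt-record column the parse
  // effectively runs in DROPMALFORMED, so malformed payloads come back as null instead of failing.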
private def doParseColumn(xml: String,
schema: StructType,
options: XmlOptions,
parseMode: ParseMode,
xsdSchema: Option[Schema]): Option[Row] = {
try {
xsdSchema.foreach { schema =>
schema.newValidator().validate(new StreamSource(new StringReader(xml)))
}
val parser = StaxXmlParserUtils.filteredReader(xml)
val rootAttributes = StaxXmlParserUtils.gatherRootAttributes(parser)
Some(convertObject(parser, schema, options, rootAttributes))
} catch {
case e: PartialResultException =>
failedRecord(xml, options, parseMode, schema,
e.cause, Some(e.partialResult))
case NonFatal(e) =>
failedRecord(xml, options, parseMode, schema, e)
}
}
private def failedRecord(record: String,
options: XmlOptions,
parseMode: ParseMode,
schema: StructType,
cause: Throwable = null,
partialResult: Option[Row] = None): Option[Row] = {
// create a row even if no corrupt record column is present
parseMode match {
case FailFastMode =>
val abbreviatedRecord =
if (record.length() > 1000) record.substring(0, 1000) + "..." else record
throw new IllegalArgumentException(
s"Malformed line in FAILFAST mode: ${abbreviatedRecord.replaceAll("\n", "")}", cause)
case DropMalformedMode =>
val reason = if (cause != null) s"Reason: ${cause.getMessage}" else ""
val abbreviatedRecord =
if (record.length() > 1000) record.substring(0, 1000) + "..." else record
logger.warn(s"Dropping malformed line: ${abbreviatedRecord.replaceAll("\n", "")}. $reason")
logger.debug("Malformed line cause:", cause)
None
case PermissiveMode =>
logger.debug("Malformed line cause:", cause)
// The logic below is borrowed from Apache Spark's FailureSafeParser.
val resultRow = new Array[Any](schema.length)
schema.filterNot(_.name == options.columnNameOfCorruptRecord).foreach { from =>
val sourceIndex = schema.fieldIndex(from.name)
resultRow(sourceIndex) = partialResult.map(_.get(sourceIndex)).orNull
}
val corruptFieldIndex = Try(schema.fieldIndex(options.columnNameOfCorruptRecord)).toOption
corruptFieldIndex.foreach(resultRow(_) = record)
Some(Row.fromSeq(resultRow.toIndexedSeq))
}
}
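  // Example of the PERMISSIVE branch above: with a schema of ("title" STRING, "_corrupt_record"
  // STRING) and options.columnNameOfCorruptRecord set to "_corrupt_record" (an illustrative name),
  // a record that fails mid-parse yields Row(<partial value or null>, <original XML text>).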
/**
* Parse the current token (and related children) according to a desired schema
*/
private[xml] def convertField(
parser: XMLEventReader,
dataType: DataType,
options: XmlOptions,
attributes: Array[Attribute] = Array.empty): Any = {
def convertComplicatedType(dt: DataType, attributes: Array[Attribute]): Any = dt match {
case st: StructType => convertObject(parser, st, options)
case MapType(StringType, vt, _) => convertMap(parser, vt, options, attributes)
case ArrayType(st, _) => convertField(parser, st, options)
case _: StringType => StaxXmlParserUtils.currentStructureAsString(parser)
}
(parser.peek, dataType) match {
case (_: StartElement, dt: DataType) => convertComplicatedType(dt, attributes)
case (_: EndElement, _: StringType) =>
        // Empty element: return null if empty values are explicitly treated as null
        // or if "" is the configured null value; otherwise return "".
        if (options.treatEmptyValuesAsNulls || options.nullValue == "") {
null
} else {
""
}
case (_: EndElement, _: DataType) => null
case (c: Characters, ArrayType(st, _)) =>
        // For `ArrayType`, return a single converted element here; the values are merged
        // into an array later.
convertTo(c.getData, st, options)
case (c: Characters, st: StructType) =>
        // If a value tag is present, this can be an attribute-only element whose value is in that
        // value tag field. Or, it can be a mixed-content element with both character data and
        // other complex structure. In that case, the character data is ignored.
val attributesOnly = st.fields.forall { f =>
f.name == options.valueTag || f.name.startsWith(options.attributePrefix)
}
if (attributesOnly) {
// If everything else is an attribute column, there's no complex structure.
// Just return the value of the character element, or null if we don't have a value tag
st.find(_.name == options.valueTag).map(
valueTag => convertTo(c.getData, valueTag.dataType, options)).orNull
} else {
// Otherwise, ignore this character element, and continue parsing the following complex
// structure
parser.next
parser.peek match {
case _: EndElement => null // no struct here at all; done
case _ => convertObject(parser, st, options)
}
}
case (c: Characters, _: DataType) if c.isWhiteSpace =>
// When `Characters` is found, we need to look further to decide
// if this is really data or space between other elements.
val data = c.getData
parser.next
parser.peek match {
case _: StartElement => convertComplicatedType(dataType, attributes)
case _: EndElement if data.isEmpty => null
case _: EndElement if options.treatEmptyValuesAsNulls => null
case _: EndElement => convertTo(data, dataType, options)
case _ => convertField(parser, dataType, options, attributes)
}
case (c: Characters, dt: DataType) =>
convertTo(c.getData, dt, options)
case (e: XMLEvent, dt: DataType) =>
throw new IllegalArgumentException(
s"Failed to parse a value for data type $dt with event ${e.toString}")
}
}
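  // Illustration of the attribute-only case above, assuming spark-xml's usual defaults of
  // attributePrefix = "_" and valueTag = "_VALUE" (both configurable via XmlOptions):
  //   <price currency="USD">10.5</price>
  // read against StructType(Seq(StructField("_currency", StringType), StructField("_VALUE", DoubleType)))
  // contains only attribute and value-tag fields, so the "10.5" characters are converted with the
  // value tag's DoubleType and returned directly; no nested object parse is attempted.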
/**
* Parse an object as map.
*/
private def convertMap(
parser: XMLEventReader,
valueType: DataType,
options: XmlOptions,
attributes: Array[Attribute]): Map[String, Any] = {
val kvPairs = ArrayBuffer.empty[(String, Any)]
attributes.foreach { attr =>
kvPairs += (options.attributePrefix + attr.getName.getLocalPart -> attr.getValue)
}
var shouldStop = false
while (!shouldStop) {
parser.nextEvent match {
case e: StartElement =>
kvPairs +=
(StaxXmlParserUtils.getName(e.asStartElement.getName, options) ->
convertField(parser, valueType, options))
case _: EndElement =>
shouldStop = StaxXmlParserUtils.checkEndElement(parser)
case _ => // do nothing
}
}
kvPairs.toMap
}
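  // Sketch: for a field declared as MapType(StringType, StringType), an element such as
  //   <tags count="2"><genre>fiction</genre><format>ebook</format></tags>
  // becomes Map("_count" -> "2", "genre" -> "fiction", "format" -> "ebook"), the "_count" entry
  // coming from the attribute handling above (assuming the default attributePrefix "_").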
/**
* Convert XML attributes to a map with the given schema types.
*/
private def convertAttributes(
attributes: Array[Attribute],
schema: StructType,
options: XmlOptions): Map[String, Any] = {
val convertedValuesMap = collection.mutable.Map.empty[String, Any]
val valuesMap = StaxXmlParserUtils.convertAttributesToValuesMap(attributes, options)
valuesMap.foreach { case (f, v) =>
val nameToIndex = schema.map(_.name).zipWithIndex.toMap
nameToIndex.get(f).foreach { i =>
convertedValuesMap(f) = convertTo(v, schema(i).dataType, options)
}
}
convertedValuesMap.toMap
}
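  // For example, attributes id="1" lang="en" with the default attributePrefix "_" and a schema
  // that only declares "_id" as LongType yield Map("_id" -> 1L); "_lang" is not in the schema
  // and is simply not carried over.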
/**
* [[convertObject()]] calls this in order to convert the nested object to a row.
* [[convertObject()]] contains some logic to find out which events are the start
* and end of a nested row and this function converts the events to a row.
*/
private def convertObjectWithAttributes(
parser: XMLEventReader,
schema: StructType,
options: XmlOptions,
attributes: Array[Attribute] = Array.empty): Row = {
    // TODO: This method might have to be removed. Some logic duplicates `convertObject()`
val row = new Array[Any](schema.length)
// Read attributes first.
val attributesMap = convertAttributes(attributes, schema, options)
// Then, we read elements here.
val fieldsMap = convertField(parser, schema, options) match {
case row: Row =>
Map(schema.map(_.name).zip(row.toSeq): _*)
case v if schema.fieldNames.contains(options.valueTag) =>
        // If this is an element having no children, then it wraps the attributes
        // in a row. So, we first need to find the field name that holds the real
        // value and then set that value.
val valuesMap = schema.fieldNames.map((_, null)).toMap
valuesMap + (options.valueTag -> v)
case _ => Map.empty
}
    // Here we merge both into a single row.
val valuesMap = fieldsMap ++ attributesMap
valuesMap.foreach { case (f, v) =>
val nameToIndex = schema.map(_.name).zipWithIndex.toMap
nameToIndex.get(f).foreach { row(_) = v }
}
if (valuesMap.isEmpty) {
      // Return an empty row with all nested elements declared in the schema set to null.
Row.fromSeq(Seq.fill(schema.fieldNames.length)(null))
} else {
Row.fromSeq(row.toIndexedSeq)
}
}
/**
* Parse an object from the event stream into a new Row representing the schema.
 * Fields in the XML that are not defined in the requested schema will be dropped.
*/
private def convertObject(
parser: XMLEventReader,
schema: StructType,
options: XmlOptions,
rootAttributes: Array[Attribute] = Array.empty): Row = {
val row = new Array[Any](schema.length)
val nameToIndex = schema.map(_.name).zipWithIndex.toMap
// If there are attributes, then we process them first.
convertAttributes(rootAttributes, schema, options).toSeq.foreach { case (f, v) =>
nameToIndex.get(f).foreach { row(_) = v }
}
val wildcardColName = options.wildcardColName
val hasWildcard = schema.exists(_.name == wildcardColName)
var badRecordException: Option[Throwable] = None
var shouldStop = false
while (!shouldStop) {
parser.nextEvent match {
case e: StartElement => try {
val attributes = e.getAttributes.asScala.map(_.asInstanceOf[Attribute]).toArray
val field = StaxXmlParserUtils.getName(e.asStartElement.getName, options)
nameToIndex.get(field) match {
case Some(index) => schema(index).dataType match {
case st: StructType =>
row(index) = convertObjectWithAttributes(parser, st, options, attributes)
case ArrayType(dt: DataType, _) =>
val values = Option(row(index))
.map(_.asInstanceOf[ArrayBuffer[Any]])
.getOrElse(ArrayBuffer.empty[Any])
val newValue = dt match {
case st: StructType =>
convertObjectWithAttributes(parser, st, options, attributes)
case dt: DataType =>
convertField(parser, dt, options)
}
row(index) = values :+ newValue
case dt: DataType =>
row(index) = convertField(parser, dt, options, attributes)
}
case None =>
if (hasWildcard) {
// Special case: there's an 'any' wildcard element that matches anything else
// as a string (or array of strings, to parse multiple ones)
val newValue = convertField(parser, StringType, options)
val anyIndex = schema.fieldIndex(wildcardColName)
schema(wildcardColName).dataType match {
case StringType =>
row(anyIndex) = newValue
case ArrayType(StringType, _) =>
val values = Option(row(anyIndex))
.map(_.asInstanceOf[ArrayBuffer[String]])
.getOrElse(ArrayBuffer.empty[String])
row(anyIndex) = values :+ newValue
}
} else {
StaxXmlParserUtils.skipChildren(parser)
}
}
} catch {
case NonFatal(exception) if options.parseMode == PermissiveMode =>
badRecordException = badRecordException.orElse(Some(exception))
}
case _: EndElement =>
shouldStop = StaxXmlParserUtils.checkEndElement(parser)
case _ => // do nothing
}
}
if (badRecordException.isEmpty) {
Row.fromSeq(row.toIndexedSeq)
} else {
throw PartialResultException(Row.fromSeq(row.toIndexedSeq), badRecordException.get)
}
}
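  // Note on the wildcard branch above: when the schema declares a column named
  // options.wildcardColName with type StringType or ArrayType(StringType), elements that match no
  // other schema field are captured there as strings instead of being skipped via skipChildren.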
}