All downloads are free. Search and download functionality uses the official Maven repository.

com.databricks.spark.xml.XmlRelation.scala Maven / Gradle / Ivy

/*
 * Copyright 2014 Databricks
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.databricks.spark.xml

import java.io.IOException

import org.apache.hadoop.fs.Path
import org.slf4j.LoggerFactory

import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
import org.apache.spark.sql.sources.{PrunedScan, InsertableRelation, BaseRelation, TableScan}
import org.apache.spark.sql.types._
import com.databricks.spark.xml.util.{CompressionCodecs, InferSchema}
import com.databricks.spark.xml.parsers.StaxXmlParser

/**
 * A Spark SQL relation over XML data, supporting full scans, column pruning and
 * `INSERT OVERWRITE` into `location`.
 *
 * @param baseRDD    factory producing the raw XML strings. It is invoked once for schema
 *                   inference (only when `userSchema` is `null`) and once per scan.
 * @param location   output path used by `insert`. When `None`, inserts fail.
 * @param parameters data source options. They are parsed into `XmlOptions` and also forwarded
 *                   to the writer on insert.
 * @param userSchema explicit schema. When `null`, the schema is inferred from `baseRDD`.
 * @param sqlContext the owning SQL context. It is marked `@transient`, so it is excluded
 *                   from serialization.
 */
case class XmlRelation protected[spark] (
    baseRDD: () => RDD[String],
    location: Option[String],
    parameters: Map[String, String],
    userSchema: StructType = null)(@transient val sqlContext: SQLContext)
  extends BaseRelation
  with InsertableRelation
  with TableScan
  with PrunedScan {

  private val logger = LoggerFactory.getLogger(XmlRelation.getClass)

  private val options = XmlOptions(parameters)

  /** The user-supplied schema if one was given, otherwise one inferred by scanning `baseRDD`. */
  override val schema: StructType =
    // `getOrElse` takes its argument by name, so inference only runs when no schema was given.
    Option(userSchema).getOrElse(InferSchema.infer(baseRDD(), options))

  /** Parses every record produced by `baseRDD` against the full `schema`. */
  override def buildScan(): RDD[Row] =
    StaxXmlParser.parse(baseRDD(), schema, options)

  /**
   * Scans only the requested columns. Each returned row holds values in `requiredColumns` order.
   *
   * When fail-fast parsing is enabled, only the requested fields are parsed. Otherwise every
   * field is parsed, with the requested ones first, and the trailing unrequested values are
   * dropped afterwards.
   *
   * @param requiredColumns names of the columns to return; each must exist in `schema`
   */
  override def buildScan(requiredColumns: Array[String]): RDD[Row] = {
    val requiredFields = requiredColumns.map(schema(_))
    if (schema.fields.sameElements(requiredFields)) {
      // Nothing is pruned or reordered, so this is just a full scan.
      buildScan()
    } else if (options.failFastFlag) {
      StaxXmlParser.parse(baseRDD(), StructType(requiredFields), options)
    } else {
      // If `failFast` is disabled, then it needs to parse all the values
      // so that we can decide which row is malformed.
      val requiredSet = requiredFields.toSet
      val safeRequestedSchema = StructType(
        requiredFields ++ schema.fields.filterNot(requiredSet.contains))
      // Only a local Int is captured by the closure below, not `this`.
      val rowSize = requiredFields.length
      StaxXmlParser.parse(baseRDD(), safeRequestedSchema, options)
        .map(row => Row.fromSeq(row.toSeq.take(rowSize)))
    }
  }

  /**
   * Writes `data` to `location` as XML, replacing the existing output.
   *
   * Borrowed from Spark's `JSONRelation`. Only `INSERT OVERWRITE` is supported. The relation's
   * schema is assumed to be unchanged by the write and is not updated.
   *
   * @param data      rows to write
   * @param overwrite must be `true`; appending is not supported
   * @throws IOException      if no `location` is defined, or the existing output cannot be deleted
   * @throws RuntimeException if `overwrite` is `false`
   */
  override def insert(data: DataFrame, overwrite: Boolean): Unit = {
    val filesystemPath = location match {
      case Some(p) => new Path(p)
      case None =>
        throw new IOException("Cannot INSERT into table with no path defined")
    }

    val fs = filesystemPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration)

    if (overwrite) {
      try {
        fs.delete(filesystemPath, true)
      } catch {
        case e: IOException =>
          // Chain the original failure as the cause so its stack trace is preserved.
          throw new IOException(
            s"Unable to clear output directory ${filesystemPath.toString} prior"
              + s" to INSERT OVERWRITE a XML table:\n${e.toString}",
            e)
      }
      // Write the data. We assume that schema isn't changed, and we won't update it.
      val codecClass = CompressionCodecs.getCodecClass(options.codec)
      data.saveAsXmlFile(filesystemPath.toString, parameters, codecClass)
    } else {
      sys.error("XML tables only support INSERT OVERWRITE for now.")
    }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy