All downloads are free. Search and download functionality uses the official Maven repository.

com.crealytics.spark.excel.ExcelRelation.scala Maven / Gradle / Ivy

The newest version!
package com.crealytics.spark.excel

import java.sql.Timestamp
import java.text.SimpleDateFormat

import org.apache.poi.ss.usermodel.{Cell, CellType, DataFormatter, DateUtil, Row => _}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._

import scala.util.Try

/** Spark SQL relation backed by rows of an Excel workbook.
  *
  * The schema is `userSchema` when given; otherwise it is derived from the first `excerptSize` rows produced by
  * `dataLocator`. Scans read the located rows through `workbookReader`, materialize them locally and then
  * `parallelize` them into an RDD.
  *
  * @param dataLocator produces the sheet rows to read (via `readFrom(workbook)`)
  * @param header whether the first located row holds column names; it is skipped when scanning
  * @param treatEmptyValuesAsNulls forwarded to every [[HeaderDataColumn]]
  * @param usePlainNumberFormat forwarded to every [[HeaderDataColumn]]
  * @param inferSheetSchema infer column types from the excerpt; otherwise every column is `StringType`
  * @param setErrorCellsToFallbackValues forwarded to every [[HeaderDataColumn]]
  * @param addColorColumns without a user schema, append a nullable `<name>_color` string column per header column
  * @param userSchema explicit schema; when present no names, types or color columns are inferred
  * @param timestampFormat `SimpleDateFormat` pattern used to parse string cells as timestamps; when absent
  *   `java.sql.Timestamp.valueOf` is used
  * @param excerptSize number of leading rows read for schema inference
  * @param workbookReader supplies the workbook through `withWorkbook`
  */
case class ExcelRelation(
  dataLocator: DataLocator,
  header: Boolean,
  treatEmptyValuesAsNulls: Boolean,
  usePlainNumberFormat: Boolean,
  inferSheetSchema: Boolean,
  setErrorCellsToFallbackValues: Boolean,
  addColorColumns: Boolean = true,
  userSchema: Option[StructType] = None,
  timestampFormat: Option[String] = None,
  excerptSize: Int = 10,
  workbookReader: WorkbookReader
)(@transient val sqlContext: SQLContext)
    extends BaseRelation
    with TableScan
    with PrunedScan {
  type SheetRow = Seq[Cell]

  /** The first `excerptSize` located rows (header row included), used for schema inference. */
  lazy val excerpt: List[SheetRow] = workbookReader.withWorkbook(dataLocator.readFrom(_).take(excerptSize).to[List])

  /** Header columns keyed by their field name. */
  lazy val headerColumnForName = headerColumns.map(c => c.name -> c).toMap

  /** Parses string cells of `TimestampType` columns.
    *
    * Must be `lazy` (and is kept above `schema`): `schema` is a strict val whose initializer forces
    * `headerColumns`, which hands this parser to every [[HeaderDataColumn]]. A strict val declared after
    * `schema` would still be `null` at that moment.
    */
  private lazy val timestampParser: String => Timestamp =
    timestampFormat
      .map { fmt =>
        // NOTE(review): one SimpleDateFormat (not thread-safe) is shared by every parse; fine as long as
        // buildScan is not run concurrently on the same relation.
        val parser = new SimpleDateFormat(fmt)
        (stringValue: String) => new Timestamp(parser.parse(stringValue).getTime)
      }
      .getOrElse((stringValue: String) => Timestamp.valueOf(stringValue))

  override val schema: StructType = inferSchema

  val dataFormatter = new DataFormatter()

  /** Full scan: reads every column of `schema`. */
  override def buildScan: RDD[Row] = buildScan(schema.map(_.name).toArray)

  /** Splits a requested column name into its base name and an optional `_color` suffix. */
  val columnNameRegex = s"(?s)^(.*?)(_color)?$$".r.unanchored

  /** Builds a function extracting the value of `column` from a sheet row.
    *
    * An exact header name selects that header column. Otherwise a `<name>_color` name selects a
    * [[ColorDataColumn]] for header column `<name>`. Rows the chosen extractor is not defined for yield `null`.
    * Throws `NoSuchElementException` when no matching header column exists.
    */
  private def columnExtractor(column: String): SheetRow => Any = {
    // Prefer a real header column so headers that happen to end in "_color" (e.g. "eye_color") stay readable
    // instead of being resolved against a (possibly missing) base column.
    val cellExtractor: PartialFunction[Seq[Cell], Any] = headerColumnForName.get(column) match {
      case Some(headerColumn) => headerColumn
      case None =>
        val columnNameRegex(columnName, isColor) = column
        val headerColumn = headerColumnForName(columnName)
        if (isColor == null) headerColumn
        else new ColorDataColumn(headerColumn.name, headerColumn.columnIndex)
    }

    cellExtractor.applyOrElse(_, (_: Seq[Cell]) => null)
  }

  /** Reads `requiredColumns` (in that order) from every located row after the optional header row.
    *
    * Rows whose extraction throws a non-fatal exception are skipped (best effort). All rows are materialized
    * locally before being parallelized.
    */
  override def buildScan(requiredColumns: Array[String]): RDD[Row] = {
    val lookups = requiredColumns.map(columnExtractor).toSeq
    workbookReader.withWorkbook { workbook =>
      val allDataIterator = dataLocator.readFrom(workbook)
      val iter = if (header) allDataIterator.drop(1) else allDataIterator
      // Try only captures non-fatal errors; failing rows become None and are dropped by flatMap.
      val rows: Iterator[Seq[Any]] = iter.flatMap(row => Try(lookups.map(l => l(row))).toOption)
      val result = rows.to[Vector]
      parallelize(result.map(Row.fromSeq))
    }
  }

  /** Maps a cell to the Spark type used for schema inference.
    *
    * Empty strings, blanks, errors and formulas whose cached result is neither string nor numeric map to
    * `NullType`; date-formatted numeric cells map to `TimestampType`, other numerics to `DoubleType`.
    */
  private def getSparkType(cell: Cell): DataType = {
    cell.getCellType match {
      case CellType.FORMULA =>
        cell.getCachedFormulaResultType match {
          case CellType.STRING => StringType
          case CellType.NUMERIC => DoubleType
          case _ => NullType
        }
      case CellType.STRING if cell.getStringCellValue == "" => NullType
      case CellType.STRING => StringType
      case CellType.BOOLEAN => BooleanType
      case CellType.NUMERIC => if (DateUtil.isCellDateFormatted(cell)) TimestampType else DoubleType
      case CellType.BLANK => NullType
      case CellType.ERROR => NullType
    }
  }

  /** Distributes a local sequence through the relation's SparkContext. */
  private def parallelize[T : scala.reflect.ClassTag](seq: Seq[T]): RDD[T] = sqlContext.sparkContext.parallelize(seq)

  /** Builds one [[HeaderDataColumn]] per cell of the first excerpt row (null-safe and duplicate-safe names).
    *
    * Fields come from `userSchema` when given. Otherwise names are read from the first row when `header` is set
    * (blank/null names become `_c<index>`, duplicated names get their index appended) or are `_c<index>`, and
    * types are inferred from the remaining excerpt rows when `inferSheetSchema` is set, else `StringType`.
    * Fails on an empty excerpt because it reads `excerpt.head`.
    */
  lazy val headerColumns: Seq[HeaderDataColumn] = {
    val firstRow = excerpt.head
    val nonHeaderRows = if (header) excerpt.tail else excerpt

    val fields = userSchema.getOrElse {
      val dataTypes = if (this.inferSheetSchema) {
        val headerIndices = firstRow.map(_.getColumnIndex)
        // One type per header column index; missing cells count as NullType.
        val cellTypes: Seq[Seq[DataType]] = nonHeaderRows
          .map { r =>
            headerIndices.map(i => r.find(_.getColumnIndex == i).map(getSparkType).getOrElse(DataTypes.NullType))
          }
        InferSchema(parallelize(cellTypes))
      } else {
        // By default fields are assumed to be StringType
        excerpt.map(_.size).reduceOption(math.max) match {
          case None => Array()
          case Some(maxCellsPerRow) => {
            (0 until maxCellsPerRow).map(_ => StringType: DataType).toArray
          }
        }
      }

      // NOTE(review): POI's getStringCellValue rejects numeric cells, so a numeric header cell likely fails here.
      def colName(cell: Cell) = cell.getStringCellValue

      val colNames = if (header) {
        val headerNames = firstRow.map(colName)
        val duplicates = {
          val nonNullHeaderNames = headerNames.filter(_ != null)
          nonNullHeaderNames.groupBy(identity).filter(_._2.size > 1).keySet
        }

        firstRow.zipWithIndex.map { case (cell, index) =>
          val value = colName(cell)
          if (value == null || value.isEmpty) {
            // Null or empty header names fall back to a positional name, "_c" followed by the index.
            s"_c$index"
          } else if (duplicates.contains(value)) {
            // When there are duplicates, put the index as the suffix.
            s"$value$index"
          } else {
            value
          }
        }
      } else {
        firstRow.zipWithIndex.map { case (_, index) =>
          // Uses default column names, "_c#" where # is its position of fields
          // when header option is disabled.
          s"_c$index"
        }
      }
      colNames.zip(dataTypes).map { case (colName, dataType) =>
        StructField(name = colName, dataType = dataType, nullable = true)
      }
    }

    firstRow.zip(fields).map { case (cell, field) =>
      new HeaderDataColumn(
        field,
        cell.getColumnIndex,
        treatEmptyValuesAsNulls,
        usePlainNumberFormat,
        timestampParser,
        setErrorCellsToFallbackValues
      )
    }
  }

  /** Returns `userSchema` when given; otherwise the header fields, followed (when `addColorColumns` is set) by
    * one nullable `<name>_color` `StringType` column per header column, in header order.
    */
  private def inferSchema(): StructType =
    this.userSchema.getOrElse {
      val baseSchema = StructType(headerColumns.map(_.field))
      if (addColorColumns) {
        headerColumns.foldLeft(baseSchema) { (schema, header) =>
          schema.add(s"${header.name}_color", StringType, nullable = true)
        }
      } else {
        baseSchema
      }
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy