com.crealytics.spark.excel.InferSchema.scala Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of spark-excel-2.12.17-3.1.1_2.12 Show documentation
Show all versions of spark-excel-2.12.17-3.1.1_2.12 Show documentation
A Spark plugin for reading and writing Excel files
The newest version!
/*
* Copyright 2022 Martin Mauch (@nightscape)
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.crealytics.spark.excel
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.types._
private[excel] object InferSchema {
  type CellType = Int

  /** Infers one column type per column from the cell types of every row.
    *
    * Similar to Spark's JSON schema inference ([[org.apache.spark.sql.execution.datasources.json.InferSchema]]):
    *   1. within each partition, fold every row's cell types into a running per-column type ([[inferRowType]]);
    *   2. merge the per-partition results into a common type per column ([[mergeRowTypes]]);
    *   3. replace any column that is still `NullType` (no typed cell was seen) with `StringType`.
    *
    * @param rowsRDD
    *   the cell types of each row, in column order; rows may have different lengths
    * @return
    *   one type per column, as many as the longest row has cells (empty if `rowsRDD` is empty)
    */
  def apply(rowsRDD: RDD[Seq[DataType]]): Array[DataType] = {
    val startType: Array[DataType] = Array.empty
    val rootTypes: Array[DataType] = rowsRDD.aggregate(startType)(inferRowType, mergeRowTypes)
    rootTypes.map {
      case _: NullType => StringType
      case other => other
    }
  }

  /** Folds one row's cell types into the per-column types accumulated so far.
    *
    * The shorter of the two is padded with `NullType`, so a cell missing from a row (or a column not seen
    * before) leaves the other side's type unchanged.
    */
  private def inferRowType(rowSoFar: Array[DataType], next: Seq[DataType]): Array[DataType] =
    rowSoFar.zipAll(next, NullType, NullType).map { case (soFar, cell) => inferField(soFar, cell) }

  /** Merges two per-column type arrays, e.g. the results computed for two different partitions.
    *
    * Columns present on only one side keep that side's type. Two types without a tightest common type are
    * widened to `StringType`, the same resolution [[inferField]] applies to conflicting cells.
    *
    * Falling back to `NullType` here would make the merge non-associative. A `NullType` produced by a
    * conflict would be replaced by whatever type the next partition reports, so the inferred schema would
    * depend on how the rows happen to be partitioned.
    */
  private[excel] def mergeRowTypes(first: Array[DataType], second: Array[DataType]): Array[DataType] =
    first.zipAll(second, NullType, NullType).map { case (a, b) =>
      findTightestCommonType(a, b).getOrElse(StringType)
    }

  /** Combines the type accumulated so far for a column with the type of the next cell.
    *
    *   - a `NullType` cell never changes the accumulated type;
    *   - a `NullType` accumulator takes the cell's type;
    *   - matching `DoubleType`, `BooleanType` or `TimestampType` types are kept;
    *   - every other combination widens to `StringType`.
    */
  private[excel] def inferField(typeSoFar: DataType, field: DataType): DataType = {
    // Defining a function to return the StringType constant is necessary in order to work around
    // a Scala compiler issue which leads to runtime incompatibilities with certain Spark versions;
    // see issue #128 for more details.
    def stringType(): DataType = {
      StringType
    }
    if (field == NullType) {
      typeSoFar
    } else {
      (typeSoFar, field) match {
        case (NullType, ct) => ct
        case (DoubleType, DoubleType) => DoubleType
        case (BooleanType, BooleanType) => BooleanType
        case (TimestampType, TimestampType) => TimestampType
        // Covers StringType on either side as well as any mismatched pair.
        case _ => stringType()
      }
    }
  }

  /** Widening order used by [[findTightestCommonType]]: a later entry wins over an earlier one.
    *
    * Adapted from the internal Spark api [[org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion]]. Unlike a
    * purely numeric ordering, `TimestampType` is included and ranks above every numeric type.
    */
  private val numericPrecedence: IndexedSeq[DataType] =
    IndexedSeq[DataType](ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType, TimestampType)

  /** Returns the narrowest type both arguments can be widened to, or `None` if there is none.
    *
    * Adapted from the internal Spark api [[org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion]].
    */
  val findTightestCommonType: (DataType, DataType) => Option[DataType] = {
    case (t1, t2) if t1 == t2 => Some(t1)
    case (NullType, t1) => Some(t1)
    case (t1, NullType) => Some(t1)
    case (StringType, _) => Some(StringType)
    case (_, StringType) => Some(StringType)
    // Promote to whichever of the two ranks higher in numericPrecedence.
    // NOTE(review): this maps (DoubleType, TimestampType) to TimestampType, whereas inferField widens the
    // same pair to StringType, so the result for such a column can still vary with partitioning — confirm
    // which resolution is intended.
    case (t1, t2) if Seq(t1, t2).forall(numericPrecedence.contains) =>
      val index = numericPrecedence.lastIndexWhere(t => t == t1 || t == t2)
      Some(numericPrecedence(index))
    case _ => None
  }
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy