za.co.absa.cobrix.spark.cobol.utils.SparkUtils.scala
/*
* Copyright 2018 ABSA Group Limited
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package za.co.absa.cobrix.spark.cobol.utils
import com.fasterxml.jackson.databind.ObjectMapper
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.spark.SparkContext
import org.apache.spark.sql.functions.{array, col, expr, max, struct}
import za.co.absa.cobrix.spark.cobol.utils.impl.HofsWrapper.transform
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Column, DataFrame, SparkSession}
import za.co.absa.cobrix.cobol.internal.Logging
import za.co.absa.cobrix.spark.cobol.parameters.MetadataFields.MAX_ELEMENTS
import scala.annotation.tailrec
import scala.collection.mutable
import scala.collection.mutable.ListBuffer
import scala.util.Try
/**
* This object contains common Spark tools that make it easier to process dataframes originating from mainframes.
*/
object SparkUtils extends Logging {
/**
* Retrieves all executors available for the current job.
*/
def currentActiveExecutors(sc: SparkContext): Seq[String] = {
val allExecutors = sc.getExecutorMemoryStatus.map(_._1.split(":").head)
val driverHost: String = sc.getConf.get("spark.driver.host", "localhost")
logger.info(s"Going to filter driver from available executors: Driver host: $driverHost, Available executors: $allExecutors")
allExecutors.filter(!_.equals(driverHost)).toList.distinct
}
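// Usage sketch for `currentActiveExecutors` (assuming an active SparkSession named `spark`):
// lists executor hosts with the driver host filtered out.
//
//   val executorHosts: Seq[String] = SparkUtils.currentActiveExecutors(spark.sparkContext)
//   executorHosts.foreach(host => println(s"Executor host: $host"))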
/**
* Returns true if a Spark data type is a primitive data type.
*
* @param dataType Spark data type
* @return true if the data type is primitive.
*/
def isPrimitive(dataType: DataType): Boolean = {
dataType match {
case _: ArrayType => false
case _: StructType => false
case _: MapType => false
case _ => true
}
}
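// For example, atomic types are primitive while container types are not:
//
//   SparkUtils.isPrimitive(StringType)                                    // true
//   SparkUtils.isPrimitive(DecimalType(10, 2))                            // true
//   SparkUtils.isPrimitive(ArrayType(IntegerType))                        // false
//   SparkUtils.isPrimitive(StructType(Seq(StructField("a", LongType))))   // false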
/**
* Given an instance of a DataFrame, returns a dataframe with a flattened schema.
* All nested structures are flattened and arrays are projected as columns.
*
* Note: the method computes the maximum number of elements for each array (unless it is
* already available in the field metadata), and that can be slow, especially on very big dataframes.
*
* @param df A dataframe
* @param useShortFieldNames When flattening a schema, each field name contains the full path by default.
*                           Set this to true to use short field names instead.
* @return A new dataframe with flat schema.
*/
def flattenSchema(df: DataFrame, useShortFieldNames: Boolean = false): DataFrame = {
val fields = new mutable.ListBuffer[Column]()
val stringFields = new mutable.ListBuffer[String]()
val usedNames = new mutable.HashSet[String]()
def getNewFieldName(desiredName: String): String = {
var name = desiredName
var i = 1
while (usedNames.contains(name)) {
name = s"$desiredName$i"
i += 1
}
usedNames.add(name)
name
}
/**
* Flattens an array by projecting its elements as columns.
*
* @param path            Path to the array field
* @param fieldNamePrefix Prefix for the generated field names
* @param structField     The schema field that contains the array
* @param arrayType       The Spark data type of the array
*/
def flattenStructArray(path: String, fieldNamePrefix: String, structField: StructField, arrayType: ArrayType): Unit = {
val maxInd = getMaxArraySize(s"$path${structField.name}")
var i = 0
while (i < maxInd) {
arrayType.elementType match {
case st: StructType =>
val newFieldNamePrefix = s"${fieldNamePrefix}${i}_"
flattenGroup(s"$path`${structField.name}`[$i].", newFieldNamePrefix, st)
case ar: ArrayType =>
val newFieldNamePrefix = s"${fieldNamePrefix}${i}_"
flattenArray(s"$path`${structField.name}`[$i].", newFieldNamePrefix, structField, ar)
// AtomicType is protected at the 'sql' package level, so we have to enumerate all its subtypes :(
case _ =>
val newFieldNamePrefix = s"${fieldNamePrefix}${i}"
val newFieldName = getNewFieldName(s"$newFieldNamePrefix")
fields += expr(s"$path`${structField.name}`[$i]").as(newFieldName, structField.metadata)
stringFields += s"""expr("$path`${structField.name}`[$i] AS `$newFieldName`")"""
}
i += 1
}
}
def flattenNestedArrays(path: String, fieldNamePrefix: String, arrayType: ArrayType, metadata: Metadata): Unit = {
val maxInd = getMaxArraySize(path)
var i = 0
while (i < maxInd) {
arrayType.elementType match {
case st: StructType =>
val newFieldNamePrefix = s"${fieldNamePrefix}${i}_"
flattenGroup(s"$path[$i]", newFieldNamePrefix, st)
case ar: ArrayType =>
val newFieldNamePrefix = s"${fieldNamePrefix}${i}_"
flattenNestedArrays(s"$path[$i]", newFieldNamePrefix, ar, metadata)
// AtomicType is protected at the 'sql' package level, so we have to enumerate all its subtypes :(
case _ =>
val newFieldNamePrefix = s"${fieldNamePrefix}${i}"
val newFieldName = getNewFieldName(s"$newFieldNamePrefix")
fields += expr(s"$path[$i]").as(newFieldName, metadata)
stringFields += s"""expr("$path[$i] AS `$newFieldName`")"""
}
i += 1
}
}
def getMaxArraySize(path: String): Int = {
getField(path, df.schema) match {
case Some(field) if field.metadata.contains(MAX_ELEMENTS) =>
field.metadata.getLong(MAX_ELEMENTS).toInt
case _ =>
val collected = df.agg(max(expr(s"size($path)"))).collect()(0)(0)
if (collected != null) {
// can be null for empty dataframe
collected.toString.toInt
} else {
1
}
}
}
def flattenArray(path: String, fieldNamePrefix: String, structField: StructField, arrayType: ArrayType): Unit = {
arrayType.elementType match {
case _: ArrayType =>
flattenNestedArrays(s"$path${structField.name}", fieldNamePrefix, arrayType, structField.metadata)
case _ =>
flattenStructArray(path, fieldNamePrefix, structField, arrayType)
}
}
def flattenGroup(path: String, fieldNamePrefix: String, structField: StructType): Unit = {
structField.foreach(field => {
val newFieldNamePrefix = if (useShortFieldNames) {
s"${field.name}_"
} else {
s"$fieldNamePrefix${field.name}_"
}
field.dataType match {
case st: StructType =>
flattenGroup(s"$path`${field.name}`.", newFieldNamePrefix, st)
case arr: ArrayType =>
flattenArray(path, newFieldNamePrefix, field, arr)
case _ =>
val newFieldName = getNewFieldName(s"$fieldNamePrefix${field.name}")
fields += expr(s"$path`${field.name}`").as(newFieldName, field.metadata)
if (path.contains('['))
stringFields += s"""expr("$path`${field.name}` AS `$newFieldName`")"""
else
stringFields += s"""col("$path`${field.name}`").as("$newFieldName")"""
}
})
}
flattenGroup("", "", df.schema)
logger.info(stringFields.mkString("Flattening code: \n.select(\n", ",\n", "\n)"))
df.select(fields.toSeq: _*)
}
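// Usage sketch for `flattenSchema` (assuming a dataframe with nested structs and arrays,
// e.g. one produced by the 'cobol' data source; the copybook and data paths are placeholders):
//
//   val df = spark.read
//     .format("cobol")
//     .option("copybook", "/path/to/copybook.cpy")
//     .load("/path/to/data")
//
//   // Structs become prefixed top-level columns; arrays are expanded into indexed columns
//   // up to the maximum observed (or metadata-declared) number of elements.
//   val flatDf = SparkUtils.flattenSchema(df, useShortFieldNames = false)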
/**
* Removes all struct nesting when possible for a given schema.
*/
def unstructSchema(schema: StructType, useShortFieldNames: Boolean = false): StructType = {
def mapFieldShort(field: StructField): Array[StructField] = {
field.dataType match {
case st: StructType =>
st.fields flatMap mapFieldShort
case _ =>
Array(field)
}
}
def mapFieldLong(field: StructField, path: String): Array[StructField] = {
field.dataType match {
case st: StructType =>
st.fields.flatMap(f => mapFieldLong(f, s"$path${field.name}_"))
case _ =>
Array(field.copy(name = s"$path${field.name}"))
}
}
val fields = if (useShortFieldNames)
schema.fields flatMap mapFieldShort
else
schema.fields.flatMap(f => mapFieldLong(f, ""))
StructType(fields)
}
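// Sketch for `unstructSchema`: nested struct fields are promoted to top-level fields whose
// names are either the concatenated path (default) or the leaf name (useShortFieldNames = true).
// Assuming a schema such as:
//
//   val schema = StructType(Seq(
//     StructField("ID", IntegerType),
//     StructField("ADDRESS", StructType(Seq(
//       StructField("CITY", StringType),
//       StructField("ZIP", StringType))))
//   ))
//
//   SparkUtils.unstructSchema(schema).fieldNames
//   // Array(ID, ADDRESS_CITY, ADDRESS_ZIP)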
/**
* Removes all struct nesting when possible for a given dataframe.
*
* Similar to `flattenSchema()`, but does not flatten arrays.
*/
def unstructDataFrame(df: DataFrame, useShortFieldNames: Boolean = false): DataFrame = {
def mapFieldShort(column: Column, field: StructField): Array[Column] = {
field.dataType match {
case st: StructType =>
st.fields.flatMap(f => mapFieldShort(column.getField(f.name), f))
case _ =>
Array(column.as(field.name, field.metadata))
}
}
def mapFieldLong(column: Column, field: StructField, path: String): Array[Column] = {
field.dataType match {
case st: StructType =>
st.fields.flatMap(f => mapFieldLong(column.getField(f.name), f, s"$path${field.name}_"))
case _ =>
Array(column.as(s"$path${field.name}", field.metadata))
}
}
val columns = if (useShortFieldNames)
df.schema.fields.flatMap(f => mapFieldShort(col(f.name), f))
else
df.schema.fields.flatMap(f => mapFieldLong(col(f.name), f, ""))
df.select(columns: _*)
}
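// Usage sketch for `unstructDataFrame` (assuming a dataframe `df` with nested structs):
// arrays are kept as-is, only struct nesting is removed, and field metadata is preserved.
//
//   val unstructedDf = SparkUtils.unstructDataFrame(df)
//   unstructedDf.printSchema()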
/**
* Copies metadata from one schema to another for fields that have matching names.
*
* @param schemaFrom Schema to copy metadata from.
* @param schemaTo Schema to copy metadata to.
* @param overwrite If true, the metadata of schemaTo is not retained
* @param sourcePreferred If true, schemaFrom metadata is used on conflicts, schemaTo otherwise.
* @param copyDataType If true, data type is copied as well. This is limited to primitive data types.
* @return Same schema as schemaTo with metadata from schemaFrom.
*/
def copyMetadata(schemaFrom: StructType,
schemaTo: StructType,
overwrite: Boolean = false,
sourcePreferred: Boolean = false,
copyDataType: Boolean = false): StructType = {
def joinMetadata(from: Metadata, to: Metadata): Metadata = {
val newMetadataMerged = new MetadataBuilder
if (sourcePreferred) {
newMetadataMerged.withMetadata(to)
newMetadataMerged.withMetadata(from)
} else {
newMetadataMerged.withMetadata(from)
newMetadataMerged.withMetadata(to)
}
newMetadataMerged.build()
}
@tailrec
def processArray(ar: ArrayType, fieldFrom: StructField, fieldTo: StructField): ArrayType = {
ar.elementType match {
case st: StructType if fieldFrom.dataType.isInstanceOf[ArrayType] && fieldFrom.dataType.asInstanceOf[ArrayType].elementType.isInstanceOf[StructType] =>
val innerStructFrom = fieldFrom.dataType.asInstanceOf[ArrayType].elementType.asInstanceOf[StructType]
val newDataType = StructType(copyMetadata(innerStructFrom, st, overwrite, sourcePreferred, copyDataType).fields)
ArrayType(newDataType, ar.containsNull)
case at: ArrayType =>
processArray(at, fieldFrom, fieldTo)
case p =>
if (copyDataType && fieldFrom.dataType.isInstanceOf[ArrayType] && isPrimitive(fieldFrom.dataType.asInstanceOf[ArrayType].elementType)) {
ArrayType(fieldFrom.dataType.asInstanceOf[ArrayType].elementType, ar.containsNull)
} else {
ArrayType(p, ar.containsNull)
}
}
}
val fieldsMap = schemaFrom.fields.map(f => (f.name, f)).toMap
val newFields: Array[StructField] = schemaTo.fields.map { fieldTo =>
fieldsMap.get(fieldTo.name) match {
case Some(fieldFrom) =>
val newMetadata = if (overwrite) {
fieldFrom.metadata
} else {
joinMetadata(fieldFrom.metadata, fieldTo.metadata)
}
fieldTo.dataType match {
case st: StructType if fieldFrom.dataType.isInstanceOf[StructType] =>
val newDataType = StructType(copyMetadata(fieldFrom.dataType.asInstanceOf[StructType], st, overwrite, sourcePreferred, copyDataType).fields)
fieldTo.copy(dataType = newDataType, metadata = newMetadata)
case at: ArrayType =>
val newType = processArray(at, fieldFrom, fieldTo)
fieldTo.copy(dataType = newType, metadata = newMetadata)
case _ =>
if (copyDataType && isPrimitive(fieldFrom.dataType)) {
fieldTo.copy(dataType = fieldFrom.dataType, metadata = newMetadata)
} else {
fieldTo.copy(metadata = newMetadata)
}
}
case None =>
fieldTo
}
}
StructType(newFields)
}
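// Usage sketch for `copyMetadata` (the dataframe names are assumptions): restore Cobrix
// metadata on a dataframe whose transformations dropped it, given a source with the same field names.
//
//   val schemaWithMetadata = SparkUtils.copyMetadata(dfFromCobol.schema, dfTransformed.schema)
//   val dfRestored = spark.createDataFrame(dfTransformed.rdd, schemaWithMetadata)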
/**
* Allows mapping every primitive field in a dataframe with a Spark expression.
*
* The metadata of the original schema is retained.
*
* @param df The dataframe to map.
* @param f The function to apply to each primitive field.
* @return The new dataframe with the mapping applied.
*/
def mapPrimitives(df: DataFrame)(f: (StructField, Column) => Column): DataFrame = {
def mapField(column: Column, field: StructField): Column = {
field.dataType match {
case st: StructType =>
val columns = st.fields.map(f => mapField(column.getField(f.name), f))
struct(columns: _*).as(field.name, field.metadata)
case ar: ArrayType =>
mapArray(ar, column, field.name).as(field.name, field.metadata)
case _ =>
f(field, column).as(field.name, field.metadata)
}
}
def mapArray(arr: ArrayType, column: Column, columnName: String): Column = {
arr.elementType match {
case st: StructType =>
transform(column, c => {
val columns = st.fields.map(f => mapField(c.getField(f.name), f))
struct(columns: _*)
})
case ar: ArrayType =>
array(mapArray(ar, column, columnName))
case p =>
array(f(StructField(columnName, p), column))
}
}
val columns = df.schema.fields.map(f => mapField(col(f.name), f))
val newDf = df.select(columns: _*)
val newSchema = copyMetadata(df.schema, newDf.schema)
df.sparkSession.createDataFrame(newDf.rdd, newSchema)
}
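// Usage sketch for `mapPrimitives` (assuming `import org.apache.spark.sql.functions.trim`
// and a dataframe `df`): trim every string field, however deeply nested, and leave other
// primitive fields unchanged.
//
//   val trimmedDf = SparkUtils.mapPrimitives(df) { (field, c) =>
//     if (field.dataType == StringType) trim(c) else c
//   }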
def covertIntegralToDecimal(df: DataFrame): DataFrame = {
mapPrimitives(df) { (field, c) =>
val metadata = field.metadata
if (metadata.contains("precision") && (field.dataType == LongType || field.dataType == IntegerType || field.dataType == ShortType)) {
val precision = metadata.getLong("precision").toInt
c.cast(DecimalType(precision, 0)).as(field.name)
} else {
c
}
}
}
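// Usage sketch for `covertIntegralToDecimal` (assuming a dataframe `df`): integral fields
// (short, int, long) that carry a 'precision' value in their field metadata are re-cast to
// DecimalType(precision, 0); all other fields pass through unchanged.
//
//   val decimalDf = SparkUtils.covertIntegralToDecimal(df)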
/**
* Given an instance of a DataFrame, returns a dataframe where all primitive fields are converted to String.
*
* @param df A dataframe
* @return A new dataframe with all primitive fields as Strings.
*/
def convertDataframeFieldsToStrings(df: DataFrame): DataFrame = {
val fields = new mutable.ListBuffer[Column]()
def convertArrayToStrings(path: String, structField: StructField, arrayType: ArrayType): Unit = {
arrayType.elementType match {
case st: StructType =>
// ToDo convert array's inner struct fields to Strings.
// Possibly Spark 2.4 array transform API could be used for that.
fields += expr(s"$path`${structField.name}`")
case fld =>
fields += expr(s"$path`${structField.name}`").cast(ArrayType(StringType))
}
}
def convertToStrings(path: String, structField: StructType): Unit = {
structField.foreach(field => {
field.dataType match {
case st: StructType =>
convertToStrings(s"$path`${field.name}`.", st)
case arr: ArrayType =>
convertArrayToStrings(path, field, arr)
case fld =>
fields += expr(s"$path`${field.name}`").cast(StringType)
}
})
}
convertToStrings("", df.schema)
df.select(fields.toSeq: _*)
}
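// Usage sketch for `convertDataframeFieldsToStrings` (assuming a dataframe `df`): primitive
// columns and arrays of primitives are cast to strings; arrays of structs are currently left
// as-is (see the ToDo comment above).
//
//   val stringDf = SparkUtils.convertDataframeFieldsToStrings(df)
//   stringDf.printSchema()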
def convertDataFrameToPrettyJSON(df: DataFrame, takeN: Int = 0): String = {
val collected = if (takeN <= 0) {
df.toJSON.collect().mkString("\n")
} else {
df.toJSON.take(takeN).mkString("\n")
}
val json = "[" + "}\n".r.replaceAllIn(collected, "},\n") + "]"
prettyJSON(json)
}
def prettyJSON(jsonIn: String): String = {
val mapper = new ObjectMapper()
val jsonUnindented = mapper.readValue(jsonIn, classOf[Any])
val indented = mapper.writerWithDefaultPrettyPrinter.writeValueAsString(jsonUnindented)
indented.replace("\r\n", "\n")
}
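// Usage sketch for `convertDataFrameToPrettyJSON` / `prettyJSON` (assuming a dataframe `df`):
// renders the first few records as an indented JSON array.
//
//   val jsonText: String = SparkUtils.convertDataFrameToPrettyJSON(df, takeN = 3)
//   println(jsonText)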
/**
* Get a Spark field from a text path and a given schema
* (originally implemented here: https://github.com/AbsaOSS/enceladus/blob/665b34fa1c04fe255729e4b6706cf9ea33227b3e/utils/src/main/scala/za/co/absa/enceladus/utils/schema/SchemaUtils.scala#L45)
*
* @param path The dot-separated path to the field
* @param schema The schema which should contain the specified path
* @return Some(the requested field) or None if the field does not exist
*/
def getField(path: String, schema: StructType): Option[StructField] = {
@tailrec
def goThroughArrayDataType(dataType: DataType): DataType = {
dataType match {
case ArrayType(dt, _) => goThroughArrayDataType(dt)
case result => result
}
}
@tailrec
def examineStructField(names: List[String], structField: StructField): Option[StructField] = {
if (names.isEmpty) {
Option(structField)
} else {
structField.dataType match {
case struct: StructType => examineStructField(names.tail, struct(names.head))
case ArrayType(el: DataType, _) =>
goThroughArrayDataType(el) match {
case struct: StructType => examineStructField(names.tail, struct(names.head))
case _ => None
}
case _ => None
}
}
}
val pathTokens = splitFieldPath(path)
Try {
examineStructField(pathTokens.tail, schema(pathTokens.head))
}.getOrElse(None)
}
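// Usage sketch for `getField` (the field names below are placeholders): look up a nested
// field by its dot-separated path; backquoted names and array indexes are supported.
//
//   val fieldOpt: Option[StructField] = SparkUtils.getField("RECORD.SEGMENT1.FIELD1", df.schema)
//   fieldOpt.foreach(f => println(s"${f.name}: ${f.dataType}"))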
def getDefaultHdfsBlockSize(spark: SparkSession, pathOpt: Option[String]): Option[Int] = {
val conf = spark.sparkContext.hadoopConfiguration
val fileSystem = pathOpt match {
case Some(pathStr) => new Path(pathStr).getFileSystem(conf)
case None => FileSystem.get(conf)
}
val hdfsBlockSize = HDFSUtils.getHDFSDefaultBlockSizeMB(fileSystem)
hdfsBlockSize match {
case None => logger.info(s"Unable to get default block size for '${fileSystem.getScheme}://'.")
case Some(size) => logger.info(s"Default block size for '${fileSystem.getScheme}://' is $size MB.")
}
hdfsBlockSize
}
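// The `splitFieldPath` helper below tokenizes a field path with a small state machine so that
// dots inside backquoted names are preserved and array indexes are dropped. For instance
// (a hypothetical path):
//
//   splitFieldPath("RECORD.`GRP.FIELD`[0].SUB_FIELD")
//   // List(RECORD, GRP.FIELD, SUB_FIELD)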
private def splitFieldPath(path: String): List[String] = {
var state = 0
var currentField = new StringBuilder()
val fields = new ListBuffer[String]()
var i = 0
while (i < path.length) {
val c = path(i)
state match {
case 0 =>
// The character might be part of the path
if (c == '.') {
fields.append(currentField.toString())
currentField = new StringBuilder()
} else if (c == '`') {
state = 1
} else if (c == '[') {
state = 2
} else {
currentField.append(c)
}
case 1 =>
// The character is part of the backquoted field name
if (c == '`') {
state = 0
} else {
currentField.append(c)
}
case 2 =>
// The character is an index (that should be ignored)
if (c == ']') {
state = 0
}
}
i += 1
}
if (currentField.nonEmpty) {
fields.append(currentField.toString())
}
fields.toList
}
}