All downloads are free. Search and download functionality uses the official Maven repository.

org.apache.spark.sql.HoodieInternalRowUtils.scala Maven / Gradle / Ivy

The newest version!
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql

import org.apache.hudi.AvroConversionUtils.convertAvroSchemaToStructType
import org.apache.hudi.avro.HoodieAvroUtils.{createFullName, toJavaDate}
import org.apache.hudi.exception.HoodieException

import org.apache.avro.Schema
import org.apache.hbase.thirdparty.com.google.common.base.Supplier
import org.apache.spark.sql.HoodieCatalystExpressionUtils.generateUnsafeProjection
import org.apache.spark.sql.HoodieUnsafeRowUtils.{NestedFieldPath, composeNestedFieldPath}
import org.apache.spark.sql.catalyst.expressions.{SpecificInternalRow, UnsafeArrayData, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData}
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.types.Decimal.ROUND_HALF_UP
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

import java.util.concurrent.ConcurrentHashMap
import java.util.function.{Function => JFunction}
import java.util.{ArrayDeque => JArrayDeque, Collections => JCollections, Deque => JDeque, Map => JMap}

import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

object HoodieInternalRowUtils {

  private type RenamedColumnMap = JMap[String, String]
  private type UnsafeRowWriter = InternalRow => UnsafeRow

  // NOTE: [[UnsafeRowWriter]] objects cache have to stay [[ThreadLocal]] since these are not thread-safe
  //       (every writer reuses a mutable row buffer across invocations, see [[genUnsafeRowWriter]])
  private val unsafeWriterThreadLocal: ThreadLocal[mutable.HashMap[(StructType, StructType, RenamedColumnMap), UnsafeRowWriter]] =
    ThreadLocal.withInitial(new Supplier[mutable.HashMap[(StructType, StructType, RenamedColumnMap), UnsafeRowWriter]] {
      override def get(): mutable.HashMap[(StructType, StructType, RenamedColumnMap), UnsafeRowWriter] =
        new mutable.HashMap[(StructType, StructType, RenamedColumnMap), UnsafeRowWriter]
    })

  // NOTE: [[UnsafeProjection]] objects cache have to stay [[ThreadLocal]] since these are not thread-safe
  private val unsafeProjectionThreadLocal: ThreadLocal[mutable.HashMap[(StructType, StructType), UnsafeProjection]] =
    ThreadLocal.withInitial(new Supplier[mutable.HashMap[(StructType, StructType), UnsafeProjection]] {
      override def get(): mutable.HashMap[(StructType, StructType), UnsafeProjection] =
        new mutable.HashMap[(StructType, StructType), UnsafeProjection]
    })

  // Caches shared across threads (backed by [[ConcurrentHashMap]])
  private val schemaMap = new ConcurrentHashMap[Schema, StructType]
  private val orderPosListMap = new ConcurrentHashMap[(StructType, String), Option[NestedFieldPath]]

  /**
   * Provides cached instance of [[UnsafeProjection]] transforming provided [[InternalRow]]s from
   * one [[StructType]] and into another [[StructType]]
   *
   * For more details regarding its semantic, please check corresponding scala-doc for
   * [[HoodieCatalystExpressionUtils.generateUnsafeProjection]]
   */
  def getCachedUnsafeProjection(from: StructType, to: StructType): UnsafeProjection = {
    unsafeProjectionThreadLocal.get()
      .getOrElseUpdate((from, to), generateUnsafeProjection(from, to))
  }

  /**
   * Provides cached instance of [[UnsafeRowWriter]] transforming provided [[InternalRow]]s from
   * one [[StructType]] and into another [[StructType]]
   *
   * Unlike [[UnsafeProjection]] requiring that [[from]] has to be a proper subset of [[to]] schema,
   * [[UnsafeRowWriter]] is able to perform whole spectrum of schema-evolution transformations including:
   *
   * <ul>
   *   <li>Transforming nested structs/maps/arrays</li>
   *   <li>Handling type promotions (int -> long, etc)</li>
   *   <li>Handling (field) renames</li>
   * </ul>
   *
   * Writers are cached per-thread, keyed by the (from, to, renamedColumnsMap) triple.
   *
   * @param from              schema of the incoming rows
   * @param to                schema the rows should be rewritten into
   * @param renamedColumnsMap map from the new (fully-qualified) column name to the previous one
   */
  def getCachedUnsafeRowWriter(from: StructType,
                               to: StructType,
                               renamedColumnsMap: JMap[String, String] = JCollections.emptyMap()): UnsafeRowWriter = {
    unsafeWriterThreadLocal.get()
      .getOrElseUpdate((from, to, renamedColumnsMap), genUnsafeRowWriter(from, to, renamedColumnsMap))
  }

  /**
   * Returns (cached) [[NestedFieldPath]] of the (potentially nested) `field` within `structType`,
   * as computed by [[HoodieUnsafeRowUtils.composeNestedFieldPath]]
   */
  def getCachedPosList(structType: StructType, field: String): Option[NestedFieldPath] = {
    val nestedFieldPathOpt = orderPosListMap.get((structType, field))
    // NOTE: This specifically designed to do 2 lookups (in case of cache-miss) to avoid
    //       allocating the closure when using [[computeIfAbsent]] on more frequent cache-hit path
    if (nestedFieldPathOpt != null) {
      nestedFieldPathOpt
    } else {
      orderPosListMap.computeIfAbsent((structType, field), new JFunction[(StructType, String), Option[NestedFieldPath]] {
        override def apply(t: (StructType, String)): Option[NestedFieldPath] =
          composeNestedFieldPath(structType, field)
      })
    }
  }

  /**
   * Returns (cached) Catalyst [[StructType]] corresponding to the provided Avro [[Schema]]
   */
  def getCachedSchema(schema: Schema): StructType = {
    val structType = schemaMap.get(schema)
    // NOTE: This specifically designed to do 2 lookups (in case of cache-miss) to avoid
    //       allocating the closure when using [[computeIfAbsent]] on more frequent cache-hit path
    if (structType != null) {
      structType
    } else {
      schemaMap.computeIfAbsent(schema, new JFunction[Schema, StructType] {
        override def apply(t: Schema): StructType = convertAvroSchemaToStructType(schema)
      })
    }
  }

  /**
   * Generates [[UnsafeRowWriter]] rewriting rows of `prevSchema` into `newSchema`.
   *
   * NOTE: Returned writer reuses a mutable row buffer across invocations and therefore is NOT thread-safe.
   * NOTE(review): the returned [[UnsafeRow]] is produced by an [[UnsafeProjection]], which in Spark typically
   *               reuses its output buffer — callers retaining rows likely need to copy them; confirm.
   */
  private[sql] def genUnsafeRowWriter(prevSchema: StructType,
                                      newSchema: StructType,
                                      renamedColumnsMap: JMap[String, String]): UnsafeRowWriter = {
    val unsafeProjection = generateUnsafeProjection(newSchema, newSchema)

    // NOTE: Mirrors the identity check of [[newWriterRenaming]]: identical schemas only require conversion
    if (prevSchema.sql == newSchema.sql) {
      oldRow => unsafeProjection(oldRow)
    } else {
      val structWriter = genUnsafeStructWriter(prevSchema, newSchema, renamedColumnsMap, new JArrayDeque[String]())
      // NOTE: Reusing the top-level buffer is safe: it is fully consumed by [[unsafeProjection]]
      //       before the next row is written into it (nested rows are allocated per value)
      val resultRow = new SpecificInternalRow(newSchema.fields.map(_.dataType))
      val resultUpdater = new RowUpdater(resultRow)

      oldRow => {
        structWriter(resultUpdater, oldRow)
        unsafeProjection(resultRow)
      }
    }
  }

  // (target container, ordinal within it, source value) => writes converted value into the container
  private type RowFieldUpdater = (CatalystDataUpdater, Int, Any) => Unit

  /**
   * Generates writer copying fields of a row of `prevStructType` into a row of `newStructType`.
   *
   * Every field of `newStructType` is resolved against `prevStructType` by name, then (if missing) via
   * `renamedColumnsMap`; fields that can't be resolved are written as nulls.
   */
  private def genUnsafeStructWriter(prevStructType: StructType,
                                    newStructType: StructType,
                                    renamedColumnsMap: JMap[String, String],
                                    fieldNamesStack: JDeque[String]): (CatalystDataUpdater, Any) => Unit = {
    // TODO need to canonicalize schemas (casing)
    val fieldWritersBuf = ArrayBuffer.empty[RowFieldUpdater]
    val positionMapBuf = ArrayBuffer.empty[Int]

    for (newField <- newStructType.fields) {
      fieldNamesStack.push(newField.name)

      val (fieldWriter, prevFieldPos): (RowFieldUpdater, Int) =
        prevStructType.getFieldIndex(newField.name) match {
          case Some(prevFieldPos) =>
            val prevField = prevStructType(prevFieldPos)
            (newWriterRenaming(prevField.dataType, newField.dataType, renamedColumnsMap, fieldNamesStack), prevFieldPos)

          case None =>
            val newFieldQualifiedName = createFullName(fieldNamesStack)
            val prevFieldName: String = lookupRenamedField(newFieldQualifiedName, renamedColumnsMap)

            // Handle rename
            prevStructType.getFieldIndex(prevFieldName) match {
              case Some(prevFieldPos) =>
                val prevField = prevStructType.fields(prevFieldPos)
                (newWriterRenaming(prevField.dataType, newField.dataType, renamedColumnsMap, fieldNamesStack), prevFieldPos)

              case None =>
                val updater: RowFieldUpdater = (fieldUpdater, ordinal, _) => fieldUpdater.setNullAt(ordinal)
                (updater, -1)
            }
        }

      fieldWritersBuf += fieldWriter
      positionMapBuf += prevFieldPos

      fieldNamesStack.pop()
    }

    val fieldWriters = fieldWritersBuf.toArray
    val positionMap = positionMapBuf.toArray
    // Source field types resolved once (instead of per row); placeholder for fields missing in the source
    val prevFieldTypes: Array[DataType] =
      positionMap.map(prevPos => if (prevPos >= 0) prevStructType.fields(prevPos).dataType else NullType)

    (fieldUpdater, row) => {
      val prevRow = row.asInstanceOf[InternalRow]
      var pos = 0
      while (pos < fieldWriters.length) {
        val prevPos = positionMap(pos)
        val prevValue = if (prevPos >= 0) prevRow.get(prevPos, prevFieldTypes(pos)) else null

        if (prevValue == null) {
          fieldUpdater.setNullAt(pos)
        } else {
          fieldWriters(pos)(fieldUpdater, pos, prevValue)
        }

        pos += 1
      }
    }
  }

  /**
   * Generates [[RowFieldUpdater]] converting values of `prevDataType` into `newDataType`.
   *
   * @throws IllegalArgumentException (at generation time) when the types can't be converted
   */
  private def newWriterRenaming(prevDataType: DataType,
                                newDataType: DataType,
                                renamedColumnsMap: JMap[String, String],
                                fieldNameStack: JDeque[String]): RowFieldUpdater = {
    (newDataType, prevDataType) match {
      case (newType, prevType) if prevType.sql == newType.sql =>
        (fieldUpdater, ordinal, value) => fieldUpdater.set(ordinal, value)

      case (newStructType: StructType, prevStructType: StructType) =>
        val writer = genUnsafeStructWriter(prevStructType, newStructType, renamedColumnsMap, fieldNameStack)
        val newFieldTypes: Seq[DataType] = newStructType.map(_.dataType)

        (fieldUpdater, ordinal, value) => {
          // NOTE: New row has to be allocated for every value, since the parent container (row, array or map)
          //       retains reference to it: reusing single buffer would make every element of, for ex,
          //       an array of structs alias the last element written
          val newRow = new SpecificInternalRow(newFieldTypes)
          writer(new RowUpdater(newRow), value)
          fieldUpdater.set(ordinal, newRow)
        }

      case (ArrayType(newElementType, _), ArrayType(prevElementType, containsNull)) =>
        fieldNameStack.push("element")
        val elementWriter = newWriterRenaming(prevElementType, newElementType, renamedColumnsMap, fieldNameStack)
        fieldNameStack.pop()
        // NOTE: Path has to be captured here, since [[fieldNameStack]] is only populated while writers are built
        val arrayPath = createFullName(fieldNameStack)

        (fieldUpdater, ordinal, value) => {
          val prevArrayData = value.asInstanceOf[ArrayData]
          val prevArray = prevArrayData.toObjectArray(prevElementType)

          val newArrayData = createArrayData(newElementType, prevArrayData.numElements())
          val elementUpdater = new ArrayDataUpdater(newArrayData)

          var i = 0
          while (i < prevArray.length) {
            val element = prevArray(i)
            if (element == null) {
              if (!containsNull) {
                throw new HoodieException(
                  s"Array value at path '$arrayPath' is not allowed to be null")
              } else {
                elementUpdater.setNullAt(i)
              }
            } else {
              elementWriter(elementUpdater, i, element)
            }
            i += 1
          }

          fieldUpdater.set(ordinal, newArrayData)
        }

      case (MapType(_, newValueType, _), MapType(_, prevValueType, valueContainsNull)) =>
        fieldNameStack.push("value")
        val valueWriter = newWriterRenaming(prevValueType, newValueType, renamedColumnsMap, fieldNameStack)
        fieldNameStack.pop()
        // NOTE: Path has to be captured here, since [[fieldNameStack]] is only populated while writers are built
        val mapPath = createFullName(fieldNameStack)

        (updater, ordinal, value) => {
          val mapData = value.asInstanceOf[MapData]
          val prevKeyArrayData = mapData.keyArray()
          val prevValueArray = mapData.valueArray().toObjectArray(prevValueType)

          val newValueArray = createArrayData(newValueType, mapData.numElements())
          val valueUpdater = new ArrayDataUpdater(newValueArray)

          var i = 0
          while (i < prevValueArray.length) {
            val prevValue = prevValueArray(i)
            if (prevValue == null) {
              if (!valueContainsNull) {
                throw new HoodieException(s"Map value at path $mapPath is not allowed to be null")
              } else {
                valueUpdater.setNullAt(i)
              }
            } else {
              valueWriter(valueUpdater, i, prevValue)
            }
            i += 1
          }

          // NOTE: Key's couldn't be transformed and have to always be of [[StringType]]
          updater.set(ordinal, new ArrayBasedMapData(prevKeyArrayData, newValueArray))
        }

      case (newDecimal: DecimalType, _) =>
        prevDataType match {
          case IntegerType | LongType | FloatType | DoubleType | StringType =>
            (fieldUpdater, ordinal, value) => {
              val scale = newDecimal.scale
              // TODO this has to be revisited to avoid loss of precision (for fps)
              fieldUpdater.setDecimal(ordinal, Decimal.fromDecimal(BigDecimal(value.toString).setScale(scale, ROUND_HALF_UP)))
            }

          case _: DecimalType =>
            (fieldUpdater, ordinal, value) =>
              fieldUpdater.setDecimal(ordinal, Decimal.fromDecimal(value.asInstanceOf[Decimal].toBigDecimal.setScale(newDecimal.scale)))

          case _ =>
            throw new IllegalArgumentException(s"$prevDataType and $newDataType are incompatible")
        }

      case (_: ShortType, _) =>
        prevDataType match {
          case _: ByteType => (fieldUpdater, ordinal, value) => fieldUpdater.setShort(ordinal, value.asInstanceOf[Byte].toShort)
          case _ =>
            throw new IllegalArgumentException(s"$prevDataType and $newDataType are incompatible")
        }

      case (_: IntegerType, _) =>
        prevDataType match {
          case _: ShortType => (fieldUpdater, ordinal, value) => fieldUpdater.setInt(ordinal, value.asInstanceOf[Short].toInt)
          case _: ByteType => (fieldUpdater, ordinal, value) => fieldUpdater.setInt(ordinal, value.asInstanceOf[Byte].toInt)
          case _ =>
            throw new IllegalArgumentException(s"$prevDataType and $newDataType are incompatible")
        }

      case (_: LongType, _) =>
        prevDataType match {
          case _: IntegerType => (fieldUpdater, ordinal, value) => fieldUpdater.setLong(ordinal, value.asInstanceOf[Int].toLong)
          case _: ShortType => (fieldUpdater, ordinal, value) => fieldUpdater.setLong(ordinal, value.asInstanceOf[Short].toLong)
          case _: ByteType => (fieldUpdater, ordinal, value) => fieldUpdater.setLong(ordinal, value.asInstanceOf[Byte].toLong)
          case _ =>
            throw new IllegalArgumentException(s"$prevDataType and $newDataType are incompatible")
        }

      case (_: FloatType, _) =>
        prevDataType match {
          case _: LongType => (fieldUpdater, ordinal, value) => fieldUpdater.setFloat(ordinal, value.asInstanceOf[Long].toFloat)
          case _: IntegerType => (fieldUpdater, ordinal, value) => fieldUpdater.setFloat(ordinal, value.asInstanceOf[Int].toFloat)
          case _: ShortType => (fieldUpdater, ordinal, value) => fieldUpdater.setFloat(ordinal, value.asInstanceOf[Short].toFloat)
          case _: ByteType => (fieldUpdater, ordinal, value) => fieldUpdater.setFloat(ordinal, value.asInstanceOf[Byte].toFloat)
          case _ =>
            throw new IllegalArgumentException(s"$prevDataType and $newDataType are incompatible")
        }

      case (_: DoubleType, _) =>
        prevDataType match {
          case _: FloatType => (fieldUpdater, ordinal, value) => fieldUpdater.setDouble(ordinal, value.asInstanceOf[Float].toDouble)
          case _: LongType => (fieldUpdater, ordinal, value) => fieldUpdater.setDouble(ordinal, value.asInstanceOf[Long].toDouble)
          case _: IntegerType => (fieldUpdater, ordinal, value) => fieldUpdater.setDouble(ordinal, value.asInstanceOf[Int].toDouble)
          case _: ShortType => (fieldUpdater, ordinal, value) => fieldUpdater.setDouble(ordinal, value.asInstanceOf[Short].toDouble)
          case _: ByteType => (fieldUpdater, ordinal, value) => fieldUpdater.setDouble(ordinal, value.asInstanceOf[Byte].toDouble)
          case _ =>
            throw new IllegalArgumentException(s"$prevDataType and $newDataType are incompatible")
        }

      case (_: BinaryType, _: StringType) =>
        (fieldUpdater, ordinal, value) => fieldUpdater.set(ordinal, value.asInstanceOf[UTF8String].getBytes)

      // TODO revisit this (we need to align permitted casting w/ Spark)
      // NOTE: This is supported to stay compatible w/ [[HoodieAvroUtils.rewriteRecordWithNewSchema]]
      case (_: StringType, _) =>
        prevDataType match {
          case BinaryType => (fieldUpdater, ordinal, value) =>
            fieldUpdater.set(ordinal, UTF8String.fromBytes(value.asInstanceOf[Array[Byte]]))
          case DateType => (fieldUpdater, ordinal, value) =>
            fieldUpdater.set(ordinal, UTF8String.fromString(toJavaDate(value.asInstanceOf[Integer]).toString))
          case IntegerType | LongType | FloatType | DoubleType | _: DecimalType =>
            (fieldUpdater, ordinal, value) => fieldUpdater.set(ordinal, UTF8String.fromString(value.toString))

          case _ =>
            throw new IllegalArgumentException(s"$prevDataType and $newDataType are incompatible")
        }

      case (DateType, StringType) =>
        (fieldUpdater, ordinal, value) =>
          fieldUpdater.set(ordinal, CatalystTypeConverters.convertToCatalyst(java.sql.Date.valueOf(value.toString)))

      case (_, _) =>
        throw new IllegalArgumentException(s"$prevDataType and $newDataType are incompatible")
    }
  }

  /**
   * Looks up the previous name of the field (last segment of its fully-qualified previous name) in
   * `renamedColumnsMap`; returns empty string if the field wasn't renamed.
   */
  private def lookupRenamedField(newFieldQualifiedName: String, renamedColumnsMap: JMap[String, String]) = {
    val prevFieldQualifiedName = renamedColumnsMap.getOrDefault(newFieldQualifiedName, "")
    val prevFieldQualifiedNameParts = prevFieldQualifiedName.split("\\.")
    val prevFieldName = prevFieldQualifiedNameParts(prevFieldQualifiedNameParts.length - 1)

    prevFieldName
  }

  /**
   * Allocates [[ArrayData]] of given `length`: [[UnsafeArrayData]] for primitive element types,
   * [[GenericArrayData]] otherwise.
   */
  private def createArrayData(elementType: DataType, length: Int): ArrayData = elementType match {
    case BooleanType => UnsafeArrayData.fromPrimitiveArray(new Array[Boolean](length))
    case ByteType => UnsafeArrayData.fromPrimitiveArray(new Array[Byte](length))
    case ShortType => UnsafeArrayData.fromPrimitiveArray(new Array[Short](length))
    case IntegerType => UnsafeArrayData.fromPrimitiveArray(new Array[Int](length))
    case LongType => UnsafeArrayData.fromPrimitiveArray(new Array[Long](length))
    case FloatType => UnsafeArrayData.fromPrimitiveArray(new Array[Float](length))
    case DoubleType => UnsafeArrayData.fromPrimitiveArray(new Array[Double](length))
    case _ => new GenericArrayData(new Array[Any](length))
  }

  /**
   * Abstraction over a Catalyst container (row or array) values are written into by ordinal.
   * Typed setters default to the generic [[set]].
   */
  sealed trait CatalystDataUpdater {
    def set(ordinal: Int, value: Any): Unit

    def setNullAt(ordinal: Int): Unit = set(ordinal, null)
    def setBoolean(ordinal: Int, value: Boolean): Unit = set(ordinal, value)
    def setByte(ordinal: Int, value: Byte): Unit = set(ordinal, value)
    def setShort(ordinal: Int, value: Short): Unit = set(ordinal, value)
    def setInt(ordinal: Int, value: Int): Unit = set(ordinal, value)
    def setLong(ordinal: Int, value: Long): Unit = set(ordinal, value)
    def setDouble(ordinal: Int, value: Double): Unit = set(ordinal, value)
    def setFloat(ordinal: Int, value: Float): Unit = set(ordinal, value)
    def setDecimal(ordinal: Int, value: Decimal): Unit = set(ordinal, value)
  }

  /** [[CatalystDataUpdater]] writing into an [[InternalRow]] */
  final class RowUpdater(row: InternalRow) extends CatalystDataUpdater {
    override def set(ordinal: Int, value: Any): Unit = row.update(ordinal, value)

    override def setNullAt(ordinal: Int): Unit = row.setNullAt(ordinal)
    override def setBoolean(ordinal: Int, value: Boolean): Unit = row.setBoolean(ordinal, value)
    override def setByte(ordinal: Int, value: Byte): Unit = row.setByte(ordinal, value)
    override def setShort(ordinal: Int, value: Short): Unit = row.setShort(ordinal, value)
    override def setInt(ordinal: Int, value: Int): Unit = row.setInt(ordinal, value)
    override def setLong(ordinal: Int, value: Long): Unit = row.setLong(ordinal, value)
    override def setDouble(ordinal: Int, value: Double): Unit = row.setDouble(ordinal, value)
    override def setFloat(ordinal: Int, value: Float): Unit = row.setFloat(ordinal, value)
    override def setDecimal(ordinal: Int, value: Decimal): Unit =
      row.setDecimal(ordinal, value, value.precision)
  }

  /** [[CatalystDataUpdater]] writing into an [[ArrayData]] */
  final class ArrayDataUpdater(array: ArrayData) extends CatalystDataUpdater {
    override def set(ordinal: Int, value: Any): Unit = array.update(ordinal, value)

    override def setNullAt(ordinal: Int): Unit = array.setNullAt(ordinal)
    override def setBoolean(ordinal: Int, value: Boolean): Unit = array.setBoolean(ordinal, value)
    override def setByte(ordinal: Int, value: Byte): Unit = array.setByte(ordinal, value)
    override def setShort(ordinal: Int, value: Short): Unit = array.setShort(ordinal, value)
    override def setInt(ordinal: Int, value: Int): Unit = array.setInt(ordinal, value)
    override def setLong(ordinal: Int, value: Long): Unit = array.setLong(ordinal, value)
    override def setDouble(ordinal: Int, value: Double): Unit = array.setDouble(ordinal, value)
    override def setFloat(ordinal: Int, value: Float): Unit = array.setFloat(ordinal, value)
    override def setDecimal(ordinal: Int, value: Decimal): Unit = array.update(ordinal, value)
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy