// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
package com.azure.cosmos.spark
import com.azure.cosmos.implementation.{Constants, Utils}
import com.azure.cosmos.spark.CosmosTableSchemaInferrer._
import com.azure.cosmos.spark.SchemaConversionModes.SchemaConversionMode
import com.azure.cosmos.spark.diagnostics.BasicLoggingTrait
import com.fasterxml.jackson.databind.node._
import com.fasterxml.jackson.databind.{JsonNode, ObjectMapper}
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.{GenericRowWithSchema, UnsafeMapData}
import org.apache.spark.sql.catalyst.util.ArrayData
import java.io.IOException
import java.sql.{Date, Timestamp}
import java.time.format.DateTimeFormatter
import java.time.{Instant, LocalDate, OffsetDateTime, ZoneOffset}
import java.util.concurrent.TimeUnit
// scalastyle:off underscore.import
import org.apache.spark.sql.types._
import scala.collection.JavaConverters._
// scalastyle:on underscore.import
import org.apache.spark.unsafe.types.UTF8String
import scala.util.{Try, Success, Failure}
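// Converts between Spark rows and Cosmos DB JSON documents in both directions.
// Illustrative usage sketch, assuming a CosmosSerializationConfig value named `serializationConfig`
// and a Row `row` with a matching StructType `schema` are in scope:
//   val converter = new CosmosRowConverterBase(Utils.getSimpleObjectMapper, serializationConfig)
//   val doc: ObjectNode = converter.fromRowToObjectNode(row)                  // Row -> Cosmos document
//   val back: Row = converter.fromObjectNodeToRow(schema, doc, SchemaConversionModes.Relaxed)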
private[cosmos] class CosmosRowConverterBase(
private val objectMapper: ObjectMapper,
private val serializationConfig: CosmosSerializationConfig)
extends BasicLoggingTrait {
private val skipDefaultValues =
serializationConfig.serializationInclusionMode == SerializationInclusionModes.NonDefault
private val TimeToLiveExpiredPropertyName = "timeToLiveExpired"
private val utcFormatter = DateTimeFormatter
.ofPattern("yyyy-MM-dd'T'HH:mm:ss'Z'").withZone(ZoneOffset.UTC)
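  // Converts a Spark Row into Catalyst's InternalRow via the supplied ExpressionEncoder serializer,
  // wrapping any serialization failure in a descriptive exception.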
def fromRowToInternalRow(row: Row,
rowSerializer: ExpressionEncoder.Serializer[Row]): InternalRow = {
try {
rowSerializer.apply(row)
}
catch {
case inner: RuntimeException =>
throw new Exception(
s"Cannot convert row into InternalRow",
inner)
}
}
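  // Cosmos DB documents must be JSON objects; scalar or array payloads are wrapped in an object
  // under the Constants.Properties.VALUE property so they can still be stored and read back.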
def ensureObjectNode(jsonNode: JsonNode): ObjectNode = {
if (jsonNode.isValueNode || jsonNode.isArray) {
try Utils
.getSimpleObjectMapper.readTree(s"""{"${Constants.Properties.VALUE}": $jsonNode}""")
.asInstanceOf[ObjectNode]
catch {
case e: IOException =>
throw new IllegalStateException(s"Unable to parse JSON $jsonNode", e)
}
} else {
jsonNode.asInstanceOf[ObjectNode]
}
}
def fromObjectNodeToRow(schema: StructType,
objectNode: ObjectNode,
schemaConversionMode: SchemaConversionMode): Row = {
val values: Seq[Any] = convertStructToSparkDataType(schema, objectNode, schemaConversionMode)
new GenericRowWithSchema(values.toArray, schema)
}
def fromObjectNodeToRowWithComputedColumns(schema: StructType,
objectNode: ObjectNode,
schemaConversionMode: SchemaConversionMode,
computedColumns: Map[String, ObjectNode => Any]): Row = {
val values: Seq[Any] =
convertStructToSparkDataTypeWithComputedColumns(
schema,
objectNode,
schemaConversionMode,
computedColumns)
new GenericRowWithSchema(values.toArray, schema)
}
def fromObjectNodeToChangeFeedRowV1(schema: StructType,
objectNode: ObjectNode,
schemaConversionMode: SchemaConversionMode): Row = {
val values: Seq[Any] = convertStructToChangeFeedSparkDataTypeV1(schema, objectNode, schemaConversionMode)
new GenericRowWithSchema(values.toArray, schema)
}
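  // Serializes a Row into an ObjectNode. When the row exposes a StringType raw JSON body column
  // (RawJsonBodyAttributeName or OriginRawJsonBodyAttributeName) that JSON is parsed and used as-is;
  // otherwise each field is converted individually according to its Spark data type.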
def fromRowToObjectNode(row: Row): ObjectNode = {
val rawBodyFieldName = if (row.schema.names.contains(CosmosTableSchemaInferrer.RawJsonBodyAttributeName) &&
row.schema.apply(CosmosTableSchemaInferrer.RawJsonBodyAttributeName).dataType.isInstanceOf[StringType]) {
Some(CosmosTableSchemaInferrer.RawJsonBodyAttributeName)
} else if (row.schema.names.contains(CosmosTableSchemaInferrer.OriginRawJsonBodyAttributeName) &&
row.schema.apply(CosmosTableSchemaInferrer.OriginRawJsonBodyAttributeName).dataType.isInstanceOf[StringType]) {
Some(CosmosTableSchemaInferrer.OriginRawJsonBodyAttributeName)
} else {
None
}
if (rawBodyFieldName.isDefined) {
      // Special case when the reader has already read the raw JSON body
val rawJson = row.getAs[String](rawBodyFieldName.get)
convertRawBodyJsonToObjectNode(rawJson, rawBodyFieldName.get)
} else {
val objectNode: ObjectNode = objectMapper.createObjectNode()
row.schema.fields.zipWithIndex.foreach({
case (field, i) =>
field.dataType match {
case _: NullType => putNullConditionally(objectNode, field.name)
case _ if row.isNullAt(i) => putNullConditionally(objectNode, field.name)
case _ =>
val nodeOpt = convertSparkDataTypeToJsonNode(field.dataType, row.get(i))
if (nodeOpt.isDefined) {
objectNode.set(field.name, nodeOpt.get)
}
}
})
objectNode
}
}
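  // Extracts the LSN from the change feed metadata node; returns null when the metadata or LSN node is missing.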
def getChangeFeedLsn(objectNode: ObjectNode): String = {
objectNode.get(MetadataJsonBodyAttributeName) match {
case metadataNode: JsonNode =>
metadataNode.get(MetadataLsnAttributeName) match {
case lsnNode: JsonNode => lsnNode.asText()
case _ => null
}
case _ => null
}
}
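  // Parses a raw JSON body into an ObjectNode. For the origin raw body column the document's
  // ETag and timestamp values are additionally copied into the corresponding origin attributes.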
private def convertRawBodyJsonToObjectNode(json: String, rawBodyFieldName: String): ObjectNode = {
val doc = objectMapper.readTree(json).asInstanceOf[ObjectNode]
if (rawBodyFieldName == CosmosTableSchemaInferrer.OriginRawJsonBodyAttributeName) {
doc.set(
CosmosTableSchemaInferrer.OriginETagAttributeName,
doc.get(CosmosTableSchemaInferrer.ETagAttributeName))
doc.set(
CosmosTableSchemaInferrer.OriginTimestampAttributeName,
doc.get(CosmosTableSchemaInferrer.TimestampAttributeName))
}
doc
}
def fromInternalRowToObjectNode(row: InternalRow, schema: StructType): ObjectNode = {
val rawBodyFieldName = if (schema.names.contains(CosmosTableSchemaInferrer.RawJsonBodyAttributeName) &&
schema.apply(CosmosTableSchemaInferrer.RawJsonBodyAttributeName).dataType.isInstanceOf[StringType]) {
Some(CosmosTableSchemaInferrer.RawJsonBodyAttributeName)
} else if (schema.names.contains(CosmosTableSchemaInferrer.OriginRawJsonBodyAttributeName) &&
schema.apply(CosmosTableSchemaInferrer.OriginRawJsonBodyAttributeName).dataType.isInstanceOf[StringType]) {
Some(CosmosTableSchemaInferrer.OriginRawJsonBodyAttributeName)
} else {
None
}
if (rawBodyFieldName.isDefined) {
val rawBodyFieldIndex = schema.fieldIndex(rawBodyFieldName.get)
      // Special case when the reader has already read the raw JSON body
val rawJson = convertRowDataToString(row.get(rawBodyFieldIndex, StringType))
convertRawBodyJsonToObjectNode(rawJson, rawBodyFieldName.get)
} else {
val objectNode: ObjectNode = objectMapper.createObjectNode()
schema.fields.zipWithIndex.foreach({
case (field, i) =>
field.dataType match {
case _: NullType => putNullConditionally(objectNode, field.name)
case _ if row.isNullAt(i) => putNullConditionally(objectNode, field.name)
case _ =>
val nodeOpt = convertSparkDataTypeToJsonNode(field.dataType, row.get(i, field.dataType))
if (nodeOpt.isDefined) {
objectNode.set(field.name, nodeOpt.get)
}
}
})
objectNode
}
}
private def convertToStringKeyMap(input: Any): Map[String, _] = {
try {
input.asInstanceOf[Map[String, _]]
}
catch {
case _: ClassCastException =>
throw new Exception(
s"Cannot cast $input into a Json value. MapTypes must have "
+ s"keys of StringType for conversion Json")
}
}
private def convertRowDataToString(rowData: Any): String = {
rowData match {
case str: String =>
str
case string: UTF8String =>
string.toString
case _ =>
throw new Exception(s"Cannot cast $rowData into a String.")
}
}
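  // Entry point for field conversion: the NonEmpty/NonDefault inclusion modes use the conditional
  // path (which may return None to omit the property); all other modes always emit a JsonNode.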
private def convertSparkDataTypeToJsonNode(fieldType: DataType, rowData: Any): Option[JsonNode] = {
if (serializationConfig.serializationInclusionMode == SerializationInclusionModes.NonEmpty ||
serializationConfig.serializationInclusionMode == SerializationInclusionModes.NonDefault) {
convertSparkDataTypeToJsonNodeConditionally(fieldType, rowData: Any)
} else {
Some(convertSparkDataTypeToJsonNodeNonNull(fieldType, rowData: Any))
}
}
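  // Identifies "default" values (empty strings/collections/maps, zero numerics, false booleans)
  // so the NonEmpty/NonDefault inclusion modes can omit the corresponding properties.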
private def isDefaultValue(value: Any): Boolean = {
value match {
case stringValue: String => stringValue.isEmpty
case intValue: Int => intValue == 0
case shortValue: Short => shortValue == 0
case byteValue: Byte => byteValue == 0
case longValue: Long => longValue == 0
case arrayValue: Array[_] => arrayValue.isEmpty
case booleanValue: Boolean => !booleanValue
case doubleValue: Double => doubleValue == 0
case floatValue: Float => floatValue == 0
case bigDecimalValue: java.math.BigDecimal => bigDecimalValue.compareTo(java.math.BigDecimal.ZERO) == 0
case arrayDataValue: ArrayData => arrayDataValue.numElements() == 0
case sequenceValue: Seq[_] => sequenceValue.isEmpty
case stringMapValue: Map[String, _] => stringMapValue.isEmpty
case unsafeMapDataValue: UnsafeMapData => unsafeMapDataValue.numElements() == 0
case _ => throw new Exception("Invalid value type used - can't determine default value")
}
}
protected[spark] def convertToJsonNodeConditionally[T](value: T): Option[JsonNode] = {
if (skipDefaultValues && isDefaultValue(value)) {
None
} else {
Some(objectMapper.convertValue(value, classOf[JsonNode]))
}
}
private def convertSparkDataTypeToJsonNodeConditionally
(
fieldType: DataType,
rowData: Any
): Option[JsonNode] = {
fieldType match {
case StringType =>
val stringValue = convertRowDataToString(rowData)
if (isDefaultValue(stringValue)) {
None
} else {
Some(objectMapper.convertValue(stringValue, classOf[JsonNode]))
}
case BinaryType =>
val blobValue = rowData.asInstanceOf[Array[Byte]]
if (isDefaultValue(blobValue)) {
None
} else {
Some(objectMapper.convertValue(blobValue, classOf[JsonNode]))
}
case BooleanType => convertToJsonNodeConditionally(rowData.asInstanceOf[Boolean])
case DoubleType => convertToJsonNodeConditionally(rowData.asInstanceOf[Double])
case IntegerType => convertToJsonNodeConditionally(rowData.asInstanceOf[Int])
case ShortType => convertToJsonNodeConditionally(rowData.asInstanceOf[Short])
case ByteType => convertToJsonNodeConditionally(rowData.asInstanceOf[Byte])
case LongType => convertToJsonNodeConditionally(rowData.asInstanceOf[Long])
case FloatType => convertToJsonNodeConditionally(rowData.asInstanceOf[Float])
case DecimalType() if rowData.isInstanceOf[Decimal] =>
convertToJsonNodeConditionally(rowData.asInstanceOf[Decimal].toJavaBigDecimal)
case DecimalType() if rowData.isInstanceOf[Long] =>
convertToJsonNodeConditionally(new java.math.BigDecimal(rowData.asInstanceOf[java.lang.Long]))
case DecimalType() =>
convertToJsonNodeConditionally(rowData.asInstanceOf[java.math.BigDecimal])
case DateType if rowData.isInstanceOf[java.lang.Long] =>
serializationConfig.serializationDateTimeConversionMode match {
case SerializationDateTimeConversionModes.Default =>
convertToJsonNodeConditionally(rowData.asInstanceOf[Long])
case SerializationDateTimeConversionModes.AlwaysEpochMillisecondsWithUtcTimezone =>
convertToJsonNodeConditionally(LocalDate
.ofEpochDay(rowData.asInstanceOf[Long])
.atStartOfDay()
.toInstant(ZoneOffset.UTC).toEpochMilli)
case SerializationDateTimeConversionModes.AlwaysEpochMillisecondsWithSystemDefaultTimezone =>
val localDate = LocalDate
.ofEpochDay(rowData.asInstanceOf[Long])
.atStartOfDay()
val localTimestampInstant = Timestamp.valueOf(localDate).toInstant
convertToJsonNodeConditionally(
localDate
.toInstant(java.time.ZoneId.systemDefault.getRules().getOffset(localTimestampInstant)).toEpochMilli)
}
case DateType if rowData.isInstanceOf[java.lang.Integer] =>
serializationConfig.serializationDateTimeConversionMode match {
case SerializationDateTimeConversionModes.Default =>
convertToJsonNodeConditionally(rowData.asInstanceOf[java.lang.Integer])
case SerializationDateTimeConversionModes.AlwaysEpochMillisecondsWithUtcTimezone =>
convertToJsonNodeConditionally(LocalDate
.ofEpochDay(rowData.asInstanceOf[java.lang.Integer].longValue())
.atStartOfDay()
.toInstant(ZoneOffset.UTC).toEpochMilli)
case SerializationDateTimeConversionModes.AlwaysEpochMillisecondsWithSystemDefaultTimezone =>
val localDate = LocalDate
.ofEpochDay(rowData.asInstanceOf[java.lang.Integer].longValue())
.atStartOfDay()
val localTimestampInstant = Timestamp.valueOf(localDate).toInstant
convertToJsonNodeConditionally(
localDate
.toInstant(java.time.ZoneId.systemDefault.getRules().getOffset(localTimestampInstant)).toEpochMilli)
}
case DateType => convertToJsonNodeConditionally(rowData.asInstanceOf[Date].getTime)
case TimestampType if rowData.isInstanceOf[java.lang.Long] =>
serializationConfig.serializationDateTimeConversionMode match {
case SerializationDateTimeConversionModes.Default =>
convertToJsonNodeConditionally(rowData.asInstanceOf[java.lang.Long])
case SerializationDateTimeConversionModes.AlwaysEpochMillisecondsWithUtcTimezone |
SerializationDateTimeConversionModes.AlwaysEpochMillisecondsWithSystemDefaultTimezone =>
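          // Spark passes TimestampType values as microseconds since the Unix epoch;
          // split into seconds plus nanosecond remainder and emit epoch milliseconds.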
val microsSinceEpoch = rowData.asInstanceOf[java.lang.Long]
convertToJsonNodeConditionally(
Instant.ofEpochSecond(
TimeUnit.MICROSECONDS.toSeconds(microsSinceEpoch),
TimeUnit.MICROSECONDS.toNanos(
Math.floorMod(microsSinceEpoch, TimeUnit.SECONDS.toMicros(1))
)
).toEpochMilli)
}
case TimestampType if rowData.isInstanceOf[java.lang.Integer] =>
serializationConfig.serializationDateTimeConversionMode match {
case SerializationDateTimeConversionModes.Default =>
convertToJsonNodeConditionally(rowData.asInstanceOf[java.lang.Integer])
case SerializationDateTimeConversionModes.AlwaysEpochMillisecondsWithUtcTimezone |
SerializationDateTimeConversionModes.AlwaysEpochMillisecondsWithSystemDefaultTimezone =>
val microsSinceEpoch = rowData.asInstanceOf[java.lang.Integer].longValue()
convertToJsonNodeConditionally(
Instant.ofEpochSecond(
TimeUnit.MICROSECONDS.toSeconds(microsSinceEpoch),
TimeUnit.MICROSECONDS.toNanos(
Math.floorMod(microsSinceEpoch, TimeUnit.SECONDS.toMicros(1))
)
).toEpochMilli)
}
case TimestampType => convertToJsonNodeConditionally(rowData.asInstanceOf[Timestamp].getTime)
case arrayType: ArrayType if rowData.isInstanceOf[ArrayData] =>
val arrayDataValue = rowData.asInstanceOf[ArrayData]
if (isDefaultValue(arrayDataValue)) {
None
} else {
Some(convertSparkArrayToArrayNode(arrayType.elementType, arrayType.containsNull, arrayDataValue))
}
case arrayType: ArrayType =>
val seqValue = rowData.asInstanceOf[Seq[_]]
if (isDefaultValue(seqValue)) {
None
} else {
Some(convertSparkArrayToArrayNode(arrayType.elementType, arrayType.containsNull, seqValue))
}
case structType: StructType => Some(rowTypeRouterToJsonArray(rowData, structType))
case mapType: MapType =>
mapType.keyType match {
case StringType if rowData.isInstanceOf[Map[_, _]] =>
val stringKeyMap = convertToStringKeyMap(rowData)
if (isDefaultValue(stringKeyMap)) {
None
} else {
Some(convertSparkMapToObjectNode(
mapType.valueType,
mapType.valueContainsNull,
stringKeyMap))
}
case StringType if rowData.isInstanceOf[UnsafeMapData] =>
val unsafeMapDataValue = rowData.asInstanceOf[UnsafeMapData]
if (isDefaultValue(unsafeMapDataValue)) {
None
} else {
Some(convertSparkMapToObjectNode(
mapType.valueType,
mapType.valueContainsNull,
unsafeMapDataValue))
}
case _ =>
throw new Exception(s"Cannot cast $rowData into a Json value. MapTypes "
+ s"must have keys of StringType for conversion Json")
}
      case _ =>
        convertSparkDataTypeToJsonNodeConditionallyForSparkRuntimeSpecificDataType(fieldType, rowData) match {
          case Some(jsonNode) => Some(jsonNode)
          case None =>
            throw new Exception(s"Cannot cast $rowData into a Json value. $fieldType has no matching Json value.")
        }
}
}
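  // Hook for Spark-runtime-specific data types, overridden by version-specific row converters;
  // the base implementation handles no additional types.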
protected[spark] def convertSparkDataTypeToJsonNodeConditionallyForSparkRuntimeSpecificDataType
(
fieldType: DataType,
rowData: Any
): Option[JsonNode] = {
None
}
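  // Unconditional conversion path used by the Always/NonNull inclusion modes: every non-null value
  // is mapped to a JsonNode, honoring the configured date/time conversion mode.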
private def convertSparkDataTypeToJsonNodeNonNull(fieldType: DataType, rowData: Any): JsonNode = {
fieldType match {
case StringType => objectMapper.convertValue(convertRowDataToString(rowData), classOf[JsonNode])
case BinaryType => objectMapper.convertValue(rowData.asInstanceOf[Array[Byte]], classOf[JsonNode])
case BooleanType => objectMapper.convertValue(rowData.asInstanceOf[Boolean], classOf[JsonNode])
case DoubleType => objectMapper.convertValue(rowData.asInstanceOf[Double], classOf[JsonNode])
case ShortType => objectMapper.convertValue(rowData.asInstanceOf[Short], classOf[JsonNode])
case ByteType => objectMapper.convertValue(rowData.asInstanceOf[Byte], classOf[JsonNode])
case IntegerType => objectMapper.convertValue(rowData.asInstanceOf[Int], classOf[JsonNode])
case LongType => objectMapper.convertValue(rowData.asInstanceOf[Long], classOf[JsonNode])
case FloatType => objectMapper.convertValue(rowData.asInstanceOf[Float], classOf[JsonNode])
case DecimalType() if rowData.isInstanceOf[Decimal] => objectMapper.convertValue(rowData.asInstanceOf[Decimal].toJavaBigDecimal, classOf[JsonNode])
case DecimalType() if rowData.isInstanceOf[Long] => objectMapper.convertValue(new java.math.BigDecimal(rowData.asInstanceOf[java.lang.Long]), classOf[JsonNode])
case DecimalType() => objectMapper.convertValue(rowData.asInstanceOf[java.math.BigDecimal], classOf[JsonNode])
case DateType if rowData.isInstanceOf[java.lang.Long] =>
serializationConfig.serializationDateTimeConversionMode match {
case SerializationDateTimeConversionModes.Default =>
objectMapper.convertValue(rowData.asInstanceOf[java.lang.Long], classOf[JsonNode])
case SerializationDateTimeConversionModes.AlwaysEpochMillisecondsWithUtcTimezone =>
objectMapper.convertValue(
LocalDate
.ofEpochDay(rowData.asInstanceOf[java.lang.Long])
.atStartOfDay()
.toInstant(ZoneOffset.UTC).toEpochMilli,
classOf[JsonNode])
case SerializationDateTimeConversionModes.AlwaysEpochMillisecondsWithSystemDefaultTimezone =>
val localDate = LocalDate
.ofEpochDay(rowData.asInstanceOf[java.lang.Long])
.atStartOfDay()
val localTimestampInstant = Timestamp.valueOf(localDate).toInstant
objectMapper.convertValue(
localDate
.toInstant(java.time.ZoneId.systemDefault.getRules().getOffset(localTimestampInstant)).toEpochMilli,
classOf[JsonNode])
}
case DateType if rowData.isInstanceOf[java.lang.Integer] =>
serializationConfig.serializationDateTimeConversionMode match {
case SerializationDateTimeConversionModes.Default =>
objectMapper.convertValue(rowData.asInstanceOf[java.lang.Integer], classOf[JsonNode])
case SerializationDateTimeConversionModes.AlwaysEpochMillisecondsWithUtcTimezone =>
objectMapper.convertValue(
LocalDate
.ofEpochDay(rowData.asInstanceOf[java.lang.Integer].longValue())
.atStartOfDay()
.toInstant(ZoneOffset.UTC).toEpochMilli,
classOf[JsonNode])
case SerializationDateTimeConversionModes.AlwaysEpochMillisecondsWithSystemDefaultTimezone =>
val localDate = LocalDate
.ofEpochDay(rowData.asInstanceOf[java.lang.Integer].longValue())
.atStartOfDay()
val localTimestampInstant = Timestamp.valueOf(localDate).toInstant
objectMapper.convertValue(
localDate
.toInstant(java.time.ZoneId.systemDefault.getRules().getOffset(localTimestampInstant)).toEpochMilli,
classOf[JsonNode])
}
case DateType => objectMapper.convertValue(rowData.asInstanceOf[Date].getTime, classOf[JsonNode])
case TimestampType if rowData.isInstanceOf[java.lang.Long] =>
serializationConfig.serializationDateTimeConversionMode match {
case SerializationDateTimeConversionModes.Default =>
objectMapper.convertValue(rowData.asInstanceOf[java.lang.Long], classOf[JsonNode])
case SerializationDateTimeConversionModes.AlwaysEpochMillisecondsWithUtcTimezone |
SerializationDateTimeConversionModes.AlwaysEpochMillisecondsWithSystemDefaultTimezone =>
val microsSinceEpoch = rowData.asInstanceOf[java.lang.Long]
objectMapper.convertValue(
Instant.ofEpochSecond(
TimeUnit.MICROSECONDS.toSeconds(microsSinceEpoch),
TimeUnit.MICROSECONDS.toNanos(
Math.floorMod(microsSinceEpoch, TimeUnit.SECONDS.toMicros(1))
)
).toEpochMilli,
classOf[JsonNode])
}
case TimestampType if rowData.isInstanceOf[java.lang.Integer] =>
serializationConfig.serializationDateTimeConversionMode match {
case SerializationDateTimeConversionModes.Default =>
objectMapper.convertValue(rowData.asInstanceOf[java.lang.Integer], classOf[JsonNode])
case SerializationDateTimeConversionModes.AlwaysEpochMillisecondsWithUtcTimezone |
SerializationDateTimeConversionModes.AlwaysEpochMillisecondsWithSystemDefaultTimezone =>
val microsSinceEpoch = rowData.asInstanceOf[java.lang.Integer].longValue()
objectMapper.convertValue(
Instant.ofEpochSecond(
TimeUnit.MICROSECONDS.toSeconds(microsSinceEpoch),
TimeUnit.MICROSECONDS.toNanos(
Math.floorMod(microsSinceEpoch, TimeUnit.SECONDS.toMicros(1))
)
).toEpochMilli,
classOf[JsonNode])
}
case TimestampType => objectMapper.convertValue(rowData.asInstanceOf[Timestamp].getTime, classOf[JsonNode])
case arrayType: ArrayType if rowData.isInstanceOf[ArrayData] => convertSparkArrayToArrayNode(arrayType.elementType, arrayType.containsNull, rowData.asInstanceOf[ArrayData])
case arrayType: ArrayType => convertSparkArrayToArrayNode(arrayType.elementType, arrayType.containsNull, rowData.asInstanceOf[Seq[_]])
case structType: StructType => rowTypeRouterToJsonArray(rowData, structType)
case mapType: MapType =>
mapType.keyType match {
case StringType if rowData.isInstanceOf[Map[_, _]] =>
val stringKeyMap = convertToStringKeyMap(rowData)
convertSparkMapToObjectNode(
mapType.valueType,
mapType.valueContainsNull,
stringKeyMap)
case StringType if rowData.isInstanceOf[UnsafeMapData] =>
convertSparkMapToObjectNode(
mapType.valueType,
mapType.valueContainsNull,
rowData.asInstanceOf[UnsafeMapData])
case _ =>
throw new Exception(s"Cannot cast $rowData into a Json value. MapTypes "
+ s"must have keys of StringType for conversion Json")
}
case _ =>
convertSparkDataTypeToJsonNodeNonNullForSparkRuntimeSpecificDataType(fieldType, rowData)
}
}
protected[spark] def convertSparkDataTypeToJsonNodeNonNullForSparkRuntimeSpecificDataType(fieldType: DataType, rowData: Any): JsonNode = {
throw new Exception(s"Cannot cast $rowData into a Json value. $fieldType has no matching Json value.")
}
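  // Explicit JSON nulls are only written in the Always inclusion mode; otherwise the property is omitted.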
private def putNullConditionally(objectNode: ObjectNode, fieldName: String) = {
if (serializationConfig.serializationInclusionMode == SerializationInclusionModes.Always) {
objectNode.putNull(fieldName)
}
}
private def convertSparkMapToObjectNode(elementType: DataType, containsNull: Boolean, data: Map[String, Any]): ObjectNode = {
val objectNode = objectMapper.createObjectNode()
data.foreach(x =>
if (containsNull && x._2 == null) {
putNullConditionally(objectNode, x._1)
}
else {
val nodeOpt = convertSparkSubItemToJsonNode(elementType, containsNull, x._2)
if (nodeOpt.isDefined) {
objectNode.set(x._1, nodeOpt.get)
}
})
objectNode
}
private def convertSparkMapToObjectNode(elementType: DataType, containsNull: Boolean, data: UnsafeMapData): ObjectNode = {
val objectNode = objectMapper.createObjectNode()
val keys: Array[String] = data.keyArray().toArray[UTF8String](StringType).map(_.toString)
val values: Array[AnyRef] = data.valueArray().toObjectArray(elementType)
keys.zip(values).toMap.foreach(x =>
if (containsNull && x._2 == null) {
putNullConditionally(objectNode, x._1)
}
else {
val nodeOpt = convertSparkSubItemToJsonNode(elementType, containsNull, x._2)
if (nodeOpt.isDefined) {
objectNode.set(x._1, nodeOpt.get)
}
})
objectNode
}
private def convertSparkArrayToArrayNode(elementType: DataType, containsNull: Boolean, data: Seq[Any]): ArrayNode = {
val arrayNode = objectMapper.createArrayNode()
data.foreach(value => writeSparkArrayDataToArrayNode(arrayNode, elementType, containsNull, value))
arrayNode
}
private def convertSparkArrayToArrayNode(elementType: DataType, containsNull: Boolean, data: ArrayData): ArrayNode = {
val arrayNode = objectMapper.createArrayNode()
    data.foreach(elementType,
      (_, value) => writeSparkArrayDataToArrayNode(arrayNode, elementType, containsNull, value))
arrayNode
}
private def writeSparkArrayDataToArrayNode(arrayNode: ArrayNode,
elementType: DataType,
containsNull: Boolean,
value: Any): Unit = {
if (containsNull && value == null) {
arrayNode.add(objectMapper.nullNode())
}
else {
val nodeOpt = convertSparkSubItemToJsonNode(elementType, containsNull, value)
if (nodeOpt.isDefined) {
arrayNode.add(nodeOpt.get)
}
}
}
private def convertSparkSubItemToJsonNode
(
elementType: DataType,
containsNull: Boolean,
data: Any
): Option[JsonNode] = {
if (serializationConfig.serializationInclusionMode == SerializationInclusionModes.NonEmpty ||
serializationConfig.serializationInclusionMode == SerializationInclusionModes.NonDefault) {
convertSparkSubItemToJsonNodeConditionally(elementType, containsNull, data: Any)
} else {
Some(convertSparkSubItemToJsonNodeNonNull(elementType, containsNull, data: Any))
}
}
private def convertSparkSubItemToJsonNodeNonNull
(
elementType: DataType,
containsNull: Boolean,
data: Any
): JsonNode = {
elementType match {
case subDocuments: StructType => rowTypeRouterToJsonArray(data, subDocuments)
case subArray: ArrayType if data.isInstanceOf[ArrayData]
=> convertSparkArrayToArrayNode(subArray.elementType, containsNull, data.asInstanceOf[ArrayData])
case subArray: ArrayType
=> convertSparkArrayToArrayNode(subArray.elementType, containsNull, data.asInstanceOf[Seq[_]])
case _ => convertSparkDataTypeToJsonNodeNonNull(elementType, data)
}
}
private def convertSparkSubItemToJsonNodeConditionally
(
elementType: DataType,
containsNull: Boolean,
data: Any
): Option[JsonNode] = {
elementType match {
case subDocuments: StructType => Some(rowTypeRouterToJsonArray(data, subDocuments))
case subArray: ArrayType if data.isInstanceOf[ArrayData] =>
val arrayDataValue = data.asInstanceOf[ArrayData]
if (isDefaultValue(arrayDataValue)) {
None
} else {
Some(convertSparkArrayToArrayNode(subArray.elementType, containsNull, arrayDataValue))
}
case subArray: ArrayType =>
val sequenceData = data.asInstanceOf[Seq[_]]
if (isDefaultValue(sequenceData)) {
None
} else {
Some(convertSparkArrayToArrayNode(subArray.elementType, containsNull, sequenceData))
}
case _ => convertSparkDataTypeToJsonNodeConditionally(elementType, data)
}
}
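  // Nested structs can be represented either as Row or as InternalRow depending on the code path;
  // dispatch to the matching converter.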
private def rowTypeRouterToJsonArray(element: Any, schema: StructType): ObjectNode = {
element match {
case e: Row => fromRowToObjectNode(e)
case e: InternalRow => fromInternalRowToObjectNode(e, schema)
case _ => throw new Exception(s"Cannot cast $element into a Json value. Struct $element has no matching Json value.")
}
}
private def getAttributeNodeAsString(objectNode: ObjectNode, attributeName: String): String = {
objectNode.get(attributeName) match {
case jsonNode: JsonNode => jsonNode.toString
case _ => null
}
}
private def parseLsn(objectNode: ObjectNode): Long = {
objectNode.get(LsnAttributeName)
match {
case lsnNode: JsonNode =>
Option(lsnNode).fold(-1L)(v => v.asLong(-1))
case _ => -1L
}
}
private def parseTtlExpired(objectNode: ObjectNode): Boolean = {
objectNode.get(MetadataJsonBodyAttributeName) match {
case metadataNode: JsonNode =>
metadataNode.get(TimeToLiveExpiredPropertyName) match {
case valueNode: JsonNode =>
Option(valueNode).fold(false)(v => v.asBoolean(false))
case _ => false
}
case _ => false
}
}
private def parseId(objectNode: ObjectNode): String = {
val currentNode = getCurrentOrPreviousNode(objectNode)
currentNode.get(IdAttributeName) match {
case valueNode: JsonNode =>
Option(valueNode).fold(null: String)(v => v.asText(null))
case _ => null
}
}
private def parseTimestamp(objectNode: ObjectNode): Long = {
val currentNode = getCurrentOrPreviousNode(objectNode)
currentNode.get(TimestampAttributeName) match {
case valueNode: JsonNode =>
Option(valueNode).fold(-1L)(v => v.asLong(-1))
case _ => -1L
}
}
private def parseETag(objectNode: ObjectNode): String = {
val currentNode = getCurrentOrPreviousNode(objectNode)
currentNode.get(ETagAttributeName) match {
case valueNode: JsonNode =>
Option(valueNode).fold(null: String)(v => v.asText(null))
case _ => null
}
}
private def getCurrentOrPreviousNode(objectNode: ObjectNode): JsonNode = {
var currentNode = objectNode.get(CurrentAttributeName)
if (currentNode == null || currentNode.isEmpty) {
currentNode = objectNode.get(PreviousRawJsonBodyAttributeName)
}
currentNode
}
  // For single-master, crts will always be the same as _ts
  // For multi-master, crts will be the latest resolution timestamp of any conflicts
private def parseCrts(objectNode: ObjectNode): Long = {
objectNode.get(MetadataJsonBodyAttributeName) match {
case metadataNode: JsonNode =>
metadataNode.get(CrtsAttributeName) match {
case valueNode: JsonNode =>
Option(valueNode).fold(-1L)(v => v.asLong(-1))
case _ => -1L
        }
      case _ => -1L
    }
}
private def parseOperationType(objectNode: ObjectNode): String = {
objectNode.get(MetadataJsonBodyAttributeName) match {
case metadataNode: JsonNode =>
metadataNode.get(OperationTypeAttributeName) match {
case valueNode: JsonNode =>
Option(valueNode).fold(null: String)(v => v.asText(null))
case _ => null
}
case _ => null
}
}
private def parseMetadataLsn(objectNode: ObjectNode): Long = {
getChangeFeedLsn(objectNode) match {
case lsn: String => lsn.toLong
case _ => -1L
}
}
private def parsePreviousImageLsn(objectNode: ObjectNode): Long = {
objectNode.get(MetadataJsonBodyAttributeName) match {
case metadataNode: JsonNode =>
metadataNode.get(PreviousImageLsnAttributeName) match {
case lsnNode: JsonNode =>
Option(lsnNode).fold(-1L)(v => v.asLong(-1))
case _ => -1L
}
case _ => -1L
}
}
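  // Projects an ObjectNode onto the target schema. Well-known columns (raw JSON body, previous body,
  // operation type, TTL-expired flag, lsn) are derived from the document or its metadata; all other
  // fields are looked up by name and converted to their Spark types.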
private def convertStructToSparkDataType(schema: StructType,
objectNode: ObjectNode,
schemaConversionMode: SchemaConversionMode): Seq[Any] =
schema.fields.map {
case StructField(CosmosTableSchemaInferrer.RawJsonBodyAttributeName, StringType, _, _) =>
objectNode.toString
case StructField(CosmosTableSchemaInferrer.PreviousRawJsonBodyAttributeName, StringType, _, _) =>
getAttributeNodeAsString(objectNode, PreviousRawJsonBodyAttributeName)
case StructField(CosmosTableSchemaInferrer.OperationTypeAttributeName, StringType, _, _) =>
parseOperationType(objectNode)
case StructField(CosmosTableSchemaInferrer.TtlExpiredAttributeName, BooleanType, _, _) =>
parseTtlExpired(objectNode)
case StructField(CosmosTableSchemaInferrer.LsnAttributeName, LongType, _, _) =>
parseLsn(objectNode)
case StructField(name, dataType, _, _) =>
Option(objectNode.get(name)).map(convertToSparkDataType(dataType, _, schemaConversionMode)).orNull
}
private def convertStructToSparkDataTypeWithComputedColumns(
schema: StructType,
objectNode: ObjectNode,
schemaConversionMode: SchemaConversionMode,
computedColumns: Map[String, (ObjectNode) => Any]): Seq[Any] =
schema.fields.map {
case StructField(CosmosTableSchemaInferrer.RawJsonBodyAttributeName, StringType, _, _) =>
objectNode.toString
case StructField(CosmosTableSchemaInferrer.PreviousRawJsonBodyAttributeName, StringType, _, _) =>
getAttributeNodeAsString(objectNode, PreviousRawJsonBodyAttributeName)
case StructField(CosmosTableSchemaInferrer.OperationTypeAttributeName, StringType, _, _) =>
parseOperationType(objectNode)
case StructField(CosmosTableSchemaInferrer.TtlExpiredAttributeName, BooleanType, _, _) =>
parseTtlExpired(objectNode)
case StructField(CosmosTableSchemaInferrer.LsnAttributeName, LongType, _, _) =>
parseLsn(objectNode)
case StructField(name, dataType, _, _) =>
Option(objectNode.get(name)).map(convertToSparkDataType(dataType, _, schemaConversionMode)).orElse {
computedColumns.get(name) match {
case Some(function) => Some(function.apply(objectNode))
case _ => None
}
}.orNull
}
private def convertStructToChangeFeedSparkDataTypeV1(schema: StructType,
objectNode: ObjectNode,
schemaConversionMode: SchemaConversionMode): Seq[Any] =
schema.fields.map {
case StructField(CosmosTableSchemaInferrer.RawJsonBodyAttributeName, StringType, _, _) =>
getAttributeNodeAsString(objectNode, CurrentAttributeName)
case StructField(CosmosTableSchemaInferrer.IdAttributeName, StringType, _, _) =>
parseId(objectNode)
case StructField(CosmosTableSchemaInferrer.TimestampAttributeName, LongType, _, _) =>
parseTimestamp(objectNode)
case StructField(CosmosTableSchemaInferrer.ETagAttributeName, StringType, _, _) =>
parseETag(objectNode)
case StructField(CosmosTableSchemaInferrer.LsnAttributeName, LongType, _, _) =>
parseMetadataLsn(objectNode)
case StructField(CosmosTableSchemaInferrer.MetadataJsonBodyAttributeName, StringType, _, _) =>
getAttributeNodeAsString(objectNode, MetadataJsonBodyAttributeName)
case StructField(CosmosTableSchemaInferrer.PreviousRawJsonBodyAttributeName, StringType, _, _) =>
getAttributeNodeAsString(objectNode, PreviousRawJsonBodyAttributeName)
case StructField(CosmosTableSchemaInferrer.OperationTypeAttributeName, StringType, _, _) =>
parseOperationType(objectNode)
case StructField(CosmosTableSchemaInferrer.CrtsAttributeName, LongType, _, _) =>
parseCrts(objectNode)
case StructField(CosmosTableSchemaInferrer.PreviousImageLsnAttributeName, LongType, _, _) =>
parsePreviousImageLsn(objectNode)
case StructField(name, dataType, _, _) =>
Option(objectNode.get(name)).map(convertToSparkDataType(dataType, _, schemaConversionMode)).orNull
}
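  // Converts a single JsonNode into the requested Spark type, recursing into structs, maps and arrays.
  // Unsupported combinations return null in Relaxed mode and throw otherwise.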
private def convertToSparkDataType(dataType: DataType,
value: JsonNode,
schemaConversionMode: SchemaConversionMode): Any =
(value, dataType) match {
case (_: NullNode, _) | (_, _: NullType) => null
case (jsonNode: ObjectNode, struct: StructType) =>
fromObjectNodeToRow(struct, jsonNode, schemaConversionMode)
case (jsonNode: ObjectNode, map: MapType) =>
jsonNode.fields().asScala
.map(element => (
element.getKey,
convertToSparkDataType(map.valueType, element.getValue, schemaConversionMode))).toMap
case (jsonNode: ObjectNode, string: StringType) =>
jsonNode.toString
case (arrayNode: ArrayNode, array: ArrayType) =>
arrayNode.elements().asScala
.map(convertToSparkDataType(array.elementType, _, schemaConversionMode)).toArray
case (binaryNode: BinaryNode, _: BinaryType) =>
binaryNode.binaryValue()
case (arrayNode: ArrayNode, _: BinaryType) =>
// Assuming the array is of bytes
objectMapper.convertValue(arrayNode, classOf[Array[Byte]])
case (_, _: BooleanType) => value.asBoolean()
case (_, _: StringType) => value.asText()
case (_, _: DateType) => handleConversionErrors(() => toDate(value), schemaConversionMode)
case (_, _: TimestampType) => handleConversionErrors(() => toTimestamp(value), schemaConversionMode)
case (isJsonNumber(), DoubleType) => value.asDouble()
case (isJsonNumber(), DecimalType()) => value.decimalValue()
case (isJsonNumber(), FloatType) => value.asDouble()
case (isJsonNumber(), LongType) => value.asLong()
case (isJsonNumber(), _) => value.asInt()
case (textNode: TextNode, DoubleType) =>
handleConversionErrors(() => textNode.asText.toDouble, schemaConversionMode)
case (textNode: TextNode, DecimalType()) =>
handleConversionErrors(() => new java.math.BigDecimal(textNode.asText), schemaConversionMode)
case (textNode: TextNode, FloatType) =>
handleConversionErrors(() => textNode.asText.toFloat, schemaConversionMode)
case (textNode: TextNode, LongType) =>
handleConversionErrors(() => textNode.asText.toLong, schemaConversionMode)
case (textNode: TextNode, IntegerType) =>
handleConversionErrors(() => textNode.asText.toInt, schemaConversionMode)
case _ =>
if (schemaConversionMode == SchemaConversionModes.Relaxed) {
try {
convertToSparkDataTypeForSparkRuntimeSpecificDataType(dataType, value, schemaConversionMode)
}
catch {
            case e: Exception =>
              this.logError(s"Unsupported datatype conversion [Value: $value of ${value.getClass}] to $dataType")
              null
}
}
else {
throw new IllegalArgumentException(
s"Unsupported datatype conversion [Value: $value] of ${value.getClass}] to $dataType]")
}
}
protected[spark] def convertToSparkDataTypeForSparkRuntimeSpecificDataType
(dataType: DataType,
value: JsonNode,
schemaConversionMode: SchemaConversionMode): Any =
(value, dataType) match {
case _ =>
this.logError(s"Unsupported datatype conversion [Value: $value] of ${value.getClass}] to $dataType]")
null
}
private def handleConversionErrors[A] = (conversion: () => A,
schemaConversionMode: SchemaConversionMode) => {
Try(conversion()) match {
case Success(convertedValue) => convertedValue
case Failure(error) =>
if (schemaConversionMode == SchemaConversionModes.Relaxed) {
null
}
else {
throw error
}
}
}
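  // Numeric JSON values are interpreted as epoch milliseconds; strings are parsed as ISO-8601
  // offset date-times, falling back to the "yyyy-MM-dd'T'HH:mm:ss'Z'" UTC pattern.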
private def toTimestamp(value: JsonNode): Timestamp = {
value match {
case isJsonNumber() => new Timestamp(value.asLong())
case textNode: TextNode =>
parseDateTimeFromString(textNode.asText()) match {
case Some(odt) => Timestamp.valueOf(odt.toLocalDateTime)
case None =>
throw new IllegalArgumentException(
s"Value '${textNode.asText()} cannot be parsed as Timestamp.")
}
case _ => Timestamp.valueOf(value.asText())
}
}
private def toDate(value: JsonNode): Date = {
value match {
case isJsonNumber() => new Date(value.asLong())
case textNode: TextNode =>
parseDateTimeFromString(textNode.asText()) match {
case Some(odt) => Date.valueOf(odt.toLocalDate)
case None =>
throw new IllegalArgumentException(
s"Value '${textNode.asText()} cannot be parsed as Date.")
}
case _ => Date.valueOf(value.asText())
}
}
private def parseDateTimeFromString(value: String): Option[OffsetDateTime] = {
try {
val odt = OffsetDateTime.parse(value, DateTimeFormatter.ISO_OFFSET_DATE_TIME) //yyyy-MM-ddTHH:mm:ss+01:00
Some(odt)
}
catch {
case _: Exception =>
try {
val odt = OffsetDateTime.parse(value, utcFormatter) //yyyy-MM-ddTHH:mm:ssZ
Some(odt)
}
catch {
case _: Exception => None
}
}
}
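  // Extractor that matches any numeric Jackson node type, enabling the `case isJsonNumber()` patterns above.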
protected[spark] object isJsonNumber {
def unapply(x: JsonNode): Boolean = x match {
case _: com.fasterxml.jackson.databind.node.IntNode
| _: com.fasterxml.jackson.databind.node.DecimalNode
| _: com.fasterxml.jackson.databind.node.DoubleNode
| _: com.fasterxml.jackson.databind.node.FloatNode
| _: com.fasterxml.jackson.databind.node.LongNode => true
case _ => false
}
}
}
// scalastyle:on