
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
package com.azure.cosmos.spark

import com.azure.cosmos.spark.SchemaConversionModes.SchemaConversionMode
import com.fasterxml.jackson.annotation.JsonInclude.Include
// scalastyle:off underscore.import
import com.fasterxml.jackson.databind.node._
// scalastyle:on underscore.import
import com.fasterxml.jackson.databind.{JsonNode, ObjectMapper}
import java.time.format.DateTimeFormatter
import java.time.LocalDateTime
import scala.collection.concurrent.TrieMap

// scalastyle:off underscore.import
import org.apache.spark.sql.types._
// scalastyle:on underscore.import

import scala.util.{Try, Success, Failure}

// scalastyle:off
private[cosmos] object CosmosRowConverter {

    // TODO: Expose configuration to handle duplicate fields
    // See: https://github.com/Azure/azure-sdk-for-java/pull/18642#discussion_r558638474
    private val rowConverterMap = new TrieMap[CosmosSerializationConfig, CosmosRowConverter]

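    /**
     * Returns the converter cached for this serialization config, creating it on first use.
     * putIfAbsent keeps the cache consistent under concurrency: if two threads race to create
     * a converter for the same config, both end up with the instance that won the race.
     */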
    def get(serializationConfig: CosmosSerializationConfig): CosmosRowConverter = {
        rowConverterMap.get(serializationConfig) match {
            case Some(existingRowConverter) => existingRowConverter
            case None =>
                val newRowConverterCandidate = createRowConverter(serializationConfig)
                rowConverterMap.putIfAbsent(serializationConfig, newRowConverterCandidate) match {
                    case Some(existingConcurrentlyCreatedRowConverter) => existingConcurrentlyCreatedRowConverter
                    case None => newRowConverterCandidate
                }
        }
    }

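    /**
     * Builds the Jackson ObjectMapper backing a new converter: registers JavaTimeModule so
     * java.time types serialize cleanly, then maps the configured inclusion mode onto the
     * corresponding Jackson JsonInclude setting (defaulting to ALWAYS).
     */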
    private def createRowConverter(serializationConfig: CosmosSerializationConfig): CosmosRowConverter = {
        val objectMapper = new ObjectMapper()
        import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule
        objectMapper.registerModule(new JavaTimeModule)
        serializationConfig.serializationInclusionMode match {
            case SerializationInclusionModes.NonNull => objectMapper.setSerializationInclusion(Include.NON_NULL)
            case SerializationInclusionModes.NonEmpty => objectMapper.setSerializationInclusion(Include.NON_EMPTY)
            case SerializationInclusionModes.NonDefault => objectMapper.setSerializationInclusion(Include.NON_DEFAULT)
            case _ => objectMapper.setSerializationInclusion(Include.ALWAYS)
        }

        new CosmosRowConverter(objectMapper, serializationConfig)
    }
}

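/**
 * Row converter for data types that only exist in newer Spark runtimes (currently
 * TimestampNTZType); everything else is handled by CosmosRowConverterBase.
 */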
private[cosmos] class CosmosRowConverter(private val objectMapper: ObjectMapper, private val serializationConfig: CosmosSerializationConfig)
    extends CosmosRowConverterBase(objectMapper, serializationConfig) {

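    /**
     * Spark -> JSON direction for runtime-specific types, honoring the serialization
     * inclusion mode: the base-class helper convertToJsonNodeConditionally decides whether
     * the value should be emitted at all (hence the Option result).
     */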
    override def convertSparkDataTypeToJsonNodeConditionallyForSparkRuntimeSpecificDataType
    (
        fieldType: DataType,
        rowData: Any
    ): Option[JsonNode] = {
        fieldType match {
            case TimestampNTZType if rowData.isInstanceOf[LocalDateTime] =>
                convertToJsonNodeConditionally(rowData.asInstanceOf[LocalDateTime].toString)
            case _ =>
                throw new Exception(s"Cannot cast $rowData into a Json value. $fieldType has no matching Json value.")
        }
    }

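    /**
     * Spark -> JSON direction for runtime-specific types when the value must always be
     * emitted: a LocalDateTime is serialized through its ISO-8601 toString representation.
     */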
    override def convertSparkDataTypeToJsonNodeNonNullForSparkRuntimeSpecificDataType(fieldType: DataType, rowData: Any): JsonNode = {
        fieldType match {
            case TimestampNTZType if rowData.isInstanceOf[LocalDateTime] =>
                objectMapper.convertValue(rowData.asInstanceOf[LocalDateTime].toString, classOf[JsonNode])
            case _ =>
                throw new Exception(s"Cannot cast $rowData into a Json value. $fieldType has no matching Json value.")
        }
    }

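    /**
     * JSON -> Spark direction for runtime-specific types: TimestampNTZType values are parsed
     * into LocalDateTime, with failures handled per the configured schema conversion mode.
     */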
    override def convertToSparkDataTypeForSparkRuntimeSpecificDataType
    (dataType: DataType,
     value: JsonNode,
     schemaConversionMode: SchemaConversionMode): Any =
        (value, dataType) match {
            case (_, _: TimestampNTZType) => handleConversionErrors(() => toTimestampNTZ(value), schemaConversionMode)
            case _ =>
                throw new IllegalArgumentException(
                    s"Unsupported datatype conversion [Value: $value] of ${value.getClass}] to $dataType]")
        }


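    /**
     * Parses a JSON node into the LocalDateTime backing a TIMESTAMP_NTZ column. Text nodes
     * go through the two-stage string parser below; any other node type falls back to
     * LocalDateTime.parse on its text representation.
     */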
    def toTimestampNTZ(value: JsonNode): LocalDateTime = {
        value match {
            case isJsonNumber() => LocalDateTime.parse(value.asText())
            case textNode: TextNode =>
                parseDateTimeNTZFromString(textNode.asText()) match {
                    case Some(localDateTime) => localDateTime
                    case None =>
                        throw new IllegalArgumentException(
                            s"Value '${textNode.asText()}' cannot be parsed as LocalDateTime (TIMESTAMP_NTZ).")
                }
            case _ => LocalDateTime.parse(value.asText())
        }
    }

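    // Applies the schema conversion mode to a failed conversion: Relaxed maps the error
    // to null, any other mode rethrows it.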
    private def handleConversionErrors[A] = (conversion: () => A,
                                             schemaConversionMode: SchemaConversionMode) => {
        Try(conversion()) match {
            case Success(convertedValue) => convertedValue
            case Failure(error) =>
                if (schemaConversionMode == SchemaConversionModes.Relaxed) {
                    null
                }
                else {
                    throw error
                }
        }
    }

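    /**
     * Two-stage string parser: first tries DateTimeFormatter.ISO_DATE_TIME (which also
     * tolerates offset/zone suffixes), then falls back to the default ISO_LOCAL_DATE_TIME
     * format; returns None when neither succeeds.
     */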
    def parseDateTimeNTZFromString(value: String): Option[LocalDateTime] = {
        try {
            Some(LocalDateTime.parse(value, DateTimeFormatter.ISO_DATE_TIME))
        }
        catch {
            case _: Exception =>
                try {
                    // Fall back to the default ISO_LOCAL_DATE_TIME parser.
                    Some(LocalDateTime.parse(value))
                }
                catch {
                    case _: Exception => None
                }
        }
    }

}
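
// Usage sketch (illustrative only; the CosmosSerializationConfig construction below is
// hypothetical -- consult that class for its actual factory or constructor signature):
//
//   val config = CosmosSerializationConfig(SerializationInclusionModes.NonNull) // hypothetical ctor
//   val converter = CosmosRowConverter.get(config) // cached: repeated calls return the same instance
//   val node = new com.fasterxml.jackson.databind.node.TextNode("2021-12-17T01:59:00")
//   converter.toTimestampNTZ(node) // => LocalDateTime.of(2021, 12, 17, 1, 59)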
