All downloads are free. The search and download features use the official Maven repository.

matchers.DefaultValueMatcher.scala Maven / Gradle / Ivy

package avrohugger
package matchers

import avrohugger.matchers.custom.CustomDefaultParamMatcher
import avrohugger.matchers.custom.CustomDefaultValueMatcher
import avrohugger.stores.ClassStore
import avrohugger.types._
import java.nio.charset.StandardCharsets

import treehugger.forest._
import definitions._
import org.apache.avro.Schema
import com.fasterxml.jackson.databind.JsonNode
import com.fasterxml.jackson.databind.ObjectMapper
import com.fasterxml.jackson.databind.node.{NullNode, ObjectNode, TextNode}
import treehugger.forest
import treehuggerDSL._
import scala.util.Try

import scala.jdk.CollectionConverters._

/**
  * Translates the default value declared on an Avro schema field into a
  * treehugger `Tree`.
  */
object DefaultValueMatcher {

  // A textual "null" default. `unionDefaultArgsImpl` treats a node equal to this
  // one the same way as a JSON null, i.e. it yields `NONE`.
  val nullNode = new TextNode("null")

  // This code was originally taken from:
  // https://github.com/julianpeeters/avro-scala-macro-annotations/blob/104fa325a00044ff6d31184fa7ff7b6852e9acd5/macros/src/main/scala/avro/scala/macro/annotations/provider/matchers/FromJsonMatcher.scala
  /**
    * Builds the tree for the default value of `field`.
    *
    * @param classStore  passed through to `typeMatcher.toScalaType` when a union
    *                    default has to be wrapped in a shapeless coproduct
    * @param namespace   passed through to `typeMatcher.toScalaType`, see above
    * @param field       the Avro field whose default value is translated
    * @param typeMatcher supplies the configured Scala representations
    *                    (`avroScalaTypes`) for dates, decimals, enums, arrays, unions
    * @param useFullName if true, enum defaults are referenced through the schema's
    *                    full name instead of its simple name
    * @return the tree of the default value, or `EmptyTree` when the default
    *         node obtained for the field is `null` (no default declared)
    */
  def getDefaultValue(
    classStore: ClassStore,
    namespace: Option[String],
    field: Schema.Field,
    typeMatcher: TypeMatcher,
    useFullName: Boolean = false): Tree = {

    // Shared by the BYTES and FIXED cases: either the custom decimal
    // representation or a plain `Array[Byte](...)` literal; the choice is made
    // by `checkCustomDecimalType`.
    def bytesOrDecimalDefault(node: JsonNode, schema: Schema) = {
      // Per the Avro specification, bytes/fixed defaults are JSON strings whose
      // code points 0-255 map one-to-one onto byte values 0-255, which is exactly
      // ISO-8859-1. Decoding with UTF-8 or the platform charset would corrupt
      // every byte >= 0x80. Kept as a `def` so that a non-textual node fails at
      // the same point of evaluation as before (inside the `Try` for decimals).
      def rawBytes: Array[Byte] =
        node.textValue().getBytes(StandardCharsets.ISO_8859_1)

      CustomDefaultParamMatcher.checkCustomDecimalType(
        decimalType = typeMatcher.avroScalaTypes.decimal,
        schema = schema,
        default = REF("Array[Byte]") APPLY rawBytes.map((e: Byte) => LIT(e)),
        decimalValue =
          Try(new org.apache.avro.Conversions.DecimalConversion().fromBytes(
            java.nio.ByteBuffer.wrap(rawBytes),
            schema,
            schema.getLogicalType).toString
          ).toOption
      )
    }

    // Recursively converts a JSON default node into a tree, driven by the
    // schema type it is supposed to conform to.
    def fromJsonNode(node: JsonNode, schema: Schema): Tree = {
      schema.getType match {
        case _ if node == null => EmptyTree //not `default=null`, but no default
        case Schema.Type.INT =>
          LogicalType.foldLogicalTypes[Tree](
            schema = schema,
            default = LIT(node.intValue())) {
            case Date =>
              CustomDefaultValueMatcher.checkCustomDateType(
                node.longValue(),
                typeMatcher.avroScalaTypes.date)
            case TimeMillis =>
              CustomDefaultValueMatcher.checkCustomTimeMillisType(
                node.longValue(),
                typeMatcher.avroScalaTypes.timeMillis)
          }
        case Schema.Type.FLOAT => LIT(node.doubleValue().toFloat)
        case Schema.Type.LONG =>
          LogicalType.foldLogicalTypes[Tree](
            schema = schema,
            default = LIT(node.longValue())) {
            case TimestampMillis =>
              CustomDefaultValueMatcher.checkCustomTimestampMillisType(
                node.longValue(),
                typeMatcher.avroScalaTypes.timestampMillis)
          }
        case Schema.Type.DOUBLE => LIT(node.doubleValue())
        case Schema.Type.BOOLEAN => LIT(node.booleanValue())
        case Schema.Type.STRING =>
          LogicalType.foldLogicalTypes[Tree](
            schema = schema,
            default = LIT(node.textValue())) {
            case UUID => REF("java.util.UUID.fromString") APPLY LIT(node.textValue())
          }
        case Schema.Type.BYTES =>
          bytesOrDecimalDefault(node, schema)
        case Schema.Type.ENUM => {
          val refName = if (useFullName) schema.getFullName else schema.getName
          typeMatcher.avroScalaTypes.`enum` match {
            // All three object-like enum styles reference the symbol as a member.
            case JavaEnum | ScalaEnumeration | ScalaCaseObjectEnum =>
              (REF(refName) DOT node.textValue())
            case EnumAsScalaString => LIT(node.textValue())
          }
        }
        case Schema.Type.NULL => LIT(null)
        case Schema.Type.UNION =>
          unionDefaultArgsImpl(
            node,
            schema.getTypes().asScala.toList,
            fromJsonNode,
            typeMatcher,
            classStore,
            namespace)
        case Schema.Type.ARRAY => {
          val collectionType = CustomDefaultParamMatcher.checkCustomArrayType(typeMatcher.avroScalaTypes.array)
          collectionType APPLY(node.elements().asScala.toSeq.map(e => fromJsonNode(e, schema.getElementType)))
        }
        case Schema.Type.MAP => {
          val kvps = node.fields().asScala.toList.map(e => LIT(e.getKey) ANY_-> fromJsonNode(e.getValue, schema.getValueType))
          MAKE_MAP(kvps)
        }
        case Schema.Type.RECORD => {
          // The default may arrive either as a JSON object or as a string
          // holding the JSON text of that object.
          val jsObject = node match {
            case t: TextNode =>
              new ObjectMapper().readValue(t.textValue(), classOf[ObjectNode])
            case o: ObjectNode => o
            case _ => throw new Exception(s"Invalid default value for field: $field, value: $node")
          }

          // One constructor argument per schema field, in declaration order.
          // A field missing from the JSON object yields `EmptyTree`.
          val fieldValues = schema.getFields.asScala.toList.map { f =>
            fromJsonNode(jsObject.get(f.name), f.schema)
          }
          NEW(schema.getName, fieldValues: _*)
        }
        case Schema.Type.FIXED =>
          REF(schema.getName()) APPLY(bytesOrDecimalDefault(node, schema))
      }
    }

    val defaultValue = org.apache.avro.util.internal.Accessor.defaultValue(field)
    fromJsonNode(defaultValue, field.schema)
  }

  /**
    * Handles unions default values.
    *
    * Per the avro spec:
    * (Note that when a default value is specified for a record field whose type is a union,
    * the type of the default value must match the first element of the union)
    *
    * @param node         the JSON default value of the union-typed field
    * @param unionSchemas the member schemas of the union, in declaration order
    * @param treeMatcher  converts a JSON node into a tree for a given member schema
    * @param typeMatcher  supplies the configured union representation and the
    *                     Scala types of the union members
    * @param classStore   passed through to `typeMatcher.toScalaType`
    * @param namespace    passed through to `typeMatcher.toScalaType`
    * @return `NONE` for a null (or textual "null") default; otherwise the default
    *         rendered against the first non-null member, wrapped in `SOME` when
    *         the union contains a null member
    */
  private[this] def unionDefaultArgsImpl(node: JsonNode,
                                         unionSchemas: List[Schema],
                                         treeMatcher: (JsonNode, Schema) => Tree,
                                         typeMatcher: TypeMatcher,
                                         classStore: ClassStore,
                                         namespace: Option[String]) : Tree = {

    // Renders `shapeless.Coproduct[A :+: B :+: CNil](default)`, where the default
    // itself is rendered against `defaultParam`.
    def COPRODUCT(defaultParam: Schema, tp: List[Type]): Tree = {
      val copTypes = tp :+ typeRef(RootClass.newClass(newTypeName("CNil")))
      val chain: forest.Tree = INFIX_CHAIN(":+:", copTypes.map(t => Ident(t.safeToString)))
      val chainedS = treeToString(chain)
      val copType = typeRef(RootClass.newClass(newTypeName(chainedS)))
      REF("shapeless.Coproduct") APPLYTYPE copType APPLY treeMatcher(node, defaultParam)
    }

    val includesNull: Boolean = unionSchemas.exists(_.getType == Schema.Type.NULL)

    val nonNullableSchemas: List[Schema] = unionSchemas.filter(_.getType != Schema.Type.NULL)

    // Coproduct over every non-null member; the default is matched against
    // `firstSchema`.
    def coproductOfNonNullTypes(firstSchema: Schema): Tree =
      COPRODUCT(firstSchema,
        nonNullableSchemas
          .map(typeMatcher.toScalaType(classStore, namespace, _)))

    // Always a coproduct, regardless of how many non-null members there are.
    def unionsAsOptionalShapelessCoproductStrategy: Tree = nonNullableSchemas match {
      case firstSchema :: _ => //Coproduct
        coproductOfNonNullTypes(firstSchema)
      case _ => throw new Exception("unrecognized shape for shapeless coproduct")
    }

    // Picks the representation from the number of non-null members.
    def unionsArityStrategy(useEither: Boolean): Tree =
      nonNullableSchemas match {
        case Nil =>
          UNIT
        case List(schemaA) => //Option
          treeMatcher(node, schemaA)
        case List(schemaA, _) if useEither => //Either
          LEFT(treeMatcher(node, schemaA))
        case firstSchema :: _ => //Coproduct
          coproductOfNonNullTypes(firstSchema)
      }

    // Kept as a `def` so nothing is built for null defaults.
    def matchedTree: Tree = typeMatcher.avroScalaTypes.union match {
      case OptionShapelessCoproduct | OptionEitherShapelessCoproduct =>
        unionsArityStrategy(typeMatcher.avroScalaTypes.union.useEitherForTwoNonNullTypes)
      case OptionalShapelessCoproduct => unionsAsOptionalShapelessCoproductStrategy
    }

    node match {
      case `nullNode` => NONE
      case _ : NullNode => NONE
      case _ if includesNull => SOME(matchedTree)
      case _ => matchedTree
    }
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy