All downloads are free. Search and download functionality uses the official Maven repository.

format.standard.StandardImporter.scala Maven / Gradle / Ivy

The newest version!
package avrohugger
package format
package standard

import avrohugger.format.abstractions.Importer
import avrohugger.matchers.TypeMatcher
import avrohugger.stores.SchemaStore
import avrohugger.types._
import org.apache.avro.{ Protocol, Schema }
import treehugger.forest._
import definitions._
import treehuggerDSL._

import scala.jdk.CollectionConverters._

object StandardImporter extends Importer {

  /**
    * Builds the `shapeless` import statements required by the fields of the
    * given record schemas.
    *
    * Every field schema is walked (through unions, arrays, maps and nested
    * records) and two independent groups of symbols are gathered:
    *  - coproduct symbols (`:+:`, `CNil` and, when the field has a default
    *    value, `Coproduct`) for unions with more members than `unionType`
    *    tolerates before a coproduct representation is used;
    *  - `tag.@@` for `bytes` schemas with a decimal logical type when the
    *    configured decimal type is `ScalaBigDecimalWithPrecision`.
    *
    * Each group is de-duplicated and, if non-empty, rendered as one import.
    *
    * @param topLevelRecordSchemas record schemas whose fields are inspected
    * @param typeMatcher           supplies the configured Avro-to-Scala type mapping
    * @param unionType             the shapeless union strategy in effect
    * @return zero, one or two imports (coproduct symbols first, then the tag)
    */
  private[this] def getShapelessImports(
    topLevelRecordSchemas: List[Schema],
    typeMatcher: TypeMatcher,
    unionType: ShapelessUnionType): List[Import] = {

    /**
      * Depth-first walk over `schema`: the symbols contributed by the node
      * itself (`symbolsAt`) come first, followed by those of its children.
      * `visited` holds the full names of the records currently being
      * descended into, so that recursive record definitions terminate.
      */
    def collectSymbols(
      schema: Schema,
      visited: Set[String])(symbolsAt: Schema => List[String]): List[String] = {
      val fromChildren: List[String] = schema.getType match {
        case Schema.Type.UNION =>
          schema.getTypes().asScala.toList.flatMap(member =>
            collectSymbols(member, visited)(symbolsAt))
        case Schema.Type.ARRAY =>
          collectSymbols(schema.getElementType(), visited)(symbolsAt)
        case Schema.Type.MAP =>
          collectSymbols(schema.getValueType(), visited)(symbolsAt)
        case Schema.Type.RECORD =>
          val recordName = schema.getFullName
          if (visited.contains(recordName)) Nil
          else schema.getFields().asScala.toList.flatMap(nested =>
            collectSymbols(nested.schema(), visited + recordName)(symbolsAt))
        case _ =>
          Nil
      }
      symbolsAt(schema) ++ fromChildren
    }

    /** True for a `bytes` schema whose decimal logical type maps to the precision-tagged BigDecimal. */
    def usesTaggedDecimal(schema: Schema): Boolean =
      schema.getType == Schema.Type.BYTES && LogicalType.foldLogicalTypes(
        schema = schema,
        default = false) {
        case Decimal(_, _) => typeMatcher.avroScalaTypes.decimal match {
          case ScalaBigDecimal(_) => false
          case ScalaBigDecimalWithPrecision(_) => true
        }
      }

    /** Symbols a single schema node contributes to the tag import. */
    def tagSymbolsAt(schema: Schema): List[String] =
      if (usesTaggedDecimal(schema)) List("tag.@@") else Nil

    /**
      * Coproduct symbols needed for one union, given the field it was reached from.
      * `Coproduct` itself is only needed when that field declares a default value.
      */
    def coproductSymbolsForUnion(
      field: Schema.Field,
      unionSchema: Schema): List[String] = {
      // Largest number of non-null members still expressible without a coproduct.
      val maxNonNullTypes = unionType match {
        case OptionalShapelessCoproduct => 0 // any union with at least one type needs :+:
        case OptionShapelessCoproduct => 1
        case OptionEitherShapelessCoproduct => 2 // one nullable type is Option, two are Either
      }
      val members = unionSchema.getTypes().asScala.toList
      val nullable = members.exists(_.getType == Schema.Type.NULL)
      // A nullable union spends one extra slot on the `null` member.
      val maxMembers = if (nullable) maxNonNullTypes + 1 else maxNonNullTypes

      if (members.length <= maxMembers) Nil
      else if (field.hasDefaultValue) List(":+:", "CNil", "Coproduct")
      else List(":+:", "CNil")
    }

    /** Renders a group of symbols as at most one `shapeless` import. */
    def toShapelessImport(symbols: List[String]): List[Import] = symbols match {
      case Nil => Nil
      case single :: Nil => List(IMPORT(RootClass.newClass(s"shapeless.$single")))
      case several => List(IMPORT(RootClass.newClass(s"shapeless.{${several.mkString(", ")}}")))
    }

    val topLevelFields: List[Schema.Field] =
      topLevelRecordSchemas.flatMap(_.getFields().asScala.toList)

    val coproductSymbols: List[String] = topLevelFields.flatMap { field =>
      collectSymbols(field.schema(), Set.empty[String]) { node =>
        if (node.getType == Schema.Type.UNION) coproductSymbolsForUnion(field, node)
        else Nil
      }
    }
    val tagSymbols: List[String] = topLevelFields.flatMap { field =>
      collectSymbols(field.schema(), Set.empty[String])(tagSymbolsAt)
    }

    toShapelessImport(coproductSymbols.distinct) ++ toShapelessImport(tagSymbols.distinct)
  }

  /**
    * Computes the imports for the generated source of a schema or protocol:
    * library (shapeless) imports first, followed by the imports of
    * user-defined types referenced from other namespaces.
    *
    * Shapeless imports are only produced when the configured union type is a
    * `ShapelessUnionType`; the Scala 3 union representation needs none.
    */
  def getImports(
    schemaOrProtocol: Either[Schema, Protocol],
    currentNamespace: Option[String],
    schemaStore: SchemaStore,
    typeMatcher: TypeMatcher): List[Import] = {
    val topLevel = getTopLevelSchemas(schemaOrProtocol, schemaStore, typeMatcher)
    val records = getRecordSchemas(topLevel)
    val fixeds = getFixedSchemas(topLevel)
    val enums = getEnumSchemas(topLevel)
    val userDefinedImports =
      getUserDefinedImports(records ++ fixeds ++ enums, currentNamespace, typeMatcher)
    val libraryImports = typeMatcher.avroScalaTypes.union match {
      case shapelessUnion: ShapelessUnionType =>
        getShapelessImports(records, typeMatcher, shapelessUnion)
      case OptionScala3UnionType =>
        Nil
    }
    libraryImports ++ userDefinedImports
  }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy