All Downloads are FREE. Search and download functionalities are using the official Maven repository.

format.specific.SpecificImporter.scala Maven / Gradle / Ivy

package avrohugger
package format
package specific

import avrohugger.format.abstractions.Importer
import avrohugger.matchers.TypeMatcher
import avrohugger.types._
import avrohugger.stores.SchemaStore
import org.apache.avro.Schema.Type.RECORD
import org.apache.avro.{Protocol, Schema}
import treehugger.forest._
import definitions._
import treehuggerDSL._

import scala.jdk.CollectionConverters._

/**
  * Computes the import statements emitted at the top of Scala files generated
  * in the `specific` output format.
  */
object SpecificImporter extends Importer {

  /**
    * Builds the imports for a single generated file.
    *
    * @param schemaOrProtocol the schema (`Left`) or protocol (`Right`) being generated
    * @param currentNamespace namespace of the generated file, forwarded to
    *                         `getUserDefinedImports`
    * @param schemaStore      store consulted by `getTopLevelSchemas`
    * @param typeMatcher      type-mapping configuration; decides whether unions and
    *                         decimals need shapeless imports
    * @return for a record schema, or a protocol without messages (ADT): the
    *         `scala.annotation.switch` import, then shapeless imports, then
    *         user-defined imports; for any other schema: shapeless and
    *         user-defined imports only; for a protocol with messages (RPC): no imports
    */
  def getImports(
    schemaOrProtocol: Either[Schema, Protocol],
    currentNamespace: Option[String],
    schemaStore: SchemaStore,
    typeMatcher: TypeMatcher): List[Import] = {

    val switchImport = IMPORT(RootClass.newClass("scala.annotation.switch"))
    val topLevelSchemas =
      getTopLevelSchemas(schemaOrProtocol, schemaStore, typeMatcher)
    val recordSchemas = getRecordSchemas(topLevelSchemas)
    val enumSchemas = getEnumSchemas(topLevelSchemas)
    val userDefinedDeps =
      getUserDefinedImports(recordSchemas ++ enumSchemas, currentNamespace, typeMatcher)
    val libraryDeps = getShapelessImports(recordSchemas, typeMatcher)
    // Every non-RPC case shares this tail; only the leading @switch import varies.
    val commonDeps = libraryDeps ::: userDefinedDeps

    schemaOrProtocol match {
      case Left(schema) if schema.getType == RECORD => switchImport :: commonDeps
      case Left(_)                                  => commonDeps
      // A protocol without messages is generated as an ADT.
      case Right(protocol) if protocol.getMessages.isEmpty => switchImport :: commonDeps
      // A protocol with messages is RPC; no imports are produced for it here.
      case Right(_) => List.empty
    }
  }

  /**
    * Collects the shapeless imports required by the fields of the given records.
    *
    * A union is rendered as a shapeless Coproduct only when its number of non-null
    * member types exceeds a threshold selected by `typeMatcher.avroScalaTypes.union`:
    *   - `OptionalShapelessCoproduct`: any non-null member (threshold 0)
    *   - `OptionShapelessCoproduct`: more than one non-null member (threshold 1)
    *   - `OptionEitherShapelessCoproduct`: more than two non-null members
    *     (threshold 2); smaller unions are generated as [[Option]] / [[Either]]
    *     and need no special imports.
    *
    * Coproduct unions need `:+:` and `CNil`. They also need `Coproduct` when the
    * union has no `null` member. `tag.@@` is needed when a `bytes` schema carries a
    * decimal logical type mapped to `ScalaBigDecimalWithPrecision`.
    *
    * @param topLevelRecordSchemas records whose fields are scanned, descending into
    *                              unions, arrays, maps and nested records
    * @param typeMatcher           type-mapping configuration
    * @return at most two imports: one for the coproduct symbols and one for the tag symbol
    */
  private[this] def getShapelessImports(
    topLevelRecordSchemas: List[Schema],
    typeMatcher: TypeMatcher): List[Import] = {

    // True when `record` is already being expanded higher up the current path.
    // Checking this stops recursive schemas from being traversed forever.
    def alreadyVisited(record: Schema, visited: List[Schema]): Boolean =
      visited.exists(_.getFullName == record.getFullName)

    // Symbols needed when `unionSchema` is rendered as a shapeless Coproduct.
    // `Coproduct` is added only when the union has no null member. Presumably this
    // is because the generated default then uses `Coproduct(...)` rather than
    // `None` (TODO confirm against the tree generator).
    def coproductImportsForUnionType(unionSchema: Schema): List[String] = {
      val thresholdNonNullTypes = typeMatcher.avroScalaTypes.union match {
        case OptionalShapelessCoproduct => 0 // if a union contains at least one type, then it will need :+:
        case OptionShapelessCoproduct => 1
        case OptionEitherShapelessCoproduct => 2 // unions of one nullable type become Option, two become Either
      }
      val unionTypes = unionSchema.getTypes().asScala.toList
      val unionNonNullTypes = unionTypes.filterNot(_.getType == Schema.Type.NULL)
      val unionContainsNull = unionNonNullTypes.length < unionTypes.length
      val isShapelessCoproduct = unionNonNullTypes.length > thresholdNonNullTypes

      if (!isShapelessCoproduct) List.empty[String]
      else if (unionContainsNull) List(":+:", "CNil")
      else List(":+:", "CNil", "Coproduct")
    }

    // `tag.@@` is needed only for a bytes decimal mapped to a precision-tagged BigDecimal.
    def importsForBigDecimalTagged(schema: Schema): List[String] = {
      val isTaggedDecimal = schema.getType == Schema.Type.BYTES &&
        LogicalType.foldLogicalTypes(
          schema = schema,
          default = false) {
            case Decimal(_, _) => typeMatcher.avroScalaTypes.decimal match {
              case ScalaBigDecimal(_) => false
              case ScalaBigDecimalWithPrecision(_) => true
            }
        }
      if (isTaggedDecimal) List("tag.@@") else Nil
    }

    // Walks `schema` and collects the coproduct symbols of every union reachable from it.
    def determineShapelessCoproductImports(
      schema: Schema,
      potentialRecursives: List[Schema]): List[String] = schema.getType match {
      case Schema.Type.UNION =>
        coproductImportsForUnionType(schema) ++
          schema.getTypes().asScala.toList.flatMap(member =>
            determineShapelessCoproductImports(member, potentialRecursives))
      case Schema.Type.ARRAY =>
        determineShapelessCoproductImports(schema.getElementType(), potentialRecursives)
      case Schema.Type.MAP =>
        determineShapelessCoproductImports(schema.getValueType(), potentialRecursives)
      case Schema.Type.RECORD if alreadyVisited(schema, potentialRecursives) =>
        List.empty[String]
      case Schema.Type.RECORD =>
        schema.getFields().asScala.toList.flatMap(f =>
          determineShapelessCoproductImports(f.schema(), potentialRecursives :+ schema))
      case _ =>
        List.empty[String]
    }

    // Walks `schema` and collects `tag.@@` if any reachable bytes schema needs it.
    def determineShapelessTagImport(
      schema: Schema,
      potentialRecursives: List[Schema]): List[String] = schema.getType match {
      case Schema.Type.UNION =>
        schema.getTypes().asScala.toList.flatMap(member =>
          determineShapelessTagImport(member, potentialRecursives))
      case Schema.Type.ARRAY =>
        determineShapelessTagImport(schema.getElementType(), potentialRecursives)
      case Schema.Type.MAP =>
        determineShapelessTagImport(schema.getValueType(), potentialRecursives)
      case Schema.Type.RECORD if alreadyVisited(schema, potentialRecursives) =>
        List.empty[String]
      case Schema.Type.RECORD =>
        schema.getFields().asScala.toList.flatMap(f =>
          determineShapelessTagImport(f.schema(), potentialRecursives :+ schema))
      case Schema.Type.BYTES =>
        importsForBigDecimalTagged(schema)
      case _ =>
        List.empty[String]
    }

    // Turns a list of shapeless symbols into one import: a single-name import for
    // one symbol, or a brace-selector import for several.
    def shapelessImport(symbols: List[String]): List[Import] = symbols match {
      case Nil => Nil
      case single :: Nil => List(IMPORT(RootClass.newClass(s"shapeless.$single")))
      case several => List(IMPORT(RootClass.newClass(s"shapeless.{${several.mkString(", ")}}")))
    }

    val fieldSchemas: List[Schema] =
      topLevelRecordSchemas.flatMap(record => record.getFields().asScala.map(field => field.schema()))
    val coproductSymbols: List[String] =
      fieldSchemas.flatMap(s => determineShapelessCoproductImports(s, List.empty[Schema]))
    val tagSymbols: List[String] =
      fieldSchemas.flatMap(s => determineShapelessTagImport(s, List.empty[Schema]))

    shapelessImport(coproductSymbols.distinct) ++
      shapelessImport(tagSymbols.distinct)
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy