All Downloads are FREE. Search and download functionalities are using the official Maven repository.

format.abstractions.Importer.scala Maven / Gradle / Ivy

package avrohugger
package format
package abstractions

import avrohugger.input.DependencyInspector
import avrohugger.input.NestedSchemaExtractor
import avrohugger.matchers.TypeMatcher
import avrohugger.matchers.custom.CustomNamespaceMatcher
import avrohugger.stores.SchemaStore

import org.apache.avro.{ Schema, Protocol }
import org.apache.avro.Schema.Type.{ ENUM, RECORD, UNION, MAP, ARRAY, FIXED }

import treehugger.forest._
import definitions.RootClass
import treehuggerDSL._

import scala.jdk.CollectionConverters._

/** Parent to all output formats' importers.
  *
  * _ABSTRACT MEMBERS_: to be implemented by a subclass
  * getImports
  *
  * _CONCRETE MEMBERS_: implementations to be inherited by a subclass
  * getEnumSchemas
  * getFixedSchemas
  * getFieldSchemas
  * getTypeSchemas
  * getUserDefinedImports
  * getRecordSchemas
  * getTopLevelSchemas
  * isEnum
  * isFixed
  * isRecord
  */
trait Importer {

  ///////////////////////////// abstract members ///////////////////////////////
  /** Returns the imports an output format needs for `schemaOrProtocol`. */
  def getImports(
    schemaOrProtocol: Either[Schema, Protocol],
    currentNamespace: Option[String],
    schemaStore: SchemaStore,
    typeMatcher: TypeMatcher): List[Import]

  ////////////////////////////// concrete members //////////////////////////////
  /** Gets enum schemas which may be dependencies.
    *
    * Nested results come from `getRecordSchemas` and are therefore records, so
    * in practice this yields the ENUM schemas found directly in
    * `topLevelSchemas`.
    *
    * @param topLevelSchemas schemas to inspect
    * @param alreadyImported schemas consulted by the recursion guard
    * @return the distinct ENUM schemas, in first-seen order
    */
  def getEnumSchemas(
    topLevelSchemas: List[Schema],
    alreadyImported: List[Schema] = List.empty[Schema]): List[Schema] =
    collectCandidateSchemas(topLevelSchemas, alreadyImported)
      .filter(schema => isEnum(schema))
      .distinct

  /** Gets the distinct FIXED schemas among `topLevelSchemas`, in first-seen order. */
  def getFixedSchemas(topLevelSchemas: List[Schema]): List[Schema] =
    topLevelSchemas
      .filter(schema => isFixed(schema))
      .distinct

  /** Gets the schema of each field of a RECORD schema, in declaration order. */
  def getFieldSchemas(schema: Schema): List[Schema] =
    schema.getFields().asScala.toList.map(field => field.schema)

  /** Gets the branch schemas of a UNION schema, in declaration order. */
  def getTypeSchemas(schema: Schema): List[Schema] =
    schema.getTypes().asScala.toList

  /** Builds one import per foreign package for the named types in `recordSchemas`.
    *
    * @param recordSchemas candidate schemas; only RECORD, ENUM and FIXED ones are
    *                      considered
    * @param namespace     the namespace being generated; types already in it are
    *                      not imported
    * @param typeMatcher   supplies custom namespace mappings
    * @return one `IMPORT` per (mapped) package, listing the referred type names
    */
  def getUserDefinedImports(
    recordSchemas: List[Schema],
    namespace: Option[String],
    typeMatcher: TypeMatcher): List[Import] = {

    // Namespace the schema refers to, after applying any custom namespace mapping.
    def checkNamespace(schema: Schema): Option[String] = {
      val maybeReferredNamespace =
        DependencyInspector.getReferredNamespace(schema)
      CustomNamespaceMatcher.checkCustomNamespace(
        maybeReferredNamespace,
        typeMatcher,
        maybeDefaultNamespace = maybeReferredNamespace)
    }

    // One IMPORT of the referred type names of `fields` from the (mapped) package.
    def asImportDef(packageName: String, fields: List[Schema]): Import = {
      val maybeUpdatedPackageName = CustomNamespaceMatcher.checkCustomNamespace(
        Some(packageName),
        typeMatcher,
        maybeDefaultNamespace = Some(packageName))
      val updatedPkg = maybeUpdatedPackageName.getOrElse(packageName)
      val importedPackageSym = RootClass.newClass(updatedPkg)
      val importedTypes =
        fields.map(field => DependencyInspector.getReferredTypeName(field))
      IMPORT(importedPackageSym, importedTypes)
    }

    // Package to import `schema` from: only named types that live outside the
    // current namespace need an import.
    def importPackage(schema: Schema): Option[String] =
      if (isRecord(schema) || isEnum(schema) || isFixed(schema))
        checkNamespace(schema).filterNot(pkg => namespace.contains(pkg))
      else
        None

    recordSchemas
      .flatMap(schema => importPackage(schema).map(pkg => (pkg, schema)))
      // Grouping on the same keys, in the same element order, as before keeps
      // the resulting import order unchanged.
      .groupBy { case (pkg, _) => pkg }
      .toList
      .map { case (pkg, pairs) => asImportDef(pkg, pairs.map(_._2)) }
  }

  /** Gets record schemas which may be dependencies.
    *
    * @param topLevelSchemas schemas to inspect
    * @param alreadyImported schemas consulted by the recursion guard
    * @return the distinct RECORD schemas, in first-seen order
    */
  def getRecordSchemas(
    topLevelSchemas: List[Schema],
    alreadyImported: List[Schema] = List.empty[Schema]): List[Schema] =
    collectCandidateSchemas(topLevelSchemas, alreadyImported)
      .filter(schema => isRecord(schema))
      .distinct

  /** Traversal shared by `getEnumSchemas` and `getRecordSchemas`.
    *
    * RECORD and ENUM schemas contribute themselves. RECORD, UNION, MAP and ARRAY
    * schemas may also contribute the output of a nested `getRecordSchemas` call.
    * Callers filter and de-duplicate the result.
    */
  private def collectCandidateSchemas(
    topLevelSchemas: List[Schema],
    alreadyImported: List[Schema]): List[Schema] = {
    def nextSchemas(s: Schema, us: List[Schema]): List[Schema] =
      getRecordSchemas(List(s), us)

    // NOTE(review): every guard below keeps a child only if it is ALREADY in
    // `alreadyImported`, so with the default empty list no recursion happens.
    // When the list is non-empty and a child is found in it again, the
    // recursion does not appear to terminate. Examples are a self-referencing
    // record, or the UNION/MAP/ARRAY cases, which recurse on the container
    // itself rather than on the child. A `!alreadyImported.contains(s)` guard
    // may have been intended. Confirm before changing, because that would
    // change which imports get generated.
    def recurseViaContainer(container: Schema, children: List[Schema]): List[Schema] =
      children
        .filter(s => alreadyImported.contains(s))
        .flatMap(s => nextSchemas(container, alreadyImported :+ s))

    topLevelSchemas.flatMap { schema =>
      schema.getType match {
        case RECORD =>
          val fieldSchemasWithChildSchemas = getFieldSchemas(schema)
            .filter(s => alreadyImported.contains(s))
            .flatMap(s => nextSchemas(s, alreadyImported :+ s))
          schema :: fieldSchemasWithChildSchemas
        case ENUM =>
          List(schema)
        case UNION =>
          // Qualified explicitly: `NULL` is not among the Schema.Type imports,
          // so an unqualified `NULL` does not refer to the Avro null type.
          val firstNonNull =
            schema.getTypes().asScala.find(s => s.getType != Schema.Type.NULL)
          recurseViaContainer(schema, firstNonNull.toList)
        case MAP =>
          recurseViaContainer(schema, List(schema.getValueType))
        case ARRAY =>
          recurseViaContainer(schema, List(schema.getElementType))
        case _ =>
          Nil
      }
    }
  }

  /** Gets the declared schema(s), each followed by the schemas nested inside it.
    *
    * @param schemaOrProtocol a single schema, or a protocol whose types are all used
    * @return declared schemas interleaved with their nested schemas
    */
  def getTopLevelSchemas(
    schemaOrProtocol: Either[Schema,  Protocol],
    schemaStore: SchemaStore,
    typeMatcher: TypeMatcher): List[Schema] = {
    def withNested(schema: Schema): List[Schema] =
      schema :: NestedSchemaExtractor.getNestedSchemas(schema, schemaStore, typeMatcher)

    schemaOrProtocol match {
      case Left(schema)    => withNested(schema)
      case Right(protocol) => protocol.getTypes().asScala.toList.flatMap(withNested)
    }
  }

  def isFixed(schema: Schema): Boolean = schema.getType == FIXED

  def isEnum(schema: Schema): Boolean = schema.getType == ENUM

  def isRecord(schema: Schema): Boolean = schema.getType == RECORD

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy