All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.softwaremill.macwire.autocats.internals.CatsProvidersGraphContext.scala Maven / Gradle / Ivy

The newest version!
package com.softwaremill.macwire.autocats.internals

import scala.reflect.macros.blackbox
import com.softwaremill.macwire.internals._
import cats.implicits._

class CatsProvidersGraphContext[C <: blackbox.Context](val c: C, val log: Logger)
    extends CatsProviders[C]
    with GraphBuilderUtils[C] {

  /** A creator or factory-method parameter: the parameter's `symbol` paired with its type `tpe`. */
  case class Param(symbol: c.Symbol, tpe: c.Type)
  /** Type-checking helper bound to this macro context `c` and using the same `log`; created on first use. */
  lazy val typeCheckUtil = new TypeCheckUtil[c.type](c, log)

  class CatsProvidersGraph(providers: List[Provider], val root: Provider) {
    log(s"Created graph for providers [${mkStringProviders(providers)}]")

    /** Aborts macro expansion when some provider of this graph cannot be reached, through dependency edges, from the
      * last element of `resultProviders`.
      *
      * Nothing is checked when `resultProviders` is empty.
      */
    private def verifyOrder(resultProviders: List[Provider]) =
      log.withBlock(s"Verifying order of [${mkStringProviders(resultProviders)}]") {
        // `start` itself plus every provider found by following its dependencies transitively.
        def reachableFrom(start: Provider): Set[Provider] =
          start.dependencies.flatten.foldLeft(Set(start)) { (collected, dependency) =>
            reachableFrom(dependency._2) ++ collected
          }

        resultProviders.lastOption.map { last =>
          val unused = providers.diff(reachableFrom(last).toSeq)

          if (unused.nonEmpty)
            c.abort(
              c.enclosingPosition,
              s"Not used providers for the following types [${unused.map(_.resultType).mkString(", ")}]"
            )
        }
      }

    /** Walks the dependency graph starting at `root` and reports a compilation error listing every path that ends in
      * a `NotResolvedProvider`.
      *
      * The error is emitted with `c.error`, so the caller keeps running after the report.
      */
    private def failOnMissingDependencies() = log.withBlock("Checking missing dependencies") {
      type ProviderPath = List[(c.Symbol, Provider)]

      // Returns `missing` extended with every unresolved path found below `provider`; `path` is the route taken
      // from the root down to (but excluding) `provider`.
      def collectMissing(path: ProviderPath, missing: Set[ProviderPath])(
          sym: c.Symbol,
          provider: Provider
      ): Set[ProviderPath] =
        log.withBlock(s"Checking provider [$provider] for type [${provider.resultType}]") {
          val extendedPath = path :+ ((sym, provider))

          provider match {
            case _: NotResolvedProvider => missing + extendedPath
            case _ =>
              provider.dependencies.flatten.foldLeft(missing) { case (found, (depSym, dep)) =>
                collectMissing(extendedPath, found)(depSym, dep)
              }
          }
        }

      val missingPaths = collectMissing(List.empty, Set.empty)(c.universe.NoSymbol, root)

      if (missingPaths.nonEmpty) {
        // Renders one path as "first provider, intermediate hops, name of the missing parameter".
        def describe(path: ProviderPath): String = path match {
          case Nil | _ :: Nil => path.map(_._1).mkString
          case (_, first) :: rest =>
            val (_, unresolved) = rest.last
            val hops = rest.init.map { case (sym, provider) => s".${sym.name} -> [${provider.symbol}]" }.mkString

            s"Missing dependency of type [${unresolved.resultType}]. Path [${first.symbol}]$hops.${unresolved.symbol.name}"
        }

        val msg = missingPaths
          .map(describe)
          .mkString(s"Failed to create an instance of [${root.resultType}].\n", "\n", "\n")

        c.error(c.enclosingPosition, msg)
      }
    }

    /** Orders the providers so that the dependencies of a provider come before the provider itself, with `root`
      * appended as the last element.
      *
      * Before returning, missing dependencies are reported and the result is checked for unused providers.
      */
    def topologicalSort(): List[Provider] = log.withBlock("Stable topological sort") {
      // Lists the not-yet-used dependencies of `provider` (transitively), each one after its own dependencies.
      // `provider` itself is not part of the returned list.
      def dependenciesOf(provider: Provider, usedProviders: Set[Provider]): List[Provider] =
        log.withBlock(s"Going deeper for type [${provider.resultType}]") {

          provider.dependencies.flatten.foldLeft(List.empty[Provider]) { (ordered, dependency) =>
            val candidate = dependency._2
            val seen = usedProviders ++ ordered

            candidate match {
              case unresolved: NotResolvedProvider =>
                log.withBlock(s"Skipping not resolved provider for type [${unresolved.resultType}]")(ordered)
              case _ if seen.contains(candidate) =>
                log.withBlock(s"Already used provider for type [${candidate.resultType}]")(ordered)
              case _ =>
                log.withResult { (ordered ::: dependenciesOf(candidate, seen + candidate)) :+ candidate }(result =>
                  s"Built list for the following types [${result.map(_.resultType).mkString(", ")}]"
                )
            }
          }
        }

      val dependenciesFirst = providers.foldLeft(List.empty[Provider]) { (sorted, next) =>
        if (sorted.contains(next)) sorted
        else sorted ::: dependenciesOf(next, sorted.toSet)
      }
      val result = dependenciesFirst :+ root

      failOnMissingDependencies()
      verifyOrder(result)

      result
    }

  }

  /** Builds the graph of providers needed to create an instance of `rootType`.
    *
    * Each raw provider expression is classified as a factory method tree, a resource, an effect or a plain instance.
    * Factory method trees are resolved first, then a provider for `rootType` is looked up or created. Macro expansion
    * is aborted when no creator for `rootType` can be found.
    */
  def buildGraph(rawProviders: List[c.universe.Expr[Any]], rootType: c.Type): CatsProvidersGraph = {

    // A `FactoryMethodTree` is not a `Provider`; it only becomes a `FactoryMethod` once resolved, so the two kinds
    // of input are kept apart with an `Either`.
    def classify(tree: c.Tree): Either[Provider, FactoryMethodTree] = {
      val tpe = typeCheckUtil.typeCheckIfNeeded(tree)

      FactoryMethodTree(tree) match {
        case Some(factoryMethodTree) => Right(factoryMethodTree)
        case None =>
          val provider: Provider =
            Resource.fromTree(tree, tpe).orElse(Effect.fromTree(tree, tpe)).getOrElse(new Instance(tree))
          Left(provider)
      }
    }

    val inputProviders = rawProviders.map(expr => classify(expr.tree))

    val providers = inputProviders.collect { case Left(provider) => provider }
    val fmts = inputProviders.collect { case Right(factoryMethodTree) => factoryMethodTree }

    log(s"Providers: [${providers.mkString(", ")}]")
    log(s"Factory methods: [${fmts.mkString(", ")}]")

    val resolvedFMContext = resolveFactoryMethods(BuilderContext(providers, fmts))

    // The root comes from an already resolved factory method when one yields `rootType`; otherwise it is created
    // through a constructor or companion `apply`. Other input providers are never returned as the root directly,
    // which mirrors how `wire` behaves.
    val (resolvedCtx, rootProvider) = maybeResolveWithFactoryMethod(resolvedFMContext)(rootType)
      .orElse(findResolvableCreator(resolvedFMContext)(rootType))
      .getOrElse(c.abort(c.enclosingPosition, s"Cannot construct an instance of type: [$rootType]"))

    val inputProvidersTypes = inputProviders.map(input => input.fold(_.resultType, _.resultType))

    val sortedProviders = sortProvidersWithInputOrder(inputProvidersTypes)(resolvedCtx)

    log(s"Input providers order [${sortedProviders.mkString(", ")}]")

    new CatsProvidersGraph(sortedProviders, rootProvider)
  }

  /** Puts the providers matching the input types first, in the order the inputs were given, followed by the
    * remaining providers of `resultContext`. This keeps resources composed in the order they were received.
    *
    * Aborts macro expansion when no provider conforms to one of the input types.
    */
  private def sortProvidersWithInputOrder(
      inputProvidersTypes: List[c.Type]
  )(resultContext: BuilderContext): List[Provider] = {
    val inInputOrder = inputProvidersTypes.map { tpe =>
      resultContext.providers.find(_.resultType <:< tpe) match {
        case Some(provider) => provider
        case None           => c.abort(c.enclosingPosition, "Internal error")
      }
    }
    val remaining = resultContext.providers.diff(inInputOrder)

    inInputOrder ++ remaining
  }

  /** Looks for the first `FactoryMethod` among the providers of `ctx` whose result type conforms to `param` and
    * returns it together with the unchanged context.
    */
  private def maybeResolveWithFactoryMethod(ctx: BuilderContext)(param: c.Type): Option[(BuilderContext, Provider)] =
    ctx.providers.collectFirst {
      case fm: FactoryMethod if fm.resultType <:< param => (ctx, fm)
    }

  /** A creator described as a triple: the creator's symbol, the symbols of its parameters grouped by parameter list,
    * and a function from argument trees (one list per parameter list) to a tree — presumably the creator invocation.
    */
  type CreatorType = (c.Symbol, List[List[c.Symbol]], List[List[c.Tree]] => c.Tree)

  /** Tries to create a provider for `param` through its constructor and through its companion `apply`.
    *
    * A creator whose dependencies are all resolved wins, the constructor being preferred over `apply`. When neither
    * is fully resolved, the first creator found is returned anyway (constructor first). `None` means that no creator
    * exists at all.
    */
  private def findResolvableCreator(ctx: BuilderContext)(param: c.Type): Option[(BuilderContext, Provider)] = {
    def isFullyResolved(candidate: (BuilderContext, FactoryMethod)): Boolean =
      !candidate._2.dependencies.flatten.exists {
        case (_, _: NotResolvedProvider) => true
        case _                           => false
      }

    val viaConstructor = ConstructorCrimper.constructorFactory(c, log)(param).map(resolveCreatorParams(ctx)(param, _))
    val viaApply = CompanionCrimper.applyFactory(c, log)(param).map(resolveCreatorParams(ctx)(param, _))

    // Constructor first, so both lookups below prefer it over the companion `apply`.
    val candidates = viaConstructor.toList ++ viaApply.toList

    candidates.find(isFullyResolved).orElse(candidates.headOption)
  }

  /** Resolves every parameter of `creator` against `ctx` and wraps the creator in a `FactoryMethod` producing
    * `param`.
    *
    * @return
    *   the context extended with the new factory method, and the factory method itself
    */
  private def resolveCreatorParams(
      ctx: BuilderContext
  )(param: c.Type, creator: CreatorType): (BuilderContext, FactoryMethod) = {
    val (creatorSymbol, paramSymbolLists, buildTree) = creator

    val typedParamLists = paramSymbolLists.map { paramSymbols =>
      paramSymbols.map(paramSym => Param(paramSym, paramType(c)(param, paramSym)))
    }
    log.trace(s"Creator params [${typedParamLists.mkString(", ")}] for type [$param]")

    val (ctxAfterParams, resolvedDeps) = resolveParamsLists(ctx)(typedParamLists)

    val factoryMethod = new FactoryMethod(
      symbol = creatorSymbol,
      methodType = param,
      resultType = param,
      dependencies = resolvedDeps,
      apply = buildTree
    )
    (ctxAfterParams.addProvider(factoryMethod), factoryMethod)
  }

  /** Resolves a single parameter list, threading the builder context through the parameters from left to right.
    *
    * For each parameter the context is asked for a provider. An already available provider is used as is, a pending
    * factory method tree is resolved on the spot, and a wireable type is created through a constructor or companion
    * `apply`. Anything else is recorded as a `NotResolvedProvider`.
    *
    * @return
    *   the context after all parameters were processed, and the resolved `(parameter symbol, provider)` pairs in
    *   parameter order
    */
  private def resolveParamsList(
      ctx: BuilderContext
  )(params: List[Param]): (BuilderContext, List[(c.Symbol, Provider)]) =
    params.foldLeft((ctx, List.empty[(c.Symbol, Provider)])) {
      case ((currentCtx, resolvedParams), Param(paramSym, paramTpe)) =>
        // Keep the accumulated `currentCtx` here: falling back to the initial `ctx` would drop everything that was
        // resolved for the preceding parameters, while `resolvedParams` still refers to those providers.
        def notResolved: (BuilderContext, List[(c.Symbol, Provider)]) =
          (currentCtx, resolvedParams.:+((paramSym, new NotResolvedProvider(paramTpe, paramSym))))

        currentCtx.resolve(paramTpe) match {
          case Some(Left(provider)) => (currentCtx, resolvedParams.:+((paramSym, provider)))
          case Some(Right(fmt)) => {
            val (updatedCtx, fm) = resolveFactoryMethod(currentCtx)(fmt)

            (updatedCtx, resolvedParams.:+((paramSym, fm)))
          }
          case None if isWireable(c)(paramTpe) =>
            findResolvableCreator(currentCtx)(paramTpe)
              .map { case (updatedCtx, fm) =>
                updatedCtx.logContext
                (updatedCtx, resolvedParams.:+((paramSym, fm)))
              }
              .getOrElse(notResolved)
          case _ => notResolved
        }
    }

  /** Resolves several parameter lists one after another, feeding the context produced by one list into the next.
    *
    * @return
    *   the final context and the resolved parameters, keeping the shape and order of `params`
    */
  private def resolveParamsLists(
      ctx: BuilderContext
  )(params: List[List[Param]]): (BuilderContext, List[List[(c.Symbol, Provider)]]) =
    log.withBlock(s"Resolving params [${mkStringFrom2DimList(params)}]") {
      val nothingResolved = List.empty[List[(c.Symbol, Provider)]]

      params.foldLeft((ctx, nothingResolved)) { (acc, paramsList) =>
        val (ctxSoFar, resolvedSoFar) = acc
        val (nextCtx, resolvedList) = resolveParamsList(ctxSoFar)(paramsList)

        (nextCtx, resolvedSoFar :+ resolvedList)
      }
    }

  /** Turns a `FactoryMethodTree` into a `FactoryMethod` by resolving the types of its parameters against `ctx`.
    *
    * @return
    *   the context with the factory method registered as resolved, and the factory method itself
    */
  private def resolveFactoryMethod(ctx: BuilderContext)(fmt: FactoryMethodTree): (BuilderContext, FactoryMethod) = {
    import c.universe._

    val FactoryMethodTree(valDefs, function, producedType) = fmt

    log.withBlock(s"Resolving factory method [$function]") {
      val typedParams = valDefs.map { case valDef @ ValDef(_, _, tpt, _) =>
        Param(valDef.symbol, typeCheckUtil.typeCheckIfNeeded(tpt))
      }
      val (ctxAfterParams, resolvedDeps) = resolveParamsLists(ctx)(List(typedParams))

      log(s"Resolved dependencies [${resolvedDeps.map(_.mkString(", ")).mkString("\n")}]")

      val factoryMethod = new FactoryMethod(
        symbol = function.symbol,
        methodType = function.symbol.asMethod.returnType,
        resultType = producedType,
        dependencies = resolvedDeps,
        // Applies the function to each argument list in turn: f(args1)(args2)...
        apply = argLists => argLists.foldLeft(function)((applied: Tree, args: List[Tree]) => Apply(applied, args))
      )

      (ctxAfterParams.resolvedFactoryMethod(factoryMethod), factoryMethod)
    }
  }

  /** Resolves the pending factory method trees one at a time until `ctx.next()` reports that none is left.
    *
    * Resolving one tree may resolve further trees it depends on, which gives the traversal its depth-first nature.
    */
  private def resolveFactoryMethods(ctx: BuilderContext): BuilderContext =
    ctx.next() match {
      case Some(pendingTree) =>
        val (advancedCtx, _) = resolveFactoryMethod(ctx)(pendingTree)

        resolveFactoryMethods(advancedCtx)
      case None => ctx
    }

  /** Renders a list of lists as text: one line per inner list, elements separated by ", ". */
  private def mkStringFrom2DimList[T](ll: List[List[T]]): String = {
    val renderedRows = ll.map(row => row.mkString(", "))
    renderedRows.mkString("\n")
  }
  /** Renders the result types of `providers`, separated by ", ". */
  private def mkStringProviders(providers: List[Provider]) = {
    val resultTypes = providers.map(provider => provider.resultType)
    resultTypes.mkString(", ")
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy