All Downloads are FREE. Search and download functionalities are using the official Maven repository.

kyo.internal.LayerMacros.scala Maven / Gradle / Ivy

The newest version!
package kyo.internal

import kyo.*
import kyo.Ansi.*
import scala.quoted.*

private[kyo] object LayerMacros:

    /** Returns the reflection module (`q.reflect`) of the `Quotes` instance `q` in scope.
      *
      * The other members of this object rely on it both for `import reflect.*` and for spelling reflection types in signatures (e.g.
      * `reflect.TypeRepr`).
      */
    transparent inline def reflect(using q: Quotes): q.reflectModule = q.reflect

    /** Renders every graph resolution error as a red, human-readable line and aborts macro expansion with the combined message.
      *
      * Identical messages are reported only once.
      *
      * @param errors
      *   the non-empty list of errors collected while resolving the layer graph
      * @return
      *   never returns normally; compilation is aborted through `report.errorAndAbort`
      */
    def reportErrors(using Quotes)(errors: ::[GraphError[reflect.TypeRepr, Expr[Layer[?, ?]]]]): Nothing =
        import reflect.*
        val messages = errors.map {
            case GraphError.MissingInput(input, Some(parent)) => // TODO: show full type of parent
                s"Missing Input: ${input.show} for Layer: ${parent.value.show}".red

            case GraphError.MissingInput(input, None) =>
                s"Missing Input: ${input.show}".red

            case GraphError.AmbiguousInputs(target, found, parent) => // TODO: show full type of parent
                val candidates = found.map(_.value.show).mkString(", ")
                // Mention the requiring layer when there is one, mirroring the MissingInput message.
                val requiredBy = parent.fold("")(p => s" required by Layer: ${p.value.show}")
                s"Ambiguous Inputs: $candidates for: ${target.show}$requiredBy".red

            case GraphError.CircularDependency(node) => // TODO: improve error message to include multiple nodes
                val inputs  = node.inputs.map(_.show).mkString(", ")
                val outputs = node.outputs.map(_.show).mkString(", ")
                s"Circular dependencies found: ${node.value.show} with Inputs: $inputs and Outputs: $outputs".red
        }.distinct
        report.errorAndAbort(messages.mkString("\n"))
    end reportErrors

    /** Macro entry point: expands to the expression produced by `makeImpl[Target]` for the given `layers`.
      *
      * @tparam Target
      *   the output type of the resulting layer
      * @param layers
      *   the candidate layers, passed inline as varargs so the macro can inspect each argument expression
      */
    transparent inline def make[Target](inline layers: Layer[?, ?]*): Layer[Target, ?] =
        ${ kyo.internal.LayerMacros.makeImpl[Target]('layers) }

    /** Macro implementation behind `make`.
      *
      * Turns every provided layer expression into a graph node, resolves the components of `Target` against that graph, and folds the
      * resulting composition tree back into a single layer expression built from `and` / `to` calls.
      *
      * Compilation is aborted with a descriptive error when `Target` is `Nothing`, when the argument is not a literal varargs list, when
      * the graph cannot be resolved, when an empty composition is encountered, or when two layers that must be chained with `to` do not
      * line up type-wise.
      *
      * @tparam Target
      *   the output type of the resulting layer
      * @param expr
      *   the varargs expression holding the candidate layers
      * @return
      *   the composed layer expression, cast to `Expr[Layer[Target, ?]]`
      */
    def makeImpl[Target: Type](using Quotes)(expr: Expr[Seq[Layer[?, ?]]]): Expr[Layer[Target, ?]] =
        import reflect.*
        // NOTE(review): `Nothing` is presumably what `Target` is inferred as when the caller omits it - confirm.
        if TypeRepr.of[Target] =:= TypeRepr.of[Nothing] then
            report.errorAndAbort("Missing Target Type; Did you forget to provide it?".red)
        expr match
            case Varargs(layers) =>
                val nodes = layers.map(layerToNode(_))
                // A node satisfies a requested key when one of its outputs is a subtype of that key.
                val graph = Graph(nodes.toSet)(_ <:< _)

                val targets = flattenAnd(TypeRepr.of[Target])
                val targetLayer = graph.buildTargets(targets, None) match
                    case Validated.Success(value) =>
                        value
                    case Validated.Error(errors) =>
                        reportErrors(errors)

                val exprFold = targetLayer.fold[Expr[Layer[?, ?]]](
                    { case ('{ $left: Layer[out1, s1] }, '{ $right: Layer[out2, s2] }) =>
                        '{ $left.and($right) }
                    },
                    { case ('{ $left: Layer[out1, s1] }, right) =>
                        right match
                            case '{ $right: Layer[out2, Env[`out1`] & s2] } =>
                                '{ $left.to($right) }
                            case _ =>
                                // Without this fallback a mismatch would surface as a raw MatchError during macro expansion.
                                report.errorAndAbort(
                                    s"Cannot feed Layer: ${left.show} into Layer: ${right.show}; expected it to require Env[${TypeRepr.of[out1].show}]".red
                                )
                    },
                    _.value, {
                        // TODO: produce a friendlier diagnostic than this raw dump of the composition tree.
                        val debugFold =
                            targetLayer.fold[String]("(" + _ + " and " + _ + ")", "(" + _ + " to " + _ + ")", _.value.show, "Empty")

                        report.errorAndAbort(
                            s"""|
                                | Empty layer found as input to Layer with non-zero requirements. Did you fully resolve dependencies?
                                | Debug: $debugFold
                                |""".stripMargin
                        )
                    }
                )

                exprFold.asInstanceOf[Expr[Layer[Target, ?]]]

            case _ => report.errorAndAbort("No layers found. Did you forget to provide them?".red)
        end match
    end makeImpl

    /** Wraps a layer expression in a graph node.
      *
      * The node's outputs are the `&`-components of the layer's output type; its inputs are the types found inside the `Env[...]`
      * components of the layer's second type argument. The expression itself is kept as the node value.
      */
    def layerToNode(using Quotes)(expr: Expr[Layer[?, ?]]): Node[reflect.TypeRepr, Expr[Layer[?, ?]]] =
        import reflect.*
        expr match
            case '{ $self: Layer[provided, pending] } =>
                Node(
                    inputs = extractInputsFromPending(TypeRepr.of[pending]),
                    outputs = flattenAnd(TypeRepr.of[provided]),
                    value = expr
                )
        end match
    end layerToNode

    /** Collects the types wrapped in `Env[...]` across every `&`-component of `typeRepr`. */
    def extractInputsFromPending(using Quotes)(typeRepr: reflect.TypeRepr): Set[reflect.TypeRepr] =
        for
            component <- flattenAnd(typeRepr)
            required  <- extractEnvs(component)
        yield required

    /** Extractor that views a `TypeRepr` as a `Type[?]`; it always matches. */
    object AsType:
        def unapply(using Quotes)(typeRepr: reflect.TypeRepr): Option[Type[?]] =
            val asType: Type[?] = typeRepr.asType
            Some(asType)

    /** Returns the `&`-components of `T` when `typeRepr` is `Env[T]`, and the empty set for any other type. */
    def extractEnvs(using Quotes)(typeRepr: reflect.TypeRepr): Set[reflect.TypeRepr] =
        import reflect.*
        val candidate = typeRepr.asType
        candidate match
            case '[Env[required]] =>
                flattenAnd(TypeRepr.of[required])
            case _ =>
                Set.empty
        end match
    end extractEnvs

    /** Splits an intersection type into its components.
      *
      * `A & B & C` yields `Set(A, B, C)`; any type that is not an `AndType` yields a singleton set containing itself.
      */
    def flattenAnd(using Quotes)(typeRepr: reflect.TypeRepr): Set[reflect.TypeRepr] =
        import reflect.*
        typeRepr match
            case AndType(lhs, rhs) =>
                // Recurse into both sides, left first, and merge the results.
                List(lhs, rhs).flatMap(side => flattenAnd(side)).toSet
            case leaf =>
                Set(leaf)
        end match
    end flattenAnd

    /** A small composition tree: leaves hold a value (or are `Empty`), inner nodes combine two subtrees with either `And` or `To`. */
    enum LayerLike[+A]:
        case And(left: LayerLike[A], right: LayerLike[A])
        case To(left: LayerLike[A], right: LayerLike[A])
        case Value(value: A)
        case Empty

        /** Builds an `And` node with this tree on the left and `that` on the right. */
        infix def and[A1 >: A](that: LayerLike[A1]): LayerLike[A1] = And(this, that)

        /** Builds a `To` node with this tree on the left and `that` on the right. */
        infix def to[A1 >: A](that: LayerLike[A1]): LayerLike[A1] = To(this, that)

        /** Collapses the tree bottom-up into a single `B`.
          *
          * @param andCase
          *   combines the folded left and right subtrees of an `And` node
          * @param toCase
          *   combines the folded left and right subtrees of a `To` node
          * @param valueCase
          *   converts the payload of a `Value` leaf
          * @param emptyCase
          *   evaluated each time an `Empty` leaf is reached
          */
        def fold[B](andCase: (B, B) => B, toCase: (B, B) => B, valueCase: A => B, emptyCase: => B): B =
            def go(tree: LayerLike[A]): B =
                tree match
                    case Value(payload) => valueCase(payload)
                    case Empty          => emptyCase
                    case And(lhs, rhs)  => andCase(go(lhs), go(rhs))
                    case To(lhs, rhs)   => toCase(go(lhs), go(rhs))
            go(this)
        end fold
    end LayerLike

    object LayerLike:
        /** Wraps `value` in a `Value` leaf. */
        def apply[A](value: A): LayerLike[A] = LayerLike.Value(value)

        given [A]: CanEqual[LayerLike[A], LayerLike[A]] = CanEqual.derived
    end LayerLike

    /** A graph node: a `value` together with the keys it requires (`inputs`) and the keys it provides (`outputs`). */
    final case class Node[Key, Value](inputs: Set[Key], outputs: Set[Key], value: Value)
    object Node:
        // Allows nodes of the same type to be compared with `==`.
        given [Key, Value]: CanEqual[Node[Key, Value], Node[Key, Value]] = CanEqual.derived

    /** Either a successful value or a non-empty list of accumulated errors. */
    enum Validated[+E, +A]:
        case Success(value: A)
        case Error(errors: ::[E])

        /** Transforms the successful value; errors pass through untouched. */
        def map[B](f: A => B): Validated[E, B] =
            flatMap[E, B](value => Success(f(value)))

        /** Chains a dependent validation; an existing error short-circuits and `f` is not called. */
        def flatMap[E1 >: E, B](f: A => Validated[E1, B]): Validated[E1, B] =
            this match
                case Error(errors)  => Error(errors)
                case Success(value) => f(value)

        /** Combines two validations with `f`. When both sides failed, the errors of this side come first, followed by those of `that`. */
        def zipWith[E1 >: E, B, C](that: Validated[E1, B])(f: (A, B) => C): Validated[E1, C] =
            this match
                case Success(mine) =>
                    that match
                        case Success(theirs)     => Success(f(mine, theirs))
                        case Error(theirErrors) => Error(theirErrors)
                case Error(myErrors) =>
                    that match
                        // `myErrors` is a `::`, so head/tail are always defined.
                        case Error(theirErrors) => Error(::(myErrors.head, myErrors.tail ++ theirErrors))
                        case Success(_)         => Error(myErrors)
    end Validated

    object Validated:
        /** Lifts a value into a successful validation. */
        def succeed[E, A](value: A): Validated[E, A] = Validated.Success(value)

        /** Builds a failed validation holding exactly one error. */
        def error[E, A](error: E): Validated[E, A] = Validated.Error(::(error, Nil))

        /** Applies `f` to every element of `set`, accumulating all errors or, if there are none, all results. */
        def traverse[E, A, B](set: Set[A])(f: A => Validated[E, B]): Validated[E, Set[B]] =
            val start: Validated[E, Set[B]] = Validated.Success(Set.empty[B])
            set.foldLeft(start)((acc, item) => acc.zipWith(f(item))(_ + _))

        /** Turns a set of validations into a validation of a set, accumulating all errors. */
        def sequence[E, A](set: Set[Validated[E, A]]): Validated[E, Set[A]] =
            traverse[E, Validated[E, A], A](set)(validated => validated)
    end Validated

    /** Failures that can occur while resolving targets against a `Graph`. */
    enum GraphError[Key, Value]:
        // No node outputs `input`; `parent` is the node that required it, if any.
        case MissingInput(input: Key, parent: Option[Node[Key, Value]])
        // The only node providing a requested key had already been visited on the current resolution path.
        case CircularDependency(node: Node[Key, Value]) // TODO: Maybe also trace the path.
        // More than one node (`found`) outputs `target`; `parent` is the node that required it, if any.
        case AmbiguousInputs(target: Key, found: Set[Node[Key, Value]], parent: Option[Node[Key, Value]])
    end GraphError

    /** A set of nodes that can be resolved into a composition tree for a set of requested keys.
      *
      * @param nodes
      *   every node available for resolution
      * @param equals
      *   decides whether a node output (first argument) satisfies a requested key (second argument)
      */
    final case class Graph[Key, Value](nodes: Set[Node[Key, Value]])(equals: (Key, Key) => Boolean):

        /** Resolves `targets` into a composition tree.
          *
          * Each target is matched to the unique node providing it. A node with inputs is preceded (via `to`) by the tree that resolves
          * those inputs, and the trees of all providers are combined with `and`. With no providers the result is `LayerLike.Empty`.
          *
          * @param targets
          *   the keys to provide
          * @param parent
          *   the node requiring these keys, used for error reporting
          * @param seen
          *   nodes already on the current resolution path, used to detect cycles
          */
        def buildTargets(
            targets: Set[Key],
            parent: Option[Node[Key, Value]],
            seen: Set[Node[Key, Value]] = Set.empty
        ): Validated[GraphError[Key, Value], LayerLike[Node[Key, Value]]] =
            // Builds the subtree for a single provider, resolving its own inputs first when it has any.
            def resolve(node: Node[Key, Value]): Validated[GraphError[Key, Value], LayerLike[Node[Key, Value]]] =
                if node.inputs.isEmpty then Validated.succeed(LayerLike.Value(node))
                else
                    buildTargets(node.inputs, Some(node), seen = seen + node)
                        .map(dependencies => dependencies.to(LayerLike.Value(node)))

            findNodesWithOutputs(targets, parent, seen)
                .flatMap(providers => Validated.traverse(providers)(resolve))
                .map(subtrees => subtrees.reduceOption(_ and _).getOrElse(LayerLike.Empty))
        end buildTargets

        /** Finds the provider of every key in `targets`, accumulating the errors of all keys that cannot be resolved. */
        def findNodesWithOutputs(
            targets: Set[Key],
            parent: Option[Node[Key, Value]],
            seen: Set[Node[Key, Value]]
        ): Validated[GraphError[Key, Value], Set[Node[Key, Value]]] =
            Validated.traverse(targets)(findNodeWithOutputFor(_, parent, seen))

        /** Finds the single node whose outputs satisfy `target`.
          *
          * Fails with `MissingInput` when no node qualifies, `AmbiguousInputs` when several do, and `CircularDependency` when the only
          * candidate is already part of `seen`.
          */
        def findNodeWithOutputFor(
            target: Key,
            parent: Option[Node[Key, Value]],
            seen: Set[Node[Key, Value]]
        ): Validated[GraphError[Key, Value], Node[Key, Value]] =
            def provides(node: Node[Key, Value]): Boolean =
                node.outputs.exists(output => equals(output, target))

            val candidates = nodes.filter(provides)
            if candidates.isEmpty then Validated.error(GraphError.MissingInput(target, parent))
            else if candidates.size > 1 then Validated.error(GraphError.AmbiguousInputs(target, candidates, parent))
            else
                val only = candidates.head
                if seen.contains(only) then Validated.error(GraphError.CircularDependency(only))
                else Validated.succeed(only)
            end if
        end findNodeWithOutputFor
    end Graph
end LayerMacros




© 2015 - 2025 Weber Informatics LLC | Privacy Policy