All Downloads are FREE. Search and download functionalities are using the official Maven repository.

libretto.scaletto.impl.FreeScaletto.scala Maven / Gradle / Ivy

The newest version!
package libretto.scaletto.impl

import libretto.scaletto.Scaletto
import libretto.lambda.{AForest, CapturingFun, ClosedSymmetricMonoidalCategory, CocartesianSemigroupalCategory, CoproductPartitioning, Distribution, EnumModule, Focus, Lambdas, LambdasImpl, Member, Partitioning, PatternMatching, SemigroupalCategory, Shuffled, Sink, SymmetricSemigroupalCategory, Tupled, Var}
import libretto.lambda.Partitioning.SubFun
import libretto.lambda.util.{Applicative, Exists, NonEmptyList, SourcePos, StaticValue, Validated}
import libretto.lambda.util.Validated.{Invalid, Valid, invalid}
import libretto.lambda.util.Monad.monadEither
import libretto.util.Async
import scala.concurrent.duration.FiniteDuration
import scala.reflect.TypeTest

object FreeScaletto extends Scaletto {

  override type One = phantoms.One
  override type Void = phantoms.Void
  override type Done = phantoms.Done
  override type Need = phantoms.Need
  override type Ping = phantoms.Ping
  override type Pong = phantoms.Pong
  override type RTerminus = phantoms.RTerminus
  override type LTerminus = phantoms.LTerminus
  override type |*|[A, B] = phantoms.|*|[A, B]
  override type |+|[A, B] = phantoms.|+|[A, B]
  override type |&|[A, B] = phantoms.|&|[A, B]
  override type ||[A, B] = phantoms.||[A, B]
  override type ::[Label, T] = phantoms.::[Label, T]
  override type OneOf[Cases] = phantoms.OneOf[Cases]
  override type Rec[F[_]] = phantoms.Rec[F[_]]
  override type Sub[A, B] = phantoms.Sub[A, B]
  override type -[A] = phantoms.-[A]
  override type Val[A] = phantoms. Val[A]
  override type Res[A] = phantoms. Res[A]
  override type UInt31 = Val[Int]
  override type -⚬[A, B] = libretto.scaletto.impl.-⚬[A, B]

  import -⚬.{*, given}
  import Fun.*

  override type ScalaFun[A, B] = ScalaFunction[A, B]

  override object ScalaFun extends ScalaFuns {
    override def apply[A, B](f: A => B): ScalaFun[A, B]        = ScalaFunction.Direct(f)
    override def blocking[A, B](f: A => B): ScalaFun[A, B]     = ScalaFunction.Blocking(f)
    override def async[A, B](f: A => Async[B]): ScalaFun[A, B] = ScalaFunction.Asynchronous(f)

    override def adapt[A, B, Z, C](f: ScalaFun[A, B])(
      pre: Z => A,
      post: B => C,
    ): ScalaFun[Z, C] =
      f.adapt(pre, post)

    override def adaptWith[A, B, Z, P, C](f: ScalaFun[A, B])(
      pre: Z => (P, A),
      post: (P, B) => C,
    ): ScalaFun[Z, C] =
      f.adaptWith(pre, post)

    override def eval[A, B]: ScalaFun[(ScalaFun[A, B], A), B] =
      ScalaFunction.eval[A, B]
  }

  override def id[A]: A -⚬ A =
    -⚬.id[A]

  override def andThen[A, B, C](f: A -⚬ B, g: B -⚬ C): A -⚬ C =
    f > g

  override def par[A, B, C, D](
    f: A -⚬ B,
    g: C -⚬ D,
  ): (A |*| C) -⚬ (B |*| D) =
    -⚬.par(f, g)

  override def introFst[B]: B -⚬ (One |*| B) =
    -⚬.𝒞.introFst[B]

  override def introSnd[A]: A -⚬ (A |*| One) =
    -⚬.𝒞.introSnd[A]

  override def elimFst[B]: (One |*| B) -⚬ B =
    -⚬.𝒞.elimFst[B]

  override def elimSnd[A]: (A |*| One) -⚬ A =
    -⚬.𝒞.elimSnd[A]

  override def assocLR[A, B, C]: ((A |*| B) |*| C) -⚬ (A |*| (B |*| C)) =
    -⚬.𝒞.assocLR

  override def assocRL[A, B, C]: (A |*| (B |*| C)) -⚬ ((A |*| B) |*| C) =
    -⚬.𝒞.assocRL

  override def swap[A, B]: (A |*| B) -⚬ (B |*| A) =
    -⚬.𝒞.swap

  override def injectL[A, B]: A -⚬ (A |+| B) =
    -⚬.cocat.injectL

  override def injectR[A, B]: B -⚬ (A |+| B) =
    -⚬.cocat.injectR

  override def either[A, B, C](f: A -⚬ C, g: B -⚬ C): (A |+| B) -⚬ C =
    -⚬.cocat.either(f, g)

  override def absurd[A]: Void -⚬ A =
    Regular(Absurd())

  override def chooseL[A, B]: (A |&| B) -⚬ A =
    Regular(ChooseL())

  override def chooseR[A, B]: (A |&| B) -⚬ B =
    Regular(ChooseR())

  override def choice[A, B, C](f: A -⚬ B, g: A -⚬ C): A -⚬ (B |&| C) =
    -⚬.choice(f, g)

  override def ping: One -⚬ Ping =
    Regular(PingF())

  override def pong: Pong -⚬ One =
    Regular(PongF())

  override def delayIndefinitely: Done -⚬ RTerminus =
    Regular(DelayIndefinitely())

  override def regressInfinitely: LTerminus -⚬ Need =
    Regular(RegressInfinitely())

  override def fork: Done -⚬ (Done |*| Done) =
    Regular(Fork())

  override def join: (Done |*| Done) -⚬ Done =
    Regular(Join())

  override def forkNeed: (Need |*| Need) -⚬ Need =
    Regular(ForkNeed())

  override def joinNeed: Need -⚬ (Need |*| Need) =
    Regular(JoinNeed())

  override def notifyDoneL: Done -⚬ (Ping |*| Done) =
    Regular(NotifyDoneL())

  override def notifyNeedL: (Pong |*| Need) -⚬ Need =
    Regular(NotifyNeedL())

  override def forkPing: Ping -⚬ (Ping |*| Ping) =
    Regular(ForkPing())

  override def forkPong: (Pong |*| Pong) -⚬ Pong =
    Regular(ForkPong())

  override def joinPing: (Ping |*| Ping) -⚬ Ping =
    Regular(JoinPing())

  override def joinPong: Pong -⚬ (Pong |*| Pong) =
    Regular(JoinPong())

  override def strengthenPing: Ping -⚬ Done =
    Regular(StrengthenPing())

  override def strengthenPong: Need -⚬ Pong =
    Regular(StrengthenPong())

  override def joinRTermini: (RTerminus |*| RTerminus) -⚬ RTerminus =
    Regular(JoinRTermini())

  override def joinLTermini: LTerminus -⚬ (LTerminus |*| LTerminus) =
    Regular(JoinLTermini())

  override def notifyEither[A, B]: (A |+| B) -⚬ (Ping |*| (A |+| B)) =
    Regular(NotifyEither())

  override def notifyChoice[A, B]: (Pong |*| (A |&| B)) -⚬ (A |&| B) =
    Regular(NotifyChoice())

  override def injectLOnPing[A, B]: (Ping |*| A) -⚬ (A |+| B) =
    Regular(InjectLOnPing())

  override def chooseLOnPong[A, B]: (A |&| B) -⚬ (Pong |*| A) =
    Regular(ChooseLOnPong())

  override def distributeL[A, B, C]: (A |*| (B |+| C)) -⚬ ((A |*| B) |+| (A |*| C)) =
    Regular(DistributeL())

  override def coDistributeL[A, B, C]: ((A |*| B) |&| (A |*| C)) -⚬ (A |*| (B |&| C)) =
    Regular(CoDistributeL())

  override def rInvertSignal: (Done |*| Need) -⚬ One =
    Regular(RInvertSignal())

  override def lInvertSignal: One -⚬ (Need |*| Done) =
    Regular(LInvertSignal())

  override def rInvertPingPong: (Ping |*| Pong) -⚬ One =
    Regular(RInvertPingPong())

  override def lInvertPongPing: One -⚬ (Pong |*| Ping) =
    Regular(LInvertPongPing())

  override def rInvertTerminus: (RTerminus |*| LTerminus) -⚬ One =
    Regular(RInvertTerminus())

  override def lInvertTerminus: One -⚬ (LTerminus |*| RTerminus) =
    Regular(LInvertTerminus())

  override def rec[A, B](using pos: SourcePos)(f: (A -⚬ B) => (A -⚬ B)): A -⚬ B =
    -⚬.rec(pos, f)

  override def rec[A, B](f: (Sub[A, B] |*| A) -⚬ B): A -⚬ B =
    Regular(RecFun(f))

  override def invoke[A, B]: (Sub[A, B] |*| A) -⚬ B =
    Regular(InvokeSub())

  override def comonoidSub[A, B]: Comonoid[Sub[A, B]] =
    new Comonoid[Sub[A, B]]:
      override def counit: Sub[A, B] -⚬ One = -⚬.ignoreSub
      override def split: Sub[A, B] -⚬ (Sub[A, B] |*| Sub[A, B]) = -⚬.dupSub

  override def unpack[F[_]]: Rec[F] -⚬ F[Rec[F]] =
    Regular(Unpack())

  override def pack[F[_]]: F[Rec[F]] -⚬ Rec[F] =
    Regular(Pack())

  override def racePair: (Ping |*| Ping) -⚬ (One |+| One) =
    Regular(RacePair())

  override def selectPair: (One |&| One) -⚬ (Pong |*| Pong) =
    Regular(SelectPair())

  override def sharedCode[A, B](using pos: SourcePos)(f: A -⚬ B): A -⚬ B =
    f.blueprint match
      case Valid(f) =>
        FunRef(new Object, f) // XXX use a proper ID
      case Invalid(es) =>
        assemblyError(ForbiddenCaptureOfRecursiveCallsIntoLibCode(pos, es))

  override def sub[A, B](using pos: SourcePos)(f: A -⚬ B): One -⚬ Sub[A, B] =
    f.blueprint match
      case Valid(f) =>
        ConstSub(f)
      case Invalid(es) =>
        assemblyError(ForbiddenCaptureOfRecursiveCallsIntoLibCode(pos, es))

  override def crashWhenDone[A, B](msg: String): (Done |*| A) -⚬ B =
    Regular(CrashWhenDone(msg))

  override def delay: Val[FiniteDuration] -⚬ Done =
    Regular(Delay())

  override def liftEither[A, B]: Val[Either[A, B]] -⚬ (Val[A] |+| Val[B]) =
    Regular(LiftEither())

  override def liftPair[A, B]: Val[(A, B)] -⚬ (Val[A] |*| Val[B]) =
    Regular(LiftPair())

  override def unliftPair[A, B]: (Val[A] |*| Val[B]) -⚬ Val[(A, B)] =
    Regular(UnliftPair())

  override def mapVal[A, B](f: ScalaFun[A, B]): Val[A] -⚬ Val[B] =
    Regular(MapVal(f))

  override def constVal[A](a: A): Done -⚬ Val[A] =
    Regular(ConstVal(a))

  override def constNeg[A](a: A): Neg[A] -⚬ Need =
    Regular(ConstNeg(a))

  override def neglect[A]: Val[A] -⚬ Done =
    Regular(Neglect())

  override def notifyVal[A]: Val[A] -⚬ (Ping |*| Val[A]) =
    Regular(NotifyVal())

  override def notifyNeg[A]: (Pong |*| Neg[A]) -⚬ Neg[A] =
    Regular(NotifyNeg())

  override def debugPrint(msg: String): Ping -⚬ One =
    Regular(DebugPrint(msg))

  override def acquire[A, R, B](
    acquire: ScalaFun[A, (R, B)],
    release: Option[ScalaFun[R, Unit]],
  ): Val[A] -⚬ (Res[R] |*| Val[B]) =
    Regular(Acquire(acquire, release))

  override def tryAcquire[A, R, B, E](
    acquire: ScalaFun[A, Either[E, (R, B)]],
    release: Option[ScalaFun[R, Unit]],
  ): Val[A] -⚬ (Val[E] |+| (Res[R] |*| Val[B])) =
    Regular(TryAcquire(acquire, release))

  override def release[R]: Res[R] -⚬ Done =
    Regular(Release())

  override def releaseWith[R, A, B](f: ScalaFunction[(R, A), B]): (Res[R] |*| Val[A]) -⚬ Val[B] =
    Regular(ReleaseWith(f))

  override def effect[R, A, B](f: ScalaFunction[(R, A), B]): (Res[R] |*| Val[A]) -⚬ (Res[R] |*| Val[B]) =
    Regular(Effect(f))

  override def effectWr[R, A](f: ScalaFunction[(R, A), Unit]): (Res[R] |*| Val[A]) -⚬ Res[R] =
    Regular(EffectWr(f))

  override def tryEffectAcquire[R, A, S, B, E](
    f: ScalaFunction[(R, A), Either[E, (S, B)]],
    release: Option[ScalaFunction[S, Unit]],
  ): (Res[R] |*| Val[A]) -⚬ (Res[R] |*| (Val[E] |+| (Res[S] |*| Val[B]))) =
    Regular(TryEffectAcquire(f, release))

  override def tryTransformResource[R, A, S, B, E](
    f: ScalaFunction[(R, A), Either[E, (S, B)]],
    release: Option[ScalaFunction[S, Unit]],
  ): (Res[R] |*| Val[A]) -⚬ (Val[E] |+| (Res[S] |*| Val[B])) =
    Regular(TryTransformResource(f, release))

  override def trySplitResource[R, A, S, T, B, E](
    f: ScalaFunction[(R, A), Either[E, (S, T, B)]],
    release1: Option[ScalaFunction[S, Unit]],
    release2: Option[ScalaFunction[T, Unit]],
  ): (Res[R] |*| Val[A]) -⚬ (Val[E] |+| ((Res[S] |*| Res[T]) |*| Val[B])) =
    Regular(TrySplitResource(f, release1, release2))

  override def forevert[A]: One -⚬ (-[A] |*| A) =
    Regular(Forevert())

  override def backvert[A]: (A |*| -[A]) -⚬ One =
    Regular(Backvert())

  override def distributeInversion[A, B]: -[A |*| B] -⚬ (-[A] |*| -[B]) =
    Regular(DistributeInversion())

  override def factorOutInversion[A, B]: (-[A] |*| -[B]) -⚬ -[A |*| B] =
    Regular(FactorOutInversion())

  override def category: ClosedSymmetricMonoidalCategory[-⚬, |*|, One, =⚬] =
    𝒞

  override def sizeOf[A, B](f: A -⚬ B): Long =
    f.size

  override object UInt31 extends UInt31Scaletto {
    override def apply(n: Int): Done -⚬ UInt31 = {
      require(n >= 0, s"$n is negative")
      constVal(n)
    }

    override def add: (UInt31 |*| UInt31) -⚬ UInt31 =
      unliftPair > mapVal { case (x, y) => x + y }

    override def multiply: (UInt31 |*| UInt31) -⚬ UInt31 =
      unliftPair > mapVal { case (x, y) => x * y }

    override def increment: UInt31 -⚬ UInt31 =
      mapVal { _ + 1 }

    override def decrement: UInt31 -⚬ (Done |+| UInt31) =
      mapVal[Int, Either[Unit, Int]] {
        case 0 => Left(())
        case n => Right(n-1)
      } > liftEither > either(
        FreeScaletto.this.neglect > injectL,
        injectR,
      )

    override def neglect: UInt31 -⚬ Done =
      FreeScaletto.this.neglect

    override def fromInt: Val[Int] -⚬ UInt31 =
      id

    override def toInt: UInt31 -⚬ Val[Int] =
      id
  }

  override val OneOf: EnumModule.LeftAssociative[-⚬, |*|, OneOf, ||, ::] =
    EnumModule.fromBinarySums[-⚬, |*|, |+|, OneOf, [x, y] =>> y || x, ::](
      inj = [Label, A, Cases] => (i: Member[[x, y] =>> y || x, ::, Label, A, Cases]) => Regular(OneOfInject(i)),
      peel = [Label, A, Cases] => DummyImplicit ?=> Regular(OneOfPeel()),
      unpeel = [Label, A, Cases] => DummyImplicit ?=> Regular(OneOfUnpeel()),
      extract = [Label, A] => DummyImplicit ?=> Regular(OneOfExtractSingle()),
    )(using
      𝒞,
      cocat,
      distribution,
    )

  override val SumPartitioning =
    new CoproductPartitioning[-⚬, |*|, |+|]("InL", "InR")(using 𝒞, cocat, distribution)

  private val patmat =
    new PatternMatching[-⚬, |*|]

  import patmat.{Pattern, PatternMatchError}

  private type Var[A] = libretto.lambda.Var[VarOrigin, A]

  private type Extractor[A, B] =
    libretto.lambda.Extractor[-⚬, |*|, A, B]

  private val Extractor =
    libretto.lambda.Extractor

  private case class ExtractorOccurrence[A, B](
    ext: Extractor[A, B],
    pos: SourcePos,
  )

  private type PartialFun[A, B] =
    (A -⚬ B) | ExtractorOccurrence[A, B]

  private val psh: Shuffled[PartialFun, |*|] =
    Shuffled[PartialFun, |*|]

  private type -?>[A, B] = psh.Shuffled[A, B]

  private type LinearityViolation = Lambdas.LinearityViolation[VarOrigin, ScopeInfo]

  private def total[A, B](f: PartialFun[A, B]): TotalRes[A -⚬ B] =
    f match
      case g: (A -⚬ B) =>
        TotalRes.success(g)
      case p: ExtractorOccurrence[A, B] =>
        p.ext.isTotal match
          case Some(g) => TotalRes.success(g)
          case None => libretto.lambda.UnhandledCase.raise(s"Non-exhaustive matcher $p")

  private def foldTotal[A, B](f: A -?> B): TotalRes[A -⚬ B] =
    f.foldMapA(
      [X, Y] => (g: PartialFun[X, Y]) => total(g)
    )

  /** The result of extracting a total function from a partial one. */
  private type TotalRes[T] = Validated[(SourcePos, NonEmptyList[String]), T]
  private object TotalRes {
    def success[T](value: T): TotalRes[T] =
      Valid(value)
  }

  private val lambdas: Lambdas[PartialFun, |*|, VarOrigin, ScopeInfo] { val shuffled: psh.type } =
    Lambdas.using(psh)[VarOrigin, ScopeInfo]()

  override opaque type $[A] = lambdas.Expr[A]

  override opaque type LambdaContext = lambdas.Context

  private type MetaFun[A, B] = CapturingFun[-⚬, |*|, Tupled[|*|, $, _], A, B]

  override val `$`: $Ops = new $Ops {
    override def one(using pos: SourcePos, ctx: lambdas.Context): $[One] =
      lambdas.Expr.const([x] => (_: Unit) => introFst[x])(VarOrigin.OneIntro(pos))

    override def map[A, B](a: $[A])(f: A -⚬ B)(pos: SourcePos)(using
      lambdas.Context,
    ): $[B] =
      (a map f)(VarOrigin.FunAppRes(pos))

    override def zip[A, B](a: $[A], b: $[B])(
      pos: SourcePos,
    )(using
      lambdas.Context,
    ): $[A |*| B] =
      (a zip b)(VarOrigin.Pairing(pos))

    override def unzip[A, B](ab: $[A |*| B])(
      pos: SourcePos,
    )(using
      lambdas.Context,
    ): ($[A], $[B]) =
      lambdas.Expr.unzip(ab)(
        VarOrigin.Prj1(pos),
        VarOrigin.Prj2(pos),
      )

    override def matchAgainst[A, B](a: $[A], extractor: Extractor[A, B])(pos: SourcePos)(using LambdaContext): $[B] =
      lambdas.Expr.map(a, ExtractorOccurrence(extractor, pos))(VarOrigin.ExtractorRes(pos))

    override def app[A, B](f: $[A =⚬ B], a: $[A])(
      pos: SourcePos,
    )(using
      lambdas.Context,
    ): $[B] =
      map((f zip a)(VarOrigin.FunAndArg(pos)))(eval[A, B])(pos)

    override def nonLinear[A](a: $[A])(
      split: Option[A -⚬ (A |*| A)],
      discard: Option[A -⚬ One],
    )(
      pos: SourcePos,
    )(using
      lambdas.Context,
    ): $[A] = {
      lambdas.Context.registerNonLinearOps(a)(
        split,
        discard.map(f => [B] => (_: DummyImplicit) ?=> elimFst[A, B](f)),
      )
      a
    }
  }

  private sealed trait UnpackOnlyFun[A, B]:
    def compile: (A -⚬ B)
    def reverseCompile: (B -⚬ A)

  private object UnpackOnlyFun {
    case class Unpack[F[_]]() extends UnpackOnlyFun[Rec[F], F[Rec[F]]]:
      override def compile: Rec[F] -⚬ F[Rec[F]] = unpack
      override def reverseCompile: F[Rec[F]] -⚬ Rec[F] = pack

    object UnpackOnlySubFun extends Partitioning.SubFun[-⚬, UnpackOnlyFun] {
      override def compile[X, Y](f: UnpackOnlyFun[X, Y]): X -⚬ Y = f.compile
      override def reverseCompile[X, Y](f: UnpackOnlyFun[X, Y]): Y -⚬ X = f.reverseCompile

      override def sameAs[G[_,_]](that: SubFun[-⚬, G]): Option[[X, Y] => UnpackOnlyFun[X, Y] => G[X, Y]] =
        that match
          case UnpackOnlySubFun =>
            Some([X, Y] => (f: UnpackOnlyFun[X, Y]) => f)
          case _ =>
            None

      override def testEqual[X, Y, Z](f: UnpackOnlyFun[X, Y], g: UnpackOnlyFun[X, Z]): Option[Y =:= Z] =
        (f, g) match
          case (Unpack(), Unpack()) =>
            Some(summon[Y =:= Z])
    }

    given Partitioning.SubFun[-⚬, UnpackOnlyFun] = UnpackOnlySubFun
  }

  override def recPartitioning[F[_]](
    p: Partitioning[-⚬, |*|, F[Rec[F]]],
  ): Partitioning[-⚬, |*|, Rec[F]] { type Partition[P] = p.Partition[P] } =
    p.contramap[UnpackOnlyFun, Rec[F]](
      UnpackOnlyFun.Unpack[F]()
    )

  private given SymmetricSemigroupalCategory[PartialFun, |*|] with {
    override def id[A]: PartialFun[A, A] = FreeScaletto.id
    override def assocLR[A, B, C]: PartialFun[(A |*| B) |*| C, A |*| (B |*| C)] = FreeScaletto.this.assocLR
    override def assocRL[A, B, C]: PartialFun[A |*| (B |*| C), (A |*| B) |*| C] = FreeScaletto.this.assocRL
    override def swap[A, B]: PartialFun[A |*| B, B |*| A] = FreeScaletto.this.swap
    override def andThen[A, B, C](f: PartialFun[A, B], g: PartialFun[B, C]): PartialFun[A, C] =
      (total(f) zip total(g)) match
        case Valid((f, g)) => FreeScaletto.this.andThen(f, g)
        case Invalid(es) => raiseTotalityViolations(es) // XXX
    override def par[A1, A2, B1, B2](f1: PartialFun[A1, B1], f2: PartialFun[A2, B2]): PartialFun[A1 |*| A2, B1 |*| B2] =
      (total(f1) zip total(f2)) match
        case Valid((f1, f2)) => FreeScaletto.this.par(f1, f2)
        case Invalid(es) => raiseTotalityViolations(es) // XXX
  }

  private type UnusedInBranch = PatternMatching.UnusedInBranch[VarOrigin, ScopeInfo]

  private def switchSink[A, R](
    a: $[A],
    cases: Sink[[x, y] =>> (ScopeInfo, lambdas.Delambdifold[x, y]), |+|, A, R],
  )(
    pos: SourcePos,
  )(using
    lambdas.Context,
  ): Validated[UnusedInBranch, $[R]] = {
    type -?>[X, Y] = PartialFun[X, Y]

    given CocartesianSemigroupalCategory[-?>, |+|] with {
      override def andThen[A, B, C](f: A -?> B, g: B -?> C): A -?> C = f > g
      override def id[A]: A -?> A = FreeScaletto.id
      override def injectL[A, B]: A -?> (A |+| B) = FreeScaletto.this.injectL
      override def injectR[A, B]: B -?> (A |+| B) = FreeScaletto.this.injectR
      override def either[A, B, C](f: A -?> C, g: B -?> C): (A |+| B) -?> C =
        (total(f) zip total(g)) match
          case Valid((f, g)) => FreeScaletto.this.either(f, g)
          case Invalid(es)   => raiseTotalityViolations(es) // XXX
    }

    given Distribution[-?>, |*|, |+|] with {
      override val cat: SemigroupalCategory[-?>, |*|] = summon

      override def distLR[A, B, C]: (A |*| (B |+| C)) -?> ((A |*| B) |+| (A |*| C)) =
        FreeScaletto.this.distributeL

      override def distRL[A, B, C]: ((A |+| B) |*| C) -?> ((A |*| C) |+| (B |*| C)) =
        FreeScaletto.this.distributeR
    }

    CapturingFun.compileSink[PartialFun, |*|, |+|, lambdas.Expr, ScopeInfo, A, R](
      cases,
    )(
      lambdas.Context.exprDiscarder,
    ).map {
      case CapturingFun.NoCapture(f)  => lambdas.Expr.map(a, f)(VarOrigin.FunAppRes(pos))
      case CapturingFun.Closure(x, f) => mapTupled(Tupled.zip(x, Tupled.atom(a)), f)(pos)
    }.emap { case (c, exprs) =>
      PatternMatching.UnusedInBranch(Var.Set.fromList(exprs.toList.map(_.value.resultVar)), c)
    }
  }

  override def switch[A, R](using
    ctx: LambdaContext,
    switchPos: SourcePos,
  )(
    a: $[A],
  )(
    cases: (SourcePos, LambdaContext ?=> $[A] => $[R])*
  ): $[R] =
    cases.toList match
      case Nil =>
        throw IllegalArgumentException("switch must have at least 1 case")
      case c :: cs =>
        switchImpl(switchPos, a, NonEmptyList(c, cs))
          .getOrReportErrors

  private case class MisplacedExtractor(ext: ExtractorOccurrence[?, ?])

  private case class PatternMatchError(switchPos: SourcePos, error: patmat.PatternMatchError)

  private type SwitchRes[T] = Validated[LinearityViolation | UnusedInBranch | MisplacedExtractor | PatternMatchError, T]

  extension [T](r: SwitchRes[T]) {
    private def getOrReportErrors: T =
      r match
        case Valid(value) => value
        case Invalid(errors) => assemblyErrors(errors)
  }

  private def switchImpl[A, R](
    switchPos: SourcePos,
    scrutinee: $[A],
    cases: NonEmptyList[(SourcePos, LambdaContext ?=> $[A] => $[R])],
  )(using
    ctx: LambdaContext,
  ): SwitchRes[$[R]] = {
    import CapturingFun.{Closure, NoCapture}

    val inVar  = VarOrigin.SwitchIn(switchPos)
    val outVar = VarOrigin.SwitchOut(switchPos)
    val cases1: NonEmptyList[(ScopeInfo, VarOrigin, LambdaContext ?=> $[A] => $[R])] =
      cases.map { case (pos, f) => (ScopeInfo.SwitchCase(pos), VarOrigin.SwitchCase(pos), f) }

    patmat
      .forLambdas(lambdas)(
        isExtractor =
          [X, Y] => (f: PartialFun[X, Y]) =>
            f match {
              case eo: ExtractorOccurrence[X, Y] => Some(eo.ext)
              case _ => None
            },
        lower =
          [X, Y] => (f: PartialFun[X, Y]) =>
            f match {
              case f: (X -⚬ Y) => Valid(f)
              case f: ExtractorOccurrence[x, y] => invalid(MisplacedExtractor(f))
            },
        lift = [X, Y] => (f: X -⚬ Y) => (f: PartialFun[X, Y]),
      )
      .delambdifyAndCompile(scrutinee, inVar, outVar, cases1)
      .emap {
        case e: LinearityViolation => e
        case e: UnusedInBranch => e
        case e: MisplacedExtractor => e
        case e: patmat.PatternMatchError => PatternMatchError(switchPos, e)
      }
  }

  override val λ = new LambdaOps {
    override def apply[A, B](using pos: SourcePos)(f: lambdas.Context ?=> $[A] => $[B]): A -⚬ B =
      compile(f)(pos)

    private def recLocal[A, B](using pos: SourcePos, ctx: LambdaContext)(
      f: LambdaContext ?=> $[Sub[A, B]] => $[A] => $[B],
    ): MetaFun[A, B] =
      given Comonoid[Sub[A, B]] =
        comonoidSub[A, B]

      val f1: LambdaContext ?=> $[Sub[A, B] |*| A] => $[B] =
        { case *(self) |*| a => f(self)(a) }

      val g: MetaFun[Sub[A, B] |*| A, B] =
        compileMeta(f1)(pos)

      import CapturingFun.{Closure, NoCapture}

      g match
        case NoCapture(g) =>
          NoCapture(FreeScaletto.rec(g))
        case Closure(x, g) =>
          recLocalWithCapture(pos, x, g)

    private def recLocalWithCapture[X, A, B](
      recFunPos: SourcePos,
      x: Tupled[|*|, $, X],
      f: (X |*| (Sub[A, B] |*| A)) -⚬ B
    )(using
      LambdaContext
    ): MetaFun[A, B] =
      val dupX: X -⚬ (X |*| X) =
        multiDup(x)
          .valueOr { vs => assemblyError(MissingDupForRecFun(recFunPos, vs)) }
      val elimX: X -⚬ One =
        multiDiscard(x)
          .valueOr { vs => assemblyError(MissingDiscardForRecFun(recFunPos, vs)) }
      val g: (Sub[X |*| A, B] |*| (X |*| A)) -⚬ B =
        λ { case h |*| (x |*| a) =>
          val x1 |*| x2 = dupX(x)
          val h1: $[Sub[A, B]] =
            (h |*| x1) |> captureIntoSub(elimX, dupX)
          f(x2 |*| (h1 |*| a))
        }
      val h: (X |*| A) -⚬ B =
        FreeScaletto.rec(g)
      CapturingFun.Closure(x, h)

    private def multiDup[X](x: Tupled[|*|, $, X])(using lambdas.Context): Validated[Var[?], X -⚬ (X |*| X)] =
      import libretto.lambda.Bin.{Leaf, Branch}
      import Tupled.fromBin

      x.asBin match
        case Leaf(x) =>
          val v: Var[X] = x.resultVar
          lambdas.Context.getSplit(v) match
            case None => invalid(v)
            case Some(dup) =>
              dup match
                case ExtractorOccurrence(_, _) => throw AssertionError(s"TODO: make unrepresentable (at ${summon[SourcePos]})")
                case f: (X -⚬ (X |*| X))       => Valid(f: X -⚬ (X |*| X))
        case Branch(l, r) =>
          (multiDup(fromBin(l)) zip multiDup(fromBin(r)))
            .map { case (dup1, dup2) => andThen(par(dup1, dup2), 𝒞.ixi) }


    private def multiDiscard[X](x: Tupled[|*|, $, X])(using lambdas.Context): Validated[Var[?], X -⚬ One] =
      import libretto.lambda.Bin.{Leaf, Branch}
      import Tupled.fromBin

      x.asBin match
        case Leaf(x) =>
          val v: Var[X] = x.resultVar
          lambdas.Context.getDiscard(v) match
            case None => invalid(v)
            case Some(elimFst) =>
              elimFst[One] match
                case ExtractorOccurrence(_, _) => throw AssertionError(s"TODO: make unrepresentable (at ${summon[SourcePos]})")
                case elimFst: ((X |*| One) -⚬ One) => Valid(andThen(introSnd[X], elimFst))
        case Branch(l, r) =>
          (multiDiscard(fromBin(l)) zip multiDiscard(fromBin(r)))
            .map { case (dis1, dis2) => andThen(par(dis1, dis2), elimFst) }

    override val closure: ClosureOps =
      new ClosureOps {
        override def apply[A, B](using pos: SourcePos, ctx: lambdas.Context)(
          f: lambdas.Context ?=> $[A] => $[B],
        ): $[A =⚬ B] =
          compileClosure(f)(pos)(using ctx)

        override def rec[A, B](using pos: SourcePos, ctx: LambdaContext)(
          f: LambdaContext ?=> $[Sub[A, B]] => $[A] => $[B],
        ): $[A =⚬ B] = {
          import CapturingFun.{Closure, NoCapture}

          val captVar   = VarOrigin.CapturedVars(pos)
          val resultVar = VarOrigin.ClosureVal(pos)

          recLocal(f) match
            case NoCapture(f) =>
              constant(obj(f))
            case Closure(captured, f) =>
              val x = lambdas.Expr.zipN(captured)(captVar)
              lambdas.Expr.map(x, 𝒞.curry(f))(resultVar)
        }
      }

    private def compile[A, B](f: lambdas.Context ?=> $[A] => $[B])(
      pos: SourcePos,
    ): A -⚬ B = {
      import CapturingFun.{Closure, NoCapture}

      val c = ScopeInfo.TopLevelLambda(pos)
      val a = VarOrigin.Lambda(pos)

      metaFunOrError(
        lambdas.delambdifyTopLevel(c, a, f)
      ) match {
        case NoCapture(f) =>
          f
        case Closure(captured, f) =>
          assemblyError(UnboundVariables(c, lambdas.Expr.initialVars(captured)))
      }
    }

    private def compileMeta[A, B](f: lambdas.Context ?=> $[A] => $[B])(
      pos: SourcePos,
    )(using
      ctx: lambdas.Context,
    ): MetaFun[A, B] = {
      val c = ScopeInfo.NestedLambda(pos)
      val a = VarOrigin.Lambda(pos)

      metaFunOrError(
        lambdas.delambdifyNested(c, a, f)
      )
    }

    private def compileClosure[A, B](f: lambdas.Context ?=> $[A] => $[B])(
      pos: SourcePos,
    )(using
      ctx: lambdas.Context,
    ): $[A =⚬ B] = {
      import CapturingFun.{Closure, NoCapture}

      val scopeInfo = ScopeInfo.NestedLambda(pos)
      val bindVar   = VarOrigin.Lambda(pos)
      val captVar   = VarOrigin.CapturedVars(pos)
      val resultVar = VarOrigin.ClosureVal(pos)

      metaFunOrError(
        lambdas.delambdifyNested[A, B](scopeInfo, bindVar, f)
      ) match {
        case Closure(captured, f) =>
          val x = lambdas.Expr.zipN(captured)(captVar)
          lambdas.Expr.map(x, 𝒞.curry(f))(resultVar)
        case NoCapture(f) =>
          val captured0 = $.one(using pos)
          (captured0 map 𝒞.curry(elimFst > f))(resultVar)
      }
    }

    private def metaFunOrError[A, B](
      res: Validated[LinearityViolation, lambdas.Delambdified[A, B]]
    ): MetaFun[A, B] =
      res match
        case Invalid(es) =>
          assemblyErrors(es)
        case Valid(f) =>
          f.traverseFun([x, y] => f => foldTotal(f)) match
            case Invalid(es) => raiseTotalityViolations(es)
            case Valid(f) => f
  }

  /** Preprocessed [[ValSwitch]]. */
  private sealed trait ValHandler[A, R] {
    def compile(using LambdaContext): Validated[
      LinearityViolation,
      Exists[[AA] =>> (Val[A] -⚬ AA, Sink[[x, y] =>> (ScopeInfo, lambdas.Delambdifold[x, y]), |+|, AA, R])]
    ]
  }

  private object ValDecomposition {
    class Last[A, R](
      pos: SourcePos,
      f: LambdaContext ?=> $[Val[A]] => $[R],
    ) extends ValHandler[A, R] {
      override def compile(using LambdaContext): Validated[
        LinearityViolation,
        Exists[[AA] =>> (Val[A] -⚬ AA, Sink[[x, y] =>> (ScopeInfo, lambdas.Delambdifold[x, y]), |+|, AA, R])]
      ] = {
        val scope = ScopeInfo.ValCase(pos)
        val label = VarOrigin.Lambda(pos)
        lambdas.delambdifyFoldNested(scope, label, f)
          .map { g => Exists((id[Val[A]], Sink((scope, g)))) }
      }
    }

    class Cons[A, H, T, R](
      partition: A => Either[H, T],
      pos: SourcePos,
      f: LambdaContext ?=> $[Val[H]] => $[R],
      t: ValHandler[T, R],
    ) extends ValHandler[A, R] {
      override def compile(using LambdaContext): Validated[
        LinearityViolation,
        Exists[[AA] =>> (Val[A] -⚬ AA, Sink[[x, y] =>> (ScopeInfo, lambdas.Delambdifold[x, y]), |+|, AA, R])]
      ] = {
        val scope = ScopeInfo.ValCase(pos)
        val h = lambdas.delambdifyFoldNested(scope, VarOrigin.Lambda(pos), f)
        (h zip t.compile).map { case (h, t) =>
          type AA = Val[H] |+| t.T
          val partition1: Val[A] -⚬ AA =
            mapVal(partition) > liftEither > either(injectL, t.value._1 > injectR)
          val sink: Sink[[x, y] =>> (ScopeInfo, lambdas.Delambdifold[x, y]), |+|, AA, R] =
            Sink.Join(Sink((scope, h)), t.value._2)
          Exists((partition1, sink))
        }
      }
    }

    def from[A, R](cases: ValSwitch.Cases[A, A, R]): ValHandler[A, R] =
      cases match {
        case c1: ValSwitch.FirstCase[a, a_, r] =>
          Last(c1.pos, c1.f)
        case cn: ValSwitch.NextCase[a, a1, a2, r] =>
          summon[a =:= A]
          // (a1 | a2) =:= A
          (from[a1, a2, R](
            (cn.base: ValSwitch.Cases[A, a1, R]).asInstanceOf[ValSwitch.Cases[a1 | a2, a1, R]],
            Last(cn.pos, cn.f),
          ): ValHandler[a1 | a2, R]).asInstanceOf[ValHandler[A, R]]
      }

    def from[A1, A2, R](
      cases: ValSwitch.Cases[A1 | A2, A1, R],
      acc: ValHandler[A2, R],
    ): ValHandler[A1 | A2, R] = {
      def prepend[X](pos: SourcePos, f: LambdaContext ?=> $[Val[X]] => $[R])(using TypeTest[X | A2, X]): ValHandler[X | A2, R] = {
        val partition: (X | A2) => Either[X, A2] = {
          case (x: X) => Left(x)
          case a2     => Right(a2.asInstanceOf[A2])
        }
        Cons[X | A2, X, A2, R](partition, pos, f, acc)
      }
      cases match {
        case c1: ValSwitch.FirstCase[a, a1, r] =>
          prepend[A1](c1.pos, c1.f)(using c1.typeTest)
        case cn: ValSwitch.NextCase[a, a10, a11, r] =>
          // a =:= (A1 | A2)
          // A1 =:= (a10 | a11)
          // Compiler does not infer these equations. See // https://github.com/lampepfl/dotty/issues/17075
          val ev = summon[A1 =:= A1].asInstanceOf[A1 =:= (a10 | a11)]
          (from[a10, a11 | A2, R](
            (cn.base: ValSwitch.Cases[a, a10, r]).asInstanceOf[ValSwitch.Cases[a10 | (a11 | A2), a10, R]],
            prepend[a11](cn.pos, cn.f)(using cn.typeTest),
          ): ValHandler[a10 | (a11 | A2), R])
            .asInstanceOf[ValHandler[A1 | A2, R]]
      }
    }
  }

  override def switchVal[A, R](
    a: $[Val[A]],
    cases: ValSwitch.Cases[A, A, R],
  )(pos: SourcePos)(using
    LambdaContext,
  ): $[R] =
    ValDecomposition.from(cases).compile
      .flatMap {
        case Exists.Some((partition, sink)) =>
          switchSink(
            $.map(a)(partition)(pos),
            sink,
          )(pos)
      }
      .valueOr(assemblyErrors)

  override val |*| : ConcurrentPairOps =
    new ConcurrentPairOps {}

  private def mapTupled[A, B](a: Tupled[|*|, lambdas.Expr, A], f: PartialFun[A, B])(pos: SourcePos)(using lambdas.Context): lambdas.Expr[B] =
    lambdas.Expr.map(
      lambdas.Expr.zipN(a)(VarOrigin.CapturedVars(pos)),
      f,
    )(VarOrigin.FunAppRes(pos))

  private type AssemblyErrorData =
    UnboundVariables
    | LinearityViolation
    | UnusedInBranch
    | MisplacedExtractor
    | PatternMatchError
    | MissingComonoidOpForRecFun
    | ForbiddenCaptureOfRecursiveCallsIntoLibCode

  private def assemblyError(e: AssemblyErrorData): Nothing =
    assemblyErrors(NonEmptyList.of(e))

  private def assemblyErrors(es: NonEmptyList[AssemblyErrorData]): Nothing =
    throw AssemblyError(es)

  override class AssemblyError private[FreeScaletto](
    errors: NonEmptyList[AssemblyErrorData]
  ) extends Exception(AssemblyError.formatMessages(errors))

  private object AssemblyError {

    def formatMessages(es: NonEmptyList[AssemblyErrorData]): String =
      val lines =
        es.toList.flatMap { e =>
          val NonEmptyList(line0, lines) = formatMessage(e)
          val l0 = s" * $line0"
          val ls = lines.map(l => s"   $l")
          l0 :: ls
        }
      ("Compilation errors:" :: lines).mkString("\n")

    /** Returns a list of lines. */
    def formatMessage(e: AssemblyErrorData): NonEmptyList[String] =
      e match
        case e: UnboundVariables   => unboundMsg(e)
        case e: LinearityViolation => linearityMsg(e)
        case e: UnusedInBranch     => unusedInBranchMsg(e)
        case e: MisplacedExtractor => misplacedExtMsg(e)
        case e: PatternMatchError  => patmatMsg(e)
        case e: MissingComonoidOpForRecFun => missingComonoidForRecFunMsg(e)
        case e: ForbiddenCaptureOfRecursiveCallsIntoLibCode => forbiddenCaptureOfRecursiveCallsIntoLibCodeMsg(e)

    private def unboundMsg(e: UnboundVariables): NonEmptyList[String] =
      val UnboundVariables(scope, vs) = e
      NonEmptyList(
        s"Undefined variables (possibly from outer context) when compiling ${scope.print}:",
        vs.list.map(v => s"- $v") :+ "Consider using λ.closure instead."
      )

    private def linearityMsg(e: LinearityViolation): NonEmptyList[String] = {
      import Lambdas.LinearityViolation.{Overused, Unused}

      def overusedMsg(vs: Var.Set[VarOrigin], exitedScope: ScopeInfo): NonEmptyList[String] =
        NonEmptyList(
          s"Variables used more than once when compiling ${exitedScope.print}:",
          vs.list.map(v => s" - ${v.origin.print}")
        )

      def unusedMsg[A](v: Var[A], exitedScope: ScopeInfo): NonEmptyList[String] =
        NonEmptyList.of(
          s"Unused variable: ${v.origin.print}",
          s"When exiting scope: ${exitedScope.print}",
        )

      e match
        case Overused(vs, ctxInfo) => overusedMsg(vs, ctxInfo)
        case Unused(v, ctxInfo)    => unusedMsg(v, ctxInfo)
    }

    def unusedInBranchMsg(e: UnusedInBranch): NonEmptyList[String] =
      val UnusedInBranch(vs, scope) = e
      val pos = scope.sourcePos
      NonEmptyList(
        s"Case at ${pos.filename}:${pos.line} has unused variables which are used by some of the other cases:",
        vs.list.map(v => s" - ${v.origin.print}")
      )

    private def misplacedExtMsg(e: MisplacedExtractor): NonEmptyList[String] =
      e match { case MisplacedExtractor(ExtractorOccurrence(ext, pos)) =>
        NonEmptyList.of(s"Extractor used outside of switch pattern: ${ext.show} (at ${pos.filename}:${pos.line})")
      }

    private def patmatMsg(e: PatternMatchError): NonEmptyList[String] =
      import patmat.PatternMatchError.{IncompatibleExtractors, NonExhaustiveness}

      val PatternMatchError(switchPos, error) = e

      error match
        case IncompatibleExtractors(base, incompatibles) =>
          "Use of incompatible extractors:" ::
            s"    ${base.partition} (from ${base.partitioning})" ::
            s"  is incompatible with" ::
            incompatibles.map { ext =>
              s"  - ${ext.partition} (from ${ext.partitioning})"
            }
        case NonExhaustiveness(ext) =>
          NonEmptyList.of(
            "Non-exhaustive pattern match. It would fail on",
            s"  ${ext.show} (at ${switchPos.filename}:${switchPos.line})"
          )

    private def missingComonoidForRecFunMsg(e: MissingComonoidOpForRecFun): NonEmptyList[String] =
      e match
        case MissingDupForRecFun(pos, vs) =>
          NonEmptyList(
            s"Recursive function definition at ${pos.filename}:${pos.line} captures the following variables which lack the ability to be duplicated:",
            vs.list.map { v =>  s"  - ${v.origin.print}" }
            ::: List(
              "Note: duplication is needed for potential recursive invocation.",
              "Consider using the `case *(a)` extractor for each of the varaibles to register a Comonoid instance."
            )
          )
        case MissingDiscardForRecFun(pos, vs) =>
          NonEmptyList(
            s"Recursive function definition at ${pos.filename}:${pos.line} captures the following variables which lack the ability to be discarded:",
            vs.list.map { v =>  s"  - ${v.origin.print}" }
            ::: List(
              "Note: discarding is needed when there's no more recursive invocation.",
              "Consider using the `case *(a)` extractor for each of the varaibles to register a Comonoid instance."
            )
          )

    private def forbiddenCaptureOfRecursiveCallsIntoLibCodeMsg(e: ForbiddenCaptureOfRecursiveCallsIntoLibCode): NonEmptyList[String] =
      val ForbiddenCaptureOfRecursiveCallsIntoLibCode(defnPos, refs) = e
      NonEmptyList(
        s"Library code defininition at ${defnPos.filename}:${defnPos.line} attempts to capture self-references to parent recursive functions:",
        refs.map { ref =>
          val p = ref.defnPos
          s" - Recursive function defined at ${p.filename}:${p.line}"
        }.toList
      ) :+ s"Consider rewriting the library function as taking a `Sub`-routine instead of capturing a self-reference."
  }

  private case class UnboundVariables(scope: ScopeInfo, vs: Var.Set[VarOrigin])

  private sealed trait MissingComonoidOpForRecFun

  private case class MissingDupForRecFun(recFunPos: SourcePos, vs: Var.Set[VarOrigin])
    extends MissingComonoidOpForRecFun

  private object MissingDupForRecFun {
    def apply(recFunPos: SourcePos, vs: NonEmptyList[Var[?]]): MissingDupForRecFun =
      MissingDupForRecFun(recFunPos, Var.Set.fromList(vs.toList))
  }

  private case class MissingDiscardForRecFun(recFunPos: SourcePos, vs: Var.Set[VarOrigin])
    extends MissingComonoidOpForRecFun

  private object MissingDiscardForRecFun {
    def apply(recFunPos: SourcePos, vs: NonEmptyList[Var[?]]): MissingDiscardForRecFun =
      MissingDiscardForRecFun(recFunPos, Var.Set.fromList(vs.toList))
  }

  private case class ForbiddenCaptureOfRecursiveCallsIntoLibCode(
    libCodeDefnPos: SourcePos,
    capturedRefs: NonEmptyList[-⚬.SelfRef[?, ?]],
  )

  private def raiseTotalityViolations(es: NonEmptyList[(SourcePos, NonEmptyList[String])]): Nothing =
    libretto.lambda.UnhandledCase.raise(s"raiseTotalityViolations($es)")
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy