All downloads are free. Search and download functionality uses the official Maven repository.

org.aya.tyck.ExprTycker Maven / Gradle / Ivy

There is a newer version: 0.36.0
Show newest version
// Copyright (c) 2020-2024 Tesla (Yinsen) Zhang.
// Use of this source code is governed by the MIT license that can be found in the LICENSE.md file.
package org.aya.tyck;

import kala.collection.immutable.ImmutableSeq;
import kala.collection.immutable.ImmutableTreeSeq;
import kala.collection.mutable.MutableList;
import kala.collection.mutable.MutableTreeSet;
import org.aya.generic.Constants;
import org.aya.generic.term.DTKind;
import org.aya.pretty.doc.Doc;
import org.aya.syntax.concrete.Expr;
import org.aya.syntax.core.Closure;
import org.aya.syntax.core.def.DataDefLike;
import org.aya.syntax.core.def.PrimDef;
import org.aya.syntax.core.repr.AyaShape;
import org.aya.syntax.core.repr.ShapeRecognition;
import org.aya.syntax.core.term.*;
import org.aya.syntax.core.term.call.ClassCall;
import org.aya.syntax.core.term.call.DataCall;
import org.aya.syntax.core.term.call.MetaCall;
import org.aya.syntax.core.term.repr.IntegerTerm;
import org.aya.syntax.core.term.repr.ListTerm;
import org.aya.syntax.core.term.repr.MetaLitTerm;
import org.aya.syntax.core.term.repr.StringTerm;
import org.aya.syntax.core.term.xtt.DimTerm;
import org.aya.syntax.core.term.xtt.DimTyTerm;
import org.aya.syntax.core.term.xtt.EqTerm;
import org.aya.syntax.ref.*;
import org.aya.syntax.telescope.AbstractTele;
import org.aya.syntax.telescope.Signature;
import org.aya.tyck.ctx.LocalLet;
import org.aya.tyck.error.*;
import org.aya.tyck.pat.ClauseTycker;
import org.aya.tyck.tycker.AbstractTycker;
import org.aya.tyck.tycker.AppTycker;
import org.aya.tyck.tycker.Unifiable;
import org.aya.unify.TermComparator;
import org.aya.unify.Unifier;
import org.aya.util.Ordering;
import org.aya.util.error.Panic;
import org.aya.util.error.SourceNode;
import org.aya.util.error.SourcePos;
import org.aya.util.error.WithPos;
import org.aya.util.reporter.Reporter;
import org.jetbrains.annotations.Contract;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

import java.util.Comparator;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.function.Supplier;

public final class ExprTycker extends AbstractTycker implements Unifiable {
  public final @NotNull MutableTreeSet> withTerms =
    MutableTreeSet.create(Comparator.comparing(SourceNode::sourcePos));
  public final @NotNull MutableList> userHoles = MutableList.create();
  private @NotNull LocalLet localLet;

  public void addWithTerm(@NotNull Expr.WithTerm with, @NotNull SourcePos pos, @NotNull Term type) {
    withTerms.add(new WithPos<>(pos, with));
    with.theCoreType().set(type);
  }

  public ExprTycker(
    @NotNull TyckState state, @NotNull LocalCtx ctx, @NotNull LocalLet let,
    @NotNull Reporter reporter
  ) {
    super(state, ctx, reporter);
    this.localLet = let;
  }

  public ExprTycker(@NotNull TyckState state, @NotNull Reporter reporter) {
    this(state, new MapLocalCtx(), new LocalLet(), reporter);
  }

  public void solveMetas() {
    state.solveMetas(reporter);
    withTerms.forEach(with -> with.data().theCoreType().update(this::freezeHoles));
    userHoles.forEach(hole -> hole.data().solution().update(this::freezeHoles));
  }

  /**
   * @param type may not be in whnf, because we want unnormalized type to be used for unification.
   */
  public @NotNull Jdg inherit(@NotNull WithPos expr, @NotNull Term type) {
    return switch (expr.data()) {
      case Expr.Lambda(var ref, var body) -> switch (whnf(type)) {
        case DepTypeTerm(var kind, var dom, var cod) when kind == DTKind.Pi -> {
          // unifyTyReported(param, dom, expr);
          var core = subscoped(ref, dom, () ->
            inherit(body, cod.apply(new FreeTerm(ref))).wellTyped()).bind(ref);
          yield new Jdg.Default(new LamTerm(core), type);
        }
        case EqTerm eq -> {
          var core = subscoped(ref, DimTyTerm.INSTANCE, () ->
            inherit(body, eq.appA(new FreeTerm(ref))).wellTyped()).bind(ref);
          checkBoundaries(eq, core, body.sourcePos(), msg ->
            new CubicalError.BoundaryDisagree(expr, msg, new UnifyInfo(state)));
          yield new Jdg.Default(new LamTerm(core), eq);
        }
        case MetaCall metaCall -> {
          var pi = metaCall.asDt(this::whnf, "_dom", "_cod", DTKind.Pi);
          if (pi == null) yield fail(expr.data(), type, BadTypeError.absOnNonPi(state, expr, type));
          unifier(metaCall.ref().pos(), Ordering.Eq).compare(metaCall, pi, null);
          var core = subscoped(ref, pi.param(), () ->
            inherit(body, pi.body().apply(new FreeTerm(ref))).wellTyped()).bind(ref);
          yield new Jdg.Default(new LamTerm(core), pi);
        }
        default -> fail(expr.data(), type, BadTypeError.absOnNonPi(state, expr, type));
      };
      case Expr.Hole hole -> {
        var freshHole = freshMeta(Constants.randomName(hole), expr.sourcePos(),
          new MetaVar.OfType(type), hole.explicit());
        hole.solution().set(freshHole);
        userHoles.append(new WithPos<>(expr.sourcePos(), hole));
        if (hole.explicit()) fail(new Goal(state, freshHole, localCtx().clone(), hole.accessibleLocal()));
        yield new Jdg.Default(freshHole, type);
      }
      case Expr.LitInt(var end) -> {
        var ty = whnf(type);
        if (ty == DimTyTerm.INSTANCE) {
          if (end == 0 || end == 1) yield new Jdg.Default(end == 0 ? DimTerm.I0 : DimTerm.I1, ty);
          else yield fail(expr.data(), new PrimError.BadInterval(expr.sourcePos(), end));
        }
        yield inheritFallbackUnify(ty, synthesize(expr), expr);
      }
      case Expr.BinTuple(var lhs, var rhs) -> switch (whnf(type)) {
        case DepTypeTerm(var kind, var lhsT, var rhsTClos) when kind == DTKind.Sigma -> {
          var lhsX = inherit(lhs, lhsT).wellTyped();
          var rhsX = inherit(rhs, rhsTClos.apply(lhsX)).wellTyped();
          yield new Jdg.Default(new TupTerm(lhsX, rhsX), type);
        }
        case MetaCall meta -> inheritFallbackUnify(meta, synthesize(expr), expr);
        default -> fail(expr.data(), BadTypeError.sigmaCon(state, expr, type));
      };
      case Expr.Array arr when arr.arrayBlock().isRight()
        && whnf(type) instanceof DataCall dataCall
        && state.shapeFactory.find(dataCall.ref()).getOrNull() instanceof ShapeRecognition recog
        && recog.shape() == AyaShape.LIST_SHAPE -> {
        var arrayBlock = arr.arrayBlock().getRightValue();
        var elementTy = dataCall.args().get(0);
        var results = ImmutableTreeSeq.from(arrayBlock.exprList().map(
          element -> inherit(element, elementTy).wellTyped()));
        yield new Jdg.Default(new ListTerm(results, recog, dataCall), type);
      }
      case Expr.Match(var discriminant, var clauses) -> {
        var wellArgs = discriminant.map(this::synthesize);
        var telescope = new AbstractTele.Locns(
          wellArgs.map(x -> new Param(LocalVar.IGNORED.name(), x.type(), true)),
          type);
        var signature = new Signature(telescope, discriminant.map(WithPos::sourcePos));
        var clauseTycker = new ClauseTycker.Worker(
          new ClauseTycker(this),
          // always nameless
          ImmutableSeq.fill(discriminant.size(), LocalVar.IGNORED),
          signature, clauses, ImmutableSeq.empty(), true);
        var wellClauses = clauseTycker.check(expr.sourcePos())
          .wellTyped()
          .map(WithPos::data);
        yield new Jdg.Default(new MatchTerm(wellArgs.map(Jdg::wellTyped), wellClauses), type);
      }
      case Expr.Let let -> checkLet(let, e -> inherit(e, type));
      default -> inheritFallbackUnify(type, synthesize(expr), expr);
    };
  }

  /**
   * @param type   expected type
   * @param result wellTyped + actual type from synthesize
   * @param expr   original expr, used for error reporting
   */
  private @NotNull Jdg inheritFallbackUnify(@NotNull Term type, @NotNull Jdg result, @NotNull WithPos expr) {
    type = whnf(type);
    var resultType = result.type();
    // Try coercive subtyping for (Path A ...) into (I -> A)
    if (type instanceof DepTypeTerm(var kind, var dom, var cod) && kind == DTKind.Pi && dom == DimTyTerm.INSTANCE) {
      if (whnf(resultType) instanceof EqTerm eq) {
        var closure = makeClosurePiPath(expr, eq, cod, result.wellTyped());
        if (closure == null) return makeErrorResult(type, result);
        return new Jdg.Default(new LamTerm(closure), eq);
      }
    }
    // Try coercive subtyping for (I -> A) into (Path A ...)
    if (type instanceof EqTerm eq) {
      if (whnf(resultType) instanceof DepTypeTerm(
        var kind, var dom, var cod
      ) && kind == DTKind.Pi && dom == DimTyTerm.INSTANCE) {
        var closure = makeClosurePiPath(expr, eq, cod, result.wellTyped());
        if (closure == null) return makeErrorResult(type, result);
        checkBoundaries(eq, closure, expr.sourcePos(), msg ->
          new CubicalError.BoundaryDisagree(expr, msg, new UnifyInfo(state)));
        return new Jdg.Default(new LamTerm(closure), eq);
      }
    }
    // Try coercive subtyping between classes
    if (type instanceof ClassCall clazz) {
      // Try coercive subtyping for `SomeClass (foo := 114514)` into `SomeClass`
      resultType = whnf(resultType);
      if (resultType instanceof ClassCall resultClazz) {
        // TODO: check whether resultClazz <: clazz
        if (true) {
          // No need to coerce
          if (clazz.args().size() == resultClazz.args().size()) return result;
          var forget = resultClazz.args().drop(clazz.args().size());
          return new Jdg.Default(ClassCastTerm.make(clazz.ref(), result.wellTyped(), clazz.args(), forget), type);
        } else {
          return makeErrorResult(type, result);
        }
      }
    }
    if (unifyTyReported(type, resultType, expr)) return result;
    return makeErrorResult(type, result);
  }

  private static @NotNull Jdg makeErrorResult(@NotNull Term type, @NotNull Jdg result) {
    return new Jdg.Default(new ErrorTerm(result.wellTyped()), type);
  }

  private @Nullable Closure makeClosurePiPath(@NotNull WithPos expr, EqTerm eq, Closure cod, @NotNull Term core) {
    var ref = new FreeTerm(new LocalVar("i"));
    var wellTyped = subscoped(ref.name(), DimTyTerm.INSTANCE, () ->
      unifyTyReported(eq.appA(ref), cod.apply(ref), expr));
    if (!wellTyped) return null;
    if (expr.data() instanceof Expr.WithTerm with)
      addWithTerm(with, expr.sourcePos(), eq);
    return core instanceof LamTerm(var clo) ? clo
      // This is kinda unsafe but it should be fine
      : new Closure.Jit(i -> new AppTerm(core, i));
  }

  public @NotNull Term ty(@NotNull WithPos expr) {
    return switch (expr.data()) {
      case Expr.Hole hole -> {
        var meta = freshMeta(Constants.randomName(hole), expr.sourcePos(), MetaVar.Misc.IsType, hole.explicit());
        if (hole.explicit()) fail(new Goal(state, meta, localCtx().clone(), hole.accessibleLocal()));
        yield meta;
      }
      case Expr.Sort(var kind, var lift) -> new SortTerm(kind, lift);
      case Expr.DepType(var kind, var param, var last) -> {
        var wellParam = ty(param.typeExpr());
        addWithTerm(param, param.sourcePos(), wellParam);
        yield subscoped(param.ref(), wellParam, () ->
          new DepTypeTerm(kind, wellParam, ty(last).bind(param.ref())));
      }
      case Expr.Let let -> checkLet(let, e -> lazyJdg(ty(e))).wellTyped();
      default -> {
        var result = synthesize(expr);
        if (!(result.type() instanceof SortTerm))
          fail(expr.data(), BadTypeError.doNotLike(state, expr, result.type(),
            _ -> Doc.plain("type")));
        yield result.wellTyped();
      }
    };
  }

  public @NotNull Jdg.Sort sort(@NotNull WithPos expr) {
    return new Jdg.Sort(sort(expr, ty(expr)));
  }

  private @NotNull SortTerm sort(@NotNull WithPos errorMsg, @NotNull Term term) {
    return switch (whnf(term)) {
      case SortTerm u -> u;
      case MetaCall hole -> {
        unifyTyReported(hole, SortTerm.Type0, errorMsg);
        yield SortTerm.Type0;
      }
      default -> {
        fail(BadTypeError.doNotLike(state, errorMsg, term, _ -> Doc.plain("universe")));
        yield SortTerm.Type0;
      }
    };
  }

  public @NotNull Jdg synthesize(@NotNull WithPos expr) {
    var result = doSynthesize(expr);
    if (expr.data() instanceof Expr.WithTerm with) {
      addWithTerm(with, expr.sourcePos(), result.type());
    }
    return result;
  }

  public @NotNull Jdg doSynthesize(@NotNull WithPos expr) {
    return switch (expr.data()) {
      case Expr.Sugar s -> throw new Panic(s.getClass() + " is desugared, should be unreachable");
      case Expr.App(var f, var a) -> {
        int lift;
        if (f.data() instanceof Expr.Lift(var inner, var level)) {
          lift = level;
          f = inner;
        } else lift = 0;
        if (f.data() instanceof Expr.Ref ref) {
          yield checkApplication(ref, lift, expr.sourcePos(), a);
        } else try {
          yield ArgsComputer.generateApplication(this, a, synthesize(f)).lift(lift);
        } catch (NotPi e) {
          yield fail(expr.data(), BadTypeError.appOnNonPi(state, expr, e.actual));
        }
      }
      case Expr.Proj(var p, var ix, _, _) -> {
        var result = synthesize(p);
        var wellP = result.wellTyped();

        yield ix.fold(iix -> {
          if (iix != ProjTerm.INDEX_FST && iix != ProjTerm.INDEX_SND) {
            return fail(expr.data(), new ClassError.ProjIxError(expr, iix));
          }
          return switch (whnf(result.type())) {
            case MetaCall metaCall -> {
              var sigma = metaCall.asDt(this::whnf, "_fstTy", "_sndTy", DTKind.Sigma);
              if (sigma == null) yield fail(expr.data(), BadTypeError.sigmaAcc(state, expr, iix, result.type()));
              unifier(metaCall.ref().pos(), Ordering.Eq).compare(metaCall, sigma, null);
              if (iix == ProjTerm.INDEX_FST) {
                yield new Jdg.Default(ProjTerm.fst(wellP), sigma.param());
              } else {
                yield new Jdg.Default(ProjTerm.snd(wellP), sigma.body().apply(ProjTerm.fst(wellP)));
              }
            }
            case DepTypeTerm(var kind, var param, var body) when kind == DTKind.Sigma -> {
              var ty = iix == ProjTerm.INDEX_FST ? param : body.apply(ProjTerm.fst(wellP));
              yield new Jdg.Default(ProjTerm.make(wellP, iix == ProjTerm.INDEX_FST), ty);
            }
            default -> fail(expr.data(), BadTypeError.sigmaAcc(state, expr, iix, result.type()));
          };
        }, member -> {
          // TODO: MemberCall
          throw new UnsupportedOperationException("TODO");
        });
      }
      case Expr.Hole hole -> throw new UnsupportedOperationException("TODO");
      case Expr.Lambda lam -> inherit(expr, generatePi(lam, expr.sourcePos()));
      case Expr.LitInt(var integer) -> {
        // TODO[literal]: int literals. Currently the parser does not allow negative literals.
        var defs = state.shapeFactory.findImpl(AyaShape.NAT_SHAPE);
        if (defs.isEmpty()) {
          yield fail(expr.data(), new NoRuleError(expr, null));
        }
        if (defs.sizeGreaterThan(1)) {
          var type = freshMeta("_ty" + integer + "'", expr.sourcePos(), MetaVar.Misc.IsType, false);
          yield new Jdg.Default(new MetaLitTerm(expr.sourcePos(), integer, defs, type), type);
        }
        var match = defs.getFirst();
        var type = new DataCall((DataDefLike) match.def(), 0, ImmutableSeq.empty());
        yield new Jdg.Default(new IntegerTerm(integer, match.recog(), type), type);
      }
      case Expr.Lift(WithPos(var innerPos, Expr.Ref ref), var level) ->
        checkApplication(ref, level, innerPos, ImmutableSeq.empty());
      case Expr.Lift(var inner, var level) -> synthesize(inner).map(x -> x.elevate(level));
      case Expr.LitString litStr -> {
        if (!state.primFactory.have(PrimDef.ID.STRING))
          yield fail(litStr, new NoRuleError(expr, null));
        yield new Jdg.Default(new StringTerm(litStr.string()), state.primFactory.getCall(PrimDef.ID.STRING));
      }
      case Expr.Ref ref -> checkApplication(ref, 0, expr.sourcePos(), ImmutableSeq.empty());
      case Expr.DepType _ -> lazyJdg(ty(expr));
      case Expr.Sort _ -> sort(expr);
      case Expr.BinTuple(var lhs, var rhs) -> {
        var lhsX = synthesize(lhs);
        var rhsX = synthesize(rhs);
        var wellTyped = new TupTerm(lhsX.wellTyped(), rhsX.wellTyped());
        var ty = new DepTypeTerm(DTKind.Sigma, lhsX.type(), Closure.mkConst(rhsX.type()));

        yield new Jdg.Default(wellTyped, ty);
      }
      case Expr.Let let -> checkLet(let, this::synthesize);
      case Expr.Error err -> new Jdg.Default(new ErrorTerm(err), ErrorTerm.typeOf(err));
      case Expr.Array arr when arr.arrayBlock().isRight() -> {
        var arrayBlock = arr.arrayBlock().getRightValue();
        var elements = arrayBlock.exprList();

        // find def
        var defs = state.shapeFactory.findImpl(AyaShape.LIST_SHAPE);
        if (defs.isEmpty()) yield fail(arr, new NoRuleError(expr, null));
        if (defs.sizeGreaterThan(1)) {
          var elMeta = freshMeta("el_ty", expr.sourcePos(), MetaVar.Misc.IsType, false);
          var tyMeta = freshMeta("arr_ty", expr.sourcePos(), MetaVar.Misc.IsType, false);
          var results = elements.map(element -> inherit(element, elMeta).wellTyped());
          yield new Jdg.Default(new MetaLitTerm(expr.sourcePos(), results, defs, tyMeta), tyMeta);
        }
        var match = defs.getFirst();
        var def = (DataDefLike) match.def();

        // List (A : Type)
        var sort = def.signature().telescopeRich(0);
        // the sort of type below.
        var elementTy = freshMeta(sort.name(), expr.sourcePos(), new MetaVar.OfType(sort.type()), false);

        // do type check
        var results = ImmutableTreeSeq.from(elements.map(element -> inherit(element, elementTy).wellTyped()));
        var type = new DataCall(def, 0, ImmutableSeq.of(elementTy));
        yield new Jdg.Default(new ListTerm(results, match.recog(), type), type);
      }
      case Expr.New(var classCall) -> {
        var wellTyped = synthesize(classCall);
        if (!(wellTyped.wellTyped() instanceof ClassCall call)) {
          yield fail(expr.data(), BadTypeError.classCon(state, classCall, wellTyped.wellTyped()));
        }

        // check whether the call is fully applied
        if (call.args().size() != call.ref().members().size()) {
          yield fail(expr.data(), new ClassError.NotFullyApplied(classCall));
        }

        yield new Jdg.Default(new NewTerm(call), call);
      }
      case Expr.Unresolved _ -> Panic.unreachable();
      default -> fail(expr.data(), new NoRuleError(expr, null));
    };
  }

  private @NotNull Jdg checkApplication(
    @NotNull Expr.Ref f, int lift, @NotNull SourcePos sourcePos,
    @NotNull ImmutableSeq args
  ) {
    try {
      var result = doCheckApplication(sourcePos, f.var(), lift, args);
      addWithTerm(f, sourcePos, result.type());
      return result;
    } catch (NotPi notPi) {
      var expr = new Expr.App(new WithPos<>(sourcePos, f), args);
      return fail(expr, BadTypeError.appOnNonPi(state, new WithPos<>(sourcePos, expr), notPi.actual));
    }
  }

  private @NotNull Jdg doCheckApplication(
    @NotNull SourcePos sourcePos, @NotNull AnyVar f,
    int lift, @NotNull ImmutableSeq args
  ) throws NotPi {
    return switch (f) {
      case LocalVar ref when localLet.contains(ref) ->
        ArgsComputer.generateApplication(this, args, localLet.get(ref)).lift(lift);
      case LocalVar lVar -> ArgsComputer.generateApplication(this, args,
        new Jdg.Default(new FreeTerm(lVar), localCtx().get(lVar))).lift(lift);
      case CompiledVar(var content) -> new AppTycker<>(this, sourcePos, args.size(), lift, (params, k) ->
        computeArgs(sourcePos, args, params, k)).checkCompiledApplication(content);
      case DefVar defVar -> new AppTycker<>(this, sourcePos, args.size(), lift, (params, k) ->
        computeArgs(sourcePos, args, params, k)).checkDefApplication(defVar);
      default -> throw new UnsupportedOperationException("TODO");
    };
  }

  private Jdg computeArgs(
    @NotNull SourcePos pos, @NotNull ImmutableSeq args,
    @NotNull AbstractTele params, @NotNull BiFunction k
  ) throws NotPi {
    return new ArgsComputer(this, pos, args, params).boot(k);
  }

  /**
   * tyck a let expr with the given checker
   *
   * @param checker check the type of the body of {@param let}
   */
  private @NotNull Jdg checkLet(@NotNull Expr.Let let, @NotNull Function, Jdg> checker) {
    // pushing telescopes into lambda params, for example:
    // `let f (x : A) : B x` is desugared to `let f : Pi (x : A) -> B x`
    var letBind = let.bind();
    var typeExpr = Expr.buildPi(letBind.sourcePos(),
      letBind.telescope().view(), letBind.result());
    // as well as the body of the binding, for example:
    // `let f x := g` is desugared to `let f := \x => g`
    var definedAsExpr = Expr.buildLam(letBind.sourcePos(),
      letBind.telescope().view().map(Expr.Param::ref), letBind.definedAs());

    // Now everything is in form `let f : G := g in h`

    var type = freezeHoles(ty(typeExpr));
    var definedAsResult = inherit(definedAsExpr, type);

    return subscoped(() -> {
      localLet.put(let.bind().bindName(), definedAsResult);
      return checker.apply(let.body());
    });
  }

  /// region Overrides and public APIs
  @Override public @NotNull TermComparator unifier(@NotNull SourcePos pos, @NotNull Ordering order) {
    return new Unifier(state(), localCtx(), reporter(), pos, order, true);
  }
  @Contract(mutates = "this") public  R subscoped(@NotNull Supplier action) {
    var derived = localCtx().derive();
    var parentCtx = setLocalCtx(derived);
    var parentDef = setLocalLet(localLet.derive());
    var result = action.get();
    setLocalCtx(parentCtx);
    setLocalLet(parentDef);
    derived.extractLocal().forEach(state::removeConnection);
    return result;
  }
  @Contract(mutates = "this")
  public  R subscoped(@NotNull LocalVar var, @NotNull Term type, @NotNull Supplier action) {
    var parentCtx = setLocalCtx(localCtx().derive1(var, type));
    var result = action.get();
    setLocalCtx(parentCtx);
    state.removeConnection(var);
    return result;
  }

  public @NotNull LocalLet localLet() { return localLet; }
  public @NotNull LocalLet setLocalLet(@NotNull LocalLet let) {
    var old = localLet;
    this.localLet = let;
    return old;
  }
  /// endregion Overrides and public APIs

  protected static final class NotPi extends Exception {
    public final @NotNull Term actual;
    public NotPi(@NotNull Term actual) { this.actual = actual; }
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy