All Downloads are FREE. Search and download functionality uses the official Maven repository.

org.aya.unify.TermComparator Maven / Gradle / Ivy

There is a newer version: 0.36.0
Show newest version
// Copyright (c) 2020-2024 Tesla (Yinsen) Zhang.
// Use of this source code is governed by the MIT license that can be found in the LICENSE.md file.
package org.aya.unify;

import kala.collection.immutable.ImmutableSeq;
import org.aya.generic.Renamer;
import org.aya.generic.term.DTKind;
import org.aya.generic.term.SortKind;
import org.aya.prettier.AyaPrettierOptions;
import org.aya.syntax.core.term.*;
import org.aya.syntax.core.term.call.*;
import org.aya.syntax.core.term.marker.Formation;
import org.aya.syntax.core.term.repr.IntegerTerm;
import org.aya.syntax.core.term.repr.ListTerm;
import org.aya.syntax.core.term.repr.MetaLitTerm;
import org.aya.syntax.core.term.xtt.*;
import org.aya.syntax.ref.LocalCtx;
import org.aya.syntax.ref.LocalVar;
import org.aya.syntax.ref.MetaVar;
import org.aya.syntax.telescope.AbstractTele;
import org.aya.tyck.TyckState;
import org.aya.tyck.error.LevelError;
import org.aya.tyck.tycker.AbstractTycker;
import org.aya.tyck.tycker.Contextful;
import org.aya.util.Ordering;
import org.aya.util.Pair;
import org.aya.util.error.Panic;
import org.aya.util.error.SourcePos;
import org.aya.util.reporter.Reporter;
import org.jetbrains.annotations.ApiStatus;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

import java.util.Objects;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.function.UnaryOperator;

public abstract sealed class TermComparator extends AbstractTycker permits Unifier {
  /** Source position given to the constructor; used when building {@link LevelError}s in {@code compareSort}. */
  protected final @NotNull SourcePos pos;
  /** Current comparison direction; temporarily inverted while a callback runs inside {@code swapped}. */
  protected @NotNull Ordering cmp;
  // If false, we refrain from solving meta, and return false if we encounter a non-identical meta.
  private boolean solveMeta = true;
  /** First failing pair recorded by {@code fail}, or {@code null}; cleared before retrying after an approximate comparison. */
  private @Nullable FailureData failure = null;
  /** Name generator: the local context is stored into it by the constructor, and it is handed to {@code super.subscoped}. */
  final @NotNull Renamer nameGen = new Renamer();

  /**
   * Creates a comparator.
   *
   * @param state    forwarded to the super constructor
   * @param ctx      forwarded to the super constructor and also stored into {@link #nameGen}
   * @param reporter forwarded to the super constructor
   * @param pos      kept in {@link #pos}
   * @param cmp      initial value of {@link #cmp}
   */
  public TermComparator(
    @NotNull TyckState state,
    @NotNull LocalCtx ctx,
    @NotNull Reporter reporter,
    @NotNull SourcePos pos,
    @NotNull Ordering cmp
  ) {
    super(state, ctx, reporter);
    this.cmp = cmp;
    this.pos = pos;
    nameGen.store(ctx);
  }

  /**
   * Trying to solve {@code meta} with {@code rhs}.
   * Implemented by the subclass; invoked from {@code solveMeta}, which only calls it
   * while meta solving is enabled.
   *
   * @param meta the meta call to be solved
   * @param rhs  in whnf (the caller passes {@code whnf(rhs)})
   * @param type passed through unchanged from the caller, may be {@code null}
   * @return {@code null} on failure (the caller records a failure in that case);
   * a non-null term otherwise. NOTE(review): what the returned term denotes is
   * decided by the implementation -- confirm against the subclass.
   */
  protected abstract @Nullable Term doSolveMeta(@NotNull MetaCall meta, @NotNull Term rhs, @Nullable Term type);

  /**
   * Solve {@code meta} with {@code rhs} when meta solving is enabled.
   * The right-hand side is normalized with {@code whnf} before it is handed to {@code doSolveMeta}.
   * If solving is disabled or yields {@code null}, the pair (meta, un-normalized rhs) is recorded as a failure.
   *
   * @return the result of {@code doSolveMeta}, or {@code null} on failure
   */
  protected @Nullable Term solveMeta(@NotNull MetaCall meta, @NotNull Term rhs, @Nullable Term type) {
    Term solution = null;
    if (solveMeta) solution = doSolveMeta(meta, whnf(rhs), type);
    if (solution == null) fail(meta, rhs);
    return solution;
  }

  /// region Utilities
  /** Remember {@code lhs} and {@code rhs} as the failing pair, unless a failure is already recorded. */
  private void fail(@NotNull Term lhs, @NotNull Term rhs) {
    if (failure != null) return;
    failure = new FailureData(lhs, rhs);
  }

  /** Build (not throw) a {@link Panic} whose message is the class of {@code term} followed by its debug rendering. */
  private @NotNull Panic noRules(@NotNull Term term) {
    var rendered = term.toDoc(AyaPrettierOptions.debug()).debugRender();
    return new Panic(term.getClass() + ": " + rendered);
  }
  /// endregion Utilities

  /**
   * Approximate comparison: both sides must have the same shape, and then ONLY their arguments are compared.
   * No normalization happens here. Because this check is lossy, a {@code null} result does not
   * mean the terms differ; callers compare again afterwards and must reset {@link #failure} first.
   *
   * @return the result of the argument comparison (the integer's type for two equal integer literals),
   * or {@code null} when no rule applies or the comparison fails
   */
  private @Nullable Term compareApprox(@NotNull Term lhs, @NotNull Term rhs) {
    // The order of these checks mirrors the order of the cases they replace.
    if (lhs instanceof FnCall lFn && rhs instanceof FnCall rFn) return compareCallApprox(lFn, rFn);
    if (lhs instanceof DataCall lData && rhs instanceof DataCall rData) return compareCallApprox(lData, rData);
    if (lhs instanceof PrimCall lPrim && rhs instanceof PrimCall rPrim) return compareCallApprox(lPrim, rPrim);
    if (lhs instanceof IntegerTerm lInt && rhs instanceof IntegerTerm rInt)
      return lInt.repr() == rInt.repr() ? lInt.type() : null;
    if (lhs instanceof ConCallLike lCon && rhs instanceof ConCallLike rCon) return compareCallApprox(lCon, rCon);
    if (lhs instanceof MemberCall lMem && rhs instanceof MemberCall rMem) {
      if (!lMem.ref().equals(rMem.ref())) return null;
      // TODO: type info?
      if (!compare(lMem.of(), rMem.of(), null)) return null;
      var lift = Math.min(lMem.ulift(), rMem.ulift());
      var tele = lMem.ref().signature().inst(ImmutableSeq.of(lMem.of())).lift(lift);
      return compareMany(lMem.args(), rMem.args(), tele);
    }
    return null;
  }

  /**
   * Compare two calls by their arguments ONLY: they must refer to the same definition, and then the
   * arguments are compared against the signature lifted by the smaller of the two {@code ulift}s.
   * This method does NOT normalize and retry (the old project did).
   *
   * @return the result of {@code compareMany}, or {@code null} if the references differ
   */
  private @Nullable Term compareCallApprox(@NotNull Callable.Tele lhs, @NotNull Callable.Tele rhs) {
    var ref = lhs.ref();
    if (!ref.equals(rhs.ref())) return null;
    var lift = Math.min(lhs.ulift(), rhs.ulift());
    return compareMany(lhs.args(), rhs.args(), ref.signature().lift(lift));
  }

  /**
   * Run {@code callback} with the comparison direction {@link #cmp} inverted,
   * then invert it back so the caller sees the direction it started with.
   * The direction is restored even if {@code callback} throws.
   *
   * @param callback the action to run under the inverted ordering
   * @param <R>      the result type of the callback
   * @return whatever {@code callback} returns
   */
  private <R> R swapped(@NotNull Supplier<R> callback) {
    cmp = cmp.invert();
    try {
      return callback.get();
    } finally {
      cmp = cmp.invert();
    }
  }

  /**
   * Compare two terms with the given {@code type} (if not null).
   * <p>
   * Steps, in order: accept identical references and error terms; try the approximate comparison
   * on the terms as given; normalize both sides with {@code whnf} and retry the approximate comparison
   * if normalization changed anything; solve metas; finally delegate to {@code doCompare}.
   *
   * @param preLhs the left-hand side, not necessarily normalized
   * @param preRhs the right-hand side, not necessarily normalized
   * @param type   the type to compare under, or {@code null} for an untyped comparison
   * @return true if they are 'the same' under {@code type}, false otherwise.
   */
  public boolean compare(@NotNull Term preLhs, @NotNull Term preRhs, @Nullable Term type) {
    // Same reference, or either side is an error term: accept without further work.
    if (preLhs == preRhs || preLhs instanceof ErrorTerm || preRhs instanceof ErrorTerm) return true;
    if (checkApproxResult(type, compareApprox(preLhs, preRhs))) return true;
    // The approximate attempt may have recorded a failure; discard it before comparing again.
    failure = null;

    var lhs = whnf(preLhs);
    var rhs = whnf(preRhs);
    // Retry the approximate comparison only if whnf produced a different object on either side.
    if (!(lhs == preLhs && rhs == preRhs) &&
      checkApproxResult(type, compareApprox(lhs, rhs))) return true;

    if (rhs instanceof MetaCall rMeta) {
      // In case we're comparing two metas with one IsType and the other has OfType,
      // prefer solving the IsType one as the OfType one.
      if (lhs instanceof MetaCall lMeta && lMeta.ref().req() == MetaVar.Misc.IsType)
        return solveMeta(lMeta, rMeta, type) != null;
      // The meta is on the right, so it is solved with the ordering inverted.
      return swapped(() -> solveMeta(rMeta, lhs, type)) != null;
    }
    // ^ Beware of the order!!
    if (lhs instanceof MetaCall lMeta) {
      return solveMeta(lMeta, rhs, type) != null;
    }

    // doCompare expects a MemberCall on the left when only one side is one, so swap the sides (and the ordering).
    if (rhs instanceof MemberCall && !(lhs instanceof MemberCall)) {
      return swapped(() -> doCompare(rhs, lhs, type));
    }

    return doCompare(lhs, rhs, type);
  }

  /**
   * Compare {@code lhs} and {@code rhs} against {@code type} (if not null), choosing the typed or the
   * untyped comparison accordingly, and record the pair as a failure when they are not the same.
   * A well-shaped {@code lhs} is assumed, see below.
   *
   * @param lhs is {@link MemberCall} if rhs is not;
   *            if there is a {@link MetaCall} then it must be lhs.
   *            Reason: we case on lhs.
   */
  private boolean doCompare(Term lhs, Term rhs, @Nullable Term type) {
    boolean same;
    if (type != null) same = doCompareTyped(lhs, rhs, type);
    else same = compareUntyped(lhs, rhs) != null;

    if (same) return true;
    fail(lhs, rhs);
    return false;
  }

  /**
   * Interpret the outcome of an approximate comparison.
   *
   * @param type         the expected type, if any
   * @param approxResult the term returned by the approximate comparison, {@code null} if it did not succeed
   * @return whether the approximate comparison succeeded
   */
  private boolean checkApproxResult(@Nullable Term type, Term approxResult) {
    if (approxResult == null) return false;
    // The computed type is compared with the expected one, but the outcome of that comparison is not inspected.
    if (type != null) compare(approxResult, type, null);
    return true;
  }

  /**
   * Compare whnf {@code lhs} and whnf {@code rhs} with {@code type} information.
   * The method dispatches on the weak-head-normalized {@code type}; types without a dedicated
   * rule fall back to the untyped comparison.
   *
   * @param type the whnf type (it is passed through {@code whnf} again before dispatching).
   * @return whether they are 'the same' and their types are {@code type}
   */
  private boolean doCompareTyped(@NotNull Term lhs, @NotNull Term rhs, @NotNull Term type) {
    return switch (whnf(type)) {
      // These term formers are treated as impossible in type position.
      case LamTerm _, ConCallLike _, TupTerm _ -> Panic.unreachable();
      case ErrorTerm _ -> true;
      case ClassCall classCall -> {
        // Accept immediately when the class call carries as many arguments as the class has members.
        if (classCall.args().size() == classCall.ref().members().size()) yield true;
        // TODO: skip comparing fields that already have impl specified in the type
        yield classCall.ref().members().allMatch(member -> {
          // loop invariant: first [i] members are the "same". ([i] is the loop counter, count from 0)
          // Note that member can only refer to first [i] members, so it is safe that we supply [lhs] or [rhs]
          var ty = member.signature().inst(ImmutableSeq.of(lhs));
          var lproj = MemberCall.make(classCall, lhs, member, 0, ImmutableSeq.empty());
          var rproj = MemberCall.make(classCall, rhs, member, 0, ImmutableSeq.empty());
          return compare(lproj, rproj, ty.makePi(ImmutableSeq.empty()));
        });
      }
      case EqTerm eq -> switch (new Pair<>(lhs, rhs)) {
        // Two lambdas: compare the bodies under one fresh variable of the dimension type.
        case Pair(LamTerm(var lbody), LamTerm(var rbody)) -> subscoped(DimTyTerm.INSTANCE, var ->
          compare(
            lbody.apply(var),
            rbody.apply(var),
            eq.appA(new FreeTerm(var))
          ));
        // Exactly one lambda: compare its body against the other side applied to the fresh variable.
        case Pair(LamTerm lambda, _) -> compareLambda(lambda, rhs, eq);
        case Pair(_, LamTerm rambda) -> compareLambda(rambda, lhs, eq);
        default -> compare(lhs, rhs, null);
      };
      case DepTypeTerm pi when pi.kind() == DTKind.Pi -> switch (new Pair<>(lhs, rhs)) {
        // Same scheme as above, with the fresh variable typed by the parameter of the Pi type.
        case Pair(LamTerm(var lbody), LamTerm(var rbody)) -> subscoped(pi.param(), var ->
          compare(
            lbody.apply(var),
            rbody.apply(var),
            pi.body().apply(var)
          ));
        case Pair(LamTerm lambda, _) -> compareLambda(lambda, rhs, pi);
        case Pair(_, LamTerm rambda) -> compareLambda(rambda, lhs, pi);
        default -> compare(lhs, rhs, null);
      };
      // Sigma types
      case DepTypeTerm(_, var lTy, var rTy) -> {
        // Compare the first projections, then the second projections at the type
        // obtained by applying the second component type to the left first projection.
        var lProj = ProjTerm.fst(lhs);
        var rProj = ProjTerm.fst(rhs);
        if (!compare(lProj, rProj, lTy)) yield false;
        yield compare(ProjTerm.snd(lhs), ProjTerm.snd(rhs), rTy.apply(lProj));
      }
      // No type-directed rule: fall back to the untyped comparison.
      default -> compareUntyped(lhs, rhs) != null;
    };
  }

  /**
   * Compare {@code preLhs} and {@code preRhs} without type information.
   * The approximate comparison is tried first on the terms as given, then again on their weak head
   * normal forms (only if normalization changed something), and finally the structural comparison
   * is run. When the right-hand side is a meta call, a meta literal or a member call, the sides are
   * exchanged (with the ordering inverted) so that it becomes the scrutinee.
   *
   * @return the head-normalized type of {@code preLhs} and {@code preRhs} if they are 'the same', null otherwise.
   */
  private @Nullable Term compareUntyped(@NotNull Term preLhs, @NotNull Term preRhs) {
    var quick = compareApprox(preLhs, preRhs);
    if (quick != null) return whnf(quick);
    // Forget any failure recorded by the approximate attempt.
    failure = null;

    var lhs = whnf(preLhs);
    var rhs = whnf(preRhs);
    var normalizationChanged = lhs != preLhs || rhs != preRhs;
    if (normalizationChanged) {
      var retried = compareApprox(lhs, rhs);
      if (retried != null) return whnf(retried);
    }

    var scrutinizeRhs = rhs instanceof MetaCall || rhs instanceof MetaLitTerm || rhs instanceof MemberCall;
    Term type;
    if (scrutinizeRhs) type = swapped(() -> doCompareUntyped(rhs, lhs));
    else type = doCompareUntyped(lhs, rhs);

    if (type == null) {
      fail(lhs, rhs);
      return null;
    }
    return whnf(type);
  }

  /**
   * Structural, untyped comparison that cases on {@code lhs}.
   * Type formations are handled first through {@code doCompareType}; every other
   * term former has its own rule in the switch below.
   *
   * @param lhs the scrutinee; callers put meta calls, meta literals and member calls on this side
   * @param rhs the term compared against {@code lhs}
   * @return a type for the two terms if they are 'the same', {@code null} otherwise
   * @throws UnsupportedOperationException when a {@link MetaLitTerm} is compared with a {@link ConCall}
   * @throws Panic                         (from {@code noRules}) when {@code lhs} matches no rule
   */
  private @Nullable Term doCompareUntyped(@NotNull Term lhs, @NotNull Term rhs) {
    if (lhs instanceof Formation form)
      return doCompareType(form, rhs) ?
        // It's going to be used in the synthesizer, so we freeze it first
        new Synthesizer(this).synthDontNormalize(freezeHoles(form)) : null;
    return switch (lhs) {
      case AppTerm(var f, var a) -> {
        if (!(rhs instanceof AppTerm(var g, var b))) yield null;
        // The heads must compare to a Pi type; the arguments are then compared at its parameter type.
        var fTy = compareUntyped(f, g);
        if (!(fTy instanceof DepTypeTerm(var kk, var param, var body) && kk == DTKind.Pi)) yield null;
        if (!compare(a, b, param)) yield null;
        yield body.apply(a);
      }
      case PAppTerm(var f, var a, _, _) -> {
        if (!(rhs instanceof PAppTerm(var g, var b, _, _))) yield null;
        // The heads must compare to an EqTerm; the arguments are compared at the dimension type.
        var fTy = compareUntyped(f, g);
        if (!(fTy instanceof EqTerm eq)) yield null;
        if (!compare(a, b, DimTyTerm.INSTANCE)) yield null;
        yield eq.appA(a);
      }
      case CoeTerm coe -> {
        if (!(rhs instanceof CoeTerm(var rType, var rR, var rS))) yield null;
        if (!compare(coe.r(), rR, DimTyTerm.INSTANCE)) yield null;
        if (!compare(coe.s(), rS, DimTyTerm.INSTANCE)) yield null;
        // The two type lines are compared under a fresh variable of the dimension type.
        if (!subscoped(DimTyTerm.INSTANCE, var ->
          compare(coe.type().apply(var), rType.apply(var), null))) yield null;
        yield coe.family();
      }
      case ProjTerm(var lof, var ldx) -> {
        // Since {lhs} and {rhs} are whnf, at this point, {lof} is unable to evaluate.
        // Thus the only thing we can do is check whether {lof} and {rhs.of()} (if rhs is ProjTerm) are 'the same'.
        if (!(rhs instanceof ProjTerm(var rof, var rdx))) yield null;
        if (!(compareUntyped(lof, rof) instanceof DepTypeTerm(var k, var lhsT, var rhsTClos) && k == DTKind.Sigma))
          yield null;
        if (ldx != rdx) yield null;
        // {ldx} set: the result is the first component type; otherwise the second component type,
        // applied to a projection built with flag {true}.
        if (ldx) yield lhsT;
        yield rhsTClos.apply(new ProjTerm(lof, true));
      }
      // Free variables are the same only when they are the same variable object; the type comes from the local context.
      case FreeTerm(var lvar) -> rhs instanceof FreeTerm(var rvar) && lvar == rvar ? localCtx().get(lvar) : null;
      case DimTerm l -> rhs instanceof DimTerm r && l == r ? l : null;
      case MetaCall mCall -> solveMeta(mCall, rhs, null);
      // By typing invariant, they should have the same type, so no need to check for repr equality.
      case IntegerTerm(var lepr, _, _, var ty) -> rhs instanceof IntegerTerm rInt && lepr == rInt.repr() ? ty : null;
      case ListTerm list -> switch (rhs) {
        case ListTerm rist -> {
          if (!list.compareUntyped(rist, (l, r) ->
            compare(l, r, null))) yield null;
          yield list.type();
        }
        // A list literal against a constructor call: compare using the list's constructor form.
        case ConCall rCon -> compareUntyped(list.constructorForm(), rCon);
        default -> null;
      };
      // fallback case
      case ConCallLike lCon -> switch (rhs) {
        case ListTerm rList -> compareUntyped(lhs, rList.constructorForm());
        case ConCallLike rCon -> compareCallApprox(lCon, rCon);
        default -> null;
      };
      case MetaLitTerm mlt -> switch (rhs) {
        case IntegerTerm mrt -> compareMetaLitWithLit(mlt, mrt.repr(), mrt.type());
        case ListTerm mrt -> compareMetaLitWithLit(mlt, mrt.repr(), mrt.type());
        case ConCall _ -> throw new UnsupportedOperationException("TODO (I have no time to implement this)");
        case MetaLitTerm mrt -> compareMetaLitWithLit(mlt, mrt.repr(), mrt.type());
        default -> null;
      };
      case MemberCall memberCall -> {
        // it is impossible that memberCall.of() is a cast term, since it is whnfed.
        assert !(memberCall.of() instanceof ClassCastTerm);
        if (rhs instanceof MemberCall memberCarr) {
          assert !(memberCarr.of() instanceof ClassCastTerm);
          // NOTE(review): only the receivers are compared here; member refs and arguments
          // are presumably covered by the earlier approximate comparison -- confirm.
          yield compareUntyped(memberCall.of(), memberCarr.of());
        } else {
          yield null;
        }
      }
      // We already compare arguments in compareApprox, if we arrive here,
      // it means their arguments don't match (even the ref don't match),
      // so we are unable to do more if we can't normalize them.
      case FnCall _, PrimCall _ -> null;

      default -> throw noRules(lhs);
    };
  }

  /**
   * Compare a meta literal with another literal, given by its representation and its type.
   *
   * @param repr    the representation of the other literal, checked with {@link Objects#equals}
   * @param rhsType the type of the other literal
   * @return the type of {@code lhs} if the representations are equal and the types compare, {@code null} otherwise
   */
  private @Nullable Term compareMetaLitWithLit(@NotNull MetaLitTerm lhs, Object repr, @NotNull Term rhsType) {
    var sameRepr = Objects.equals(lhs.repr(), repr);
    if (!sameRepr) return null;
    var lhsType = lhs.type();
    return compare(lhsType, rhsType, null) ? lhsType : null;
  }

  /**
   * Compare {@code lambda} and {@code rhs} with {@code type}: under a fresh variable typed by the
   * parameter of {@code type}, the body of the lambda is compared with {@code rhs} applied to that variable.
   */
  private boolean compareLambda(@NotNull LamTerm lambda, @NotNull Term rhs, @NotNull DepTypeTerm type) {
    return subscoped(type.param(), fresh -> compare(
      lambda.body().apply(fresh),
      AppTerm.make(rhs, new FreeTerm(fresh)),
      type.body().apply(fresh)));
  }

  /**
   * Compare {@code lambda} and {@code rhs} with the {@link EqTerm} {@code type}: under a fresh variable of
   * the dimension type, the body of the lambda is compared with {@code rhs} applied to that variable.
   */
  private boolean compareLambda(@NotNull LamTerm lambda, @NotNull Term rhs, @NotNull EqTerm type) {
    return subscoped(DimTyTerm.INSTANCE, fresh -> {
      var dim = new FreeTerm(fresh);
      return compare(lambda.body().apply(fresh), AppTerm.make(rhs, dim), type.appA(dim));
    });
  }

  /**
   * Compare two argument lists pairwise against a telescope.
   * The type of the i-th pair is obtained from the telescope using the left-hand arguments
   * that have already been compared.
   *
   * @param list  the left-hand arguments
   * @param rist  the right-hand arguments, asserted to have the same size as {@code list}
   * @param types the telescope, asserted to have as many entries as there are arguments
   * @return the telescope's result for the left-hand arguments if every pair is 'the same',
   * {@code null} as soon as one pair is not
   */
  private @Nullable Term compareMany(
    @NotNull ImmutableSeq<Term> list,
    @NotNull ImmutableSeq<Term> rist,
    @NotNull AbstractTele types
  ) {
    assert list.sizeEquals(rist);
    assert rist.sizeEquals(types.telescopeSize());
    var size = types.telescopeSize();
    // Left-hand arguments accepted so far; later telescope entries are computed from them.
    var argsCum = new Term[size];

    for (var i = 0; i < size; ++i) {
      var l = list.get(i);
      var r = rist.get(i);
      var ty = types.telescope(i, argsCum);
      if (!compare(l, r, ty)) return null;
      argsCum[i] = l;
    }

    return types.result(argsCum);
  }

  /**
   * Compare {@code lTy} and {@code rTy} (untyped), then continue under a fresh binding.
   *
   * @param lTy          the left type; it is also the type given to the fresh variable
   * @param rTy          the right type
   * @param onFailed     supplies the result when the two types are not 'the same'
   * @param continuation invoked with {@code ? : lTy} in {@link Contextful#localCtx()} if {@code lTy} is the 'same' as {@code rTy}
   * @param <R>          the result type shared by {@code onFailed} and {@code continuation}
   * @return the result of {@code continuation}, or of {@code onFailed} if the comparison fails
   */
  private <R> R compareTypeWith(
    @NotNull Term lTy,
    @NotNull Term rTy,
    @NotNull Supplier<R> onFailed,
    @NotNull Function<LocalVar, R> continuation
  ) {
    if (!compare(lTy, rTy, null)) return onFailed.get();
    return subscoped(lTy, continuation);
  }

  /**
   * Whether sort {@code l} is below or equal to sort {@code r}:
   * <ul>
   *   <li>ISet is below ISet and below every Set, whatever the lifts are;</li>
   *   <li>Set i is below Set j if i <= j;</li>
   *   <li>Type i is below Type j if i <= j.</li>
   * </ul>
   */
  private boolean sortLt(@NotNull SortTerm l, @NotNull SortTerm r) {
    var liftFits = l.lift() <= r.lift();
    var lKind = l.kind();
    var rKind = r.kind();
    return switch (lKind) {
      case ISet -> rKind == SortKind.Set || rKind == SortKind.ISet;
      // Type and Set follow the same rule: identical kind and a lift that does not decrease.
      case Type, Set -> rKind == lKind && liftFits;
    };
  }

  /**
   * Compare two sorts according to the current ordering {@link #cmp}.
   * On mismatch a {@link LevelError} is reported through {@code fail} before {@code false} is returned.
   *
   * @return whether {@code l} and {@code r} are related as {@link #cmp} demands
   */
  private boolean compareSort(@NotNull SortTerm l, @NotNull SortTerm r) {
    return switch (cmp) {
      case Gt -> {
        var ok = sortLt(r, l);
        if (!ok) fail(new LevelError(pos, l, r, false));
        yield ok;
      }
      case Eq -> {
        var ok = l.kind() == r.kind() && l.lift() == r.lift();
        if (!ok) fail(new LevelError(pos, l, r, true));
        yield ok;
      }
      case Lt -> {
        var ok = sortLt(l, r);
        // Note the flipped argument order in the error for this direction.
        if (!ok) fail(new LevelError(pos, r, l, false));
        yield ok;
      }
    };
  }

  /**
   * Compare two type formations.
   * Note: don't confuse with {@link TermComparator#doCompareTyped(Term, Term, Term)}
   *
   * @param preLhs the left-hand formation
   * @param preRhs the right-hand term; the comparison fails right away unless its runtime class
   *               is exactly the class of {@code preLhs}
   * @return whether the two formations are 'the same'
   * @throws Panic (from {@code noRules}) when no rule exists for the class of {@code preLhs}
   */
  private boolean doCompareType(@NotNull Formation preLhs, @NotNull Term preRhs) {
    // After this check the cast of {preRhs} to Formation below cannot fail.
    if (preLhs.getClass() != preRhs.getClass()) return false;
    return switch (new Pair<>(preLhs, (Formation) preRhs)) {
      case Pair(DataCall lhs, DataCall rhs) -> compareCallApprox(lhs, rhs) != null;
      case Pair(DimTyTerm _, DimTyTerm _) -> true;
      // Same kind, same parameter type, and the bodies agree under a fresh variable of the left parameter type.
      case Pair(DepTypeTerm(var lK, var lParam, var lBody), DepTypeTerm(var rK, var rParam, var rBody)) ->
        lK == rK && compareTypeWith(lParam, rParam, () -> false, var ->
          compare(lBody.apply(var), rBody.apply(var), null));
      case Pair(SortTerm lhs, SortTerm rhs) -> compareSort(lhs, rhs);
      case Pair(EqTerm(var A, var a0, var a1), EqTerm(var B, var b0, var b1)) -> {
        // First the two type lines under a fresh dimension variable, then both endpoints,
        // each at the left type line applied to the matching dimension constant.
        var tyResult = subscoped(DimTyTerm.INSTANCE, var ->
          compare(A.apply(var), B.apply(var), null));
        if (!tyResult) yield false;
        yield compare(a0, b0, A.apply(DimTerm.I0)) && compare(a1, b1, A.apply(DimTerm.I1));
      }
      default -> throw noRules(preLhs);
    };
  }

  /**
   * Run {@code action} through the three-argument {@code subscoped} of the superclass,
   * supplying this comparator's {@link #nameGen} as the name generator.
   *
   * @param type   the type passed on to the superclass method (callers use it as the type of the fresh variable)
   * @param action the computation that receives the fresh variable
   * @param <R>    the result type of {@code action}
   * @return whatever the superclass method returns for {@code action}
   */
  public <R> R subscoped(@NotNull Term type, @NotNull Function<LocalVar, R> action) {
    return super.subscoped(type, action, nameGen);
  }

  /**
   * The recorded failure, with {@code freezeHoles} applied to both of its terms.
   * A failure must have been recorded before this is called (checked by an assertion).
   */
  public @NotNull FailureData getFailure() {
    var recorded = this.failure;
    assert recorded != null;
    var frozen = recorded.map(this::freezeHoles);
    return frozen;
  }

  /**
   * The pair of terms on which a comparison failed.
   *
   * @param lhs the left-hand term of the failing comparison
   * @param rhs the right-hand term of the failing comparison
   */
  public record FailureData(@NotNull Term lhs, @NotNull Term rhs) {
    /**
     * Apply {@code f} to both terms.
     *
     * @param f the transformation, applied to {@code lhs} first and then to {@code rhs}
     * @return a new {@link FailureData} holding the transformed terms
     */
    public @NotNull FailureData map(@NotNull UnaryOperator<Term> f) {
      return new FailureData(f.apply(lhs), f.apply(rhs));
    }
  }

  /**
   * Check a postponed equation: if the meta on its left-hand side already has a solution in the state,
   * both sides are compared; otherwise the meta is solved with the right-hand side.
   * Maybe you're looking for {@link #compare} instead.
   *
   * @return whether the comparison succeeded, or whether the meta could be solved
   */
  @ApiStatus.Internal public boolean checkEqn(@NotNull TyckState.Eqn eqn) {
    var meta = eqn.lhs();
    var alreadySolved = state.solutions.containsKey(meta.ref());
    return alreadySolved
      ? compare(meta, eqn.rhs(), eqn.type())
      : solveMeta(meta, eqn.rhs(), eqn.type()) != null;
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy