
gapt.grammars.recursionSchemes.scala Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of gapt_3 Show documentation
Show all versions of gapt_3 Show documentation
General Architecture for Proof Theory
The newest version!
package gapt.grammars
import gapt.expr.formula.fol._
import gapt.expr._
import gapt.expr.formula.All
import gapt.expr.formula.And
import gapt.expr.formula.Atom
import gapt.expr.formula.Bottom
import gapt.expr.formula.Eq
import gapt.expr.formula.Formula
import gapt.expr.formula.Imp
import gapt.expr.formula.Neg
import gapt.expr.formula.Or
import gapt.expr.formula.Top
import gapt.expr.formula.fol.FOLAtom
import gapt.expr.formula.fol.FOLFormula
import gapt.expr.formula.fol.FOLVar
import gapt.expr.formula.hol._
import gapt.expr.subst.FOLSubstitution
import gapt.expr.subst.Substitution
import gapt.expr.ty.FunctionType
import gapt.expr.ty.TBase
import gapt.expr.ty.To
import gapt.expr.ty.arity
import gapt.expr.util.constants
import gapt.expr.util.expressionSize
import gapt.expr.util.freeVariables
import gapt.expr.util.rename
import gapt.expr.util.subTerms
import gapt.expr.util.syntacticMGU
import gapt.expr.util.syntacticMatching
import gapt.formats.babel.{BabelExporter, BabelSignature, MapBabelSignature, Precedence}
import gapt.logic.hol.simplifyPropositional
import gapt.logic.hol.toNNF
import gapt.proofs.context.Context
import gapt.proofs.RichFormulaSequent
import gapt.provers.maxsat.{MaxSATSolver, bestAvailableMaxSatSolver}
import gapt.utils.{Doc, Logger}
import scala.collection.mutable
/**
 * A rewrite rule `lhs -> rhs` between two expressions.
 *
 * Construction fails (via `require`) when `rhs` contains a free variable that
 * does not occur in `lhs`, or when the two sides have different types.
 */
case class Rule(lhs: Expr, rhs: Expr) {
  require(freeVariables(rhs).subsetOf(freeVariables(lhs)), s"$rhs has more free variables than $lhs")
  require(lhs.ty == rhs.ty, s"$lhs has different type than $rhs")

  /**
   * Rewrites `term` at its root.
   *
   * @return the instance of `rhs` under the matching substitution if `lhs`
   *         matches `term`, `None` otherwise.
   */
  def apply(term: Expr): Option[Expr] =
    for (matcher <- syntacticMatching(lhs, term)) yield matcher(rhs)

  /** Applies `subst` to both sides of this rule. */
  def apply(subst: Substitution): Rule = {
    val substitutedLhs = subst(lhs)
    val substitutedRhs = subst(rhs)
    Rule(substitutedLhs, substitutedRhs)
  }

  override def toString: String = toSigRelativeString

  /** Renders the rule as `lhs -> rhs`, printing both sides relative to `sig`. */
  def toSigRelativeString(implicit sig: BabelSignature) = {
    val left = lhs.toSigRelativeString
    val right = rhs.toSigRelativeString
    s"$left -> $right"
  }
}
/**
 * Renders a recursion scheme as text: the non-terminal declarations, the
 * terminal declarations, and then the rules sorted by their string form.
 */
private class RecursionSchemeExporter(unicode: Boolean, rs: RecursionScheme)
    extends BabelExporter(unicode, rs.babelSignature) {

  import Doc._

  /** Comma-separates the documents, wrapping them across lines. */
  def csep(docs: List[Doc]): Doc = wordwrap(docs, ",")

  def `export`(): String = {
    // Declarations are shown without an assumed type and with no known names.
    def declaration(c: Const) = show(c, false, Map(), Map())._1.inPrec(0)

    // The start symbol is listed first, the remaining non-terminals by name.
    val otherNonTerminals = (rs.nonTerminals - rs.startSymbol).toList.sortBy(_.name)
    val orderedNonTerminals = rs.startSymbol +: otherNonTerminals
    val ntDecl = group("Non-terminals:" <> nest(line <> csep(orderedNonTerminals.map(declaration(_)))))

    val orderedTerminals = rs.terminals.toList.sortBy(_.name)
    val tDecl = group("Terminals:" <> nest(line <> csep(orderedTerminals.map(declaration(_)))))

    // Inside rules the types of all declared symbols are taken as known.
    val knownTypes = (rs.nonTerminals union rs.terminals).map { c => c.name -> c }.toMap

    val ruleDocs = rs.rules.toList.sortBy(_.toString).map { rule =>
      val left = show(rule.lhs, false, Map(), knownTypes)._1.inPrec(Precedence.impl)
      val right = show(rule.rhs, true, Map(), knownTypes)._1.inPrec(Precedence.impl)
      group(left > nest("→" > right))
    }
    val rules = group(stack(ruleDocs))

    group(ntDecl > tDecl <> line > rules <> line).render(lineWidth)
  }
}
/**
 * A recursion scheme: a start symbol, a set of non-terminal constants and a
 * set of rewrite rules.
 *
 * Construction requires that the start symbol is a non-terminal and that the
 * left-hand side of every rule is an application headed by a non-terminal.
 */
case class RecursionScheme(startSymbol: Const, nonTerminals: Set[Const], rules: Set[Rule]) {
  require(nonTerminals contains startSymbol)
  rules foreach { r =>
    (r: @unchecked) match {
      case Rule(Apps(leftHead: Const, _), _) =>
        require(nonTerminals contains leftHead)
    }
  }

  /** All non-logical constants occurring in some rule that are not non-terminals. */
  def terminals: Set[Const] =
    rules flatMap { case Rule(lhs, rhs) => constants.nonLogical(lhs) union constants.nonLogical(rhs) } diff nonTerminals

  /** Babel signature containing exactly the terminals and non-terminals. */
  def babelSignature = MapBabelSignature(terminals union nonTerminals)

  /** The language generated from the start symbol applied to no arguments. */
  def language: Set[Expr] = parametricLanguage()

  /**
   * The language generated from the start symbol applied to fresh constants
   * `dummy0`, `dummy1`, ... of the start symbol's argument types.
   */
  def languageWithDummyParameters: Set[Expr] =
    (startSymbol.ty: @unchecked) match {
      case FunctionType(_, argtypes) =>
        parametricLanguage(argtypes.zipWithIndex.map { case (t, i) => Const(s"dummy$i", t) }: _*)
    }

  /** The rules whose left-hand side is headed by `nonTerminal`. */
  def rulesFrom(nonTerminal: Const): Set[Rule] =
    rules collect { case r @ Rule(Apps(`nonTerminal`, _), _) => r }

  /**
   * The language generated from the start symbol applied to `params`.
   *
   * Requires that the number of parameters equals the arity of the start symbol.
   */
  def parametricLanguage(params: Expr*): Set[Expr] = {
    require(params.size == arity(startSymbol))
    generatedTerms(startSymbol(params: _*))
  }

  /**
   * Collects every term reachable from `from` by rewriting at the root with
   * `rules` whose head is not a non-terminal.
   *
   * A term headed by a non-terminal is expanded with every rule that matches
   * it and is not itself part of the result. Each such term is marked as seen
   * *before* it is expanded, so derivations that lead back to a term already
   * under expansion (for instance through a rule `A(x) -> A(x)`) terminate
   * instead of recursing forever. An explicit worklist replaces recursion, so
   * long derivation chains do not consume call stack.
   */
  def generatedTerms(from: Expr): Set[Expr] = {
    val seen = mutable.Set[Expr]()
    val gen = mutable.Set[Expr]()
    val pending = mutable.Stack[Expr](from)

    while (pending.nonEmpty) {
      val t = pending.pop()
      t match {
        case Apps(head: Const, _) if nonTerminals contains head =>
          // `add` is false when `t` was expanded (or scheduled) before.
          if (seen.add(t))
            rules.foreach { rule => rule(t).foreach { rewritten => pending.push(rewritten) } }
        case _ =>
          gen += t
      }
    }

    gen.toSet
  }

  override def toString: String = new RecursionSchemeExporter(unicode = true, rs = this).`export`()
}
/** Convenience constructors for [[RecursionScheme]]. */
object RecursionScheme {

  /** Converts `(lhs, rhs)` pairs into a set of rules. */
  private def toRules(pairs: Seq[(Expr, Expr)]): Set[Rule] =
    pairs.map { case (from, to) => Rule(from, to) }.toSet

  /** Builds a scheme from rule pairs; the non-terminals are inferred from the rules. */
  def apply(startSymbol: Const, rules: (Expr, Expr)*): RecursionScheme =
    apply(startSymbol, toRules(rules))

  /** Builds a scheme from rule pairs with an explicitly given set of non-terminals. */
  def apply(startSymbol: Const, nonTerminals: Set[Const], rules: (Expr, Expr)*): RecursionScheme =
    RecursionScheme(startSymbol, nonTerminals, toRules(rules))

  /**
   * Builds a scheme whose non-terminals are the head constants of the rules'
   * left-hand sides together with the start symbol.
   */
  def apply(startSymbol: Const, rules: Set[Rule]): RecursionScheme = {
    val ruleHeads = rules.map { r =>
      val Rule(Apps(head: Const, _), _) = r: @unchecked
      head
    }
    RecursionScheme(startSymbol, ruleHeads + startSymbol, rules)
  }
}
/**
 * Lists the subterms of a term in pre-order: an application comes first,
 * followed by the traversal of its function part and then of its argument.
 * Only applications, constants and variables are handled.
 */
object preOrderTraversal {
  def apply(term: Expr): Seq[Expr] = {
    val visited = Seq.newBuilder[Expr]

    def visit(t: Expr): Unit = t match {
      case App(fn, arg) =>
        visited += t
        visit(fn)
        visit(arg)
      case _: Const | _: Var =>
        visited += t
    }

    visit(term)
    visited.result()
  }
}
/**
 * Renames the first-order variables of a term to `x0`, `x1`, ... numbered by
 * the order of their first occurrence in the pre-order traversal of the term.
 */
object canonicalVars {
  def apply(term: Expr): Expr = {
    val varsByFirstOccurrence = preOrderTraversal(term).collect { case v: FOLVar => v }.distinct
    val renaming = varsByFirstOccurrence.zipWithIndex.map { case (v, i) => v -> FOLVar(s"x$i") }
    FOLSubstitution(renaming)(term)
  }
}
/**
 * Target filters decide, for a pair `(from, to)`, whether the pair is accepted
 * (`Some(true)`), rejected (`Some(false)`) or undecided (`None`).
 */
object TargetFilter {
  type Type = (Expr, Expr) => Option[Boolean]

  /** Accepts a pair when `to` matches `from`; leaves it undecided otherwise. */
  def default: Type = { (from: Expr, to: Expr) =>
    if (syntacticMatching(to, from).isDefined) Some(true) else None
  }
}
/**
 * Builds a propositional formula over the atoms [[ruleIncluded]] and
 * [[derivable]] that describes which rules of `recursionScheme` are needed to
 * derive a collection of targets.
 *
 * @param recursionScheme the scheme whose rules are considered
 * @param targetFilter    decides for every intermediate target whether it is a
 *                        goal, a dead end, or has to be explored further
 */
class RecSchemGenLangFormula(
    val recursionScheme: RecursionScheme,
    val targetFilter: TargetFilter.Type = TargetFilter.default
) {

  /** Atom stating that `rule` is part of the selected scheme. */
  def ruleIncluded(rule: Rule) = FOLAtom(s"${rule.lhs}->${rule.rhs}")

  /** Atom stating that `to` is derivable from `from`. */
  def derivable(from: Expr, to: Expr) = FOLAtom(s"$from=>$to")

  // Rules indexed by the head constant of their right-hand side.
  private val rulesPerNonTerminal = recursionScheme.rules.groupBy { r =>
    val Rule(_, Apps(nt: Const, _)) = r: @unchecked
    nt
  }.view.mapValues(_.toSeq).toMap

  /**
   * Applies rules backwards to `against`: for every rule whose right-hand side
   * (after renaming its variables away from those of `against` when they
   * overlap) unifies with `against`, yields the instantiated left-hand side
   * with canonical variable names, paired with the original rule.
   */
  def reverseMatches(against: Expr) =
    against match {
      case Apps(nt: Const, _) =>
        val fvsAgainst = freeVariables(against)
        rulesPerNonTerminal.getOrElse(nt, Seq()).flatMap { rule =>
          val fvsRule = freeVariables(rule.lhs)
          val renamedRule =
            if ((fvsRule intersect fvsAgainst).nonEmpty) rule(Substitution(rename(fvsRule, fvsAgainst)))
            else rule
          syntacticMGU(renamedRule.rhs, against).headOption.map { unifier =>
            canonicalVars(unifier(renamedRule.lhs)) -> rule
          }
        }
    }

  type Target = (Expr, Expr)

  /**
   * Constructs the formula for `targets`.
   *
   * @return `Bottom()` if some target cannot reach a goal by reverse rule
   *         application; otherwise the conjunction of the derivability atoms
   *         of the targets, one implication per reachable non-goal target
   *         listing its possible derivation steps, and implications between
   *         reachable targets with the same source where one matches the other.
   */
  def apply(targets: Iterable[Target]): FOLFormula = {
    val edges = mutable.ArrayBuffer[(Target, Rule, Target)]()
    val goals = mutable.Set[Target]()
    val queue = mutable.Queue(targets.toSeq: _*)
    val alreadyDone = mutable.Set[Target]()

    // Explore backwards from the targets until the filter decides every branch.
    while (queue.nonEmpty) {
      val target = queue.dequeue()
      val (from, to) = target
      if (alreadyDone.add(target)) {
        reverseMatches(to).foreach {
          case (newTo, rule) =>
            val successor = from -> newTo
            targetFilter(from, newTo) match {
              case Some(true) =>
                goals += successor
                edges += ((target, rule, successor))
              case Some(false) => ()
              case None =>
                edges += ((target, rule, successor))
                queue.enqueue(successor)
            }
        }
      }
    }

    // Propagate reachability of a goal backwards along the edges until stable.
    val reachable = mutable.Set[Target](goals.toSeq: _*)
    var changed = true
    while (changed) {
      changed = false
      edges.foreach {
        case (source, _, destination) =>
          if (reachable.contains(destination) && reachable.add(source)) {
            changed = true
          }
      }
    }

    if (!targets.forall(reachable.contains)) {
      Bottom()
    } else {
      val edgesPerFrom = edges.groupBy(_._1)

      val targetsDerivable = targets.toSeq.map { case (from, to) => derivable(from, to) }

      val derivationSteps = reachable collect {
        case t @ (from, to) if !(goals contains t) =>
          Imp(
            derivable(from, to),
            Or(
              edgesPerFrom(t) collect {
                case (_, r, b) if goals contains b => ruleIncluded(r)
                case (_, r, b @ (from_, to_)) if reachable contains b =>
                  And(ruleIncluded(r), derivable(from_, to_))
              }
            )
          )
      }

      val instanceImplications =
        for {
          (from1, to1) <- reachable
          (from2, to2) <- reachable if from1 == from2 && to1 != to2
          if syntacticMatching(to2, to1).isDefined
        } yield Imp(derivable(from1, to1), derivable(from1, to2))

      And(targetsDerivable ++ derivationSteps ++ instanceImplications)
    }
  }
}
/**
 * Selects a subset of the rules of a recursion scheme that still derives a
 * given collection of targets, minimising the total weight of the kept rules
 * with a MaxSAT solver.
 */
object minimizeRecursionScheme {
  val logger = Logger("minimizeRecursionScheme"); import logger._

  /**
   * Replaces every free variable occurring in `targets` by a constant of the
   * same type whose name is fresh with respect to the non-logical constants of
   * the targets.
   */
  private def groundTargets(targets: Iterable[(Expr, Expr)]) = {
    val fvs = freeVariables(targets.map(_._1)) union freeVariables(targets.map(_._2))
    val nameGen = rename.awayFrom(constants.nonLogical(targets.map(_._1)) union constants.nonLogical(targets.map(_._2)))
    val grounding = Substitution(for (v @ Var(name, ty) <- fvs) yield v -> Const(nameGen fresh name, ty))
    grounding(targets.toSet)
  }

  /** Fails with the same exception type as `Option.get`, but with a message. */
  private def noModel(): Nothing =
    throw new NoSuchElementException(
      "minimizeRecursionScheme: MaxSAT solver returned no model for the minimization formula"
    )

  /**
   * Minimises `recSchem` directly.
   *
   * @param recSchem     scheme whose rules are candidates
   * @param targets      pairs `(from, to)` that must remain derivable
   * @param targetFilter filter passed to [[RecSchemGenLangFormula]]
   * @param solver       MaxSAT solver to use
   * @param weight       cost of keeping a rule
   * @return the scheme restricted to the rules selected by the solver
   * @throws NoSuchElementException if the solver returns no model
   */
  def apply(recSchem: RecursionScheme, targets: Iterable[(Expr, Expr)], targetFilter: TargetFilter.Type = TargetFilter.default, solver: MaxSATSolver = bestAvailableMaxSatSolver, weight: Rule => Int = _ => 1) = {
    val targets_ = groundTargets(targets)
    val formula = new RecSchemGenLangFormula(recSchem, targetFilter)
    val hard = formula(targets_)
    debug(s"Logical complexity of the minimization formula: ${lcomp(simplifyPropositional(toNNF(hard)))}")
    // Every rule is penalised by its weight when it is included.
    val soft = recSchem.rules map { rule => Neg(formula.ruleIncluded(rule)) -> weight(rule) }
    val interp = time("maxsat") { solver.solve(hard, soft).getOrElse(noModel()) }
    RecursionScheme(recSchem.startSymbol, recSchem.nonTerminals, recSchem.rules.filter { rule => interp(formula ruleIncluded rule) })
  }

  /**
   * Minimises `recSchem` via an instantiated scheme: the derivability formula
   * is built for `instantiateRS(recSchem, instTerms)`, and every instantiated
   * rule is linked to the original rules that match it.
   *
   * Parameters and result are as for [[apply]].
   *
   * @throws NoSuchElementException if the solver returns no model
   */
  def viaInst(recSchem: RecursionScheme, targets: Iterable[(Expr, Expr)], targetFilter: TargetFilter.Type = TargetFilter.default, solver: MaxSATSolver = bestAvailableMaxSatSolver, weight: Rule => Int = _ => 1) = {
    val targets_ = groundTargets(targets)
    // Instantiation terms are taken from the arguments of the target sources.
    val instTerms = targets_.map { _._1 }.flatMap { case Apps(_, as) => as }.flatMap { flatSubterms(_) }
    val instRS = instantiateRS(recSchem, instTerms)
    val formula = new RecSchemGenLangFormula(instRS, targetFilter)
    // An instantiated rule may only be used if an original rule matching it is included.
    val ruleCorrespondence =
      for (ir <- instRS.rules.toSeq) yield formula.ruleIncluded(ir) --> Or(
        for {
          r <- recSchem.rules.toSeq
          _ <- syntacticMatching(List(r.lhs -> ir.lhs, r.rhs -> ir.rhs))
        } yield formula.ruleIncluded(r)
      )
    val hard = formula(targets_) & And(ruleCorrespondence)
    debug(s"Logical complexity of the minimization formula: ${lcomp(simplifyPropositional(toNNF(hard)))}")
    val soft = recSchem.rules map { rule => Neg(formula.ruleIncluded(rule)) -> weight(rule) }
    val interp = time("maxsat") { solver.solve(hard, soft).getOrElse(noModel()) }
    RecursionScheme(recSchem.startSymbol, recSchem.nonTerminals, recSchem.rules.filter { rule => interp(formula ruleIncluded rule) })
  }
}
/**
 * A template for recursion schemes: a start symbol together with a set of
 * `(lhs, rhs)` rule shapes.
 *
 * From the template this class derives constraints relating the arguments of
 * one non-terminal to those of another, a target filter based on these
 * constraints, and a "stable" recursion scheme for a set of targets that can
 * then be minimised.
 */
case class RecSchemTemplate(startSymbol: Const, template: Set[(Expr, Expr)]) {
// The non-terminals are the head constants of the template left-hand sides.
val nonTerminals: Set[Const] = template map { case (Apps(nt: Const, _), _) => nt }
// Name of the constant used to build the subterm atoms of the constraints.
val isSubtermC = "is_subterm"
/** Builds the atom `is_subterm(v, t)`. */
def isSubterm(v: Expr, t: Expr): Formula =
Const(isSubtermC, v.ty ->: t.ty ->: To)(v, t).asInstanceOf[Formula]
// For each non-terminal, one variable `<nt>_<i>` per argument position.
val canonicalArgs = nonTerminals map {
case nt @ Const(_, FunctionType(_, argTypes), _) =>
nt -> argTypes.zipWithIndex.map { case (t, i) => Var(s"${nt}_$i", t) }
} toMap
// Each non-terminal applied to its canonical arguments.
val states = canonicalArgs map { case (nt, args) => nt(args: _*) }
/**
 * For every pair `(from, to)` of non-terminals, a formula over the canonical
 * arguments of `to` (first position of each atom) and of `from` (second
 * position), built from equalities and `is_subterm` atoms.
 */
val constraints: Map[(Const, Const), Formula] = {
// Memoises the results of `get`.
val cache = mutable.Map[(Const, Const), Formula]()
// NOTE(review): `get(from, to)` calls `get(from, prev)` before its own result
// is stored in the cache; template cycles between distinct non-terminals look
// like they would recurse without bound — confirm.
def get(from: Const, to: Const): Formula =
cache.getOrElseUpdate(
from -> to, {
// Base case: pairwise equalities between the canonical arguments.
var postCond = if (from == to)
And(canonicalArgs(from).lazyZip(canonicalArgs(to)).map { Eq(_, _) })
// Otherwise: one disjunct per template rule `prev -> to` with `prev != to`.
else Or(template collect {
case (Apps(prev: Const, prevArgs), Apps(`to`, toArgs)) if prev != to =>
// Translates a constraint between `from` and `prev` into one between
// `from` and `to` along this template rule.
def postCondition(preCond: Formula): Formula = preCond match {
case Top() => Top()
case Bottom() => Bottom()
case And(a, b) => And(postCondition(a), postCondition(b))
case Or(a, b) => Or(postCondition(a), postCondition(b))
case Eq(a, b) =>
// Look at the left-hand-side argument at the position of canonical argument `a`.
prevArgs(canonicalArgs(prev).indexOf(a)) match {
// A variable: the equality carries over to every position of `to` it is passed to.
case v: Var =>
And(for ((toArg, canToArg: Var) <- toArgs.lazyZip(canonicalArgs(to)).toSeq if v == toArg)
yield Eq(canToArg, b))
// A non-variable term: variables of it passed on to `to` only yield `is_subterm`.
case constr =>
val vars = freeVariables(constr)
And((toArgs.toSeq zip canonicalArgs(to)).collect {
case (toArg: Var, canToArg) if vars contains toArg =>
isSubterm(canToArg, b)
})
}
case Apps(Const(`isSubtermC`, _, _), Seq(a, b)) =>
// `is_subterm` is propagated to every variable of the corresponding argument.
val vars = freeVariables(prevArgs(canonicalArgs(prev).indexOf(a)))
And((toArgs.toSeq zip canonicalArgs(to)).collect {
case (toArg: Var, canToArg) if vars contains toArg =>
isSubterm(canToArg, b)
})
}
postCondition(get(from, prev))
} toSeq)
// Template rules from `to` to itself.
val recCalls = template filter {
case (Apps(`to`, _), Apps(`to`, _)) => true
case _ => false
}
if (recCalls nonEmpty) {
// Canonical arguments that every recursive call passes on unchanged.
val constArgs = canonicalArgs(to).zipWithIndex filter {
case (a, i) =>
recCalls forall {
case (Apps(_, callerArgs), Apps(_, calleeArgs)) =>
callerArgs(i) == calleeArgs(i)
}
} map { _._1 }
// Canonical arguments where, in every recursive call, the callee argument
// is found inside the caller argument.
val structRecArgs = canonicalArgs(to).zipWithIndex filter {
case (a, i) =>
recCalls forall {
case (Apps(_, callerArgs), Apps(_, calleeArgs)) =>
callerArgs(i).find(calleeArgs(i)).nonEmpty
}
} map { _._1 }
// Weakens the constraint to account for recursive calls: equalities survive
// only for unchanged arguments, become `is_subterm` for structurally
// recursive ones, and every other atom is replaced by `Top()`.
def appRecConstr(p: Formula): Formula = p match {
case Top() => Top()
case Bottom() => Bottom()
case Or(a, b) => Or(appRecConstr(a), appRecConstr(b))
case And(a, b) => And(appRecConstr(a), appRecConstr(b))
case Eq(a, b) if constArgs contains a => Eq(a, b)
case Eq(a, b) if structRecArgs contains a => isSubterm(a, b)
case Apps(Const(`isSubtermC`, _, _), Seq(a, b)) if (constArgs contains a) || (structRecArgs contains a) =>
isSubterm(a, b)
case _ => Top()
}
postCond = appRecConstr(postCond)
}
simplifyPropositional(toNNF(postCond))
}
)
(for (from <- nonTerminals; to <- nonTerminals)
yield (from, to) -> get(from, to)) toMap
}
/**
 * The constraints compiled into predicates: given the argument lists of a
 * `from` term and of a `to` term, the predicate returns whether the
 * constraint for that pair of non-terminals is considered satisfied.
 */
val constraintEvaluators: Map[(Const, Const), (Seq[Expr], Seq[Expr]) => Boolean] =
constraints map {
case ((from, to), constr) =>
def mkEval(f: Formula): ((Seq[Expr], Seq[Expr]) => Boolean) = f match {
case Top() => (_, _) => true
case Bottom() => (_, _) => false
case And(a, b) =>
val aEval = mkEval(a)
val bEval = mkEval(b)
(x, y) => aEval(x, y) && bEval(x, y)
case Or(a, b) =>
val aEval = mkEval(a)
val bEval = mkEval(b)
(x, y) => aEval(x, y) || bEval(x, y)
case Eq(b, a) =>
// `a` is a canonical argument of `from`, `b` one of `to`.
val aIdx = canonicalArgs(from).indexOf(a)
val bIdx = canonicalArgs(to).indexOf(b)
require(aIdx >= 0 && bIdx >= 0)
// Equality is evaluated as: the `to` argument matches the `from` argument.
(x, y) => syntacticMatching(y(bIdx), x(aIdx)).isDefined
case Apps(Const(`isSubtermC`, _, _), Seq(b, a)) =>
val aIdx = canonicalArgs(from).indexOf(a)
val bIdx = canonicalArgs(to).indexOf(b)
require(aIdx >= 0 && bIdx >= 0)
// Not an actual subterm test: compares expression sizes (with a slack of
// one) and requires the constants of the `to` argument to occur in the
// `from` argument.
(x, y) =>
(expressionSize(y(bIdx)) <= expressionSize(x(aIdx)) + 1) &&
constants.nonLogical(y(bIdx)).subsetOf(constants.nonLogical(x(aIdx)))
}
(from -> to) -> mkEval(constr)
}
/**
 * Target filter for this template: uses the default filter first; if that is
 * undecided, rejects the pair (`Some(false)`) when the constraint between the
 * head non-terminals is violated and stays undecided (`None`) otherwise.
 */
val targetFilter: TargetFilter.Type = (from, to) =>
TargetFilter.default(from, to).orElse {
val Apps(fromNt: Const, fromArgs) = from: @unchecked
val Apps(toNt: Const, toArgs) = to: @unchecked
val constrValue = constraintEvaluators(fromNt -> toNt)(fromArgs, toArgs)
if (constrValue) None else Some(false)
}
/**
 * Builds a recursion scheme by instantiating the template with stable terms
 * computed (by `stableTerms`) from the right-hand components of `targets`.
 */
def stableRecSchem(targets: Set[(Expr, Expr)]): RecursionScheme = {
// Variables occurring in template left-hand sides.
val neededVars = template flatMap { case (from, to) => freeVariables(from) }
val allTerms = targets map { _._2 }
// Stable terms of the whole target terms, excluding bare variables.
val topLevelStableTerms = stableTerms(allTerms, neededVars.toSeq).filter(!_.isInstanceOf[Var])
// Stable terms of the base-type subterms of the target terms' arguments.
val argumentStableTerms = stableTerms(
allTerms
flatMap { case Apps(_, as) => as }
flatMap { subTerms(_) }
filter { _.ty.isInstanceOf[TBase] },
neededVars.toSeq
)
var rules = template.flatMap {
// Right-hand side is a variable: one rule per top-level stable term whose
// free variables all occur in the left-hand side.
case (from, to: Var) =>
val allowedVars = freeVariables(from)
topLevelStableTerms.filter { st => freeVariables(st) subsetOf allowedVars }.map { Rule(from, _) }
// Otherwise: instantiate the right-hand-side variables missing from the
// left-hand side with all type-correct argument stable terms.
case (from, to) =>
val allowedVars = freeVariables(from)
val templateVars = freeVariables(to).diff(freeVariables(from))
templateVars.foldLeft(Seq[Map[Var, Expr]](Map()))((chosenValues, nextVar) =>
for (
subst <- chosenValues;
st <- argumentStableTerms if st.ty == nextVar.ty && freeVariables(st).subsetOf(allowedVars)
) yield subst + (nextVar -> st)
).map(s => Rule(from, Substitution(s)(to)))
}
// Filter out rules that only used variables that are passed unchanged from the startSymbol.
targets.map { case (Apps(nt: Const, _), _) => nt }.toSeq match {
case Seq() => // empty language
// Here `startSymbol` is the single head non-terminal of the targets and
// shadows the class member of the same name inside this case.
case Seq(startSymbol) =>
(nonTerminals - startSymbol) foreach { nt =>
constraints(startSymbol -> nt) match {
case And.nAry(constr) =>
// Argument positions of `nt` constrained to equal a start symbol argument.
val identicalArgs = constr.collect {
case Eq(ntArg, startSymbolArg) => canonicalArgs(nt).indexOf(ntArg)
}.toSet
rules = rules filter {
case Rule(Apps(`nt`, args), to) =>
!freeVariables(to).subsetOf(identicalArgs map { args(_) } collect { case v: Var => v })
case _ => true
}
}
}
}
RecursionScheme(startSymbol, nonTerminals, rules)
}
/**
 * Minimises the stable recursion scheme for `targets` with
 * `minimizeRecursionScheme`, using this template's target filter.
 */
def findMinimalCover(
targets: Set[(Expr, Expr)],
solver: MaxSATSolver = bestAvailableMaxSatSolver,
weight: Rule => Int = _ => 1
): RecursionScheme = {
minimizeRecursionScheme(stableRecSchem(targets), targets toSeq, targetFilter, solver, weight)
}
/**
 * Same as `findMinimalCover`, but uses `minimizeRecursionScheme.viaInst`.
 */
def findMinimalCoverViaInst(
targets: Set[(Expr, Expr)],
solver: MaxSATSolver = bestAvailableMaxSatSolver,
weight: Rule => Int = _ => 1
): RecursionScheme = {
minimizeRecursionScheme.viaInst(stableRecSchem(targets), targets toSeq, targetFilter, solver, weight)
}
}
/** Convenience constructor for [[RecSchemTemplate]]. */
object RecSchemTemplate {
  /** Builds a template from the given `(lhs, rhs)` pairs. */
  def apply(startSymbol: Const, rules: (Expr, Expr)*): RecSchemTemplate =
    new RecSchemTemplate(startSymbol, rules.toSet)
}
/** Converts a recursion scheme into a VTRATG. */
object recSchemToVTRATG {

  /**
   * Orders the non-terminals so that each one is preceded in the result by
   * every non-terminal that does not depend on it; a non-terminal is placed
   * in front of the sequence once all non-terminals occurring in the
   * right-hand sides of its rules have been placed.
   */
  def orderedNonTerminals(rs: RecursionScheme): Seq[Const] = {
    // For each non-terminal, the non-terminals occurring in its right-hand sides.
    val ntDeps = rs.nonTerminals.map { nt =>
      val mentioned = rs.rulesFrom(nt).map(_.rhs).flatMap { constants.nonLogical(_) }
      nt -> mentioned.intersect(rs.nonTerminals)
    }.toMap

    var ordered = Seq[Const]()
    var remaining = rs.nonTerminals -- ordered
    while (remaining.nonEmpty) {
      val placed = ordered.toSet
      val Some(next) = remaining.find { nt => ntDeps(nt).subsetOf(placed) }: @unchecked
      ordered = next +: ordered
      remaining = rs.nonTerminals -- ordered
    }
    ordered
  }

  def apply(recSchem: RecursionScheme): VTRATG = {
    val nameGen = rename.awayFrom(containedNames(recSchem))

    // One fresh variable per argument position of every non-terminal.
    val ntCorrespondence = orderedNonTerminals(recSchem).reverse.map {
      case nt @ Const(name, FunctionType(_, argTypes), _) =>
        val argVars = argTypes.zipWithIndex.map { case (t, i) => Var(nameGen.fresh(s"x_${name}_$i"), t) }
        nt -> argVars
    }
    val ntMap = ntCorrespondence.toMap

    val FunctionType(startSymbolType, _) = recSchem.startSymbol.ty: @unchecked
    val startSymbol = Var(nameGen.fresh(s"x_${recSchem.startSymbol.name}"), startSymbolType)
    val argumentVectors = ntCorrespondence.map(_._2).filter(_.nonEmpty)
    val nonTerminals = List(startSymbol) +: argumentVectors

    // Renames the variables on a rule's left-hand side to the fresh variables
    // chosen for its head non-terminal.
    def toGrammarVars(nt: Const, lhsArgs: Seq[Expr]) =
      Substitution(lhsArgs.map(_.asInstanceOf[Var]) zip ntMap(nt))

    val productions = recSchem.rules.map { r =>
      (r: @unchecked) match {
        // Non-terminal to non-terminal: a production for the argument vector.
        case Rule(Apps(nt1: Const, vars1), Apps(nt2: Const, args2)) if recSchem.nonTerminals.contains(nt1) && recSchem.nonTerminals.contains(nt2) =>
          val subst = toGrammarVars(nt1, vars1)
          ntMap(nt2) -> args2.map(subst(_))
        // Any other right-hand side: a production of the start variable.
        case Rule(Apps(nt1: Const, vars1), rhs) if recSchem.nonTerminals.contains(nt1) =>
          val subst = toGrammarVars(nt1, vars1)
          List(startSymbol) -> List(subst(rhs))
      }
    }

    VTRATG(startSymbol, nonTerminals, productions)
  }
}
/**
 * Builds a recursion scheme template with a single additional non-terminal
 * (named from `B`) that takes the start symbol's arguments twice followed by
 * arguments of the types `pi1QTys`.
 */
object simplePi1RecSchemTempl {
  def apply(startSymbol: Expr, pi1QTys: Seq[TBase])(implicit ctx: Context): RecSchemTemplate = {
    val nameGen = rename.awayFrom(ctx.constants)

    val Apps(startSymbolNT: Const, startSymbolArgs) = startSymbol: @unchecked
    val FunctionType(instTT, startSymbolArgTys) = startSymbolNT.ty: @unchecked

    // TODO: handle strong quantifiers in conclusion correctly

    val startSymbolArgs2 = startSymbolArgTys.zipWithIndex.map { case (t, i) => Var(s"x_$i", t) }

    val indLemmaNT = Const(
      nameGen fresh "B",
      FunctionType(instTT, startSymbolArgTys ++ startSymbolArgTys ++ pi1QTys)
    )

    val lhsPi1QArgs = for ((t, i) <- pi1QTys.zipWithIndex) yield Var(s"w_$i", t)
    val rhsPi1QArgs = for ((t, i) <- pi1QTys.zipWithIndex) yield Var(s"v_$i", t)

    // Right-hand side shared by all rules that produce an instance.
    val instanceVar = Var("u", instTT)

    // The new non-terminal applied to the start symbol's arguments, then the
    // `x_i` variables with position `argIdx` replaced by `inductionArg`, then
    // the given trailing arguments.
    def indLemmaAt(argIdx: Int, inductionArg: Expr, pi1Args: Seq[Var]) =
      indLemmaNT(startSymbolArgs)(
        startSymbolArgs2.take(argIdx)
      )(
        inductionArg
      )(
        startSymbolArgs2.drop(argIdx + 1)
      )(
        pi1Args
      )

    val indLemmaRules = startSymbolArgTys.zipWithIndex.flatMap {
      case (indLemmaArgTy, indLemmaArgIdx) =>
        val indTy = indLemmaArgTy.asInstanceOf[TBase]
        ctx.getConstructors(indTy) match {
          case None => Seq()
          case Some(ctrs) =>
            ctrs flatMap { ctr =>
              val FunctionType(_, ctrArgTys) = ctr.ty: @unchecked
              val ctrArgs = ctrArgTys.zipWithIndex.map { case (t, i) => Var(s"x_${indLemmaArgIdx}_$i", t) }
              val lhs = indLemmaAt(indLemmaArgIdx, ctr(ctrArgs: _*), lhsPi1QArgs)
              // One recursive rule per constructor argument of the same type.
              val recRules = ctrArgTys.zipWithIndex.collect {
                case (ctrArgTy, ctrArgIdx) if ctrArgTy == indTy =>
                  lhs -> indLemmaAt(indLemmaArgIdx, ctrArgs(ctrArgIdx), rhsPi1QArgs)
              }
              recRules :+ (lhs -> instanceVar)
            }
        }
    }

    val startLhs = startSymbolNT(startSymbolArgs)
    val startToLemma = startLhs -> indLemmaNT(startSymbolArgs)(startSymbolArgs)(rhsPi1QArgs)
    val startToInstance = startLhs -> instanceVar
    val lemmaToInstance = indLemmaNT(startSymbolArgs)(startSymbolArgs2)(lhsPi1QArgs) -> instanceVar

    RecSchemTemplate(
      startSymbolNT,
      indLemmaRules.toSet + startToLemma + startToInstance + lemmaToInstance
    )
  }
}
/**
 * Builds a formula from a recursion scheme and a formula `conj`, in which
 * every non-terminal other than the start symbol is represented by a variable
 * `X_<name>` and the start symbol by instances of `conj`.
 */
object qbupForRecSchem {
/**
 * Computes left-hand sides covering the rules of each non-terminal: argument
 * positions at which some rule has a non-variable argument are replaced by
 * all constructor applications of the argument's type, the other positions
 * by variables.
 */
def canonicalRsLHS(recSchem: RecursionScheme)(implicit ctx: Context): Set[Expr] =
recSchem.nonTerminals flatMap { nt =>
val FunctionType(To, argTypes) = nt.ty: @unchecked
val args = for ((t, i) <- argTypes.zipWithIndex) yield Var(s"x$i", t)
// Indices of argument positions where some rule from `nt` has a non-variable argument.
recSchem.rulesFrom(nt).flatMap {
case Rule(Apps(_, as), _) => as.zipWithIndex.filterNot { _._1.isInstanceOf[Var] }.map { _._2 }
}.toSeq match {
// All rules have only variables as arguments: a single left-hand side suffices.
case Seq() => Some(nt(args: _*))
case idcs =>
// Only argument positions of base type are considered here.
val newArgs =
for (case (_: TBase, idx) <- argTypes.zipWithIndex)
yield
if (!idcs.contains(idx)) List(args(idx))
else {
val indTy = argTypes(idx).asInstanceOf[TBase]
val Some(ctrs) = ctx.getConstructors(indTy): @unchecked
// One candidate per constructor, applied to variables `x<idx>_<i>`.
for {
ctr <- ctrs.toList
FunctionType(_, ctrArgTys) = ctr.ty: @unchecked
} yield ctr(
(for ((t, i) <- ctrArgTys.zipWithIndex) yield Var(s"x${idx}_$i", t)): _*
)
}
import cats.instances.list._
import cats.syntax.traverse._
// `traverse(identity)` over lists yields every combination of one candidate per position.
newArgs.traverse(identity).map(nt(_: _*))
}
}
/**
 * For every canonical left-hand side, states (universally closed over its
 * free variables) that the conjunction of the converted right-hand sides of
 * all rules matching it implies the converted left-hand side; the result is
 * the existential closure of the conjunction of these statements.
 */
def apply(recSchem: RecursionScheme, conj: Formula)(implicit ctx: Context): Formula = {
// Start symbol applications become instances of `conj`, other non-terminal
// applications become atoms of the variable `X_<name>`, formulas stay as they are.
def convert(term: Expr): Formula = term match {
case Apps(ax, args) if ax == recSchem.startSymbol => instantiate(conj, args)
case Apps(nt @ Const(name, ty, _), args) if recSchem.nonTerminals contains nt =>
Atom(Var(s"X_$name", ty)(args: _*))
case formula: Formula => formula
}
val lhss = canonicalRsLHS(recSchem)
existentialClosure(And(for (lhs <- lhss) yield All.Block(
freeVariables(lhs) toSeq,
formulaToSequent.pos(And(for {
Rule(lhs_, rhs) <- recSchem.rules
subst <- syntacticMatching(lhs_, lhs)
} yield convert(subst(rhs)))
--> convert(lhs)).toImplication
)))
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy