All Downloads are FREE. Search and download functionalities are using the official Maven repository.

z3-z3-4.13.0.src.sat.smt.euf_proof_checker.cpp Maven / Gradle / Ivy

The newest version!
/*++
Copyright (c) 2020 Microsoft Corporation

Module Name:

    euf_proof_checker.cpp

Abstract:

    Plugin manager for checking EUF proofs

Author:

    Nikolaj Bjorner (nbjorner) 2020-08-25

--*/

#include "util/union_find.h"
#include "ast/ast_pp.h"
#include "ast/ast_util.h"
#include "ast/ast_ll_pp.h"
#include "ast/arith_decl_plugin.h"
#include "smt/smt_solver.h"
#include "sat/sat_params.hpp"
#include "sat/smt/euf_proof_checker.h"
#include "sat/smt/arith_theory_checker.h"
#include "sat/smt/q_theory_checker.h"
#include "sat/smt/bv_theory_checker.h"
#include "sat/smt/distinct_theory_checker.h"
#include "sat/smt/tseitin_theory_checker.h"
#include "params/solver_params.hpp"

namespace euf {

    /**
     * The equality proof checker checks congruence proofs.
     * A congruence claim comprises
     *   - a set of equality and disequality literals that are
     *     unsatisfiable modulo equality reasoning.
     *   - a list of congruence claims that are used for equality reasoning.
     *     Congruence claims are proof terms of the form
     *     (cc (= a b)) or (comm (= a b)),
     *     where the rule name comm indicates that commutativity is used.
     *     For comm, a and b are applications of the same binary function,
     *     a := f(x,y), b := f(z,u), such that x = u and y = z are consequences of
     *     the current equalities.
     *     For cc, a and b are applications of the same n-ary function and, for
     *     each argument position i, a_i == b_i follows from the current equalities.
     *     If the arguments are equal according to the current equalities, then the equality
     *     a = b is added as a consequence.
     *
     * The congruence claims can be justified from the equalities in the literals.
     * To be more precise, the congruence claims are justified in the order in which they appear.
     * The congruence closure algorithm (egraph) uses timestamps to record a timestamp
     * when a congruence was inferred. Proof generation ensures that the congruence premises
     * are sorted by the timestamp such that a congruence that depends on an earlier congruence
     * appears later in the sorted order.
     *
     * Equality justifications are checked using union-find. 
     * We use union-find instead of fine-grained equality proofs (symmetry and transitivity
     * of equality) assuming that it is both cheap and simple to establish a certified
     * union-find checker.
     */

    class eq_theory_checker : public theory_checker_plugin {
        ast_manager&     m;
        arith_util       m_arith;
        // Keeps numerals created by merge_numeral() alive while m_id2expr refers
        // to them during a check. Cleared by reset() so it does not grow without
        // bound across checks.
        expr_ref_vector  m_trail;
        basic_union_find m_uf;
        // Indexed by expression id: (timestamp, union-find variable). An entry is
        // current only when its timestamp equals m_ts.
        svector<std::pair<unsigned, unsigned>> m_expr2id;
        // Inverse of m_expr2id: union-find variable -> expression.
        ptr_vector<expr>                       m_id2expr;
        // Disequalities (x != y) collected from negated equality literals.
        svector<std::pair<expr*, expr*>>       m_diseqs;
        // Generation counter; incrementing it invalidates all entries of m_expr2id.
        unsigned         m_ts = 0;

        // Merge the classes of x and y. If either side is (- k) for a numeral k,
        // it is additionally merged with the numeral -k (see merge_numeral).
        void merge(expr* x, expr* y) {
            m_uf.merge(expr2id(x), expr2id(y));
            IF_VERBOSE(10, verbose_stream() << "merge " << mk_bounded_pp(x, m) << " == " << mk_bounded_pp(y, m) << "\n");
            merge_numeral(x);
            merge_numeral(y);
        }

        // If x is an arithmetic negation of a numeral n, merge x with the numeral
        // -n so that the distinct-value check in check() can detect conflicts.
        void merge_numeral(expr* x) {
            rational n;
            expr* y = nullptr;
            if (m_arith.is_uminus(x, y) && m_arith.is_numeral(y, n)) {
                y = m_arith.mk_numeral(-n, x->get_sort());
                m_trail.push_back(y);
                m_uf.merge(expr2id(x), expr2id(y));
            }
        }

        bool are_equal(expr* x, expr* y) {
            return m_uf.find(expr2id(x)) == m_uf.find(expr2id(y));
        }

        /**
         * Check a congruence claim x = y and, if justified, merge x and y.
         * comm == false: same declaration and every argument pair is equal.
         * comm == true:  same binary declaration and the arguments are equal
         *                crosswise (x_0 = y_1 and y_0 = x_1).
         * Returns false if the claim does not follow from the current equalities.
         */
        bool congruence(bool comm, app* x, app* y) {
            if (x->get_decl() != y->get_decl())
                return false;
            if (x->get_num_args() != y->get_num_args())
                return false;
            if (comm) {
                if (x->get_num_args() != 2)
                    return false;
                if (!are_equal(x->get_arg(0), y->get_arg(1)))
                    return false;
                if (!are_equal(y->get_arg(0), x->get_arg(1)))
                    return false;
                merge(x, y);
            }
            else {
                for (unsigned i = 0; i < x->get_num_args(); ++i)
                    if (!are_equal(x->get_arg(i), y->get_arg(i)))
                        return false;
                merge(x, y);
            }
            IF_VERBOSE(10, verbose_stream() << "cc " << mk_bounded_pp(x, m) << " == " << mk_bounded_pp(y, m) << "\n");
            return true;
        }

        // Start a fresh check: bump the generation (clearing m_expr2id on
        // wrap-around) and drop all per-check state.
        void reset() {
            ++m_ts;
            if (m_ts == 0) {
                m_expr2id.reset();
                ++m_ts;
            }
            m_uf.reset();
            m_diseqs.reset();
            // The previous check's numerals are no longer referenced by any
            // current union-find variable; release them.
            m_trail.reset();
        }

        // Map an expression to its union-find variable, creating a fresh
        // variable the first time e is seen in the current generation.
        unsigned expr2id(expr* e) {
            auto [ts, id] = m_expr2id.get(e->get_id(), {0,0});
            if (ts != m_ts) {
                id = m_uf.mk_var();
                m_expr2id.setx(e->get_id(), {m_ts, id}, {0,0});
                m_id2expr.setx(id, e, nullptr);
            }
            return id;
        }

    public:
        eq_theory_checker(ast_manager& m): m(m), m_arith(m), m_trail(m) {}

        // The clause justified by jst: the negation of each Boolean argument.
        expr_ref_vector clause(app* jst) override {
            expr_ref_vector result(m);
            for (expr* arg : *jst) 
                if (m.is_bool(arg)) 
                    result.push_back(mk_not(m, arg));
            return result;
        }

        /**
         * Returns true iff the literals and congruence claims in jst are
         * contradictory: either a disequality is violated, or an equivalence
         * class contains two distinct values (e.g. true and false).
         * Returns false if an argument is malformed or a congruence claim does
         * not follow from the equalities collected so far.
         */
        bool check(app* jst) override {
            IF_VERBOSE(10, verbose_stream() << mk_pp(jst, m) << "\n");
            reset();

            for (expr* arg : *jst) {
                expr* x, *y;
                bool sign = m.is_not(arg, arg);

                if (m.is_bool(arg)) {
                    if (m.is_eq(arg, x, y)) {
                        if (sign)
                            m_diseqs.push_back({x, y});
                        else 
                            merge(x, y);
                    }
                    // Record the truth value of the literal itself.
                    merge(arg, sign ? m.mk_false() : m.mk_true());
                }
                else if (m.is_proof(arg)) {
                    if (!is_app(arg))
                        return false;
                    app* a = to_app(arg);
                    if (a->get_num_args() != 1)
                        return false;
                    if (!m.is_eq(a->get_arg(0), x, y))
                        return false;
                    bool is_cc = a->get_name() == symbol("cc");
                    bool is_comm = a->get_name() == symbol("comm");
                    if (!is_cc && !is_comm)
                        return false;
                    if (!is_app(x) || !is_app(y))
                        return false;
                    if (!congruence(!is_cc, to_app(x), to_app(y))) {
                        IF_VERBOSE(0, verbose_stream() << "not congruent " << mk_pp(a, m) << "\n");
                        return false;
                    }
                }
                else {
                    IF_VERBOSE(0, verbose_stream() << "unrecognized argument " << mk_pp(arg, m) << "\n");
                    return false;
                }                
            }
            // check if a disequality is violated.
            for (auto const& [a, b] : m_diseqs)
                if (are_equal(a, b))
                    return true;

            // check if some equivalence class contains two distinct values.
            for (unsigned root = 0; root < m_uf.get_num_vars(); ++root) {
                if (root != m_uf.find(root))
                    continue;
                expr* val = nullptr;
                unsigned v = root;
                // walk the cyclic list of class members starting at the root
                do {
                    expr* e = m_id2expr[v];
                    if (val && m.are_distinct(e, val))
                        return true;
                    if (m.is_value(e))
                        val = e;
                    v = m_uf.next(v);
                }
                while (v != root);
            }
            return false;
        }

        void register_plugins(theory_checker& pc) override {
            pc.register_plugin(symbol("euf"), this);
            pc.register_plugin(symbol("smt"), this);
        }
    };

    /**
       A resolution proof term is of the form
       (res pivot proof1 proof2)
       The pivot occurs with opposite signs in proof1 and proof2
     */

    class res_checker : public theory_checker_plugin {
        ast_manager&   m;
        theory_checker& pc;

    public:
        res_checker(ast_manager& m, theory_checker& pc): m(m), pc(pc) {}

        bool check(app* jst) override {
            if (jst->get_num_args() != 3)
                return false;
            auto [pivot, proof1, proof2] = jst->args3();
            if (!m.is_bool(pivot) || !m.is_proof(proof1) || !m.is_proof(proof2))
                return false;
            expr* narg;
            bool found1 = false, found2 = false, found3 = false, found4 = false;
            for (expr* arg : pc.clause(proof1)) {
                found1 |= arg == pivot;
                found2 |= m.is_not(arg, narg) && narg == pivot;
            }
            if (found1 == found2)
                return false;
            
            for (expr* arg : pc.clause(proof2)) {
                found3 |= arg == pivot;
                found4 |= m.is_not(arg, narg) && narg == pivot;
            }
            if (found3 == found4)
                return false;
            if (found3 == found1)
                return false;
            return pc.check(proof1) && pc.check(proof2);            
        }
        
        expr_ref_vector clause(app* jst) override {
            expr_ref_vector result(m);
            auto x = jst->args3();
            auto pivot  = std::get<0>(x);
            auto proof1 = std::get<1>(x);
            auto proof2 = std::get<2>(x);
            expr* narg;
            auto is_pivot = [&](expr* arg) {
                if (arg == pivot)
                    return true;
                return m.is_not(arg, narg) && narg == pivot;                
            };
            for (expr* arg : pc.clause(proof1))
                if (!is_pivot(arg))
                    result.push_back(arg);
            for (expr* arg : pc.clause(proof2))
                if (!is_pivot(arg))
                    result.push_back(arg);            
            return result;
        }

        void register_plugins(theory_checker& pc) override {
            pc.register_plugin(symbol("res"), this);
        }
    };

    theory_checker::theory_checker(ast_manager& m):
        m(m) {
        // Each plugin is stored and then registers the rule names it handles.
        // NOTE(review): if two plugins register the same rule name, the result
        // depends on m_map.insert semantics -- keep this order stable.
        auto install = [this](theory_checker_plugin* plugin) { add_plugin(plugin); };
        install(alloc(arith::theory_checker, m));
        install(alloc(eq_theory_checker, m));
        install(alloc(res_checker, m, *this));
        install(alloc(q::theory_checker, m));
        install(alloc(distinct::theory_checker, m));
        install(alloc(smt_theory_checker_plugin, m));
        install(alloc(tseitin::theory_checker, m));
        install(alloc(bv::theory_checker, m));
    }

    // Plugins are released by the m_plugins member (presumably an owning vector -- see header).
    theory_checker::~theory_checker() = default;

    // Store p and let it register the proof rules it handles with this checker.
    void theory_checker::add_plugin(theory_checker_plugin* p) {
        theory_checker& self = *this;
        m_plugins.push_back(p);
        p->register_plugins(self);
    }

    // Associate the proof rule name `rule` with the plugin that checks it.
    void theory_checker::register_plugin(symbol const& rule, theory_checker_plugin* p) {
        theory_checker_plugin* handler = p;
        m_map.insert(rule, handler);
    }

    // Dispatch e to the plugin registered for its rule name.
    // Returns false for null/non-application terms and for unknown rules.
    bool theory_checker::check(expr* e) {
        if (e == nullptr)
            return false;
        if (!is_app(e))
            return false;
        app* jst = to_app(e);
        theory_checker_plugin* plugin = nullptr;
        if (!m_map.find(jst->get_decl()->get_name(), plugin))
            return false;
        return plugin->check(jst);
    }

    /**
     * The clause justified by proof term e, as computed by the plugin registered
     * for its rule name. For a null/non-application term or an unknown rule an
     * empty clause is returned instead of dereferencing a missing map entry;
     * such terms are rejected by check(), so callers that check premises
     * remain sound.
     */
    expr_ref_vector theory_checker::clause(expr* e) {
        if (!e || !is_app(e))
            return expr_ref_vector(m);
        app* a = to_app(e);
        theory_checker_plugin* p = nullptr;
        if (!m_map.find(a->get_name(), p)) {
            IF_VERBOSE(10, verbose_stream() << "there is no proof plugin for " << mk_pp(e, m) << "\n");
            return expr_ref_vector(m);
        }
        return p->clause(a);
    }

    /**
     * Ask the plugin for e's rule to produce a simplified verification
     * condition for `clause` in v. Returns false when e is null, not an
     * application, or has no registered plugin (smt_proof_checker::infer may
     * pass a null proof hint).
     */
    bool theory_checker::vc(expr* e, expr_ref_vector const& clause, expr_ref_vector& v) {
        if (!e || !is_app(e))
            return false;
        app* a = to_app(e);
        theory_checker_plugin* p = nullptr;
        if (m_map.find(a->get_name(), p))
            return p->vc(a, clause, v);
        IF_VERBOSE(10, verbose_stream() << "there is no proof plugin for " << mk_pp(e, m) << "\n");
        return false;
    }
   
    /**
     * Check that proof term e justifies clause1.
     * e must check, and every literal of clause1 must occur in e's clause
     * (clause2); a doubly negated literal (not (not a)) is accepted when a
     * occurs in clause2.
     * On success, units receives the negations of the literals of clause2 that
     * are not in clause1; the caller must establish these separately (RUP).
     * units is only reset when e checks.
     */
    bool theory_checker::check(expr_ref_vector const& clause1, expr* e, expr_ref_vector & units) {
        if (!check(e))
            return false;
        units.reset();
        expr_mark literals;
        auto clause2 = clause(e);

        // check that all literals in clause1 are in clause2
        for (expr* arg : clause2)
            literals.mark(arg, true);
        for (expr* arg : clause1) {
            if (literals.is_marked(arg))
                continue;
            // kludge: accept (not (not a)) when a is in clause2. Use separate
            // variables so that arg still names the offending literal below.
            expr* inner = nullptr, *atom = nullptr;
            if (m.is_not(arg, inner) && m.is_not(inner, atom) && literals.is_marked(atom))
                continue;
            IF_VERBOSE(0, verbose_stream() << mk_bounded_pp(arg, m) << " not in " << clause2 << "\n");
            return false;
        }

        // extract negated units for literals in clause2 but not in clause1
        // the literals should be rup
        literals.reset();
        for (expr* arg : clause1)
            literals.mark(arg, true);
        for (expr* arg : clause2)
            if (!literals.is_marked(arg))
                units.push_back(mk_not(m, arg));

        return true;
    }

    // The justified clause is the negation of every argument of jst.
    expr_ref_vector smt_theory_checker_plugin::clause(app* jst) {
        expr_ref_vector negated(m);
        unsigned const num_args = jst->get_num_args();
        for (unsigned i = 0; i < num_args; ++i)
            negated.push_back(mk_not(m, jst->get_arg(i)));
        return negated;
    }

    // Register this plugin for the rule names listed below, in this order.
    void smt_theory_checker_plugin::register_plugins(theory_checker& pc) {
        static char const* const rules[] = { "datatype", "array", "quant", "fpa" };
        for (char const* rule : rules)
            pc.register_plugin(symbol(rule), this);
    }

    smt_proof_checker::smt_proof_checker(ast_manager& m, params_ref const& p):
        m(m),
        m_params(p),
        m_checker(m),
        m_sat_solver(m_params, m.limit()), 
        m_drat(m_sat_solver) 
    {
        // Rule name that marks RUP proof hints.
        m_rup = symbol("rup");
        // Configure the SAT solver for DRAT unsat checking with EUF disabled,
        // then refresh the DRAT checker's configuration.
        m_params.set_bool("drat.check_unsat", true);
        m_params.set_bool("euf", false);
        m_sat_solver.updt_params(m_params);
        m_drat.updt_config();
        // Whether RUP steps are actually checked (solver parameter proof.check_rup).
        m_check_rup = solver_params(m_params).proof_check_rup();
    }

    // Lazily create the SMT solver used as a fallback for proof steps.
    void smt_proof_checker::ensure_solver() {
        if (m_solver)
            return;
        m_solver = mk_smt_solver(m, m_params, symbol());
    }


    /**
     * Count a verified (success) or unverified proof step per hint rule name.
     * The tallies are printed to stdout for the first 99 logged steps and then
     * on every 1000th step. Null hints are not logged.
     */
    void smt_proof_checker::log_verified(app* proof_hint, bool success) {
        if (proof_hint == nullptr)
            return;

        symbol const rule_name = proof_hint->get_name();
        if (!success)
            m_hint2miss.insert_if_not_there(rule_name, 0)++;
        else
            m_hint2hit.insert_if_not_there(rule_name, 0)++;
        m_num_logs++;

        bool const early = m_num_logs < 100;
        bool const periodic = (m_num_logs % 1000) == 0;
        if (!early && !periodic)
            return;

        std::cout << "(proofs";
        for (auto const& [rule, count] : m_hint2hit)
            std::cout << " +" << rule << " " << count;
        for (auto const& [rule, count] : m_hint2miss)
            std::cout << " -" << rule << " " << count;
        std::cout << ")\n";
    }

    // Check that clause is a RUP consequence of the clauses added so far.
    // Trivially true when RUP checking is disabled.
    bool smt_proof_checker::check_rup(expr_ref_vector const& clause) {
        if (m_check_rup) {
            add_units();
            mk_clause(clause);
            return m_drat.is_drup(m_clause.size(), m_clause.data(), m_units);
        }
        return true;
    }

    // Check that the unit u is a RUP consequence of the clauses added so far.
    // Trivially true when RUP checking is disabled.
    bool smt_proof_checker::check_rup(expr* u) {
        if (m_check_rup) {
            add_units();
            mk_clause(u);
            return m_drat.is_drup(m_clause.size(), m_clause.data(), m_units);
        }
        return true;
    }

    /**
     * Validate an inferred clause, trying in order:
     *   1. RUP, when the hint is a rup hint;
     *   2. the theory checker, with any residual units checked by RUP;
     *   3. a plugin-provided simplified verification condition;
     *   4. the fallback SMT solver on the negated verification condition.
     * If the fallback finds the negation satisfiable (or unknown), diagnostics
     * are printed and the process exits with a non-zero status.
     * proof_hint may be null.
     */
    void smt_proof_checker::infer(expr_ref_vector& clause, app* proof_hint) {
            
        if (is_rup(proof_hint) && check_rup(clause)) {
            if (m_check_rup) {
                log_verified(proof_hint, true);
                add_clause(clause);
            }
            return;
        }
        
        expr_ref_vector units(m);
        if (m_checker.check(clause, proof_hint, units)) {
            bool units_are_rup = true;
            for (expr* u : units) {
                if (!m.is_true(u) && !check_rup(u)) {
                    std::cout << "unit " << mk_bounded_pp(u, m) << " is not rup\n";
                    units_are_rup = false;
                }
            }
            if (units_are_rup) {
                log_verified(proof_hint, true);
                add_clause(clause);
                return;
            }
        }
        
        // extract a simplified verification condition in case proof validation does not work.
        // quantifier instantiation can be validated as follows:
        // If quantifier instantiation claims that (forall x . phi(x)) => psi using instantiation x -> t
        // then check the simplified VC: phi(t) => psi.
        // in case psi is the literal instantiation, then the clause is a propositional tautology.
        // The VC function is a no-op if the proof hint does not have an associated vc generator.
        // A null hint has no vc generator; guard here since vc() dereferences it.
        expr_ref_vector vc(clause);
        if (proof_hint && m_checker.vc(proof_hint, clause, vc)) {
            log_verified(proof_hint, true);
            add_clause(clause);
            return;
        }
        
        log_verified(proof_hint, false);

        ensure_solver();
        m_solver->push();
        for (expr* lit : vc)
            m_solver->assert_expr(m.mk_not(lit));
        lbool is_sat = m_solver->check_sat();
        if (is_sat != l_false) {
            std::cout << "did not verify: " << is_sat << " " << clause << "\n";
            std::cout << "vc:\n" << vc << "\n";
            if (proof_hint) 
                std::cout << "hint: " << mk_bounded_pp(proof_hint, m, 4) << "\n";
            m_solver->display(std::cout);
            if (is_sat == l_true) {
                model_ref mdl;
                m_solver->get_model(mdl);
                // get_model may leave mdl empty; do not dereference it then.
                if (mdl) {
                    mdl->evaluate_constants();
                    std::cout << *mdl << "\n";
                }
            }
            // Verification failed: report failure to the calling process
            // instead of exiting with the success status 0.
            exit(1);
        }
        m_solver->pop(1);
        std::cout << "(verified-smt"; 
        if (proof_hint) std::cout << "\n" << mk_bounded_pp(proof_hint, m, 4);
        for (expr* arg : clause)
            std::cout << "\n " << mk_bounded_pp(arg, m);
        std::cout << ")\n";
        std::cout.flush();

        // Diagnostics for rup hints that needed the fallback solver; currently disabled.
        if (false && is_rup(proof_hint)) 
            diagnose_rup_failure(clause);
            
        add_clause(clause);
    }

    /**
     * Diagnostic aid: find which assertions of the fallback solver imply
     * `clause`. Each assertion f is guarded as (a_i => f) with a fresh
     * assumption literal a_i. The guarded assertions and the negated clause are
     * solved together, and the resulting unsat core (the needed a_i) is printed.
     * Does nothing if the fallback solver has not been created yet.
     */
    void smt_proof_checker::diagnose_rup_failure(expr_ref_vector const& clause) {
        if (!m_solver)
            return;
        expr_ref_vector fmls(m), assumptions(m), core(m);
        m_solver->get_assertions(fmls);
        for (unsigned i = 0; i < fmls.size(); ++i) {
            assumptions.push_back(m.mk_fresh_const("a", m.mk_bool_sort()));
            fmls[i] = m.mk_implies(assumptions.back(), fmls.get(i));
        }
            
        ref<::solver> core_solver = mk_smt_solver(m, m_params, symbol());
        // Without the guarded assertions the assumptions are unconstrained and
        // the reported core would always be empty.
        core_solver->assert_expr(fmls);
        core_solver->assert_expr(m.mk_not(mk_or(clause)));
        lbool ch = core_solver->check_sat(assumptions);
        std::cout << "failed to verify\n" << clause << "\n";
        if (ch == l_false) {
            core_solver->get_unsat_core(core);
            std::cout << "core\n";
            for (expr* f : core)
                std::cout << mk_pp(f, m) << "\n";
        }
    }

    // Forward statistics from the fallback SMT solver, if it was ever created.
    void smt_proof_checker::collect_statistics(statistics& st) const {
        if (!m_solver)
            return;
        m_solver->collect_statistics(st);
    }
    
}





© 2015 - 2024 Weber Informatics LLC | Privacy Policy