// All Downloads are FREE. Search and download functionalities are using the official Maven repository.
//
// z3-z3-4.13.0.src.sat.smt.arith_theory_checker.h Maven / Gradle / Ivy
//
// The newest version!
/*++
Copyright (c) 2022 Microsoft Corporation

Module Name:

    arith_theory_checker.h

Abstract:

    Plugin for checking arithmetic lemmas

Author:

    Nikolaj Bjorner (nbjorner) 2022-08-28

Notes:

The module assumes a limited repertoire of arithmetic proof rules.

- farkas - inequalities, equalities and disequalities with coefficients
- implied-eq - last literal is a disequality. The literals before imply the complementary equality.
- bound - last literal is a bound. It is implied by prior literals.

--*/
#pragma once

#include "util/obj_pair_set.h"
#include "ast/ast_trail.h"
#include "ast/ast_util.h"
#include "ast/arith_decl_plugin.h"
#include "sat/smt/euf_proof_checker.h"
#include <iostream>
#include <utility>


namespace arith {

    class theory_checker : public euf::theory_checker_plugin {

        // Kinds of justification this plugin distinguishes.
        // rule_type() maps the name of a justification onto one of these values.
        enum rule_type_t {
            cut_t,          // justification named "cut"
            farkas_t,       // justification named "farkas"
            implied_eq_t,   // justification named "implied-eq"
            bound_t,        // justification named "bound"
            none_t          // any other name
        };
        
        // A linear expression: the sum of m_coeffs[t] * t over all stored terms t,
        // plus the constant m_coeff. The checker reads a row either as the
        // equation "row = 0" or as the inequality "row <= 0".
        struct row {
            obj_map<expr, rational> m_coeffs;   // term -> coefficient
            rational m_coeff;                   // constant part
            // Turn the row into the empty expression 0.
            void reset() {
                m_coeffs.reset();
                m_coeff = 0;
            }
            // True when the row has no terms and its constant is 0.
            bool is_zero() const {
                return m_coeffs.empty() && m_coeff == 0;
            }
        };
       
        ast_manager& m;
        arith_util   a;                             // arithmetic recognizers (is_le, is_int, is_numeral, ...)
        vector<std::pair<rational, expr*>> m_todo;  // work list of (multiplier, sub-term) used by linearize
        bool         m_strict = false;              // set by add_literal when a strict non-integer inequality is added
        row          m_ineq;                        // row that add_ineq accumulates into (all rules except implied-eq)
        row          m_conseq;                      // row built by add_conseq / add_implied_diseq
        vector<row>  m_eqs, m_ineqs;                // rows added by add_eq; one row per inequality for implied-eq
        symbol       m_farkas = symbol("farkas");
        symbol       m_implied_eq = symbol("implied-eq");
        symbol       m_bound = symbol("bound");
        symbol       m_cut = symbol("cut");

        // Classify a justification by its name; a name that matches none of the
        // known rule symbols yields none_t.
        rule_type_t rule_type(app* jst) const {
            symbol const& name = jst->get_name();
            return name == m_cut        ? cut_t
                 : name == m_bound      ? bound_t
                 : name == m_implied_eq ? implied_eq_t
                 : name == m_farkas     ? farkas_t
                 :                        none_t;
        }
        
        
        // r <- r + coeff*v.
        // A term whose coefficient cancels to zero is removed from the row,
        // so stored coefficients are never zero.
        void add(row& r, expr* v, rational const& coeff) {
            if (coeff.is_zero())
                return;
            rational sum;
            if (!r.m_coeffs.find(v, sum)) {
                // v does not occur yet
                r.m_coeffs.insert(v, coeff);
                return;
            }
            sum += coeff;
            if (sum.is_zero())
                r.m_coeffs.erase(v);
            else
                r.m_coeffs[v] = sum;
        }
        
        // r <- coeff * r (term coefficients and constant); a factor of 1 is a no-op.
        void mul(row& r, rational const& coeff) {
            if (coeff != 1) {
                for (auto& [term, c] : r.m_coeffs)
                    c *= coeff;
                r.m_coeff *= coeff;
            }
        }
        
        // dst <- dst + mul*src
        // (terms are merged one by one, then the constants are combined)
        void add(row& dst, row const& src, rational const& mul) {
            for (auto const& [term, c] : src.m_coeffs) {
                rational const scaled = c * mul;
                add(dst, term, scaled);
            }
            dst.m_coeff += mul * src.m_coeff;
        }
        
        // Eliminate v from dst using src, where A is the coefficient of v in src
        // and B is the coefficient of v in dst:
        //
        //   dst <- x*dst + y*src
        //
        //   x = lcm(|A|,|B|)/|B|, |y| = lcm(|A|,|B|)/|A|   if v is integer
        //   x = 1/|B|,            |y| = 1/|A|              if v is real
        //
        // x is always positive; y is negated when A and B have the same sign,
        // so that the two contributions to v cancel.
        // Returns false (dst untouched) when v does not occur in dst.
        bool resolve(expr* v, row& dst, rational const& A, row const& src) {
            rational B;
            if (!dst.m_coeffs.find(v, B))
                return false;
            rational const abs_a = abs(A);
            rational const abs_b = abs(B);
            rational x, y;
            if (a.is_int(v)) {
                rational const lc = lcm(abs_a, abs_b);
                x = lc / abs_b;
                y = lc / abs_a;
            }
            else {
                x = rational(1) / abs_b;
                y = rational(1) / abs_a;
            }
            bool const same_sign = (A < 0 && B < 0) || (A > 0 && B > 0);
            if (same_sign)
                y.neg();
            mul(dst, x);
            add(dst, src, y);
            return true;
        }

        // Strengthen the inequality r <= 0 by integer rounding.
        //
        // The row is first scaled so that all coefficients and the constant are
        // integers. With g the gcd of the term coefficients, the term part is a
        // multiple of g, so the constant can be rounded up to the next multiple
        // of g without losing integer solutions.
        //
        // The rounding argument only holds when every term is integer valued,
        // so the row is left untouched as soon as one term is not an integer.
        // (Testing a single term is not enough: 2x + 2y - 1 <= 0 with y real
        // must not be turned into 2x + 2y <= 0.)
        void cut(row& r) {
            if (r.m_coeffs.empty())
                return;
            for (auto const& [v, coeff] : r.m_coeffs)
                if (!a.is_int(v))
                    return;
            // clear denominators
            rational lc = denominator(r.m_coeff);
            for (auto const& [v, coeff] : r.m_coeffs)
                lc = lcm(lc, denominator(coeff));
            if (lc != 1) {
                r.m_coeff *= lc;
                for (auto & [v, coeff] : r.m_coeffs)
                    coeff *= lc;
            }
            rational g(0);
            for (auto const& [v, coeff] : r.m_coeffs)
                g = gcd(coeff, g);
            if (g == 1)
                return;
            rational m = mod(r.m_coeff, g);
            if (m == 0)
                return;
            // round the constant up to the next multiple of g
            r.m_coeff += g - m;
        }
        
        /**
         * \brief accumulate mul*e into r: numeral sub-terms are added to r.m_coeff,
         * every other sub-term is added to r.m_coeffs.
         *
         * Multiplication by a numeral, addition, subtraction and unary minus are
         * unfolded; anything else is kept as an opaque term.
         */
        void linearize(row& r, rational const& mul, expr* e) {
            SASSERT(m_todo.empty());
            m_todo.push_back({ mul, e });
            rational k;
            expr* lhs, *rhs;
            unsigned next = 0;
            while (next < m_todo.size()) {
                // work on a copy: m_todo is extended while the entry is processed
                auto [c, t] = m_todo[next++];
                if (a.is_mul(t, lhs, rhs) && is_numeral(lhs, k)) {
                    m_todo.push_back({ c * k, rhs });
                    continue;
                }
                if (a.is_mul(t, lhs, rhs) && is_numeral(rhs, k)) {
                    m_todo.push_back({ c * k, lhs });
                    continue;
                }
                if (a.is_add(t)) {
                    for (expr* arg : *to_app(t))
                        m_todo.push_back({ c, arg });
                    continue;
                }
                if (a.is_uminus(t, lhs)) {
                    m_todo.push_back({ -c, lhs });
                    continue;
                }
                if (a.is_sub(t, lhs, rhs)) {
                    m_todo.push_back({ c, lhs });
                    m_todo.push_back({ -c, rhs });
                    continue;
                }
                if (is_numeral(t, k)) {
                    r.m_coeff += c * k;
                    continue;
                }
                add(r, t, c);
            }
            m_todo.reset();
        }

        // Recognize a numeral, or the unary minus of a numeral; on success n holds its value.
        bool is_numeral(expr* e, rational& n) {
            if (a.is_numeral(e, n))
                return true;
            expr* inner = nullptr;
            if (!a.is_uminus(e, inner) || !a.is_numeral(inner, n))
                return false;
            n.neg();
            return true;
        }
        
        // True when r has no terms left and its constant alone is contradictory:
        // the constant is positive, or it is 0 while m_strict is set.
        bool check_ineq(row& r) {
            if (!r.m_coeffs.empty())
                return false;
            return r.m_coeff > 0 || (m_strict && r.m_coeff == 0);
        }
        
        // triangulate equalities, substitute results into m_ineq, m_conseq.
        // check consistency of equalities (they may be inconsisent)
        bool reduce_eq() {
            for (unsigned i = 0; i < m_eqs.size(); ++i) {
                auto& r = m_eqs[i];
                if (r.m_coeffs.empty() && r.m_coeff != 0)
                    return false;
                if (r.m_coeffs.empty())
                    continue;
                auto [v, coeff] = *r.m_coeffs.begin();
                for (unsigned j = i + 1; j < m_eqs.size(); ++j)
                    resolve(v, m_eqs[j], coeff, r);
                resolve(v, m_ineq, coeff, r);
                resolve(v, m_conseq, coeff, r);
                for (auto& ineq : m_ineqs)
                    resolve(v, ineq, coeff, r);
            }
            return true;
        }
        
        
        // Add coeff times the literal (e, negated when sign is set) to r, reading r as "r <= 0".
        //
        //    e1 <= e2          adds coeff*(e1 - e2)
        //    not (e1 < e2)     adds coeff*(e2 - e1)
        //    not (e1 <= e2)    adds coeff*(e2 - e1), strict
        //    e1 < e2           adds coeff*(e1 - e2), strict
        //
        // >= and > are handled as <= and < with swapped arguments.
        // A strict literal over integers adds coeff to the constant instead of
        // being recorded as strict; otherwise m_strict is set.
        // Returns false when e is not one of <=, <, >=, >.
        bool add_literal(row& r, rational const& coeff, expr* e, bool sign) {
            expr* e1 = nullptr, *e2 = nullptr;
            bool const weak   = a.is_le(e, e1, e2) || a.is_ge(e, e2, e1);
            bool const strong = !weak && (a.is_lt(e, e1, e2) || a.is_gt(e, e2, e1));
            if (!weak && !strong)
                return false;

            // negation swaps the two sides
            expr* pos = sign ? e2 : e1;
            expr* neg = sign ? e1 : e2;
            linearize(r, coeff, pos);
            linearize(r, -coeff, neg);

            // strict exactly for a negated weak or a plain strong inequality
            bool const is_strict = (weak == sign);
            if (is_strict) {
                if (a.is_int(e1))
                    r.m_coeff += coeff;
                else
                    m_strict = true;
            }
            // display_row(std::cout << coeff << " * " << (sign?"~":"") << mk_pp(e, m) << "\n", r) << "\n";
            return true;
        }
        
        // Accept when m_ineq is contradictory, before or after the equalities
        // have been substituted, or when the equalities themselves are inconsistent.
        bool check_farkas() {
            if (check_ineq(m_ineq) || !reduce_eq() || check_ineq(m_ineq))
                return true;
            IF_VERBOSE(3, display_row(verbose_stream() << "Failed to verify Farkas with reduced row ", m_ineq) << "\n");
            // convert to expression, maybe follows from a cut.
            return false;
        }

        //
        // farkas coefficient is computed for m_conseq 
        // after all inequalities in ineq have been added up
        //
        bool check_bound() {
            if (!reduce_eq())
                return true;
            if (check_ineq(m_conseq))
                return true;
            if (m_ineq.m_coeffs.empty() ||
                m_conseq.m_coeffs.empty())
                return false;
            cut(m_ineq);
            auto const& [v, coeff1] = *m_ineq.m_coeffs.begin();
            rational coeff2;
            if (!m_conseq.m_coeffs.find(v, coeff2))
                return false;
            add(m_conseq, m_ineq, abs(coeff2/coeff1));
            if (check_ineq(m_conseq))
                return true;            
            return false;
        }

        /**
           Check implied equality lemma:

           inequalities & equalities => equality

           
           We may assume the set of inequality assumptions we are given are all tight, non-strict and imply equalities.
           In other words, given a set of inequalities a1x + b1 <= 0, ..., anx + bn <= 0
           the equalities a1x + b1 = 0, ..., anx + bn = 0 are all consequences.

           We use a weaker property: We derive implied equalities by applying exhaustive Fourier-Motzkin
           elimination and then collect the tight 0 <= 0 inequalities that are derived.
           
           Claim: the set of inequalities used to derive 0 <= 0 are all tight equalities.
         */

        // Scratch state of check_implied_eq.
        svector<std::pair<unsigned, unsigned>> m_deps;  // for a derived row: indices of the two rows it was resolved from
        unsigned_vector m_tight_inequalities;           // work list of row indices found to be tight
        uint_set m_ineqs_that_are_eqs;                  // row indices already taken from the work list
        
        // Check an implied-equality lemma.
        //
        // 1. Substitute the given equalities; accept if they are inconsistent
        //    or if they already reduce m_conseq to 0.
        // 2. Resolve each inequality against the later ones on its first term,
        //    appending the results to m_ineqs and recording in m_deps the pair
        //    of rows each result was derived from.
        // 3. Every row that became "0 <= 0" is tight; follow m_deps back to the
        //    original inequalities and turn those into equalities.
        // 4. Substitute the new equalities and accept when m_conseq becomes 0.
        bool check_implied_eq() {
            if (!reduce_eq())
                return true;
            if (m_conseq.is_zero())
                return true;

            m_eqs.reset();
            m_deps.reset();
            unsigned orig_size = m_ineqs.size();
            m_deps.reserve(orig_size);
            for (unsigned i = 0; i < m_ineqs.size(); ++i) {
                if (m_ineqs[i].is_zero()) {
                    m_tight_inequalities.push_back(i);
                    continue;
                }
                // A row without terms but with a nonzero constant has no term to
                // eliminate; taking begin() of its empty map would be invalid.
                if (m_ineqs[i].m_coeffs.empty())
                    continue;
                // Take the pivot by value: fresh() below grows m_ineqs, so no
                // reference into the vector is kept across it.
                auto const [v, coeff] = *m_ineqs[i].m_coeffs.begin();
                unsigned sz = m_ineqs.size();

                for (unsigned j = i + 1; j < sz; ++j) {
                    rational B;
                    // only rows where v occurs with the opposite sign can cancel it
                    if (!m_ineqs[j].m_coeffs.find(v, B) || (coeff > 0 && B > 0) || (coeff < 0 && B < 0))
                        continue;
                    row& r3 = fresh(m_ineqs);
                    add(r3, m_ineqs[j], rational::one());
                    resolve(v, r3, coeff, m_ineqs[i]);
                    m_deps.push_back({i, j});
                }
                SASSERT(m_deps.size() == m_ineqs.size());
            }

            m_ineqs_that_are_eqs.reset();
            while (!m_tight_inequalities.empty()) {
                unsigned j = m_tight_inequalities.back();
                m_tight_inequalities.pop_back();
                if (m_ineqs_that_are_eqs.contains(j))
                    continue;
                m_ineqs_that_are_eqs.insert(j);
                if (j < orig_size) {
                    // an original inequality that is tight: use it as an equality
                    m_eqs.push_back(m_ineqs[j]);
                }
                else {
                    // a derived row: both rows it was resolved from are tight too
                    auto [parent1, parent2] = m_deps[j];
                    m_tight_inequalities.push_back(parent1);
                    m_tight_inequalities.push_back(parent2);
                }
            }
            m_ineqs.reset();

            VERIFY (reduce_eq());

            return m_conseq.is_zero();
        }

        // Print r as "c1 * t1 + c2 * t2 + ... + k"; a coefficient equal to 1
        // and a constant equal to 0 are omitted.
        std::ostream& display_row(std::ostream& out, row const& r) {
            unsigned printed = 0;
            for (auto const& [term, c] : r.m_coeffs) {
                if (printed++ > 0)
                    out << " + ";
                if (c != 1)
                    out << c << " * ";
                out << mk_pp(term, m);
            }
            if (r.m_coeff != 0)
                out << " + " << r.m_coeff;
            return out;
        }


        // Print r as the equation "r = 0" followed by a newline.
        void display_eq(std::ostream& out, row const& r) {
            display_row(out, r) << " = 0\n";
        }

        // Print r as "r < 0" when m_strict is set and as "r <= 0" otherwise.
        void display_ineq(std::ostream& out, row const& r) {
            display_row(out, r) << (m_strict ? " < 0\n" : " <= 0\n");
        }

        // Append an empty row to rows and return a reference to it.
        // The reference is only usable until rows is extended again.
        row& fresh(vector<row>& rows) {
            rows.push_back(row());
            return rows.back();
        }
        
    public:
        // Bind the checker to an ast manager; the arithmetic utility a is built on the same manager.
        theory_checker(ast_manager& m): 
            m(m), 
            a(m) {}

        // Forget all rows collected so far and clear the strictness flag.
        void reset() {
            m_strict = false;
            m_eqs.reset();
            m_ineqs.reset();
            m_ineq.reset();
            m_conseq.reset();
        }
        
        // Add |coeff| times a premise literal. For implied-eq every literal gets
        // a row of its own in m_ineqs; for the other rules the literals are
        // summed up in m_ineq. Returns false if e is not an inequality.
        bool add_ineq(rule_type_t rt, rational const& coeff, expr* e, bool sign) {
            row* target = &m_ineq;
            if (rt == implied_eq_t)
                target = &fresh(m_ineqs);
            return add_literal(*target, abs(coeff), e, sign);
        }

        // Add |coeff| times the literal to m_conseq. Returns false if e is not an inequality.
        bool add_conseq(rational const& coeff, expr* e, bool sign) {
            rational const weight = abs(coeff);
            return add_literal(m_conseq, weight, e, sign);
        }
        
        // Record the equality a = b as the new row "a - b" in m_eqs.
        void add_eq(expr* a, expr* b) {
            row& eq = fresh(m_eqs);
            linearize(eq, rational::one(), a);
            linearize(eq, -rational::one(), b);
        }
        
        // Run the check that belongs to rt on the rows collected so far.
        // Every rule other than farkas and implied-eq is checked as a bound.
        bool check(rule_type_t rt) {
            if (rt == farkas_t)
                return check_farkas();
            if (rt == implied_eq_t)
                return check_implied_eq();
            return check_bound();
        }

        std::ostream& display(std::ostream& out) {
            for (auto & r : m_eqs)
                display_eq(out, r);
            display_ineq(out, m_ineq);
            if (!m_conseq.m_coeffs.empty())
                display_ineq(out, m_conseq);
            return out;
        }

        // The clause of a justification: the negation of each Boolean argument, in order.
        // Non-Boolean arguments (the coefficients) are skipped.
        expr_ref_vector clause(app* jst) override {
            expr_ref_vector result(m);
            unsigned const num_args = jst->get_num_args();
            for (unsigned i = 0; i < num_args; ++i) {
                expr* arg = jst->get_arg(i);
                if (!m.is_bool(arg))
                    continue;
                result.push_back(mk_not(m, arg));
            }
            return result;
        }

        /**
           Add implied equality as an inequality
         */
        bool add_implied_diseq(bool sign, app* jst) {
            unsigned n = jst->get_num_args();
            if (n < 2)
                return false;
            expr* arg1 = jst->get_arg(n - 2);
            expr* arg2 = jst->get_arg(n - 1);
            rational coeff;
            if (!a.is_numeral(arg1, coeff))
                return false;
            if (!m.is_not(arg2, arg2))
                return false;
            if (!m.is_eq(arg2, arg1, arg2))
                return false;
            if (!sign)
                coeff.neg();
            auto& r = m_conseq;
            linearize(r, coeff, arg1);
            linearize(r, -coeff, arg2);
            return true;
        }

        // Check a justification.
        //
        // The arguments of jst alternate between a numeral coefficient and a
        // literal. Each literal is translated into a row:
        //  - inequalities go to m_ineq / m_ineqs, or to m_conseq when the literal
        //    is the last argument of a bound justification;
        //  - positive equalities go to m_eqs;
        //  - the trailing disequality of an implied-eq justification is handled
        //    by add_implied_diseq.
        // The rule specific check is then run on the collected rows.
        //
        // Returns false for "cut" and for unknown rules, for malformed arguments,
        // and when a literal cannot be translated. Checking without such a literal
        // would verify a different lemma than the one given: with m_strict set,
        // an empty m_conseq is accepted by check_ineq.
        bool check(app* jst) override {
            reset();

            auto rt = rule_type(jst);
            switch (rt) {
            case cut_t:
                return false;
            case none_t:
                IF_VERBOSE(0, verbose_stream() << "unhandled inference " << mk_pp(jst, m) << "\n");
                return false;
            default:
                break;
            }
            bool even = true;
            rational coeff;
            expr* x, * y;
            unsigned j = 0;
            unsigned const num_args = jst->get_num_args();

            for (expr* arg : *jst) {
                bool const is_last = j + 1 == num_args;

                if (even) {
                    if (!a.is_numeral(arg, coeff)) {
                        IF_VERBOSE(0, verbose_stream() << "not numeral " << mk_pp(jst, m) << "\n");
                        return false;
                    }
                }
                else {
                    bool sign = m.is_not(arg, arg);
                    bool added = true;
                    if (a.is_le(arg) || a.is_lt(arg) || a.is_ge(arg) || a.is_gt(arg)) {
                        if (rt == bound_t && is_last)
                            added = add_conseq(coeff, arg, sign);
                        else
                            added = add_ineq(rt, coeff, arg, sign);
                    }
                    else if (m.is_eq(arg, x, y)) {
                        if (rt == bound_t && is_last)
                            added = add_conseq(coeff, arg, sign);
                        else if (rt == implied_eq_t && is_last)
                            return add_implied_diseq(sign, jst) && check(rt);
                        else if (!sign)
                            add_eq(x, y);
                        else {
                            IF_VERBOSE(0, verbose_stream() << "unexpected disequality in justification " << mk_pp(arg, m) << "\n");
                            return false;
                        }
                    }
                    else {
                        IF_VERBOSE(0, verbose_stream() << "not a recognized arithmetical relation " << mk_pp(arg, m) << "\n");
                        return false;
                    }
                    if (!added) {
                        // the literal was not turned into a row: do not check an incomplete lemma
                        IF_VERBOSE(0, verbose_stream() << "literal could not be added to justification " << mk_pp(arg, m) << "\n");
                        return false;
                    }
                }
                even = !even;
                ++j;
            }
            return check(rt);
        }

        // Register this plugin as the checker for the rules farkas, bound, implied-eq and cut.
        void register_plugins(euf::theory_checker& pc) override {
            symbol const* rules[] = { &m_farkas, &m_bound, &m_implied_eq, &m_cut };
            for (symbol const* rule : rules)
                pc.register_plugin(*rule, this);
        }
        
    };

}




// © 2015 - 2024 Weber Informatics LLC | Privacy Policy