All Downloads are FREE. Search and download functionalities are using the official Maven repository.

z3-z3-4.13.0.src.sat.smt.arith_sls.cpp Maven / Gradle / Ivy

The newest version!
/*++
Copyright (c) 2023 Microsoft Corporation

Module Name:

    arith_local_search.cpp

Abstract:

    Local search dispatch for SMT

Author:

    Nikolaj Bjorner (nbjorner) 2023-02-07

--*/
#include "sat/sat_solver.h"
#include "sat/smt/arith_solver.h"


namespace arith {
    
    // Construct the arithmetic local-search helper for solver s.
    // The member m is initialized from s.m.
    sls::sls(solver& s): 
        s(s), m(s.m) {}

    // Clear the per-search state: the bool-var -> inequality table,
    // the variable records and the recorded (term, variable) pairs.
    void sls::reset() {
        m_bool_vars.reset();
        m_vars.reset();
        m_terms.reset();
    }

    void sls::save_best_values() {
        for (unsigned v = 0; v < s.get_num_vars(); ++v)
            m_vars[v].m_best_value = m_vars[v].m_value;
        check_ineqs();   
        if (unsat().size() == 1) {
            auto idx = *unsat().begin();
            verbose_stream() << idx << "\n";
            auto const& c = *m_bool_search->m_clauses[idx].m_clause;
            verbose_stream() << c << "\n";
            for (auto lit : c) {
                bool_var bv = lit.var();
                ineq* i = atom(bv);
                if (i)
                    verbose_stream() << lit << ": " << *i << "\n";
            }
            verbose_stream() << "\n";
        }
    }

    // Transfer the best values found by local search into the LP solver.
    // Nothing is done when unsatisfied clauses remain. Otherwise the best
    // value of every recorded term is recomputed from its arguments, the
    // best values are written to the non-basic LP columns, feasibility is
    // restored with make_feasible(), and finally every bound atom is
    // checked for agreement between the solver phase and the local-search
    // model.
    void sls::store_best_values() {
        // first compute assignment to terms
        // then update non-basic variables in tableau.

        // only a model without unsatisfied clauses is exported
        if (!unsat().empty())
            return;
        
        // best value of a term variable v := sum of coeff * best value over the term's arguments
        for (auto const& [t,v] : m_terms) {
            int64_t val = 0;
            lp::lar_term const& term = s.lp().get_term(t);
            for (lp::lar_term::ival const& arg : term) {
                auto t2 = arg.j();
                auto w = s.lp().local_to_external(t2);
                val += to_numeral(arg.coeff()) * m_vars[w].m_best_value;
            }
            m_vars[v].m_best_value = val;
        }

        // push best values into the LP solver; only columns that are not basic are assigned
        for (unsigned v = 0; v < s.get_num_vars(); ++v) {
            if (s.is_bool(v))
                continue;
            if (!s.lp().external_is_used(v)) 
                continue;            
            int64_t new_value = m_vars[v].m_best_value;
            s.ensure_column(v);
            lp::lpvar vj = s.lp().external_to_local(v);
            SASSERT(vj  != lp::null_lpvar);
            if (!s.lp().is_base(vj)) {
                rational new_value_(new_value, rational::i64());
                lp::impq val(new_value_, rational::zero());
                s.lp().set_value_for_nbasic_column(vj, val);
            }
        }

        // unsat() is empty here (see the early return), so this requires r == l_true
        lbool r = s.make_feasible();
        VERIFY (!unsat().empty() || r == l_true);    
#if 0
        if (unsat().empty()) 
            s.m_num_conflicts = s.get_config().m_arith_propagation_threshold;
#endif   

        // Consistency check for one boolean variable. It applies only to
        // variables that have both a local-search inequality and an api_bound.
        // If the solver phase differs from the local-search model, diagnostics
        // are printed and the process is terminated.
        auto check_bool_var = [&](sat::bool_var bv) {
            auto* ineq = m_bool_vars.get(bv, nullptr);
            if (!ineq)
                return;
            api_bound* b = nullptr;
            s.m_bool_var2bound.find(bv, b);
            if (!b)
                return;
            auto bound = b->get_value();
            theory_var v = b->get_var();
            // phase and model agree: nothing to report
            if (s.get_phase(bv) == m_bool_search->get_model()[bv])
                return;
            switch (b->get_bound_kind()) {
            case lp_api::lower_t:
                verbose_stream() << "v" << v << " " << bound << " <= " << s.get_value(v) << " " << m_vars[v].m_best_value << "\n";
                break;
            case lp_api::upper_t:
                verbose_stream() << "v" << v << " " << bound << " >= " << s.get_value(v) << " " << m_vars[v].m_best_value << "\n";
                break;
            }
            // re-evaluate the left-hand side under the best values
            // (the binding v below shadows the theory_var v declared above)
            int64_t value = 0;
            for (auto const& [coeff, v] : ineq->m_args) {
                value += coeff * m_vars[v].m_best_value;
            }
            ineq->m_args_value = value;
            verbose_stream() << *ineq << " dtt " << dtt(false, *ineq) << " phase " << s.get_phase(bv) << " model " << m_bool_search->get_model()[bv] << "\n";
            for (auto const& [coeff, v] : ineq->m_args) 
                verbose_stream() << "v" << v << " := " << m_vars[v].m_best_value << "\n";
            s.display(verbose_stream());
            display(verbose_stream());
            // NOTE(review): a detected mismatch ends the process with exit status 0 — confirm this is intended
            UNREACHABLE();
            exit(0);
        };

        // always true at this point because of the early return at the top
        if (unsat().empty()) {
            for (bool_var v = 0; v < s.s().num_vars(); ++v) 
                check_bool_var(v);            
        }
    }

    // Attach this helper to the boolean local-search engine d.
    // State is reset, one variable record is created per solver variable,
    // an inequality is set up for every boolean variable occurring in a
    // clause of d, the boolean assignment is aligned with the arithmetic
    // values, and finally this object is registered with d.
    void sls::set(sat::ddfw* d) {
        m_bool_search = d;
        reset();
        m_bool_vars.reserve(s.s().num_vars());
        add_vars();

        for (unsigned ci = 0; ci < d->num_clauses(); ++ci) {
            for (sat::literal lit : *d->get_clause_info(ci).m_clause) {
                init_bool_var(lit.var());
            }
        }

        for (unsigned bv = 0; bv < s.s().num_vars(); ++bv) {
            init_bool_var_assignment(bv);
        }

        d->set(this);
    }

    // Distance to true.
    // args is the value of the left-hand side; the atom compares it with
    // ineq.m_bound using ineq.m_op. With sign == false the distance of the
    // atom itself is returned, with sign == true the distance of its
    // negation. The result is 0 exactly when the (possibly negated) atom
    // holds for args.
    int64_t sls::dtt(bool sign, int64_t args, ineq const& ineq) const {
        auto const bound = ineq.m_bound;
        switch (ineq.m_op) {
        case ineq_kind::LE: {
            bool const holds = args <= bound;
            if (sign)
                return holds ? bound - args + 1 : 0;
            return holds ? 0 : args - bound;
        }
        case ineq_kind::EQ: {
            // distance 1 when the atom holds but its negation is wanted, or vice versa
            bool const holds = args == bound;
            return holds == sign ? 1 : 0;
        }
        case ineq_kind::NE: {
            // mirror image of EQ
            bool const equal = args == bound;
            return equal != sign ? 1 : 0;
        }
        case ineq_kind::LT: {
            bool const holds = args < bound;
            if (sign)
                return holds ? bound - args : 0;
            return holds ? 0 : args - bound + 1;
        }
        default:
            UNREACHABLE();
            return 0;
        }
    }

    //
    // dtt is high overhead. It walks ineq.m_args
    // m_vars[w].m_value can be computed outside and shared among calls
    // different data-structures for storing coefficients
    // 
    int64_t sls::dtt(bool sign, ineq const& ineq, var_t v, int64_t new_value) const {
        for (auto const& [coeff, w] : ineq.m_args) 
            if (w == v)
                return dtt(sign, ineq.m_args_value + coeff * (new_value - m_vars[v].m_value), ineq);
        return 1;
    }

    int64_t sls::dtt(bool sign, ineq const& ineq, int64_t coeff, int64_t old_value, int64_t new_value) const {
        return dtt(sign, ineq.m_args_value + coeff * (new_value - old_value), ineq);
    }

    // Critical move for variable v in ineq: look up the coefficient of v
    // and delegate to the coefficient-taking overload. Returns false when
    // v does not occur in ineq.
    bool sls::cm(bool old_sign, ineq const& ineq, var_t v, int64_t& new_value) {
        for (auto const& [c, x] : ineq.m_args) {
            if (x != v)
                continue;
            return cm(old_sign, ineq, v, c, new_value);
        }
        return false;
    }

    // Critical move: compute a value for v that changes the truth value of
    // ineq.
    //
    //   old_sign  - the truth value of ineq must differ from old_sign on
    //               entry (enforced by VERIFY).
    //   v, coeff  - the variable to move and its coefficient in ineq.
    //   new_value - out: the proposed value for v. It is not assigned when
    //               coeff is 0.
    //
    // Returns true when moving v to new_value flips the atom. Returns false
    // when coeff is 0, or when an equality cannot be hit exactly by
    // changing v alone.
    //
    // Step sizes are of the form (|delta| + |coeff| - 1) / coeff, i.e. the
    // ceiling of |delta| / |coeff| carrying the sign of coeff.
    bool sls::cm(bool old_sign, ineq const& ineq, var_t v, int64_t coeff, int64_t& new_value) {
        VERIFY(ineq.is_true() != old_sign);
        // A zero coefficient cannot move the left-hand side, and every step
        // computation below divides by coeff. Coefficients can be 0 because
        // to_numeral() maps rationals that are not int64 integers to 0.
        if (coeff == 0)
            return false;
        auto bound = ineq.m_bound;
        auto argsv = ineq.m_args_value;
        bool solved = false;
        int64_t delta = argsv - bound;

        // move v so that the left-hand side equals the bound, if possible
        auto make_eq = [&]() {
            SASSERT(delta != 0);
            if (delta < 0)
                new_value = value(v) + (abs(delta) + abs(coeff) - 1) / coeff;
            else
                new_value = value(v) - (delta + abs(coeff) - 1) / coeff;
            solved = argsv + coeff * (new_value - value(v)) == bound;
            // with a unit coefficient the equality is always reachable
            if (!solved && abs(coeff) == 1) {
                verbose_stream() << "did not solve equality " << ineq << " for " << v << "\n";
                verbose_stream() << new_value << " " << value(v) << " delta " << delta << " lhs " << (argsv + coeff * (new_value - value(v))) << " bound " << bound << "\n";
                UNREACHABLE();
            }
            return solved;
        };

        // move v so that the left-hand side differs from the bound
        auto make_diseq = [&]() {
            if (delta >= 0)
                delta++;
            else
                delta--;
            new_value = value(v) + (abs(delta) + abs(coeff) - 1) / coeff;
            VERIFY(argsv + coeff * (new_value - value(v)) != bound);
            return true;
        };

        if (!old_sign) {
            switch (ineq.m_op) {
            case ineq_kind::LE:
                // args <= bound -> args > bound
                SASSERT(argsv <= bound);
                SASSERT(delta <= 0);
                --delta;
                new_value = value(v) + (abs(delta) + abs(coeff) - 1) / coeff;
                VERIFY(argsv + coeff * (new_value - value(v)) > bound);
                return true;
            case ineq_kind::LT:
                // args < bound -> args >= bound
                SASSERT(argsv <= ineq.m_bound);
                SASSERT(delta <= 0);
                new_value = value(v) + (abs(delta) + abs(coeff) - 1) / coeff;
                VERIFY(argsv + coeff * (new_value - value(v)) >= bound);
                return true;
            case ineq_kind::EQ:
                return make_diseq();
            case ineq_kind::NE:
                return make_eq();
            default:
                UNREACHABLE();
                break;
            }
        }
        else {
            switch (ineq.m_op) {
            case ineq_kind::LE:
                // args > bound -> args <= bound
                SASSERT(argsv > ineq.m_bound);
                SASSERT(delta > 0);
                new_value = value(v) - (delta + abs(coeff) - 1) / coeff;
                VERIFY(argsv + coeff * (new_value - value(v)) <= bound);
                return true;
            case ineq_kind::LT:
                // args >= bound -> args < bound
                SASSERT(argsv >= ineq.m_bound);
                SASSERT(delta >= 0);
                ++delta;
                new_value = value(v) - (abs(delta) + abs(coeff) - 1) / coeff;
                VERIFY(argsv + coeff * (new_value - value(v)) < bound);
                return true;
            case ineq_kind::NE:
                return make_diseq();
            case ineq_kind::EQ:
                return make_eq();
            default:
                UNREACHABLE();
                break;
            }
        }
        return false;
    }

    // Flip ineq by moving the variable stored in ineq.m_var_to_flip to the
    // value computed by cm(). Returns false when no variable was selected
    // (m_var_to_flip == UINT_MAX) or when no critical move exists.
    //
    // Selection policy notes:
    // flip on the first positive score
    // it could be changed to flip on maximal positive score
    // or flip on maximal non-negative score
    // or flip on first non-negative score
    bool sls::flip(bool sign, ineq const& ineq) {
        auto const x = ineq.m_var_to_flip;
        if (x == UINT_MAX) {
            IF_VERBOSE(1, verbose_stream() << "no var to flip\n");
            return false;
        }
        int64_t target;
        if (cm(sign, ineq, x, target)) {
            update(x, target);
            return true;
        }
        verbose_stream() << "no critical move for " << x << "\n";
        return false;
    }

    //
    // dscore(op) = sum_c (dts(c,alpha) - dts(c,alpha_after)) * weight(c)
    // The sum ranges over the clauses in the use lists of both polarities
    // of every boolean variable whose inequality mentions v.
    // TODO - use cached dts instead of computed dts
    // cached dts has to be updated when the score of literals are updated.
    //
    double sls::dscore(var_t v, int64_t new_value) const {
        double total = 0;
        // add the weighted dts improvement of every clause using literal l
        auto accumulate = [&](sat::literal l) {
            for (auto cl : m_bool_search->get_use_list(l))
                total += (compute_dts(cl) - dts(cl, v, new_value)) * m_bool_search->get_weight(cl);
        };
        for (auto const& [coeff, bv] : m_vars[v].m_bool_vars) {
            sat::literal pos(bv, false);
            accumulate(pos);
            accumulate(~pos);
        }
        return total;
    }

    //
    // Score of moving v to new_value, counted in clauses: +1 for every
    // clause that is not true and uses a literal the move turns on, -1 for
    // every clause with exactly one true literal that uses its complement.
    //
    // cm_score is costly. It involves several cache misses.
    // Note that
    // - m_bool_search->get_use_list(lit).size() is "often" 1 or 2
    // - dtt_old can be saved
    //
    int sls::cm_score(var_t v, int64_t new_value) {
        int total = 0;
        auto& info = m_vars[v];
        int64_t const cur_value = info.m_value;
        for (auto const& [coeff, bv] : info.m_bool_vars) {
            auto const& a = *atom(bv);
            bool const sgn = sign(bv);
            int64_t const before = dtt(sgn, a);
            int64_t const after = dtt(sgn, a, coeff, cur_value, new_value);
            bool const was_zero = before == 0;
            // atoms whose status does not change contribute nothing
            if (was_zero == (after == 0))
                continue;

            sat::literal lit(bv, sgn);
            if (was_zero) {
                // flip from true to false
                lit.neg();
            }

            // lit flips from false to true:
            for (auto cl : m_bool_search->get_use_list(lit)) {
                if (!get_clause_info(cl).is_true())
                    ++total;
            }
            // ignore the situation where clause contains multiple literals using v
            for (auto cl : m_bool_search->get_use_list(~lit)) {
                if (get_clause_info(cl).m_num_trues == 1)
                    --total;
            }
        }
        return total;
    }

    // Distance to satisfaction of clause cl: the minimum dtt over the
    // literals of the clause that carry an arithmetic atom. Returns 1 when
    // the clause has no such literal. Scanning stops at the first 0.
    int64_t sls::compute_dts(unsigned cl) const {
        int64_t best = 1;
        bool seen = false;
        for (auto lit : get_clause(cl)) {
            auto const* a = atom(lit.var());
            if (!a)
                continue;
            int64_t const cur = dtt(lit.sign(), *a);
            best = seen ? std::min(best, cur) : cur;
            seen = true;
            if (best == 0)
                break;
        }
        return best;
    }

    // Distance to satisfaction of clause cl if v were moved to new_value:
    // the minimum over the clause's arithmetic literals of the dtt overload
    // taking (v, new_value); that overload yields 1 for atoms that do not
    // mention v. Returns 1 when the clause has no arithmetic literal.
    int64_t sls::dts(unsigned cl, var_t v, int64_t new_value) const {
        int64_t best = 1;
        bool seen = false;
        for (auto lit : get_clause(cl)) {
            auto const* a = atom(lit.var());
            if (!a)
                continue;
            int64_t const cur = dtt(lit.sign(), *a, v, new_value);
            best = seen ? std::min(best, cur) : cur;
            seen = true;
            if (best == 0)
                break;
        }
        return best;
    }

    // Assign new_value to v. The cached left-hand side value of every
    // inequality mentioning v is adjusted first; a boolean variable whose
    // inequality no longer matches its current sign is flipped in the
    // boolean search. The variable's value is stored last.
    void sls::update(var_t v, int64_t new_value) {
        auto& info = m_vars[v];
        auto const prev = info.m_value;
        for (auto const& [coeff, bv] : info.m_bool_vars) {
            auto& a = *atom(bv);
            bool const sgn = sign(bv);
            SASSERT(is_true(sat::literal(bv, sgn)));
            a.m_args_value += coeff * (new_value - prev);
            if (dtt(sgn, a) != 0)
                m_bool_search->flip(bv);
            SASSERT(dtt(sign(bv), a) == 0);
        }
        info.m_value = new_value;
    }

    // Create one variable record per solver variable. The current and the
    // best value both start at the solver's value for registered variables
    // (converted with to_numeral) and at 0 otherwise; the kind is INT or
    // REAL according to s.is_int(v).
    void sls::add_vars() {
        SASSERT(m_vars.empty());
        for (unsigned v = 0; v < s.get_num_vars(); ++v) {
            int64_t init = 0;
            if (s.is_registered_var(v))
                init = to_numeral(s.get_ivalue(v).x);
            auto kind = s.is_int(v) ? sls::var_kind::INT : sls::var_kind::REAL;
            m_vars.push_back({ init, init, kind, {} });
        }
    }

    // Allocate an inequality with operator op and right-hand side bound.
    // The object is created with alloc(); within this file the callers hand
    // its address to m_bool_vars.
    sls::ineq& sls::new_ineq(ineq_kind op, int64_t const& bound) {
        ineq* fresh = alloc(ineq);
        fresh->m_op = op;
        fresh->m_bound = bound;
        return *fresh;
    }

    // Append the monomial c * v to ineq, add its current contribution to
    // the cached left-hand side value, and record on v that the atom of
    // bv uses it with coefficient c.
    void sls::add_arg(sat::bool_var bv, ineq& ineq, int64_t const& c, var_t v) {
        ineq.m_args.push_back({ c, v });
        auto const contribution = c * value(v);
        ineq.m_args_value += contribution;
        auto& occurrences = m_vars[v].m_bool_vars;
        occurrences.push_back({ c, bv });
    }

    // Convert r to int64_t. Values for which r.is_int64() is false are
    // mapped to 0.
    int64_t sls::to_numeral(rational const& r) {
        return r.is_int64() ? r.get_int64() : 0;
    }

    // Add the LP column t, scaled by sign, to ineq.
    // A plain column contributes the single monomial sign * external(t).
    // A column that has a term is expanded into the term's monomials, and
    // the pair (t, v) is recorded in m_terms.
    void sls::add_args(sat::bool_var bv, ineq& ineq, lp::lpvar t, theory_var v, int64_t sign) {
        if (!s.lp().column_has_term(t)) {
            add_arg(bv, ineq, sign, s.lp().local_to_external(t));
            return;
        }
        lp::lar_term const& term = s.lp().get_term(t);
        m_terms.push_back({t,v});
        for (lp::lar_term::ival arg : term) {
            auto ext = s.lp().local_to_external(arg.j());
            add_arg(bv, ineq, sign * to_numeral(arg.coeff()), ext);
        }
    }

    // Create the inequality for boolean variable bv, unless one exists.
    // Two shapes are handled:
    // - bv has an api_bound: the bound becomes an LE atom; a lower bound
    //   is encoded by negating both the bound and the coefficients.
    // - bv is an equality l = r for which s.a.is_int_real(l) holds: it
    //   becomes the EQ atom l - r = 0.
    // In both cases bv is marked external in the boolean search. Any other
    // boolean variable is left without an inequality.
    void sls::init_bool_var(sat::bool_var bv) {
        if (m_bool_vars.get(bv, nullptr))
            return;

        api_bound* b = nullptr;
        s.m_bool_var2bound.find(bv, b);
        if (b) {
            auto col = b->column_index();
            rational bound = b->get_value();
            bool const is_lower = b->get_bound_kind() == lp_api::bound_kind::lower_t;
            if (is_lower)
                bound.neg();
            auto& a = new_ineq(sls::ineq_kind::LE, to_numeral(bound));
            add_args(bv, a, col, b->get_var(), is_lower ? -1 : 1);
            m_bool_vars.set(bv, &a);
            m_bool_search->set_external(bv);
            return;
        }

        expr* e = s.bool_var2expr(bv);
        expr* lhs = nullptr, * rhs = nullptr;
        if (!e || !m.is_eq(e, lhs, rhs) || !s.a.is_int_real(lhs))
            return;

        theory_var u = s.get_th_var(lhs);
        theory_var v = s.get_th_var(rhs);
        lp::lpvar tu = s.get_column(u);
        lp::lpvar tv = s.get_column(v);
        auto& eq = new_ineq(sls::ineq_kind::EQ, 0);
        add_args(bv, eq, tu, u, 1);
        add_args(bv, eq, tv, v, -1);
        m_bool_vars.set(bv, &eq);
        m_bool_search->set_external(bv);
    }

    // Align the boolean value of v with its inequality: when v has an
    // inequality and the positive literal's truth value disagrees with
    // "the inequality holds" (dtt == 0), v is flipped in the boolean search.
    void sls::init_bool_var_assignment(sat::bool_var v) {
        auto* a = m_bool_vars.get(v, nullptr);
        if (!a)
            return;
        bool const assigned_true = is_true(sat::literal(v, false));
        bool const holds = dtt(false, *a) == 0;
        if (assigned_true != holds)
            m_bool_search->flip(v);
    }

    // Start of a search: performs the same work as a restart.
    void sls::init_search()  {
        on_restart();
    }

    // End of a search: hand the best values over via store_best_values().
    void sls::finish_search()  {
        store_best_values();
    }

    // Flip the arithmetic atom of boolean variable v by moving one of its
    // variables. When v has no inequality, a message is emitted at
    // verbosity level 0 and nothing else happens.
    void sls::flip(sat::bool_var v) {
        sat::literal lit(v, !sign(v));
        SASSERT(!is_true(lit));
        auto const* a = atom(v);
        if (!a) {
            IF_VERBOSE(0, verbose_stream() << "no inequality for variable " << v << "\n");
            return;
        }
        SASSERT(a->is_true() == lit.sign());
        flip(sign(v), *a);
    }

    // Reward for flipping v: computed by dscore_reward while
    // m_dscore_mode is set, by dtt_reward otherwise.
    double sls::reward(sat::bool_var v) {
        return m_dscore_mode ? dscore_reward(v) : dtt_reward(v);
    }

    // Reward for flipping the atom of bv0, together with the choice of the
    // variable to move.
    //
    // For every argument x of the atom that admits a critical move, the
    // candidate reward is the sum of the boolean-search rewards of all
    // boolean variables whose inequality mentions x. The best candidate
    // above -1 is returned and its variable is stored in m_var_to_flip.
    // Returns -1 when bv0 has no inequality or no candidate exceeds -1; in
    // that case m_var_to_flip is left unchanged.
    //
    // A stricter variant that only counted atoms whose status changes under
    // the move was unreachable code in the previous version and is omitted.
    double sls::dtt_reward(sat::bool_var bv0) {
        bool sign0 = sign(bv0);
        auto* ineq = atom(bv0);
        if (!ineq)
            return -1;
        int64_t new_value;
        double max_result = -1;
        for (auto const& [coeff, x] : ineq->m_args) {
            // only the existence of a critical move matters here
            if (!cm(sign0, *ineq, x, coeff, new_value))
                continue;
            double result = 0;
            for (auto const& [c, bv] : m_vars[x].m_bool_vars)
                result += m_bool_search->reward(bv);
            if (result > max_result) {
                max_result = result;
                ineq->m_var_to_flip = x;
            }
        }
        return max_result;
    }

    // Reward for flipping the atom of bv based on dscore. Calling this
    // function switches dscore mode off again. The first argument whose
    // critical move has a positive dscore is stored in m_var_to_flip and
    // its score is returned; 0 is returned when bv has no inequality or no
    // argument qualifies.
    double sls::dscore_reward(sat::bool_var bv) {
        m_dscore_mode = false;
        bool const old_sign = sign(bv);
        auto* a = atom(bv);
        if (!a)
            return 0;
        SASSERT(a->is_true() != old_sign);

        int64_t candidate;
        for (auto const& [coeff, x] : a->m_args) {
            if (!cm(old_sign, *a, x, coeff, candidate))
                continue;
            double const gain = dscore(x, candidate);
            // just pick first positive, or pick a max?
            if (gain > 0) {
                a->m_var_to_flip = x;
                return gain;
            }
        }
        return 0;
    }

    // Switch to dscore mode: the next reward() call is routed to
    // dscore_reward(), which turns the mode off again.
    void sls::on_rescale()  {
        m_dscore_mode = true;
    }

    // Model-save callback: snapshot the current values as the best values.
    void sls::on_save_model() {
        save_best_values();
    }

    // Restart callback: re-align every boolean variable with its
    // inequality, then validate the resulting assignment.
    void sls::on_restart() {
        for (unsigned bv = 0; bv < s.s().num_vars(); ++bv) {
            init_bool_var_assignment(bv);
        }
        check_ineqs();
    }

    void sls::check_ineqs() {

        auto check_bool_var = [&](sat::bool_var bv) {
            auto const* ineq = atom(bv);
            if (!ineq)
                return;
            int64_t d = dtt(sign(bv), *ineq);
            sat::literal lit(bv, sign(bv));
            if (is_true(lit) != (d == 0)) {
                verbose_stream() << "invalid assignment " << bv << " " << *ineq << "\n";
            }
            VERIFY(is_true(lit) == (d == 0));
        };
        for (unsigned v = 0; v < s.get_num_vars(); ++v)
            check_bool_var(v);
    }

    // Print the local-search state to out: first every boolean variable
    // that has an inequality, then the current and best value of every
    // non-boolean solver variable. Returns out.
    std::ostream& sls::display(std::ostream& out) const {
        for (bool_var bv = 0; bv < s.s().num_vars(); ++bv) {
            if (auto const* a = atom(bv))
                out << bv << " " << *a << "\n";
        }
        for (unsigned v = 0; v < s.get_num_vars(); ++v) {
            if (!s.is_bool(v)) {
                auto const& info = m_vars[v];
                out << "v" << v << " := " << info.m_value << " " << info.m_best_value << "\n";
            }
        }
        return out;
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy