All Downloads are FREE. Search and download functionalities are using the official Maven repository.

z3-z3-4.13.0.src.sat.smt.fpa_solver.cpp Maven / Gradle / Ivy

The newest version!
/*++
Copyright (c) 2014 Microsoft Corporation

Module Name:

    fpa_solver.cpp

Abstract:

    Floating-Point Theory Plugin

Author:

    Christoph (cwinter) 2014-04-23

Revision History:

    Ported from theory_fpa by nbjorner in 2020.

--*/

#include "sat/smt/fpa_solver.h"
#include "ast/fpa/bv2fpa_converter.h"

namespace fpa {

    /**
     * Build the floating-point plugin on top of the EUF solver `ctx`.
     *
     * The converter (m_converter) and the fpa rewriter (m_rw) share the
     * context's ast_manager and the rewriter m_th_rw. The fpa/bv/arith
     * utility handles are taken from the converter. m_th_rw is then
     * configured with the parameter arith_lhs = true.
     */
    solver::solver(euf::solver& ctx) :
        euf::th_euf_solver(ctx, symbol("fpa"), ctx.get_manager().mk_family_id("fpa")),
        m_th_rw(ctx.get_manager()),
        m_converter(ctx.get_manager(), m_th_rw),
        m_rw(ctx.get_manager(), m_converter, params_ref()),
        m_fpa_util(m_converter.fu()),
        m_bv_util(m_converter.bu()),
        m_arith_util(m_converter.au())
    {
        params_ref rewriter_params;
        rewriter_params.set_bool("arith_lhs", true);
        m_th_rw.updt_params(rewriter_params);
    }

    // Release the references held on the keys and values of m_conversions.
    // convert() inc_ref's both sides of every entry it caches. After the
    // release the cache is expected to be empty.
    solver::~solver() {
        dec_ref_map_key_values(m, m_conversions);
        SASSERT(m_conversions.empty());
    }


    /**
     * Return the bit-vector level encoding of `e` computed by m_rw.
     *
     * Results are memoized in m_conversions. On a miss, the fresh result is
     * inserted, both key and value are inc_ref'ed, and an insert_ref2_map
     * trail entry is pushed on the context (presumably so the entry and its
     * references are dropped again on backtracking).
     */
    expr_ref solver::convert(expr* e) {
        TRACE("t_fpa", tout << "converting " << mk_ismt2_pp(e, m) << "\n";);

        expr* cached = nullptr;
        if (m_conversions.find(e, cached)) {
            expr_ref hit(m);
            hit = cached;
            TRACE("t_fpa_detail", tout << "cached:" << "\n";
                  tout << mk_ismt2_pp(e, m) << "\n" << " -> " << "\n" << mk_ismt2_pp(hit, m) << "\n";);
            return hit;
        }

        expr_ref fresh(m);
        fresh = m_rw.convert(m_th_rw, e);
        TRACE("t_fpa_detail", tout << "converted; caching:" << "\n";
              tout << mk_ismt2_pp(e, m) << "\n" << " -> " << "\n" << mk_ismt2_pp(fresh, m) << "\n";);

        m_conversions.insert(e, fresh);
        m.inc_ref(e);
        m.inc_ref(fresh);
        ctx.push(insert_ref2_map(m, m_conversions, e, fresh.get()));
        return fresh;
    }

    /**
     * Turn every side condition queued by the converter
     * (m_converter.m_extra_assertions) into a SAT literal.
     *
     * Each condition is first simplified by the context rewriter and then by
     * m_th_rw. The queue is emptied before the literals are returned.
     */
    sat::literal_vector solver::mk_side_conditions() {
        sat::literal_vector lits;
        expr_ref simplified(m);
        for (expr* side_cond : m_converter.m_extra_assertions) {
            ctx.get_rewriter()(side_cond, simplified);
            m_th_rw(simplified);
            lits.push_back(mk_literal(simplified));
        }
        m_converter.m_extra_assertions.reset();
        return lits;
    }

    /**
     * Theory check: flush queued nodes through unit_propagate().
     *
     * If that did any work, the SAT core is asked to continue (CR_CONTINUE).
     * Otherwise the queue is fully drained and the theory reports CR_DONE.
     */
    sat::check_result solver::check() {
        SASSERT(m_converter.m_extra_assertions.empty());
        bool const made_progress = unit_propagate();
        if (!made_progress) {
            SASSERT(m_nodes.size() <= m_nodes_qhead);
            return sat::check_result::CR_DONE;
        }
        return sat::check_result::CR_CONTINUE;
    }

    // Create a fresh theory variable for `n` and attach it, on behalf of this
    // solver, to the node in the EUF context.
    void solver::attach_new_th_var(enode* n) {
        theory_var const fresh_var = mk_var(n);
        ctx.attach_th_var(n, this, fresh_var);
        TRACE("t_fpa", tout << "new theory var: " << mk_ismt2_pp(n->get_expr(), m) << " := " << fresh_var << "\n";);
    }

    /**
     * Internalize the Boolean term `e` and return its literal.
     *
     * The literal is negated when `sign` is set. Returns sat::null_literal
     * when visit_rec reports failure.
     */
    sat::literal solver::internalize(expr* e, bool sign, bool root) {
        SASSERT(m.is_bool(e));
        bool const ok = visit_rec(m, e, sign, root);
        if (!ok)
            return sat::null_literal;
        sat::literal const lit = expr2literal(e);
        return sign ? ~lit : lit;
    }

    // Internalize `e` with neither sign nor root set. The result of
    // visit_rec is ignored.
    void solver::internalize(expr* e) {
        bool const sign = false;
        bool const root = false;
        visit_rec(m, e, sign, root);
    }

    // True iff `e` already has an enode and that enode is attached to this
    // theory.
    bool solver::visited(expr* e) {
        euf::enode* node = expr2enode(e);
        if (!node)
            return false;
        return node->is_attached_to(get_id());
    }

    /**
     * Visitor callback (presumably driven by visit_rec).
     *
     * Returns true when this solver has nothing more to do for `e`. That is
     * the case when `e` is already attached, or when it is not an application
     * of the fpa family; such terms are handed to ctx.internalize.
     * Applications of the fpa family are pushed on m_stack and false is
     * returned.
     */
    bool solver::visit(expr* e) {
        if (visited(e))
            return true;
        bool const owned_by_fpa = is_app(e) && to_app(e)->get_family_id() == get_id();
        if (owned_by_fpa) {
            m_stack.push_back(sat::eframe(e));
            return false;
        }
        ctx.internalize(e);
        return true;
    }

    /**
     * Finish internalizing the fpa application `e`.
     *
     * Steps:
     * - create its enode if needed;
     * - attach a fresh theory variable;
     * - queue the (node, sign, root) triple for unit_propagate(), recording
     *   the append on the trail.
     *
     * Always returns true.
     */
    bool solver::post_visit(expr* e, bool sign, bool root) {
        euf::enode* node = expr2enode(e);
        SASSERT(!node || !node->is_attached_to(get_id()));
        if (node == nullptr)
            node = mk_enode(e, false);
        SASSERT(!node->is_attached_to(get_id()));
        attach_new_th_var(node);
        TRACE("fp", tout << "post: " << mk_bounded_pp(e, m) << "\n";);
        m_nodes.push_back(std::tuple(node, sign, root));
        ctx.push(push_back_trail(m_nodes));
        return true;
    }

    /**
     * Sort-constraint hook for float and rounding-mode terms.
     *
     * Nodes that already carry a theory variable, and ite terms, are left
     * alone. Any other node gets a fresh theory variable. For rounding-mode
     * terms that are not bv2rm applications, the 3-bit wrapped image is
     * bounded by 4. Finally the term is activated.
     */
    void solver::apply_sort_cnstr(enode* n, sort* s) {
        TRACE("t_fpa", tout << "apply sort cnstr for: " << mk_ismt2_pp(n->get_expr(), m) << "\n";);
        SASSERT(s->get_family_id() == get_id());
        SASSERT(m_fpa_util.is_float(s) || m_fpa_util.is_rm(s));
        SASSERT(m_fpa_util.is_float(n->get_expr()) || m_fpa_util.is_rm(n->get_expr()));
        SASSERT(n->get_decl()->get_range() == s);

        expr* owner = n->get_expr();
        if (is_attached_to_var(n) || m.is_ite(owner))
            return;
        attach_new_th_var(n);

        bool const needs_rm_range = m_fpa_util.is_rm(s) && !m_fpa_util.is_bv2rm(owner);
        if (needs_rm_range) {
            // Keep the bit-vector associated with every RM term inside the
            // valid range: wrap(owner) <=_u 4 (as a 3-bit numeral).
            expr_ref upper(m_bv_util.mk_numeral(4, 3), m);
            expr_ref in_range(m_bv_util.mk_ule(m_converter.wrap(owner), upper), m);
            add_unit(mk_literal(in_range));
        }
        activate(owner);
    }

    /**
     * Process every queued node from m_nodes_qhead onward.
     *
     * The head is saved with a value_trail before it is advanced. Returns
     * false when nothing was pending, true otherwise.
     */
    bool solver::unit_propagate() {
        bool const pending = m_nodes_qhead < m_nodes.size();
        if (!pending)
            return false;
        ctx.push(value_trail(m_nodes_qhead));
        // The bound is re-read on each iteration: processing an entry may
        // (presumably via internalization) append new entries to m_nodes.
        while (m_nodes_qhead < m_nodes.size()) {
            unit_propagate(m_nodes[m_nodes_qhead]);
            ++m_nodes_qhead;
        }
        return true;
    }

    void solver::unit_propagate(std::tuple const& t) {
        auto [n, sign, root] = t;
        expr* e = n->get_expr();
        app* a = to_app(e);
        if (m.is_bool(e)) {
            sat::literal atom(ctx.get_si().add_bool_var(e), false);
            atom = ctx.attach_lit(atom, e);
            sat::literal bv_atom = mk_literal(m_rw.convert_atom(m_th_rw, e));
            sat::literal_vector conds = mk_side_conditions();
            conds.push_back(bv_atom);
            add_equiv_and(atom, conds);
            if (root) {
                if (sign)
                    atom.neg();
                add_unit(atom);
            }
        }
        else {            
            switch (a->get_decl_kind()) {
            case OP_FPA_TO_FP:
            case OP_FPA_TO_UBV:
            case OP_FPA_TO_SBV:
            case OP_FPA_TO_REAL:
            case OP_FPA_TO_IEEE_BV: {
                expr_ref conv = convert(e);
                add_unit(eq_internalize(e, conv));
                add_units(mk_side_conditions());
                break;
            }
            default: /* ignore */
                break;
            }
        }
        activate(e);
    }

    /**
     * Tie a float / rounding-mode term `n` to its wrapped bit-vector image
     * wrap(n).
     *
     * - RM numerals: wrap(n) = the 3-bit numeral of the mode.
     * - Float numerals: convert(n) must be fp(a, b, c). Then
     *   wrap(n) = concat(a, b, c) and convert(n) = n, and any converter side
     *   conditions are asserted.
     * - Other non-fp(...) terms: unwrap(wrap(n)) = n.
     *
     * ite terms and fp(...) applications get nothing. Non-float/non-RM terms
     * (fpa conversion functions, or bit-vector terms) only get a debug
     * assertion.
     */
    void solver::activate(expr* n) {
        TRACE("t_fpa", tout << "relevant_eh for: " << mk_ismt2_pp(n, m) << "\n";);

        if (m.is_ite(n))
            return;

        if (!m_fpa_util.is_float(n) && !m_fpa_util.is_rm(n)) {
            // Either an fpa conversion function (fp.to_*), or a bit-vector term:
            // theory variables can be merged when (= bv-term (bvwrap fp-term)).
            SASSERT((is_app(n) && to_app(n)->get_family_id() == get_id()) || m_bv_util.is_bv(n));
            return;
        }

        if (m_fpa_util.is_fp(n))
            return;

        app_ref wrapped = m_converter.wrap(n);

        mpf_rounding_mode rm;
        if (m_fpa_util.is_rm_numeral(n, rm)) {
            expr_ref rm_bits(m);
            rm_bits = m_bv_util.mk_numeral(rm, 3);
            add_unit(eq_internalize(wrapped, rm_bits));
            return;
        }

        scoped_mpf val(m_fpa_util.fm());
        if (!m_fpa_util.is_numeral(n, val)) {
            add_unit(eq_internalize(m_converter.unwrap(wrapped, n->get_sort()), n));
            return;
        }

        // Require
        // wrap(n) = bvK
        // fp(extract(wrap(n)) = n
        expr_ref bv_val_e(convert(n), m);
        expr* a = nullptr, * b = nullptr, * c = nullptr;
        VERIFY(m_fpa_util.is_fp(bv_val_e, a, b, c));
        expr* parts[] = { a, b, c };
        expr_ref cc_args(m_bv_util.mk_concat(3, parts), m);
        add_unit(eq_internalize(wrapped, cc_args));
        add_unit(eq_internalize(bv_val_e, n));
        add_units(mk_side_conditions());
    }

    /**
     * Relate the EUF equality between theory variables x and y to their
     * encodings by asserting (xe = ye) <=> c.
     *
     * c is built by m_converter.mk_eq when both sides are floats or both are
     * rounding modes, and is a plain equality between the encodings
     * otherwise. c is passed through m_th_rw first.
     *
     * Pairs in which either side is a bvwrap term are skipped.
     */
    void solver::ensure_equality_relation(theory_var x, theory_var y) {
        expr* xe = var2enode(x)->get_expr();
        expr* ye = var2enode(y)->get_expr();

        if (m_fpa_util.is_bvwrap(xe) || m_fpa_util.is_bvwrap(ye))
            return;

        TRACE("t_fpa", tout << "new eq: " << x << " = " << y << "\n";
              tout << mk_ismt2_pp(xe, m) << "\n" << " = " << "\n" << mk_ismt2_pp(ye, m) << "\n";);

        expr_ref xc = convert(xe);
        expr_ref yc = convert(ye);

        TRACE("t_fpa_detail", tout << "xc = " << mk_ismt2_pp(xc, m) << "\n" <<
            "yc = " << mk_ismt2_pp(yc, m) << "\n";);

        bool const both_float = m_fpa_util.is_float(xe) && m_fpa_util.is_float(ye);
        bool const both_rm = m_fpa_util.is_rm(xe) && m_fpa_util.is_rm(ye);

        expr_ref bv_eq(m);
        if (both_float || both_rm)
            m_converter.mk_eq(xc, yc, bv_eq);
        else
            bv_eq = m.mk_eq(xc, yc);
        m_th_rw(bv_eq);

        sat::literal const fp_eq_lit = eq_internalize(xe, ye);
        sat::literal const bv_eq_lit = mk_literal(bv_eq);
        add_equiv(fp_eq_lit, bv_eq_lit);
        add_units(mk_side_conditions());
    }

    // Equality callback: relate the encodings of the two theory variables.
    void solver::new_eq_eh(euf::th_eq const& eq) {
        theory_var const lhs = eq.v1();
        theory_var const rhs = eq.v2();
        ensure_equality_relation(lhs, rhs);
    }

    // Disequality callback. It is handled exactly like an equality, because
    // ensure_equality_relation asserts an equivalence (add_equiv), which
    // constrains both polarities.
    void solver::new_diseq_eh(euf::th_eq const& eq) {
        theory_var const lhs = eq.v1();
        theory_var const rhs = eq.v2();
        ensure_equality_relation(lhs, rhs);
    }

    /**
     * Called with the literal `l` that has just been assigned true.
     *
     * Let e be the atom of l, and let conds be the side conditions of
     * convert(e) plus the literal of convert(e) itself.
     * - If e is true (l positive), add e => sc for every sc in conds.
     * - If e is false (l negative), add conds => e, i.e.
     *   (~sc1 | ... | ~c | e).
     *
     * Both directions are valid lemmas, since convert(e) is e's encoding.
     * The point is to pick the direction that can propagate under the
     * current assignment. This mirrors the legacy
     * theory_fpa::assign_eh(v, is_true), which adds e => converted when
     * is_true and converted => e otherwise.
     *
     * Fix: the previous code tested `l.sign()` the other way round. It
     * therefore always produced clauses that were already satisfied by `l`
     * and could never propagate.
     */
    void solver::asserted(sat::literal l) {
        expr* e = ctx.bool_var2expr(l.var());

        TRACE("t_fpa", tout << "assign_eh for: " << l << "\n" << mk_ismt2_pp(e, m) << "\n";);

        sat::literal c = mk_literal(convert(e));
        sat::literal_vector conds = mk_side_conditions();
        conds.push_back(c);
        if (!l.sign()) {
            // e is true: e => sc, i.e. the clause (~l | sc).
            for (sat::literal sc : conds)
                add_clause(~l, sc);
        }
        else {
            // e is false (l == ~e): conds => e, i.e. (~sc1 | ... | ~c | ~l).
            for (auto& sc : conds)
                sc.neg();
            conds.push_back(~l);
            add_clause(conds);
        }
    }

    /**
     * Compute the model value of `n` and store it at n's root id in `values`.
     *
     * - RM and float numerals are their own value.
     * - fp(a, b, c) is rebuilt from the values of its three arguments.
     * - bv2rm(x) is decoded from the value of x.
     * - Other RM terms are decoded from the value of wrap(e) when that term
     *   has an enode; otherwise they default to round-toward-zero.
     * - Other float terms are decoded the same way; otherwise they default
     *   to +0.
     *
     * `mdl` is not used.
     */
    void solver::add_value(euf::enode* n, model& mdl, expr_ref_vector& values) {
        expr* e = n->get_expr();
        expr_ref value(m);

        // Value recorded for the root of the i-th argument of n.
        auto arg_value = [&](unsigned i) {
            return values.get(n->get_arg(i)->get_root_id());
        };
        // Enode of wrap(e), or nullptr if that term was never internalized.
        // The wrapper term is kept alive in `wrapped`.
        app_ref wrapped(m);
        auto wrapped_node = [&]() -> euf::enode* {
            wrapped = m_converter.wrap(e);
            return expr2enode(wrapped);
        };

        if (m_fpa_util.is_rm_numeral(e) || m_fpa_util.is_numeral(e)) {
            value = e;
        }
        else if (m_fpa_util.is_fp(e)) {
            SASSERT(n->num_args() == 3);
            value = m_converter.bv2fpa_value(e->get_sort(), arg_value(0), arg_value(1), arg_value(2));
        }
        else if (m_fpa_util.is_bv2rm(e)) {
            SASSERT(n->num_args() == 1);
            value = m_converter.bv2rm_value(arg_value(0));
        }
        else if (m_fpa_util.is_rm(e)) {
            euf::enode* w = wrapped_node();
            if (w)
                value = m_converter.bv2rm_value(values.get(w->get_root_id()));
            else
                value = m_fpa_util.mk_round_toward_zero();
        }
        else {
            euf::enode* w = m_fpa_util.is_float(e) ? wrapped_node() : nullptr;
            if (w) {
                value = m_converter.bv2fpa_value(e->get_sort(), values.get(w->get_root_id()));
            }
            else {
                SASSERT(m_fpa_util.is_float(e));
                unsigned ebits = m_fpa_util.get_ebits(e->get_sort());
                unsigned sbits = m_fpa_util.get_sbits(e->get_sort());
                value = m_fpa_util.mk_pzero(ebits, sbits);
            }
        }
        values.set(n->get_root_id(), value);
        TRACE("t_fpa", tout << ctx.bpp(n) << " := " << value << "\n";);
    }

    /**
     * Record the model-construction dependencies of `n` in `dep`.
     *
     * These are the same enodes whose values add_value reads:
     * - an fp(...) term depends on its three arguments;
     * - a bv2rm term depends on its argument;
     * - any other RM or float term depends on the enode of wrap(e), if that
     *   enode exists.
     *
     * Returns true iff a dependency was recorded.
     *
     * Fix: `top_sort` is a class template, and a bare `top_sort&` parameter
     * is ill-formed. The element type is euf::enode, as required by
     * dep.add(enode*, enode*).
     */
    bool solver::add_dep(euf::enode* n, top_sort<euf::enode>& dep) {
        expr* e = n->get_expr();
        if (m_fpa_util.is_fp(e)) {
            SASSERT(n->num_args() == 3);
            for (enode* arg : euf::enode_args(n))
                dep.add(n, arg);
            return true;
        }
        if (m_fpa_util.is_bv2rm(e)) {
            SASSERT(n->num_args() == 1);
            dep.add(n, n->get_arg(0));
            return true;
        }
        if (m_fpa_util.is_rm(e) || m_fpa_util.is_float(e)) {
            euf::enode* wrapped = expr2enode(m_converter.wrap(e));
            if (!wrapped)
                return false;
            dep.add(n, wrapped);
            return true;
        }
        return false;
    }

    std::ostream& solver::display(std::ostream& out) const {
        bool first = true;
        for (enode* n : ctx.get_egraph().nodes()) {
            theory_var v = n->get_th_var(m_fpa_util.get_family_id());
            if (v != -1) {
                if (first) out << "fpa theory variables:" << "\n";
                out << v << " -> " <<
                    mk_ismt2_pp(n->get_expr(), m) << "\n";
                first = false;
            }
        }
        // if there are no fpa theory variables, was fp ever used?
        if (first)
            return out;

        out << "bv theory variables:" << "\n";
        for (enode* n : ctx.get_egraph().nodes()) {
            theory_var v = n->get_th_var(m_bv_util.get_family_id());
            if (v != -1) out << v << " -> " <<
                mk_ismt2_pp(n->get_expr(), m) << "\n";
        }

        out << "arith theory variables:" << "\n";
        for (enode* n : ctx.get_egraph().nodes()) {
            theory_var v = n->get_th_var(m_arith_util.get_family_id());
            if (v != -1) out << v << " -> " <<
                mk_ismt2_pp(n->get_expr(), m) << "\n";
        }

        out << "equivalence classes:\n";
        for (enode* n : ctx.get_egraph().nodes()) {
            expr* e = n->get_expr();
            out << n->get_root_id() << " --> " << mk_ismt2_pp(e, m) << "\n";
        }
        return out;
    }

    /**
     * Post-process the final model `mdl`.
     *
     * bv2fpa_converter::convert_min_max_specials and convert_uf2bvuf populate
     * a fresh model `new_model` and collect in `seen` the declarations they
     * handled. Those declarations are unregistered from `mdl`. Then every
     * constant and function of `new_model` is registered in `mdl`; function
     * interpretations are copied, because `new_model` is destroyed on return.
     *
     * Fix: `obj_hashtable` is a class template, and its element type cannot
     * be deduced from a declaration without an initializer. It holds the
     * func_decl* values iterated below.
     */
    void solver::finalize_model(model& mdl) {
        model new_model(m);

        bv2fpa_converter bv2fp(m, m_converter);

        obj_hashtable<func_decl> seen;
        bv2fp.convert_min_max_specials(&mdl, &new_model, seen);
        bv2fp.convert_uf2bvuf(&mdl, &new_model, seen);

        for (func_decl* f : seen)
            mdl.unregister_decl(f);

        for (unsigned i = 0; i < new_model.get_num_constants(); i++) {
            func_decl* f = new_model.get_constant(i);
            mdl.register_decl(f, new_model.get_const_interp(f));
        }

        for (unsigned i = 0; i < new_model.get_num_functions(); i++) {
            func_decl* f = new_model.get_function(i);
            func_interp* fi = new_model.get_func_interp(f)->copy();
            mdl.register_decl(f, fi);
        }
    }

};




© 2015 - 2024 Weber Informatics LLC | Privacy Policy