All Downloads are FREE. Search and download functionalities are using the official Maven repository.

z3-z3-4.13.0.src.ast.euf.euf_bv_plugin.cpp Maven / Gradle / Ivy

The newest version!
/*++
Copyright (c) 2023 Microsoft Corporation

Module Name:

    euf_bv_plugin.cpp

Abstract:

    plugin structure for bit-vectors

Author:

    Nikolaj Bjorner (nbjorner) 2023-11-08
    Jakob Rath 2023-11-08

Objective:

satisfies extract/concat axioms.
  - concat(n{I],n[J]) = n[IJ] for I, J consecutive.
  - concat(v1, v2) = 2^width(v1)*v2 + v1
  - concat(n[width(n)-1:0]) = n
  - concat(a, b)[I] = concat(a[I1], b[I2])
  - concat(a, concat(b, c)) = concat(concat(a, b), c)

E-graph:

The E-graph contains node definitions of the form

   n := f(n1,n2,..)

and congruences:

   n ~ n' means root(n) = root(n')

Saturated state:

   1. n := n1[I], n' := n2[J], n1 ~ n2 => root(n1) contains tree refining both I, J from smaller intervals

   2. n := concat(n1[I], n2[J]), n1 ~ n2 ~ n3 I and J are consecutive => n ~ n3[IJ]

   3. n := concat(n1[I], n2[J]), I and J are consecutive & n1 ~ n2, n1[I] ~ v1, n1[J] ~ v2 => n ~ 2^width(v1)*v2 + v1

   4. n := concat(n1[I], n2[J], I, J are consecutive, n1 ~ n2, n ~ v => n1[I] ~ v mod 2^width(n1[I]), n2[J] ~ v div 2^width(n1[I])

   5. n' := n[I] => n ~ n[width(n)-1:0]

   6. n := concat(a, concat(b, c)) => n ~ concat(concat(a, b), c)
      - handled by rewriter pre-processing for inputs
      - terms created internally are not equated modulo associativity

   7, n := concat(n1, n2)[I] => n ~ concat(n1[I1],n2[I2]) or n[I1] or n[I2]
      - handled by rewriter pre-processing

Example:
   x == (x1 x2) x3
   y == y1 (y2 y3)
   x1 == y1, x2 == y2, x3 == y3
   =>
   x = y

   by x2 == y2, x3 == y3 => (x2 x3) = (y2 y3)
   by (2) => x[I23] = (x2 x3)
   by (2) => x[I123] = (x1 (x2 x3))
   by (5) => x = x[I123]

The formal properties of saturation have to be established.

- Saturation does not complete with respect to associativity.
Instead the claim is along the lines that the resulting E-graph can be used as a canonizer.
If given a set of equations E that are saturated, and terms t1, t2 that are 
both simplified with respect to left-associativity of concatentation, and t1, t2 belong to the E-graph,
then t1 = t2 iff t1 ~ t2 in the E-graph.

TODO: Is saturation for (7) overkill for the purpose of canonization?

--*/


#include "ast/ast_pp.h"
#include "ast/euf/euf_bv_plugin.h"
#include "ast/euf/euf_egraph.h"

namespace euf {

    bv_plugin::bv_plugin(egraph& g):
        plugin(g),
        bv(g.get_manager())
    {}

    enode* bv_plugin::mk_value_concat(enode* hi, enode* lo) {
        auto v1 = get_value(hi);
        auto v2 = get_value(lo);
        auto v3 = v2 + v1 * rational::power_of_two(width(lo));
        return mk_value(v3, width(lo) + width(hi));
    }

    enode* bv_plugin::mk_value(rational const& v, unsigned sz) {
        auto e = bv.mk_numeral(v, sz);
        auto n = mk(e, 0, nullptr);
        if (m_ensure_th_var)
            m_ensure_th_var(n);
        return n;
    }

    void bv_plugin::propagate_merge(enode* x, enode* y) {
        if (!bv.is_bv(x->get_expr()))
            return;

        TRACE("bv", tout << "merge_eh " << g.bpp(x) << " == " << g.bpp(y) << "\n");
        SASSERT(!m_internal);
        flet _internal(m_internal, true);

        propagate_values(x);

        // ensure slices align
        if (has_sub(x) || has_sub(y)) {
            enode_vector& xs = m_xs, & ys = m_ys;
            xs.reset();
            ys.reset();
            xs.push_back(x);
            ys.push_back(y);
            merge(xs, ys, justification::equality(x, y));
        }

        // ensure p := concat(n1[I], n2[J]), n1 ~ n2 ~ n3 I and J are consecutive => p ~ n3[IJ]
        for (auto* n : enode_class(x))
            propagate_extract(n);
    }

    void bv_plugin::register_node(enode* n) { 
        m_queue.push_back(n); 
        m_trail.push_back(new (get_region()) push_back_vector(m_queue));  
        push_plugin_undo(bv.get_family_id()); 
    }

    void bv_plugin::merge_eh(enode* n1, enode* n2) { 
        m_queue.push_back(enode_pair(n1, n2)); 
        m_trail.push_back(new (get_region()) push_back_vector(m_queue));  
        push_plugin_undo(bv.get_family_id()); 
    }

    void bv_plugin::propagate() {
        if (m_qhead == m_queue.size())
            return;
        m_trail.push_back(new (get_region()) value_trail(m_qhead));
        push_plugin_undo(bv.get_family_id());
        for (; m_qhead < m_queue.size(); ++m_qhead) {
            if (std::holds_alternative(m_queue[m_qhead])) {
                auto n = *std::get_if(&m_queue[m_qhead]);
                propagate_register_node(n);
            }
            else {
                auto [a, b] = *std::get_if(&m_queue[m_qhead]);
                propagate_merge(a, b);
            }
        }
    }

    // enforce concat(v1, v2) = v1*2^|v2| + v2
    void bv_plugin::propagate_values(enode* x) {
        if (!is_value(x))
            return;

        auto val_x = get_value(x);
        enode* a, * b;
        unsigned lo, hi;
        for (enode* p : enode_parents(x)) {
            if (is_concat(p, a, b) && is_value(a) && is_value(b))
                push_merge(mk_concat(a->get_interpreted(), b->get_interpreted()), mk_value_concat(a, b));

            if (is_extract(p, lo, hi)) {
                auto val_p = mod2k(machine_div2k(val_x, lo), hi - lo + 1);
                auto ix = x->get_interpreted();
                auto ex = mk(bv.mk_extract(hi, lo, ix->get_expr()), 1, &ix);
                push_merge(ex, mk_value(val_p, width(p)));
            }
        }

        for (enode* sib : enode_class(x)) {
            if (is_concat(sib, a, b)) {
                auto val_a = machine_div2k(val_x, width(b));
                auto val_b = mod2k(val_x, width(b));
                push_merge(mk_concat(mk_value(val_a, width(a)), mk_value(val_b, width(b))), x->get_interpreted());
            }
        }
    }

    //
    // p := concat(n1[I], n2[J]), n1 ~ n2 ~ n3 I and J are consecutive => p ~ n3[IJ]
    // 
    // n is of form arg[I]
    // p is of form concat(n, b) or concat(a, n)
    // b is congruent to arg[J], I is consecutive with J => ensure that arg[IJ] = p
    // a is congruent to arg[J], J is consecutive with I => ensure that arg[JI] = p
    // 

    void bv_plugin::propagate_extract(enode* n) {
        unsigned lo1, hi1, lo2, hi2;
        enode* a, * b;
        if (!is_extract(n, lo1, hi1))
            return;

        enode* arg = n->get_arg(0);
        enode* arg_r = arg->get_root();       
        enode* n_r = n->get_root();

        m_ensure_concat.reset();
        auto ensure_concat = [&](unsigned lo, unsigned mid, unsigned hi) {
            // verbose_stream() << lo << " " << mid << " " << hi << "\n";
            TRACE("bv", tout << "ensure-concat " << lo << " " << mid << " " << hi << "\n");
            unsigned lo_, hi_;
            for (enode* p1 : enode_parents(n))
                if (is_extract(p1, lo_, hi_) && lo_ == lo && hi_ == hi && p1->get_arg(0)->get_root() == arg_r)
                    return;
            // add the axiom instead of merge(p, mk_extract(arg, lo, hi)), which would require tracking justifications
            push_merge(mk_concat(mk_extract(arg, mid + 1, hi), mk_extract(arg, lo, mid)), mk_extract(arg, lo, hi));
        };

        auto propagate_above = [&](enode* b) {
            TRACE("bv", tout << "propagate-above " << g.bpp(b) << "\n");
            for (enode* sib : enode_class(b))
                if (is_extract(sib, lo2, hi2) && sib->get_arg(0)->get_root() == arg_r && hi1 + 1 == lo2)
                    m_ensure_concat.push_back({lo1, hi1, hi2});
        };

        auto propagate_below = [&](enode* a) {
            TRACE("bv", tout << "propagate-below " << g.bpp(a) << "\n");
            for (enode* sib : enode_class(a))
                if (is_extract(sib, lo2, hi2) && sib->get_arg(0)->get_root() == arg_r && hi2 + 1 == lo1)
                    m_ensure_concat.push_back({lo2, hi2, hi1});
        };
        
        for (enode* p : enode_parents(n)) {
            if (is_concat(p, a, b)) {
                if (a->get_root() == n_r)
                    propagate_below(b);
                if (b->get_root() == n_r)
                    propagate_above(a);
            }
        }

        for (auto [lo, mid, hi] : m_ensure_concat) 
            ensure_concat(lo, mid, hi);
        
    }

    class bv_plugin::undo_split : public trail {
        bv_plugin& p;
        enode* n;
    public:
        undo_split(bv_plugin& p, enode* n): p(p), n(n) {}
        void undo() override {
            auto& i = p.info(n);
            i.value = nullptr;
            i.lo = nullptr;
            i.hi = nullptr;
            i.cut = null_cut;            
        }
    };

    void bv_plugin::push_undo_split(enode* n) {
        m_trail.push_back(new (get_region()) undo_split(*this, n));
        push_plugin_undo(bv.get_family_id());
    }

    void bv_plugin::undo() {
        m_trail.back()->undo();
        m_trail.pop_back();
    }

    
    void bv_plugin::propagate_register_node(enode* n) {
        TRACE("bv", tout << "register " << g.bpp(n) << "\n");
        enode* a, * b;
        unsigned lo, hi;
        if (is_concat(n, a, b)) {
            auto& i = info(n);
            i.value = n;
            i.hi = a;
            i.lo = b;
            i.cut = width(b);
            push_undo_split(n);
        }
        else if (is_concat(n) && n->num_args() != 2) {
            SASSERT(n->num_args() != 0);
            auto last = n->get_arg(n->num_args() - 1);
            for (unsigned i = n->num_args() - 1; i-- > 0;) 
                last = mk_concat(n->get_arg(i), last);
            push_merge(last, n);
        }
        else if (is_extract(n, lo, hi) && (lo != 0 || hi + 1 != width(n->get_arg(0)))) {
            enode* arg = n->get_arg(0);   
            unsigned w = width(arg);
            if (all_of(enode_parents(arg), [&](enode* p) { unsigned _lo, _hi;  return !is_extract(p, _lo, _hi) || _lo != 0 || _hi + 1 != w; }))
                push_merge(mk_extract(arg, 0, w - 1), arg);
            ensure_slice(arg, lo, hi);
        }
        TRACE("bv", tout << "done register " << g.bpp(n) << "\n");
    }

    //
    // Ensure that there are slices at boundaries of n[hi:lo]
    //
    void bv_plugin::ensure_slice(enode* n, unsigned lo, unsigned hi) {
        enode* r = n;
        unsigned lb = 0, ub = width(n) - 1;
        while (true) {   
            TRACE("bv", tout << "ensure slice " << g.bpp(n) << " " << lb << " [" << lo << ", " << hi << "] " << ub << "\n");
            SASSERT(lb <= lo && hi <= ub);
            SASSERT(ub - lb + 1 == width(r));
            if (lb == lo && ub == hi)
                return;
            slice_info const& i = info(r);
            
            if (!i.lo) {
                if (lo > lb) {
                    split(r, lo - lb);
                    if (hi < ub) // or split(info(r).hi, ...)                       
                        ensure_slice(n, lo, hi);                    
                }
                else if (hi < ub)
                    split(r, ub - hi);
                break;
            }
            auto cut = i.cut;
            if (cut + lb <= lo) {
                lb += cut;
                r = i.hi;
                continue;
            }
            if (cut + lb > hi) {
                ub = cut + lb - 1;
                r = i.lo;
                continue;
            }
            SASSERT(lo < cut + lb && cut + lb <= hi);
            ensure_slice(n, lo, cut + lb - 1);
            ensure_slice(n, cut + lb, hi);
            break;
        }
    }

    enode* bv_plugin::mk_extract(enode* n, unsigned lo, unsigned hi) { 
        SASSERT(lo <= hi && width(n) > hi - lo);
        unsigned lo1, hi1;
        while (is_extract(n, lo1, hi1)) {
            lo += lo1;
            hi += lo1;
            n = n->get_arg(0);
        }
        if (n->interpreted()) {
            auto v = get_value(n);
            if (lo > 0)
                v = div(v, rational::power_of_two(lo));
            if (hi + 1 != width(n))
                v = mod(v, rational::power_of_two(hi + 1));
            return mk_value(v, hi - lo + 1);
        }
        return mk(bv.mk_extract(hi, lo, n->get_expr()), 1, &n);
    }

    enode* bv_plugin::mk_concat(enode* hi, enode* lo) { 
        enode* args[2] = { hi, lo }; 
        return mk(bv.mk_concat(hi->get_expr(), lo->get_expr()), 2, args); 
    }

    void bv_plugin::merge(enode_vector& xs, enode_vector& ys, justification dep) {  
        while (!xs.empty()) {
            SASSERT(!ys.empty());
            auto x = xs.back();
            auto y = ys.back();
            TRACE("bv", tout << "merge " << g.bpp(x) << " " << g.bpp(y) << "\n");
            if (unfold_sub(x, xs))
                continue;
            else if (unfold_sub(y, ys))
                continue;
            else if (unfold_width(x, xs, y, ys))
                continue;
            else if (unfold_width(y, ys, x, xs))
                continue;
            else if (x->get_root() != y->get_root()) 
                push_merge(x, y, dep);            
            xs.pop_back();
            ys.pop_back();
        }
        SASSERT(ys.empty());
    }

    bool bv_plugin::unfold_sub(enode* x, enode_vector& xs) {
        if (!has_sub(x))
            return false;
        xs.pop_back();
        xs.push_back(sub_hi(x));
        xs.push_back(sub_lo(x));
        return true;
    }

    bool bv_plugin::unfold_width(enode* x, enode_vector& xs, enode* y, enode_vector& ys) {
        if (width(x) <= width(y))
            return false;            
        split(x, width(y));
        xs.pop_back();
        xs.push_back(sub_hi(x));
        xs.push_back(sub_lo(x));
        return true;
    }

    void bv_plugin::split(enode* n, unsigned cut) {
        TRACE("bv", tout << "split: " << g.bpp(n) << " " << cut << "\n");
        unsigned w = width(n);
        SASSERT(!info(n).hi);
        SASSERT(0 < cut && cut < w);
        enode* hi = mk_extract(n, cut, w - 1);
        enode* lo = mk_extract(n, 0, cut - 1);        
        auto& i = info(n);        
        i.value = n;
        i.hi = hi;
        i.lo = lo;
        i.cut = cut;
        push_undo_split(n);
        push_merge(mk_concat(hi, lo), n);        
    }

    void bv_plugin::sub_slices(enode* n, std::function& consumer) {
        m_todo.push_back({ n, 0 });
        unsigned lo, hi;
        expr* e;

        for (unsigned i = 0; i < m_todo.size(); ++i) {
            auto [n, offset] = m_todo[i];
            m_offsets.reserve(n->get_root_id() + 1);
            auto& offsets = m_offsets[n->get_root_id()];
            if (offsets.contains(offset))
                continue;
            offsets.push_back(offset);
            if (!consumer(n, offset))
                continue;
            for (auto sib : euf::enode_class(n)) {
                if (bv.is_concat(sib->get_expr())) {
                    unsigned delta = 0;
                    for (unsigned j = sib->num_args(); j-- > 0; ) {
                        auto arg = sib->get_arg(j);
                        m_todo.push_back({ arg, offset + delta });
                        delta += width(arg);
                    }
                }
                for (auto p : euf::enode_parents(sib)) {
                    if (bv.is_extract(p->get_expr(), lo, hi, e)) {
                        SASSERT(g.find(e)->get_root() == n->get_root());
                        m_todo.push_back({ p, offset + lo });
                    }
                }
            }

        }
        clear_offsets();
    }

    void bv_plugin::super_slices(enode* n, std::function& consumer) {
        m_todo.push_back({ n, 0 });
        unsigned lo, hi;
        expr* e;

        for (unsigned i = 0; i < m_todo.size(); ++i) {
            auto [n, offset] = m_todo[i];
            m_offsets.reserve(n->get_root_id() + 1);
            auto& offsets = m_offsets[n->get_root_id()];
            if (offsets.contains(offset))
                continue;
            offsets.push_back(offset);
            if (!consumer(n, offset))
                continue;
            for (auto sib : euf::enode_class(n)) {
                if (bv.is_extract(sib->get_expr(), lo, hi, e)) {
                    auto child = g.find(e);
                    m_todo.push_back({ child, offset + lo });
                }
                for (auto p : euf::enode_parents(sib)) {
                    if (bv.is_concat(p->get_expr())) {
                        unsigned delta = 0;
                        for (unsigned j = p->num_args(); j-- > 0; ) {
                            auto arg = p->get_arg(j);
                            if (arg->get_root() == n->get_root())
                                m_todo.push_back({ p, offset + delta });
                            delta += width(arg);
                        }
                    }
                }            
            }
        }
        clear_offsets();
    }

    //
    // Explain that a is a subslice of b at offset
    // or that b is a subslice of a at offset
    // 
    void bv_plugin::explain_slice(enode* a, unsigned offset, enode* b, std::function& consumer) {
        if (width(a) < width(b))
            std::swap(a, b);
        SASSERT(width(a) >= width(b));
        svector> just;
        m_jtodo.push_back({ a, 0, UINT_MAX });
        unsigned lo, hi;
        expr* e;

        for (unsigned i = 0; i < m_jtodo.size(); ++i) {
            auto [n, offs, j] = m_jtodo[i];
            m_offsets.reserve(n->get_root_id() + 1);
            auto& offsets = m_offsets[n->get_root_id()];
            if (offsets.contains(offs))
                continue;
            offsets.push_back(offs);
            if (n->get_root() == b->get_root() && offs == offset) {
                if (n != b)
                    consumer(n, b);
                while (j != UINT_MAX) {
                    auto [x, y, j2] = just[j];
                    if (x != y)
                        consumer(x, y);
                    j = j2;
                }
                for (auto const& [n, offset, j] : m_jtodo) {
                    m_offsets.reserve(n->get_root_id() + 1);
                    m_offsets[n->get_root_id()].reset();
                }
                DEBUG_CODE(
                    for (auto const& off : m_offsets) {
                        VERIFY(off.empty());
                    });
                m_jtodo.reset();
                return;
            }
            for (auto sib : euf::enode_class(n)) {
                if (bv.is_concat(sib->get_expr())) {
                    unsigned delta = 0;
                    unsigned j2 = just.size();
                    just.push_back({ n, sib, j });
                    for (unsigned j = sib->num_args(); j-- > 0; ) {
                        auto arg = sib->get_arg(j);
                        m_jtodo.push_back({ arg, offs + delta, j2 });
                        delta += width(arg);
                    }
                }
                for (auto p : euf::enode_parents(sib)) {
                    if (bv.is_extract(p->get_expr(), lo, hi, e)) {
                        SASSERT(g.find(e)->get_root() == n->get_root());
                        unsigned j2 = just.size();
                        just.push_back({ g.find(e), n, j });
                        m_jtodo.push_back({ p, offs + lo, j2 });
                    }
                }
            }

        }
        IF_VERBOSE(0,
            g.display(verbose_stream());
            verbose_stream() << g.bpp(a) << " offset " << offset << " " << g.bpp(b) << "\n";
            for (auto const& [n, offset, j] : m_jtodo) 
                verbose_stream() << g.bpp(n) << " offset " << offset << " " << g.bpp(n->get_root()) << "\n";            
        );
        UNREACHABLE();
    }

    void bv_plugin::clear_offsets() {
        for (auto const& [n, offset] : m_todo) {
            m_offsets.reserve(n->get_root_id() + 1);
            m_offsets[n->get_root_id()].reset();
        }
        m_todo.reset();
    }

    std::ostream& bv_plugin::display(std::ostream& out) const {
        out << "bv\n";        
        for (auto const& i : m_info) 
            if (i.lo)
                out << g.bpp(i.value) << " cut " << i.cut << " lo " << g.bpp(i.lo) << " hi " << g.bpp(i.hi) << "\n";
        return out;
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy