All Downloads are FREE. Search and download functionalities are using the official Maven repository.

z3-z3-4.13.0.src.ast.euf.euf_egraph.cpp Maven / Gradle / Ivy

The newest version!
/*++
Copyright (c) 2020 Microsoft Corporation

Module Name:

    euf_egraph.cpp

Abstract:

    E-graph layer

Author:

    Nikolaj Bjorner (nbjorner) 2020-08-23

Notes:
--*/


#include "ast/euf/euf_egraph.h"
#include "ast/ast_pp.h"
#include "ast/ast_translation.h"

namespace euf {

    /**
       \brief Create an enode for \c f through enode::mk (using m_region) and
       record it in the egraph's lookup structures.

       The node is appended to m_nodes and \c f to m_exprs. When \c f is an
       application with at least one argument the node is also indexed under
       the small id of its declaration in m_decl2enodes. The node is stored in
       m_expr2enode under the id of \c f and passed to push_node.
       Finally every argument has congruence closure enabled and the
       shared-status of its root reset to l_undef.
     */
    enode* egraph::mk_enode(expr* f, unsigned generation, unsigned num_args, enode * const* args) {
        enode* node = enode::mk(m_region, f, generation, num_args, args);
        if (m_default_relevant)
            node->set_relevant(true);
        m_nodes.push_back(node);
        m_exprs.push_back(f);
        bool const index_by_decl = is_app(f) && num_args > 0;
        if (index_by_decl) {
            unsigned decl_id = to_app(f)->get_decl()->get_small_id();
            m_decl2enodes.reserve(decl_id + 1);
            m_decl2enodes[decl_id].push_back(node);
        }
        m_expr2enode.setx(f->get_id(), node, nullptr);
        push_node(node);
        for (unsigned idx = 0; idx != num_args; ++idx) {
            enode* arg = args[idx];
            set_cgc_enabled(arg, true);
            arg->get_root()->set_is_shared(l_undef);
        }
        return node;
    }

    /**
       \brief Look up the node stored in m_table that matches expression \c e
       applied to the \c n arguments in \c args.

       A scratch node (m_tmp_node) is filled with the query and handed to
       m_table.find. The scratch node is re-allocated whenever its recorded
       capacity is smaller than \c n.
     */
    enode* egraph::find(expr* e, unsigned n, enode* const* args) {
        bool const scratch_too_small = m_tmp_node && m_tmp_node_capacity < n;
        if (scratch_too_small) {
            memory::deallocate(m_tmp_node);
            m_tmp_node = nullptr;
        }
        if (m_tmp_node == nullptr) {
            m_tmp_node = enode::mk_tmp(n);
            m_tmp_node_capacity = n;
        }
        unsigned idx = 0;
        while (idx < n) {
            m_tmp_node->m_args[idx] = args[idx];
            ++idx;
        }
        m_tmp_node->m_num_args = n;
        m_tmp_node->m_expr = e;
        m_tmp_node->m_table_id = UINT_MAX;
        return m_table.find(m_tmp_node);
    }


    /**
       \brief Return the nodes recorded in m_decl2enodes for declaration \c f.
       Yields m_empty_enodes when the small id of \c f lies outside the table.
     */
    enode_vector const& egraph::enodes_of(func_decl* f) {
        unsigned const decl_id = f->get_small_id();
        if (decl_id >= m_decl2enodes.size())
            return m_empty_enodes;
        return m_decl2enodes[decl_id];
    }

    /**
       \brief Insert \c p into m_table and remember the first component of the
       result in p->m_cg. The pair returned by the table is passed back to the
       caller unchanged.
     */
    enode_bool_pair egraph::insert_table(enode* p) {
        TRACE("euf_verbose", tout << "insert_table " << bpp(p) << "\n");
        //SASSERT(!m_table.contains_ptr(p));
        auto inserted = m_table.insert(p);
        p->m_cg = inserted.first;
        return inserted;
    }

    /**
       \brief Remove \c p from m_table.
     */
    void egraph::erase_from_table(enode* p) {
        m_table.erase(p);
    }

    /**
       \brief Queue the equality node \c p as a literal (with no antecedent)
       when it is not already assigned l_true and both of its arguments have
       the same root.
     */
    void egraph::reinsert_equality(enode* p) {
        SASSERT(p->is_equality());
        if (p->value() == l_true)
            return;
        if (p->get_arg(0)->get_root() != p->get_arg(1)->get_root())
            return;
        queue_literal(p, nullptr);
    }

    /**
       \brief Append a (p, ante) entry to m_to_merge.
       Nothing is queued when no m_on_propagate_literal callback is set.
     */
    void egraph::queue_literal(enode* p, enode* ante) {
        if (!m_on_propagate_literal)
            return;
        m_to_merge.push_back(to_merge(p, ante));
    }

    /**
       \brief Materialize the scopes counted in m_num_scopes.

       For each pending scope the current size of m_updates is pushed on
       m_scopes, a region scope is opened, and an update record carrying
       m_new_th_eqs_qhead is appended. m_num_scopes is 0 afterwards.
     */
    void egraph::force_push() {
        if (m_num_scopes == 0)
            return;
        // DEBUG_CODE(invariant(););

        while (m_num_scopes > 0) {
            m_scopes.push_back(m_updates.size());
            m_region.push_scope();
            m_updates.push_back(update_record(m_new_th_eqs_qhead, update_record::new_th_eq_qhead()));
            --m_num_scopes;
        }
        SASSERT(m_new_th_eqs_qhead <= m_new_th_eqs.size());
    }

    /**
       \brief Register \c n as a parent of the root of each of its arguments
       and log an update_children record so the registration can be undone.
     */
    void egraph::update_children(enode* n) {
        unsigned const arity = n->num_args();
        for (unsigned idx = 0; idx < arity; ++idx)
            n->get_arg(idx)->get_root()->add_parent(n);
        // after all registrations, n must be the last parent of every argument root
        for (unsigned idx = 0; idx < arity; ++idx)
            SASSERT(n->get_arg(idx)->get_root()->m_parents.back() == n);
        m_updates.push_back(update_record(n, update_record::update_children()));
    }

    /**
       \brief Create the node for \c f with the given arguments.

       Pending scopes are forced first. Nodes without arguments are marked
       interpreted when \c f is a unique value and are returned right after
       the m_on_make callback. Nodes with arguments are flagged as equalities
       when \c f is an equality that is not an iff, and are then inserted into
       the congruence table: if the table returns the node itself its children
       are updated, otherwise a merge with the returned node is pushed.
     */
    enode* egraph::mk(expr* f, unsigned generation, unsigned num_args, enode *const* args) {
        SASSERT(!find(f));
        force_push();
        enode* node = mk_enode(f, generation, num_args, args);

        SASSERT(node->class_size() == 1);
        bool const is_leaf = (num_args == 0);
        if (is_leaf && m.is_unique_value(f))
            node->mark_interpreted();
        if (m_on_make)
            m_on_make(node);
        if (is_leaf)
            return node;
        if (m.is_eq(f) && !m.is_iff(f)) {
            node->set_is_equality();
            reinsert_equality(node);
        }
        auto [congruent, comm] = insert_table(node);
        if (congruent != node)
            push_merge(node, congruent, comm);
        else
            update_children(node);

        return node;
    }

    /**
       \brief Construct an egraph over manager \c m.
       Besides initializing the members, a temporary node with room for two
       arguments is created via enode::mk_tmp (using m_region) and stored in
       m_tmp_eq.
     */
    egraph::egraph(ast_manager& m) : m(m), m_table(m), m_tmp_app(2), m_exprs(m), m_eq_decls(m) {
        m_tmp_eq = enode::mk_tmp(m_region, 2);
    }

    /**
       \brief Finalize the parent vector of every node and release the scratch
       node used by find, if one was allocated.
     */
    egraph::~egraph() {
        for (enode* node : m_nodes) {
            node->m_parents.finalize();
        }
        if (m_tmp_node != nullptr) {
            memory::deallocate(m_tmp_node);
        }
    }

    /**
       \brief Store plugin \c p in m_plugins at the index given by its id,
       growing the vector when needed.
     */
    void egraph::add_plugin(plugin* p) {
        auto const slot = p->get_id();
        m_plugins.reserve(slot + 1);
        m_plugins.set(slot, p);
    }

    /**
       \brief Invoke propagate() on every non-null entry of m_plugins.
     */
    void egraph::propagate_plugins() {
        for (auto* plug : m_plugins) {
            if (!plug)
                continue;
            plug->propagate();
        }
    }

    /**
       \brief Record a theory equality v1 == v2 for theory \c id.

       The equality is appended to m_new_th_eqs together with a matching
       update record, the statistics counter is bumped, and the plugin
       registered for \c id (if any) is told about the merge of \c c and \c r.
     */
    void egraph::add_th_eq(theory_id id, theory_var v1, theory_var v2, enode* c, enode* r) {
        TRACE("euf", tout << "eq: " << v1 << " == " << v2 << " - " << bpp(c) << " == " << bpp(r) << "\n";);
        m_new_th_eqs.push_back(th_eq(id, v1, v2, c, r));
        m_updates.push_back(update_record(update_record::new_th_eq()));
        ++m_stats.m_num_th_eqs;
        if (auto* plug = get_plugin(id))
            plug->merge_eh(c, r);
    }

    /**
       \brief Record a theory disequality v1 != v2 for theory \c id, justified
       by the equality node \c eq.

       Does nothing unless the theory was registered as propagating
       disequalities. Otherwise the disequality is appended to m_new_th_eqs
       with an update record, the plugin for \c id (if any) is notified, and
       the statistics counter is bumped.
     */
    void egraph::add_th_diseq(theory_id id, theory_var v1, theory_var v2, enode* eq) {
        if (!th_propagates_diseqs(id))
            return;
        TRACE("euf_verbose", tout << "eq: " << v1 << " != " << v2 << "\n";);
        m_new_th_eqs.push_back(th_eq(id, v1, v2, eq->get_expr()));
        m_updates.push_back(update_record(update_record::new_th_eq()));
        if (auto* plug = get_plugin(id))
            plug->diseq_eh(eq);
        ++m_stats.m_num_th_diseqs;
    }

    /**
       \brief Report literal propagations through m_on_propagate_literal.

       Without an antecedent the callback is invoked once for \c n. With an
       antecedent the class of \c n is traversed: if the antecedent is the
       expression true or false, every class member other than the antecedent
       is reported; otherwise every member whose value differs from the
       antecedent's value is reported.
     */
    void egraph::add_literal(enode* n, enode* ante) {
        TRACE("euf", tout << "propagate " << bpp(n) << " " << bpp(ante) << "\n");
        if (!m_on_propagate_literal)
            return;
        if (!ante) {
            ++m_stats.m_num_eqs;
            m_on_propagate_literal(n, ante);
            return;
        }
        ++m_stats.m_num_lits;
        bool const ante_is_constant = m.is_true(ante->get_expr()) || m.is_false(ante->get_expr());
        for (enode* member : enode_class(n)) {
            bool const report = ante_is_constant
                ? (member != ante)
                : (member->value() != ante->value());
            if (report)
                m_on_propagate_literal(member, ante);
        }
    }

    /**
       \brief Process an equality node \c n that has been assigned l_false.

       If both arguments already share a root the node is reported through
       add_literal. Otherwise theory disequalities are added: when both roots
       carry exactly one theory variable of the same theory, the closest theory
       variables of the two arguments are used; in the general case every pair
       of theory variables of the two roots that belong to the same
       (disequality-propagating) theory is reported.
     */
    void egraph::new_diseq(enode* n) {
        SASSERT(n->is_equality());
        SASSERT(n->value() == l_false);
        enode* lhs = n->get_arg(0);
        enode* rhs = n->get_arg(1);
        enode* r1 = lhs->get_root();
        enode* r2 = rhs->get_root();
        TRACE("euf", tout << "new-diseq:  " << bpp(r1) << " " << bpp(r2) << ": " << r1->has_th_vars() << " " << r2->has_th_vars() << "\n";);
        if (r1 == r2) {
            add_literal(n, nullptr);
            return;
        }
        if (!r1->has_th_vars() || !r2->has_th_vars())
            return;
        bool const same_single_theory =
            r1->has_one_th_var() && r2->has_one_th_var() && r1->get_first_th_id() == r2->get_first_th_id();
        if (same_single_theory) {
            theory_id id = r1->get_first_th_id();
            if (th_propagates_diseqs(id)) {
                theory_var v1 = lhs->get_closest_th_var(id);
                theory_var v2 = rhs->get_closest_th_var(id);
                add_th_diseq(id, v1, v2, n);
            }
            return;
        }
        for (auto const& p : euf::enode_th_vars(r1)) {
            if (!th_propagates_diseqs(p.get_id()))
                continue;
            for (auto const& q : euf::enode_th_vars(r2)) {
                if (p.get_id() != q.get_id())
                    continue;
                add_th_diseq(p.get_id(), p.get_var(), q.get_var(), n);
            }
        }
    }


    /**
       \brief Propagate disequalities over equality atoms that are assigned to false.

       For every parent of root \c r that is an equality with value l_false,
       the root of the argument on the other side is inspected; if it has a
       closest theory variable for \c id, the disequality between \c v1 and
       that variable is added.
     */
    void egraph::add_th_diseqs(theory_id id, theory_var v1, enode* r) {
        SASSERT(r->is_root());
        if (!th_propagates_diseqs(id))
            return;
        for (enode* p : enode_parents(r)) {
            if (!p->is_equality() || p->value() != l_false)
                continue;
            bool const r_is_lhs = (r == p->get_arg(0)->get_root());
            enode* other = r_is_lhs ? p->get_arg(1) : p->get_arg(0);
            other = other->get_root();
            theory_var v2 = other->get_closest_th_var(id);
            if (v2 != null_theory_var)
                add_th_diseq(id, v1, v2, p);
        }
    }

    /**
       \brief Flag theory \c id as one for which disequalities are propagated.
       The flag vector is grown (with false entries) to cover \c id.
     */
    void egraph::set_th_propagates_diseqs(theory_id id) {
        m_th_propagates_diseqs.reserve(id + 1, false);
        m_th_propagates_diseqs[id] = true;
    }

    /**
       \brief Return the disequality-propagation flag stored for theory \c id,
       using false as the default passed to the lookup.
     */
    bool egraph::th_propagates_diseqs(theory_id id) const {
        return m_th_propagates_diseqs.get(id, false);
    }

    /**
       \brief Attach theory variable \c v of theory \c id to node \c n.

       The plugin for \c id (if any) is told to register the node. If \c n had
       no variable for the theory, the variable is added and logged; when \c n
       is not its own root, the root either inherits the variable (followed by
       disequality propagation) or, if it already has one, a theory equality is
       added. If \c n already had a variable it is replaced, the replacement is
       logged together with the root's variable, and a theory equality with the
       root's variable is added.
     */
    void egraph::add_th_var(enode* n, theory_var v, theory_id id) {
        force_push();
        theory_var const existing = n->get_th_var(id);
        enode* root = n->get_root();

        if (auto* plug = get_plugin(id))
            plug->register_node(n);

        if (existing != null_theory_var) {
            theory_var u = root->get_th_var(id);
            SASSERT(u != v && u != null_theory_var);
            n->replace_th_var(v, id);
            m_updates.push_back(update_record(n, id, u, update_record::replace_th_var()));
            add_th_eq(id, v, u, n, root);
            return;
        }

        n->add_th_var(v, id, m_region);
        m_updates.push_back(update_record(n, id, update_record::add_th_var()));
        if (root == n)
            return;
        theory_var u = root->get_th_var(id);
        if (u != null_theory_var) {
            add_th_eq(id, v, u, n, root);
            return;
        }
        root->add_th_var(v, id, m_region);
        add_th_diseqs(id, v, root);
    }

    /**
       \brief Remove the theory variable of theory \c tid from \c n.
       If the root of \c n is a different node that carries the same variable
       for \c tid, it is removed from the root as well.
     */
    void egraph::undo_add_th_var(enode* n, theory_id tid) {
        theory_var const removed = n->get_th_var(tid);
        SASSERT(removed != null_theory_var);
        n->del_th_var(tid);
        enode* root = n->get_root();
        if (root == n)
            return;
        if (root->get_th_var(tid) == removed)
            root->del_th_var(tid);
    }

    /**
       \brief Set the merge_tf flag of a Boolean node to \c enable_merge_tf.
       Non-Boolean nodes and nodes that already have the requested setting are
       left untouched; an actual change is logged for undo.
     */
    void egraph::set_merge_tf_enabled(enode* n, bool enable_merge_tf) {
        if (!m.is_bool(n->get_sort()))
            return;
        if (enable_merge_tf == n->merge_tf())
            return;
        TRACE("euf", tout << "set tf " << enable_merge_tf << " " << bpp(n) << "\n");
        n->set_merge_tf(enable_merge_tf);
        m_updates.push_back(update_record(n, update_record::toggle_merge_tf()));
    }

    /**
       \brief Bring the congruence-closure flag of \c n to \c enable_merge.
       When the flag differs it is toggled (non-backtracking mode) and the
       toggle is logged for undo.
     */
    void egraph::set_cgc_enabled(enode* n, bool enable_merge) {
        if (enable_merge == n->cgc_enabled())
            return;
        toggle_cgc_enabled(n, false);
        m_updates.push_back(update_record(n, update_record::toggle_cgc()));
    }

    /**
       \brief Mark \c n as relevant and log the change for undo.
       Nodes that are already relevant are left untouched.
     */
    void egraph::set_relevant(enode* n) {
        if (!n->is_relevant()) {
            n->set_relevant(true);
            m_updates.push_back(update_record(n, update_record::set_relevant()));
        }
    }

    /**
       \brief Flip the congruence-closure flag of \c n and keep m_table in sync.

       For nodes with arguments: when the flag becomes enabled the node is
       inserted into the table, and (unless \c backtracking) a merge with the
       node returned by the table is queued if that is a different node; when
       the flag becomes disabled the node is erased from the table if it is
       its own congruence representative (is_cgr).
     */
    void egraph::toggle_cgc_enabled(enode* n, bool backtracking) {
       // the new setting is the negation of the current one
       bool enable_merge = !n->cgc_enabled();
       n->set_cgc_enabled(enable_merge);         
       if (n->num_args() > 0) {
           if (enable_merge) {
               auto [n2, comm] = insert_table(n);
               // during backtracking no new merges are scheduled
               if (n2 != n && !backtracking)
                   m_to_merge.push_back(to_merge(n, n2, comm));
           }
           else if (n->is_cgr())
               erase_from_table(n);
       }
       // checked in all builds: an enabled node with arguments must be found in the table
       VERIFY(n->num_args() == 0 || !n->cgc_enabled() || m_table.contains(n));
    }

    /**
       \brief Assign \c value with literal justification \c j to a node whose
       value is still l_undef; nodes that already have a value are ignored.
       The assignment is logged for undo, and an equality node that ends up
       l_false is passed on to new_diseq.
     */
    void egraph::set_value(enode* n, lbool value, justification j) {
        if (n->value() != l_undef)
            return;
        force_push();
        TRACE("euf", tout << bpp(n) << " := " << value << "\n";);
        n->set_value(value);
        n->m_lit_justification = j;
        m_updates.push_back(update_record(n, update_record::value_assignment()));
        bool const is_false_equality = n->is_equality() && n->value() == l_false;
        if (is_false_equality)
            new_diseq(n);
    }

    /**
       \brief Give \c n a label hash derived from its expression id and make
       sure the label set of its root may contain it.

       m_lbl_hash is expected to be -1 on entry. Both the hash assignment and
       a change to the root's label set are preceded by an update record so
       that they can be rolled back.
     */
    void egraph::set_lbl_hash(enode* n) {
        SASSERT(n->m_lbl_hash == -1);
        // m_lbl_hash differs from -1 exactly when some pattern contains the
        // enode, so the old value is put on the trail before it is overwritten.
        m_updates.push_back(update_record(n, update_record::lbl_hash()));
        unsigned const hashed = hash_u(n->get_expr_id());
        n->m_lbl_hash = hashed & (APPROX_SET_CAPACITY - 1);
        // mirror the new label in the label set kept at the root
        enode* root = n->get_root();
        approx_set& root_lbls = root->m_lbls;
        if (root_lbls.may_contain(n->m_lbl_hash))
            return;
        m_updates.push_back(update_record(root, update_record::lbl_set()));
        root_lbls.insert(n->m_lbl_hash);
    }

    /**
       \brief Backtrack \c num_scopes scopes.

       Scopes that were only counted in m_num_scopes (not yet materialized by
       force_push) are dropped first. For the remaining ones, the update
       records above the saved limit are undone in reverse order, then
       m_updates, m_scopes and the region are shrunk accordingly. The merge
       queue m_to_merge is reset in either case.
     */
    void egraph::pop(unsigned num_scopes) {
        // only lazily counted scopes are affected: nothing to undo
        if (num_scopes <= m_num_scopes) {
            m_num_scopes -= num_scopes;
            m_to_merge.reset();
            return;
        }
        num_scopes -= m_num_scopes;
        m_num_scopes = 0;

        unsigned old_lim = m_scopes.size() - num_scopes;
        unsigned num_updates = m_scopes[old_lim];
        // removes the most recently created node from all lookup structures
        auto undo_node = [&]() {
            enode* n = m_nodes.back();
            expr* e = m_exprs.back();
            if (n->num_args() > 0 && n->is_cgr()) 
                erase_from_table(n);
            
            m_expr2enode[e->get_id()] = nullptr;
            n->~enode();
            if (is_app(e) && n->num_args() > 0)
               m_decl2enodes[to_app(e)->get_decl()->get_small_id()].pop_back();
            m_nodes.pop_back();
            m_exprs.pop_back();
        };
        unsigned sz = m_updates.size();
        // walk the trail backwards down to the saved limit
        for (unsigned i = sz; i-- > num_updates; ) {
            auto const& p = m_updates[i];
            switch (p.tag) {
            case update_record::tag_t::is_add_node:
                undo_node();
                break;
            case update_record::tag_t::is_toggle_cgc:
                toggle_cgc_enabled(p.r1, true);
                break;
            case update_record::tag_t::is_toggle_merge_tf:
                p.r1->set_merge_tf(!p.r1->merge_tf());
                break;
            case update_record::tag_t::is_set_parent:
                undo_eq(p.r1, p.n1, p.r2_num_parents);
                break;
            case update_record::tag_t::is_add_th_var:
                // NOTE(review): the theory id is read from r2_num_parents here;
                // presumably that field shares storage with the theory id - confirm in update_record.
                undo_add_th_var(p.r1, p.r2_num_parents);
                break;
            case update_record::tag_t::is_replace_th_var:
                SASSERT(p.r1->get_th_var(p.m_th_id) != null_theory_var);
                p.r1->replace_th_var(p.m_old_th_var, p.m_th_id);
                break;
            case update_record::tag_t::is_new_th_eq:
                m_new_th_eqs.pop_back();
                break;
            case update_record::tag_t::is_new_th_eq_qhead:
                m_new_th_eqs_qhead = p.qhead;
                break;
            case update_record::tag_t::is_inconsistent:
                m_inconsistent = p.m_inconsistent;
                break;
            case update_record::tag_t::is_value_assignment:
                VERIFY(p.r1->value() != l_undef);
                p.r1->set_value(l_undef);
                break;
            case update_record::tag_t::is_lbl_hash:
                p.r1->m_lbl_hash = p.m_lbl_hash;
                break;
            case update_record::tag_t::is_lbl_set:
                p.r1->m_lbls.set(p.m_lbls);
                break;
            case update_record::tag_t::is_set_relevant:
                SASSERT(p.r1->is_relevant());
                p.r1->set_relevant(false);
                break;
            case update_record::tag_t::is_update_children:
                // note: this index shadows the outer trail index i
                for (unsigned i = 0; i < p.r1->num_args(); ++i) {
                    CTRACE("euf", (p.r1->m_args[i]->get_root()->m_parents.back() != p.r1),
                           display(tout << bpp(p.r1->m_args[i]) << " " << bpp(p.r1->m_args[i]->get_root()) << " "););
                    SASSERT(p.r1->m_args[i]->get_root()->m_parents.back() == p.r1);
                    p.r1->m_args[i]->get_root()->m_parents.pop_back();
                }
                break;
            case update_record::tag_t::is_plugin_undo:
                m_plugins[p.m_th_id]->undo();
                break;
            default:
                UNREACHABLE();
                break;
            }                    
        }        
        // undoing must not have appended new records
        SASSERT(m_updates.size() == sz);
        m_updates.shrink(num_updates);
        m_scopes.shrink(old_lim);        
        m_region.pop_scope(num_scopes);  
        m_to_merge.reset();

        SASSERT(m_new_th_eqs_qhead <= m_new_th_eqs.size());

        // DEBUG_CODE(invariant(););
    }

    /**
       \brief Merge the equivalence classes of \c n1 and \c n2 with
       justification \c j.

       Nothing happens when neither node has congruence closure enabled or
       when both already share a root. A conflict is raised when both roots
       are interpreted, or when both roots carry defined but different values.
       Otherwise the class rooted at r1 is merged into the class rooted at r2,
       theory variables and parents are re-processed, literal propagation is
       triggered where values differ, and the m_on_merge callbacks are called.
     */
    void egraph::merge(enode* n1, enode* n2, justification j) {

        if (!n1->cgc_enabled() && !n2->cgc_enabled())
            return;
        SASSERT(n1->get_sort() == n2->get_sort());
        enode* r1 = n1->get_root();
        enode* r2 = n2->get_root();
        if (r1 == r2)
            return;

        TRACE("euf", j.display(tout << "merge: " << bpp(n1) << " == " << bpp(n2) << " ", m_display_justification) << "\n" << bpp(r1) << " " << bpp(r2) << "\n";);
        IF_VERBOSE(20, j.display(verbose_stream() << "merge: " << bpp(n1) << " == " << bpp(n2) << " ", m_display_justification) << "\n";);
        force_push();
        SASSERT(m_num_scopes == 0);
        ++m_stats.m_num_merge;
        if (r1->interpreted() && r2->interpreted()) {
            set_conflict(n1, n2, j);
            return;
        }
        
        if (r1->value() != r2->value() && r1->value() != l_undef && r2->value() != l_undef) {
            SASSERT(m.is_bool(r1->get_expr()));
            set_conflict(n1, n2, j);
            return;
        }
        // r2 becomes the surviving root. Swap the roles unless r2 is interpreted,
        // when r1 has the larger class, is interpreted, or already has a value.
        if (!r2->interpreted() && 
             (r1->class_size() > r2->class_size() || r1->interpreted() || r1->value() != l_undef)) {
            std::swap(r1, r2);
            std::swap(n1, n2);
        }

        remove_parents(r1);
        // the number of parents of r2 before the merge is logged for undo_eq
        push_eq(r1, n1, r2->num_parents());
        merge_justification(n1, n2, j);
        for (enode* c : enode_class(n1)) 
            c->m_root = r2;
        // exchange the m_next links of the two roots (presumably splicing the two class lists - confirm in enode)
        std::swap(r1->m_next, r2->m_next);
        r2->inc_class_size(r1->class_size());   
        r2->set_is_shared(l_undef);
        merge_th_eq(r1, r2);
        reinsert_parents(r1, r2);
        if (j.is_congruence() && (m.is_false(r2->get_expr()) || m.is_true(r2->get_expr())))
            add_literal(n1, r2);
        else if (n2->value() != l_undef && n1->value() != n2->value()) 
            add_literal(n1, n2);
        else if (n1->value() != l_undef && n2->value() != n1->value()) 
            add_literal(n2, n1);

        for (auto& cb : m_on_merge)
            cb(r2, r1);
    }

    /**
       \brief Mark the parents of \c r that need re-processing and take the
       congruence-enabled ones out of m_table.

       A parent with congruence closure enabled is marked and erased from the
       table only if it is its own congruence representative (is_cgr); a parent
       without congruence closure is marked when it is an equality. The marks
       are consumed by reinsert_parents.
     */
    void egraph::remove_parents(enode* r) {
        TRACE("euf_verbose", tout << bpp(r) << "\n");
        SASSERT(all_of(enode_parents(r), [&](enode* p) { return !p->is_marked1(); }));
        for (enode* p : enode_parents(r)) {
            // already handled (e.g. the same parent listed more than once)
            if (p->is_marked1())
                continue;
            if (p->cgc_enabled()) {
                if (!p->is_cgr())
                    continue;
                SASSERT(m_table.contains_ptr(p));
                p->mark1();
                erase_from_table(p);
                CTRACE("euf_verbose", m_table.contains_ptr(p), tout << bpp(p) << "\n"; display(tout));
                SASSERT(!m_table.contains_ptr(p));
            }
            else if (p->is_equality())
                p->mark1();
        }
    }

    /**
       \brief Re-process the marked parents of \c r1 after its class has been
       merged into the class rooted at \c r2.

       Each marked parent is unmarked. A congruence-enabled parent is inserted
       back into m_table: if the table returns a different node a merge of the
       two is queued, otherwise the parent is appended to the parents of
       \c r2. Equality parents are additionally passed to reinsert_equality.
       A parent without congruence closure that is an equality is appended to
       the parents of \c r2 and passed to reinsert_equality.
     */
    void egraph::reinsert_parents(enode* r1, enode* r2) {
        for (enode* p : enode_parents(r1)) {
            if (!p->is_marked1())
                continue;
            p->unmark1();
            TRACE("euf_verbose", tout << "reinsert " << bpp(r1) << " " << bpp(r2) << " " << bpp(p) << " " << p->cgc_enabled() << "\n";);
            if (p->cgc_enabled()) {
                auto [p_other, comm] = insert_table(p);
                SASSERT(m_table.contains_ptr(p) == (p_other == p));
                CTRACE("euf_verbose", p_other != p, tout << "reinsert " << bpp(p) << " == " << bpp(p_other) << " " << p->value() << " " << p_other->value() << "\n");
                // a different node returned by the table means p is congruent to it
                if (p_other != p) 
                    m_to_merge.push_back(to_merge(p_other, p, comm));                
                else
                    r2->m_parents.push_back(p);
                if (p->is_equality())
                    reinsert_equality(p);
            }
            else if (p->is_equality()) {
                r2->m_parents.push_back(p);
                reinsert_equality(p);
            }
        }
    }

    /**
       \brief Transfer the theory variables of \c n to \c root.

       For each theory variable of \c n: if \c root has no variable of that
       theory it receives this one (logged for undo) and disequalities are
       propagated; otherwise a theory equality between the two variables is
       added.
     */
    void egraph::merge_th_eq(enode* n, enode* root) {
        SASSERT(n != root);
        for (auto const& iv : enode_th_vars(n)) {
            theory_id id = iv.get_id();
            theory_var root_var = root->get_th_var(id);
            if (root_var != null_theory_var) {
                SASSERT(root_var != iv.get_var());
                add_th_eq(id, root_var, iv.get_var(), n, root);
                continue;
            }
            root->add_th_var(iv.get_var(), id, m_region);
            m_updates.push_back(update_record(root, id, update_record::add_th_var()));
            add_th_diseqs(id, iv.get_var(), root);
        }
    }

    /**
       \brief Undo a merge in which the class formerly rooted at \c r1 was
       merged into its current root.

       \param r1 the old root that was merged away
       \param n1 the node whose justification link was set by the merge
       \param r2_num_parents number of parents the surviving root had before the merge
     */
    void egraph::undo_eq(enode* r1, enode* n1, unsigned r2_num_parents) {
        enode* r2 = r1->get_root();
        TRACE("euf_verbose", tout << "undo-eq old-root: " << bpp(r1) << " current-root " << bpp(r2) << " node: " << bpp(n1) << "\n";);
        r2->dec_class_size(r1->class_size());
        r2->set_is_shared(l_undef);
        // exchange the m_next links back, mirroring the swap done in merge
        std::swap(r1->m_next, r2->m_next);
        // parents appended to r2 since the merge: take the congruence-enabled ones out of the table
        auto begin = r2->begin_parents() + r2_num_parents, end = r2->end_parents();
        for (auto it = begin; it != end; ++it) {
            enode* p = *it;
            TRACE("euf_verbose", tout << "erase " << bpp(p) << "\n";);
            SASSERT(!p->cgc_enabled() || m_table.contains_ptr(p));
            SASSERT(!p->cgc_enabled() || p->is_cgr());
            if (p->cgc_enabled())
                erase_from_table(p);
        }

        // restore r1 as the root of its class
        for (enode* c : enode_class(r1))
            c->m_root = r1;

        // re-insert parents of r1 that are congruence-enabled and either are their own
        // representative or are no longer congruent to the recorded one
        for (enode* p : enode_parents(r1)) 
            if (p->cgc_enabled() && (p->is_cgr() || !p->congruent(p->m_cg))) 
                insert_table(p);                    
        r2->m_parents.shrink(r2_num_parents);
        unmerge_justification(n1);
    }


    /**
       \brief Process the queued entries of m_to_merge.

       Plugins are propagated first, then the queue is drained while the
       resource limit permits and the state is consistent. The queue is reset
       afterwards.

       \return true when there are unprocessed theory equalities
       (m_new_th_eqs_qhead is behind m_new_th_eqs) or the egraph is inconsistent.
     */
    bool egraph::propagate() {
        force_push();
        unsigned i = 0;
        bool change = true;
        // NOTE(review): change is never set back to true in this body, so the
        // outer loop executes exactly once.
        while (change) {
            change = false;
            propagate_plugins();
            // the queue may grow while it is processed; the bound is re-read each iteration
            for (; i < m_to_merge.size() && m.limit().inc() && !inconsistent(); ++i) {
                auto const& w = m_to_merge[i];
                switch (w.t) {
                case to_merge_plain:
                case to_merge_comm:
                    merge(w.a, w.b, justification::congruence(w.commutativity(), m_congruence_timestamp++));
                    break;
                case to_justified:
                    merge(w.a, w.b, w.j);
                    break;
                case to_add_literal:
                    add_literal(w.a, w.b);
                    break;
                }                
            }
        }
        m_to_merge.reset();
        return 
            (m_new_th_eqs_qhead < m_new_th_eqs.size()) ||
            inconsistent();
    }

    /**
       \brief Record a conflict between \c n1 and \c n2 with justification \c j.
       The conflict counter is always incremented; the conflict itself is only
       stored (and logged for undo) when the egraph is not already inconsistent.
     */
    void egraph::set_conflict(enode* n1, enode* n2, justification j) {
        ++m_stats.m_num_conflicts;
        if (m_inconsistent)
            return;
        m_inconsistent = true;
        // the record carries the previous state (false) to restore on pop
        m_updates.push_back(update_record(false, update_record::inconsistent()));
        m_n1 = n1;
        m_n2 = n2;
        m_justification = j;
        TRACE("euf", tout << "conflict " << bpp(n1) << " " << bpp(n2) << " " << j << " " << n1->get_root()->value() << " " << n2->get_root()->value() << "\n");
    }

    /**
       \brief Link \c n1 to \c n2 in the justification structure.
       The justification chain of \c n1 is reversed first, then \c n1 gets
       \c n2 as its target with justification \c j.
     */
    void egraph::merge_justification(enode* n1, enode* n2, justification j) {
        // both roots must be at the end of their target chains
        SASSERT(!n1->get_root()->m_target);
        SASSERT(!n2->get_root()->m_target);
        SASSERT(n1->reaches(n1->get_root()));
        SASSERT(!n2->reaches(n1->get_root()));
        SASSERT(!n2->reaches(n1));
        n1->reverse_justification();
        n1->m_target = n2;
        n1->m_justification = j;
        SASSERT(n1->acyclic());
        SASSERT(n2->acyclic());
        SASSERT(n1->get_root()->reaches(n1));
        SASSERT(!n2->get_root()->m_target);
        TRACE("euf_verbose", tout << "merge " << n1->get_expr_id() << " " << n2->get_expr_id() << " updates: " << m_updates.size() << "\n";);
    }

    /**
       \brief Remove the justification link leaving \c n1 and reverse the
       chain starting at the root of \c n1. Counterpart of merge_justification.
     */
    void egraph::unmerge_justification(enode* n1) {
        TRACE("euf_verbose", tout << "unmerge " << n1->get_expr_id() << " " << n1->m_target->get_expr_id() << "\n";);
        // r1 -> ..  -> n1 -> n2 -> ... -> r2
        // where n2 = n1->m_target
        SASSERT(n1->get_root()->reaches(n1));
        SASSERT(n1->m_target);
        // cut the link n1 -> n2 and reset its justification to an axiom with null theory id
        n1->m_target        = nullptr;
        n1->m_justification = justification::axiom(null_theory_id);
        n1->get_root()->reverse_justification();
        // ---------------
        // n1 -> ... -> r1
        // n2 -> ... -> r2
        SASSERT(n1->reaches(n1->get_root()));
        SASSERT(!n1->get_root()->m_target);
    }

    /**
       \brief Check whether \c a and \c b are known to be different.

       True when their roots differ and either both roots are interpreted,
       the roots have different sorts, or tmp_eq finds an equality between the
       roots whose root has value l_false.
     */
    bool egraph::are_diseq(enode* a, enode* b) {
        enode* ra = a->get_root();
        enode* rb = b->get_root();
        if (ra == rb)
            return false;
        if (ra->interpreted() && rb->interpreted())
            return true;
        if (ra->get_sort() != rb->get_sort())
            return true;
        enode* eq = tmp_eq(ra, rb);
        return eq && eq->get_root()->value() == l_false;
    }

    /**
       \brief Look up a node for declaration \c f applied to \c args.
       The scratch application m_tmp_app is configured with \c f and
       \c num_args and then used as the expression passed to find.
     */
    enode* egraph::get_enode_eq_to(func_decl* f, unsigned num_args, enode* const* args) {
        m_tmp_app.set_decl(f);
        m_tmp_app.set_num_args(num_args);
        auto* probe = m_tmp_app.get_app();
        return find(probe, num_args, args);
    }

    /**
       \brief Search for an equality node relating the roots \c a and \c b.

       The parents of whichever root has fewer parents are scanned for an
       equality one of whose arguments has the other root as its root.
       Returns that parent, or nullptr when none is found.
     */
    enode* egraph::tmp_eq(enode* a, enode* b) {
        SASSERT(a->is_root());
        SASSERT(b->is_root());
        enode* scanned = a;
        enode* other = b;
        if (scanned->num_parents() > other->num_parents())
            std::swap(scanned, other);
        for (enode* p : enode_parents(scanned)) {
            if (!p->is_equality())
                continue;
            if (other == p->get_arg(0)->get_root() || other == p->get_arg(1)->get_root())
                return p;
        }
        return nullptr;
    }

    /**
       \brief Queue the explanation of a congruence between \c n1 and \c n2.

       Corresponding children of congruent nodes have equal roots and hence a
       least common ancestor; the explanation only needs to reach up to it.
       With \c comm set and the arguments matching crosswise, the children
       are paired in swapped order; otherwise they are paired positionally.
     */
    void egraph::push_congruence(enode* n1, enode* n2, bool comm) {
        SASSERT(is_app(n1->get_expr()));
        SASSERT(n1->get_decl() == n2->get_decl());
        m_uses_congruence = true;
        if (!comm && m_used_cc)
            m_used_cc(n1->get_app(), n2->get_app());
        bool const crosswise =
            comm &&
            n1->get_arg(0)->get_root() == n2->get_arg(1)->get_root() &&
            n1->get_arg(1)->get_root() == n2->get_arg(0)->get_root();
        if (crosswise) {
            push_lca(n1->get_arg(0), n2->get_arg(1));
            push_lca(n1->get_arg(1), n2->get_arg(0));
            return;
        }
        TRACE("euf_verbose", tout << bpp(n1) << " " << bpp(n2) << "\n");

        unsigned const arity = n1->num_args();
        for (unsigned idx = 0; idx < arity; ++idx)
            push_lca(n1->get_arg(idx), n2->get_arg(idx));
    }

    /**
       \brief Find the least common ancestor of \c a and \c b along their
       m_target chains. Both nodes must have the same root.

       The chain starting at \c a is marked, \c b is advanced along m_target
       until a marked node is reached, and the marks are cleared again.

       NOTE(review): the explicit template arguments of mark2_targets were
       missing in this copy of the file (two identical calls, which cannot
       both mark and unmark). They are restored as <true> (mark) and <false>
       (unmark) - confirm against the declaration of enode::mark2_targets.
     */
    enode* egraph::find_lca(enode* a, enode* b) {
        SASSERT(a->get_root() == b->get_root());
        a->mark2_targets<true>();
        while (!b->is_marked2()) 
            b = b->m_target;
        // clear the marks so that later traversals start from a clean state
        a->mark2_targets<false>();
        return b;
    }
    
    /**
       \brief Append to m_todo every node on the m_target chain from \c n up
       to, but not including, \c lca.
     */
    void egraph::push_to_lca(enode* n, enode* lca) {
        for (enode* cur = n; cur != lca; cur = cur->m_target)
            m_todo.push_back(cur);
    }

    /**
       \brief Append to m_todo the chain segments from \c a and from \c b up to
       their least common ancestor.
     */
    void egraph::push_lca(enode* a, enode* b) {
        enode* const ancestor = find_lca(a, b);
        push_to_lca(a, ancestor);
        push_to_lca(b, ancestor);
    }

    /**
       \brief Append to m_todo every node on the m_target chain starting at
       \c n, up to the end of the chain.
     */
    void egraph::push_todo(enode* n) {
        for (enode* cur = n; cur; cur = cur->m_target)
            m_todo.push_back(cur);
    }

    /**
       \brief Start an explanation: clears the m_uses_congruence flag.
       m_todo must be empty, and (checked in debug code) no node may carry mark1.
     */
    void egraph::begin_explain() {
        SASSERT(m_todo.empty());
        m_uses_congruence = false;
        DEBUG_CODE(for (enode* n : m_nodes) SASSERT(!n->is_marked1()););
    }

    void egraph::end_explain() {
        for (enode* n : m_todo) 
            n->unmark1();
        DEBUG_CODE(for (enode* n : m_nodes) SASSERT(!n->is_marked1()););
        m_todo.reset();        
    }

    template 
    void egraph::explain(ptr_vector& justifications, cc_justification* cc) {
        SASSERT(m_inconsistent);
        push_todo(m_n1);
        push_todo(m_n2);
        explain_eq(justifications, cc, m_n1, m_n2, m_justification);
        explain_todo(justifications, cc);
    }

    template 
    void egraph::explain_eq(ptr_vector& justifications, cc_justification* cc, enode* a, enode* b, justification const& j) {
        TRACE("euf_verbose", tout << "explain-eq: " << bpp(a) << " == " << bpp(b) << " jst: " << j << "\n";);
        if (j.is_external())
            justifications.push_back(j.ext());
        else if (j.is_congruence()) 
            push_congruence(a, b, j.is_commutative());
        else if (j.is_dependent()) {
            vector js;
            for (auto const& j2 : justification::dependency_manager::s_linearize(j.get_dependency(), js))
                explain_eq(justifications, cc, a, b, j2);
        }
        else if (j.is_equality()) 
            explain_eq(justifications, cc, j.lhs(), j.rhs());
        else if (j.is_axiom() && j.get_theory_id() != null_theory_id) {
            IF_VERBOSE(20, verbose_stream() << "TODO add theory axiom to justification\n");
        }
        if (cc && j.is_congruence()) 
            cc->push_back(std::tuple(a->get_app(), b->get_app(), j.timestamp(), j.is_commutative()));
    }


    template 
    void egraph::explain_eq(ptr_vector& justifications, cc_justification* cc, enode* a, enode* b) {
        SASSERT(a->get_root() == b->get_root());
        
        enode* lca = find_lca(a, b);
        TRACE("euf_verbose", tout << "explain-eq: " << bpp(a) << " == " << bpp(b) << " lca: " << bpp(lca) << "\n";);
        push_to_lca(a, lca);
        push_to_lca(b, lca);
        if (m_used_eq)
            m_used_eq(a->get_expr(), b->get_expr(), lca->get_expr());
        explain_todo(justifications, cc);
    }

    template 
    unsigned egraph::explain_diseq(ptr_vector& justifications, cc_justification* cc, enode* a, enode* b) {
        enode* ra = a->get_root(), * rb = b->get_root();
        SASSERT(ra != rb);
        if (ra->interpreted() && rb->interpreted()) {
            explain_eq(justifications, cc, a, ra);
            explain_eq(justifications, cc, b, rb);
            return sat::null_bool_var;
        }
        enode* r = tmp_eq(ra, rb);
        SASSERT(r && r->get_root()->value() == l_false);
        explain_eq(justifications, cc, r, r->get_root());
        return r->get_root()->bool_var();
    }


    template 
    void egraph::explain_todo(ptr_vector& justifications, cc_justification* cc) {
        for (unsigned i = 0; i < m_todo.size(); ++i) {
            enode* n = m_todo[i];
            if (n->is_marked1())
                continue;
            if (n->m_target) {
                n->mark1();
                CTRACE("euf_verbose", m_display_justification, n->m_justification.display(tout << n->get_expr_id() << " = " << n->m_target->get_expr_id() << " ", m_display_justification) << "\n";);
                explain_eq(justifications, cc, n, n->m_target, n->m_justification);
            }
            else if (!n->is_marked1() && n->value() != l_undef) {
                n->mark1();
                if (m.is_true(n->get_expr()) || m.is_false(n->get_expr()))
                    continue;
                justification j = n->m_lit_justification;
                SASSERT(j.is_external());
                justifications.push_back(j.ext());
            }
        }
    }

    void egraph::invariant() {
        for (enode* n : m_nodes)
            n->invariant(*this);
        for (enode* n : m_nodes)
            if (n->cgc_enabled() && n->num_args() > 0 && (!m_table.find(n) || n->get_root() != m_table.find(n)->get_root())) {
                CTRACE("euf", !m_table.find(n), tout << "node is not in table\n";);
                CTRACE("euf", m_table.find(n), tout << "root " << bpp(n->get_root()) << " table root " << bpp(m_table.find(n)->get_root()) << "\n";);
                TRACE("euf", display(tout << bpp(n) << " is not closed under congruence\n"););
                UNREACHABLE();
            }
    }

    std::ostream& egraph::display(std::ostream& out, unsigned max_args, enode* n) const {
        if (!n->is_relevant())
            out << "n";
        out << "#" << n->get_expr_id() << " := ";
        expr* f = n->get_expr();
        if (is_app(f))
            out << mk_bounded_pp(f, m, 1) << " ";
        else if (is_quantifier(f))
            out << "q:" << f->get_id() << " ";
        else
            out << "v:" << f->get_id() << " ";
        if (!n->is_root()) 
            out << "[r " << n->get_root()->get_expr_id() << "] ";
        if (!n->m_parents.empty()) {
            out << "[p";
            for (enode* p : enode_parents(n))
                out << " " << p->get_expr_id();
            out << "] ";
        }
        auto value_of = [&]() {
            switch (n->value()) {
            case l_true: return "T";
            case l_false: return "F";
            default: return "?";
            }
        };
        if (n->bool_var() != sat::null_bool_var) 
            out << "[b" << n->bool_var() << " := " << value_of() << (n->cgc_enabled() ? "" : " no-cgc") << (n->merge_tf()? " merge-tf" : "") << "] ";
        if (n->has_th_vars()) {
            out << "[t";
            for (auto const& v : enode_th_vars(n))
                out << " " << v.get_id() << ":" << v.get_var();
            out << "] ";
        }
        if (n->generation() > 0)
            out << "[g " << n->generation() << "] ";
        if (n->m_target && m_display_justification)
            n->m_justification.display(out << "[j " << n->m_target->get_expr_id() << " ", m_display_justification) << "] ";
        out << "\n";
        return out;
    }

    std::ostream& egraph::display(std::ostream& out) const {
        out << "updates " << m_updates.size() << "\n";
        out << "neweqs  " << m_new_th_eqs.size() << " qhead: " << m_new_th_eqs_qhead << "\n";
        m_table.display(out);
        unsigned max_args = 0;
        for (enode* n : m_nodes)
            max_args = std::max(max_args, n->num_args());
        for (enode* n : m_nodes) 
            display(out, max_args, n);          
        for (auto* p : m_plugins)
            if (p)
                p->display(out);
        return out;
    }

    /**
     * \brief Report the e-graph counters kept in m_stats to st.
     *
     * \param st  statistics object updated with one entry per counter.
     */
    void egraph::collect_statistics(statistics& st) const {
        auto const& counters = m_stats;
        st.update("euf merge",                      counters.m_num_merge);
        st.update("euf conflicts",                  counters.m_num_conflicts);
        st.update("euf propagations eqs",           counters.m_num_eqs);
        st.update("euf propagations theory eqs",    counters.m_num_th_eqs);
        st.update("euf propagations theory diseqs", counters.m_num_th_diseqs);
        st.update("euf propagations literal",       counters.m_num_lits);
    }

    void egraph::copy_from(egraph const& src, std::function& copy_justification) {
        SASSERT(m_scopes.empty());
        SASSERT(m_nodes.empty());
        ptr_vector old_expr2new_enode, args;
        ast_translation tr(src.m, m);
        for (unsigned i = 0; i < src.m_nodes.size(); ++i) {
            enode* n1 = src.m_nodes[i];
            expr* e1 = src.m_exprs[i];
            args.reset();
            for (unsigned j = 0; j < n1->num_args(); ++j) 
                args.push_back(old_expr2new_enode[n1->get_arg(j)->get_expr_id()]);
            expr*  e2 = tr(e1);
            enode* n2 = mk(e2, n1->generation(), args.size(), args.data());
            
            old_expr2new_enode.setx(e1->get_id(), n2, nullptr);
            n2->set_value(n1->value());
            n2->m_bool_var = n1->m_bool_var;
            n2->m_commutative = n1->m_commutative;
            n2->m_cgc_enabled = n1->m_cgc_enabled;
            n2->m_merge_tf_enabled = n1->m_merge_tf_enabled;
            n2->m_is_equality = n1->m_is_equality;            
        }
        for (unsigned i = 0; i < src.m_nodes.size(); ++i) {             
            enode* n1 = src.m_nodes[i];
            enode* n1t = n1->m_target;      
            enode* n2 = m_nodes[i];
            enode* n2t = n1t ? old_expr2new_enode[n1->get_expr_id()] : nullptr;
            SASSERT(!n1t || n2t);
            SASSERT(!n1t || n1->get_sort() == n1t->get_sort());
            SASSERT(!n1t || n2->get_sort() == n2t->get_sort());
            if (n1t && n2->get_root() != n2t->get_root()) 
                merge(n2, n2t, n1->m_justification.copy(copy_justification));
        }
        propagate();
        for (unsigned i = 0; i < src.m_scopes.size(); ++i)
            push();
        force_push();
    }
}

// Explicit instantiations of the explanation routines (explain, explain_todo,
// explain_eq, explain_diseq), one group of four per justification element type.
// NOTE(review): as the text stands, the three groups below are identical and
// ptr_vector carries no template argument; presumably each group names a
// different element type - confirm against the header before building.
template void euf::egraph::explain(ptr_vector& justifications, cc_justification*);
template void euf::egraph::explain_todo(ptr_vector& justifications, cc_justification*);
template void euf::egraph::explain_eq(ptr_vector& justifications, cc_justification*, enode* a, enode* b);
template unsigned euf::egraph::explain_diseq(ptr_vector& justifications, cc_justification*, enode* a, enode* b);

template void euf::egraph::explain(ptr_vector& justifications, cc_justification*);
template void euf::egraph::explain_todo(ptr_vector& justifications, cc_justification*);
template void euf::egraph::explain_eq(ptr_vector& justifications, cc_justification*, enode* a, enode* b);
template unsigned euf::egraph::explain_diseq(ptr_vector& justifications, cc_justification*, enode* a, enode* b);

template void euf::egraph::explain(ptr_vector& justifications, cc_justification*);
template void euf::egraph::explain_todo(ptr_vector& justifications, cc_justification*);
template void euf::egraph::explain_eq(ptr_vector& justifications, cc_justification*, enode* a, enode* b);
template unsigned euf::egraph::explain_diseq(ptr_vector& justifications, cc_justification*, enode* a, enode* b);


#if 0
Each node has a congruence closure root, cg.
cg is set to the representative in the cc table 
(first insertion of congruent node).
Each node n has a set of parents, denoted n.P.
The table maintains the invariant
 - p.cg = find(p)

Merge makes r2 the root of r1 
(r2 and r1 are both considered roots before the merge).
The operation Unmerge reverses the effect of Merge.


Merge(r1, r2)
-------------

Erase:        for each p in r1.P such that p.cg == p:
                 erase from table        
Update root:  r1.root := r2
Insert:       for each p in r1.P:
                 p.cg = insert p in table
                 if p.cg == p:
                   append p to r2.P
                 else 
                   add (p.cg == p) to "to_merge" 

Unmerge(r1, r2)
---------------

Erase:        for each p in r2.P added from r1.P:
                 erase p from table 
Revert root:  r1.root := r1
Insert:       for each p in r1.P:
                 insert p if p was cc root before merge

condition for being cc root before merge:
  p.cg == p or !congruent(p, p.cg)

congruent(p,q) := roots of p.args = roots of q.args

The algorithm orients r1, r2 such that class_size(r1) <= class_size(r2).
With N nodes, there can be at most N calls to Merge.
Each of the calls traverses r1.P, the parent list of the smaller class.
Label a merge tree with nodes from the larger class size.
In other words, if Merge(r2,r1); Merge(r3,r1) is a sequence
of calls where r1 is selected root, then the merge tree is

    r1 
  /   \
 r1    r3  
   \
   r2

Note that parent lists are re-examined only for nodes that join
from right subtrees (with lesser class sizes).
Claim: a node participates in a path along right adjoining sub-trees at most O(log(N)) times.
Justification (very roughly): the size of a right adjoining subtree can at most 
be equal to the left adjoining sub-tree. This entails a logarithmic number of 
re-examinations from the right adjoining tree.

The parent lists are bounded by the maximal arity of functions.

Example:

Initially:
 n1 := f(a,b)  has root n1
 n2 := f(a1,b) has root n2
 table = [f(a,b) |-> n1, f(a1,b) |-> n2]

merge(a,a1) (a1 becomes root)
 table = [f(a1,b) |-> n2]
 n1.cg = n2
 a1.P = [n2]
 n1 is not added as parent because it is not a cc root after the assignment a.root := a1

unmerge(a,a1)
- nothing is erased
- n1 is reinserted. It used to be a root.

#endif




© 2015 - 2024 Weber Informatics LLC | Privacy Policy