All Downloads are FREE. Search and download functionalities are using the official Maven repository.

z3-z3-4.13.0.src.math.lp.cross_nested.h Maven / Gradle / Ivy

The newest version!
/*++
  Copyright (c) 2017 Microsoft Corporation

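  Module Name:

     cross_nested.h

  Abstract:

     Explores cross-nested forms of a polynomial given as a nex sum.
     Common factors are pulled out. Otherwise the sum is split on a chosen
     variable j as e = j*a + b, and the non-linear parts are processed
     recursively. Every complete form found is passed to a client callback.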

  Author:
  Nikolaj Bjorner (nbjorner)
  Lev Nachmanson (levnach)

  Revision History:


  --*/
#pragma once
#include <algorithm>
#include <functional>
#include <ostream>
#include <utility>
#include "math/lp/nex.h"
#include "math/lp/nex_creator.h"

namespace nla {
class cross_nested {
    
    // fields
    nex *                                             m_e;
    std::function                  m_call_on_result;
    std::function                    m_var_is_fixed;
    std::function                        m_random;    
    bool                                              m_done;
    ptr_vector                                   m_b_split_vec;
    int                                               m_reported;
    bool                                              m_random_bit;
    std::function                      m_mk_scalar;
    nex_creator&                                      m_nex_creator;
#ifdef Z3DEBUG
    nex* m_e_clone;
#endif
public:
    
    nex_creator& get_nex_creator() { return m_nex_creator; }
    
    cross_nested(std::function call_on_result,
                 std::function var_is_fixed,
                 std::function random,
                 nex_creator& nex_cr) :
        m_call_on_result(call_on_result),
        m_var_is_fixed(var_is_fixed),
        m_random(random),
        m_done(false),
        m_reported(0),
        m_mk_scalar([this]{return m_nex_creator.mk_scalar(rational(1));}),
        m_nex_creator(nex_cr)
    {}

    
    void run(nex *e) {
        TRACE("nla_cn", tout << *e << "\n";);
        SASSERT(m_nex_creator.is_simplified(*e));
        m_e = e;
#ifdef Z3DEBUG
        m_e_clone = m_nex_creator.clone(m_e);
        TRACE("nla_cn", tout << "m_e_clone = " <<  * m_e_clone << "\n";);
        
#endif
        vector front;
        explore_expr_on_front_elem(&m_e, front);
    }

    static nex** pop_front(vector& front) {
        nex** c = front.back();
        TRACE("nla_cn", tout <<  **c << "\n";);
        front.pop_back();
        return c;
    }


    nex* extract_common_factor(nex* e) {
        nex_sum* c = to_sum(e);
        TRACE("nla_cn", tout << "c=" << *c << "\n"; tout << "occs:"; dump_occurences(tout, m_nex_creator.occurences_map()) << "\n";);
        unsigned size = c->size();
        bool have_factor = false;
        for (const auto & p : m_nex_creator.occurences_map()) {
            if (p.second.m_occs == size) {
                have_factor = true;
                break;
            }
        }
        if (have_factor == false) return nullptr;
        m_nex_creator.m_mk_mul.reset();
        for (const auto & p : m_nex_creator.occurences_map()) { // randomize here: todo
            if (p.second.m_occs == size) {
                m_nex_creator.m_mk_mul *= nex_pow(m_nex_creator.mk_var(p.first), p.second.m_power);
            }
        }
        return m_nex_creator.m_mk_mul.mk();
    }

    static bool has_common_factor(const nex_sum* c) {
        TRACE("nla_cn", tout << "c=" << *c << "\n";);
        auto & ch = *c;
        auto common_vars = get_vars_of_expr(ch[0]);
        for (lpvar j : common_vars) {
            bool divides_the_rest = true;
            for (unsigned i = 1; i < ch.size() && divides_the_rest; i++) {
                if (!ch[i]->contains(j))
                    divides_the_rest = false;
            }
            if (divides_the_rest) {
                TRACE("nla_cn_common_factor", tout << c << "\n";);
                return true;
            }
        }
        return false;
    }

    bool proceed_with_common_factor(nex** c, vector& front) {
        TRACE("nla_cn", tout << "c=" << **c << "\n";);
        nex* f = extract_common_factor(*c);
        if (f == nullptr) {
            TRACE("nla_cn", tout << "no common factor\n"; );
            return false;
        }
        TRACE("nla_cn", tout << "common factor f=" << *f << "\n";);
        
        nex* c_over_f = m_nex_creator.mk_div(**c, *f);
        c_over_f = m_nex_creator.simplify(c_over_f);
        TRACE("nla_cn", tout << "c_over_f = " << *c_over_f << std::endl;);
        nex_mul* cm; 
        *c = cm = m_nex_creator.mk_mul(f, c_over_f);
        TRACE("nla_cn", tout << "common factor=" << *f << ", c=" << **c << "\ne = " << *m_e << "\n";);
        explore_expr_on_front_elem((*cm)[1].ee(),  front);
        return true;
    }

    static void push_to_front(vector& front, nex** e) {
        TRACE("nla_cn", tout << **e << "\n";);
        front.push_back(e);
    }
    
    static vector copy_front(const vector& front) {
        vector v;
        for (nex** n: front)
            v.push_back(*n);
        return v;
    }

    static void restore_front(const vector ©, vector& front) {
        SASSERT(copy.size() == front.size());
        for (unsigned i = 0; i < front.size(); i++)
            *(front[i]) = copy[i];
    }

    void pop_allocated(unsigned sz) {
        m_nex_creator.pop(sz);
    }
    
    void explore_expr_on_front_elem_vars(nex** c, vector& front, const svector & vars) {
        TRACE("nla_cn", tout << "save c=" << **c << "; front:"; print_front(front, tout) << "\n";);           
        nex* copy_of_c = *c;
        auto copy_of_front = copy_front(front);
        int alloc_size = m_nex_creator.size();
        for (lpvar j : vars) {
            if (m_var_is_fixed(j)) {
                // it does not make sense to explore fixed multupliers
                // because the interval products do not become smaller
                // after factoring those out
                continue;
            }
            explore_of_expr_on_sum_and_var(c, j, front);
            if (m_done)
                return;
            TRACE("nla_cn", tout << "before restore c=" << **c << "\nm_e=" << *m_e << "\n";);
            *c = copy_of_c;
            restore_front(copy_of_front, front);
            pop_allocated(alloc_size);
            TRACE("nla_cn", tout << "after restore c=" << **c << "\nm_e=" << *m_e << "\n";);   
        }
    }

    template 
    static std::ostream& dump_occurences(std::ostream& out, const T& occurences) {
        out << "{";
        for (const auto& p: occurences) {
            out << "(j" << p.first << "->" << p.second << ")";
        }
        out << "}" << std::endl;
        return out;
    }

    void calc_occurences(nex_sum* e) {
        clear_maps();
        for (const auto * ce : *e) {
            if (ce->is_mul()) {
                ce->to_mul().get_powers_from_mul(m_nex_creator.powers());
                update_occurences_with_powers();
            } else if (ce->is_var()) {
                add_var_occs(ce->to_var().var());
            }
        }
        remove_singular_occurences();
        TRACE("nla_cn_details", tout << "e=" << *e << "\noccs="; dump_occurences(tout, m_nex_creator.occurences_map()) << "\n";);
    }

    void fill_vars_from_occurences_map(svector& vars) {
        for (auto & p : m_nex_creator.occurences_map())
            vars.push_back(p.first);

        m_random_bit = m_random() % 2;
        TRACE("nla_cn", tout << "m_random_bit = " << m_random_bit << "\n";);
        std::sort(vars.begin(), vars.end(), [this](lpvar j, lpvar k)
                                            {
                                                auto it_j = m_nex_creator.occurences_map().find(j);
                                                auto it_k = m_nex_creator.occurences_map().find(k);
                                                

                                                const occ& a = it_j->second;
                                                const occ& b = it_k->second;
                                                if (a.m_occs > b.m_occs)
                                                    return true;
                                                if (a.m_occs < b.m_occs)
                                                    return false;
                                                if (a.m_power > b.m_power)
                                                    return true;
                                                if (a.m_power < b.m_power)
                                                    return false;
                                                
                                                return m_random_bit? j < k : j > k;
                                          });
 
    }    
    
    bool proceed_with_common_factor_or_get_vars_to_factor_out(nex** c, svector& vars, vector front) {
        calc_occurences(to_sum(*c));
        if (proceed_with_common_factor(c, front))
            return true;

        fill_vars_from_occurences_map(vars);
        return false;
    }
    
    void explore_expr_on_front_elem(nex** c, vector& front) {
        svector vars;
        if (proceed_with_common_factor_or_get_vars_to_factor_out(c, vars, front))
            return;

        TRACE("nla_cn", tout << "m_e=" << *m_e << "\nc=" << **c << ", c vars=";
              print_vector(vars, tout) << "; front:"; print_front(front, tout) << "\n";);
    
        if (vars.empty()) {
            if (front.empty()) {
                TRACE("nla_cn", tout << "got the cn form: =" << *m_e << "\n";);
                m_done = m_call_on_result(m_e) || ++m_reported > 100;
 #ifdef Z3DEBUG
                TRACE("nla_cn", tout << "m_e_clone " << *m_e_clone << "\n";);
                SASSERT(nex_creator::equal(m_e, m_e_clone));
 #endif
            } else {
                nex** f = pop_front(front);
                explore_expr_on_front_elem(f, front);     
            }
        } else {
            explore_expr_on_front_elem_vars(c, front, vars);
        }
    }

    std::ostream& print_front(const vector& front, std::ostream& out) const {
        for (auto e : front) {
            out << **e << "\n";
        }
        return out;
    }
    // c is the sub expressiond which is going to be changed from sum to the cross nested form
    // front will be explored more
    void explore_of_expr_on_sum_and_var(nex** c, lpvar j, vector front) {
        TRACE("nla_cn", tout << "m_e=" << *m_e << "\nc=" << **c << "\nj = " << nex_creator::ch(j) << "\nfront="; print_front(front, tout) << "\n";);
        if (!split_with_var(*c, j, front))
            return;
        TRACE("nla_cn", tout << "after split c=" << **c << "\nfront="; print_front(front, tout) << "\n";);
        if (front.empty()) {
#ifdef Z3DEBUG
            TRACE("nla_cn", tout << "got the cn form: =" << *m_e <<  ", clone = " << *m_e_clone << "\n";);
#endif
            m_done = m_call_on_result(m_e) || ++m_reported > 100;
#ifdef Z3DEBUG
            SASSERT(nex_creator::equal(m_e, m_e_clone));
#endif
            return;
        }
        auto n = pop_front(front);
        explore_expr_on_front_elem(n, front);
    }

    void add_var_occs(lpvar j) {
        auto it = m_nex_creator.occurences_map().find(j);
        if (it != m_nex_creator.occurences_map().end()) {
            it->second.m_occs++;
            it->second.m_power = 1;
        } else {            
            m_nex_creator.occurences_map().insert(std::make_pair(j, occ(1, 1)));
        }
    }    

    void update_occurences_with_powers() {
        for (auto & p : m_nex_creator.powers()) {
            lpvar j = p.first;
            unsigned jp = p.second;
            auto it = m_nex_creator.occurences_map().find(j);
            if (it == m_nex_creator.occurences_map().end()) {
                m_nex_creator.occurences_map()[j] = occ(1, jp);
            } else {
                it->second.m_occs++;
                it->second.m_power = std::min(it->second.m_power, jp);
            }
        }
        TRACE("nla_cn_details", tout << "occs="; dump_occurences(tout, m_nex_creator.occurences_map()) << "\n";);
    }    
    
    void remove_singular_occurences() {
        svector r;
        for (const auto & p : m_nex_creator.occurences_map()) {
            if (p.second.m_occs <= 1) {
                r.push_back(p.first);
            }
        }
        for (lpvar j : r)
            m_nex_creator.occurences_map().erase(j);
    }

    void clear_maps() {
        m_nex_creator.occurences_map().clear();
        m_nex_creator.powers().clear();
    }
    
    // j -> the number of expressions j appears in as a multiplier
    // The result is sorted by large number of occurences first
    vector> get_mult_occurences(const nex_sum* e) {
        clear_maps();
        for (const auto * ce : *e) {
            if (ce->is_mul()) {
                to_mul(ce)->get_powers_from_mul(m_nex_creator.powers());
                update_occurences_with_powers();
            } else if (ce->is_var()) {
                add_var_occs(to_var(ce)->var());
            }
        }
        remove_singular_occurences();
        TRACE("nla_cn_details", tout << "e=" << *e << "\noccs="; dump_occurences(tout, m_nex_creator.occurences_map()) << "\n";);
        vector> ret;
        for (auto & p : m_nex_creator.occurences_map())
            ret.push_back(p);
        std::sort(ret.begin(), ret.end(), [](const std::pair& a, const std::pair& b) {
                                              if (a.second.m_occs > b.second.m_occs)
                                                  return true;
                                              if (a.second.m_occs < b.second.m_occs)
                                                  return false;
                                              if (a.second.m_power > b.second.m_power)
                                                  return true;
                                              if (a.second.m_power < b.second.m_power)
                                                  return false;

                                              return a.first < b.first;
                                          });
        return ret;
    }

    static bool is_divisible_by_var(nex const* ce, lpvar j) {
        return (ce->is_mul() && to_mul(ce)->contains(j))
            || (ce->is_var() && to_var(ce)->var() == j);
    }
    // all factors of j go to a, the rest to b
    void pre_split(nex_sum * e, lpvar j, nex_sum const*& a, nex const*& b) {
        TRACE("nla_cn_details", tout << "e = " << * e << ", j = " << m_nex_creator.ch(j) << std::endl;);
        SASSERT(m_nex_creator.is_simplified(*e));
        nex_creator::sum_factory sf(m_nex_creator);
        m_b_split_vec.clear();
        for (nex const* ce: *e) {
            TRACE("nla_cn_details", tout << "ce = " << *ce << "\n";);
            if (is_divisible_by_var(ce, j)) {
                sf += m_nex_creator.mk_div(*ce , j);
            } else {
                m_b_split_vec.push_back(const_cast(ce));
            }        
        }
        a = sf.mk();
        TRACE("nla_cn_details", tout << "a = " << *a << "\n";);
        SASSERT(a->size() >= 2 && m_b_split_vec.size());
        a = to_sum(m_nex_creator.simplify_sum(const_cast(a)));
        
        if (m_b_split_vec.size() == 1) {
            b = m_b_split_vec[0];
            TRACE("nla_cn_details", tout << "b = " << *b << "\n";);
        } else {
            SASSERT(m_b_split_vec.size() > 1);
            b = m_nex_creator.mk_sum(m_b_split_vec);
            TRACE("nla_cn_details", tout << "b = " << *b << "\n";);
        }
    }

    void update_front_with_split_with_non_empty_b(nex* &e, lpvar j, vector & front, nex_sum const* a, nex const* b) {
        TRACE("nla_cn_details", tout << "b = " << *b << "\n";);
        e = m_nex_creator.mk_sum(m_nex_creator.mk_mul(m_nex_creator.mk_var(j), a), b); // e = j*a + b
        if (!a->is_linear()) {
            nex **ptr_to_a = e->to_sum()[0]->to_mul()[1].ee();
            push_to_front(front, ptr_to_a);
        }
        
        if (b->is_sum() && !to_sum(b)->is_linear()) {
            nex **ptr_to_a = &(e->to_sum()[1]);
            push_to_front(front, ptr_to_a);
        }
    }
    
   void update_front_with_split(nex* & e, lpvar j, vector & front, nex_sum const* a, nex const* b) {
        if (b == nullptr) {
            e = m_nex_creator.mk_mul(m_nex_creator.mk_var(j), a);
            if (!to_sum(a)->is_linear())
                push_to_front(front, e->to_mul()[1].ee());
        } else {
            update_front_with_split_with_non_empty_b(e, j, front, a, b);
        }
    }
    // it returns true if the recursion brings a cross-nested form
    bool split_with_var(nex*& e, lpvar j, vector & front) {
        SASSERT(e->is_sum());
        TRACE("nla_cn", tout << "e = " << *e << ", j=" << nex_creator::ch(j) << "\n";);
        nex_sum const* a; nex const* b;
        pre_split(to_sum(e), j, a, b);
        /*
          When we have e without a non-trivial common factor then
          there is a variable j such that e = jP + Q, where Q has all members
          of e that do not have j as a factor, and
          P also does not have a non-trivial common factor. It is enough
          to explore only such variables to create all cross-nested forms.
        */
        
        if (has_common_factor(a)) {
            return false;
        }
        update_front_with_split(e, j, front, a, b);
        return true;
    }

    
    ~cross_nested() {
        m_nex_creator.clear();
    }

    bool done() const { return m_done; }

#if Z3DEBUG
    nex * normalize_sum(nex_sum* a) {
        NOT_IMPLEMENTED_YET();
        return nullptr;
    }

    nex * normalize_mul(nex_mul* a) {
        TRACE("nla_cn", tout << *a << "\n";);
        NOT_IMPLEMENTED_YET();
        return nullptr;
    }    
    
    nex * normalize(nex* a) {
        if (a->is_elementary())
            return a;
        nex *r;
        if (a->is_mul()) {
            r = normalize_mul(to_mul(a));
        } else {
            r = normalize_sum(to_sum(a));
        }
        r->sort();
        return r;
    }
#endif
    
};
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy