// Origin: z3 4.13.0, src/math/lp/cross_nested.h
/*++
Copyright (c) 2017 Microsoft Corporation
Module Name:
Abstract:
Author:
Nikolaj Bjorner (nbjorner)
Lev Nachmanson (levnach)
Revision History:
--*/
#pragma once
#include <algorithm>
#include <functional>
#include "math/lp/nex.h"
#include "math/lp/nex_creator.h"
namespace nla {
class cross_nested {
// fields
nex * m_e;
std::function m_call_on_result;
std::function m_var_is_fixed;
std::function m_random;
bool m_done;
ptr_vector m_b_split_vec;
int m_reported;
bool m_random_bit;
std::function m_mk_scalar;
nex_creator& m_nex_creator;
#ifdef Z3DEBUG
nex* m_e_clone;
#endif
public:
// Gives access to the expression factory this object was constructed with.
nex_creator& get_nex_creator() {
    return m_nex_creator;
}
cross_nested(std::function call_on_result,
std::function var_is_fixed,
std::function random,
nex_creator& nex_cr) :
m_call_on_result(call_on_result),
m_var_is_fixed(var_is_fixed),
m_random(random),
m_done(false),
m_reported(0),
m_mk_scalar([this]{return m_nex_creator.mk_scalar(rational(1));}),
m_nex_creator(nex_cr)
{}
// Entry point: explores cross-nested forms of e, starting with an empty front.
// e must already be simplified (asserted). In debug builds a clone of e is
// kept so that later reports can be checked against it.
void run(nex *e) {
    TRACE("nla_cn", tout << *e << "\n";);
    SASSERT(m_nex_creator.is_simplified(*e));
    m_e = e;
#ifdef Z3DEBUG
    m_e_clone = m_nex_creator.clone(m_e);
    TRACE("nla_cn", tout << "m_e_clone = " << * m_e_clone << "\n";);
#endif
    // The front holds addresses of sub-expression slots still to be explored.
    vector<nex**> front;
    explore_expr_on_front_elem(&m_e, front);
}
// Removes the last element of front and returns it.
// The caller must make sure front is not empty.
static nex** pop_front(vector<nex**>& front) {
    nex** c = front.back();
    TRACE("nla_cn", tout << **c << "\n";);
    front.pop_back();
    return c;
}
// Builds the product of all variables recorded in the occurrence map as
// occurring in every summand of e, each raised to its recorded power.
// Returns nullptr when no variable occurs in all summands.
nex* extract_common_factor(nex* e) {
    nex_sum* sum = to_sum(e);
    TRACE("nla_cn", tout << "c=" << *sum << "\n"; tout << "occs:"; dump_occurences(tout, m_nex_creator.occurences_map()) << "\n";);
    const unsigned n_terms = sum->size();

    // First pass: is there any variable shared by all summands?
    bool found = false;
    for (const auto & entry : m_nex_creator.occurences_map()) {
        found = (entry.second.m_occs == n_terms);
        if (found)
            break;
    }
    if (!found)
        return nullptr;

    // Second pass: multiply the shared variables together.
    m_nex_creator.m_mk_mul.reset();
    for (const auto & entry : m_nex_creator.occurences_map()) { // randomize here: todo
        if (entry.second.m_occs != n_terms)
            continue;
        m_nex_creator.m_mk_mul *= nex_pow(m_nex_creator.mk_var(entry.first), entry.second.m_power);
    }
    return m_nex_creator.m_mk_mul.mk();
}
// Returns true if some variable of the first summand of c is contained in
// every other summand as well. c must have at least one summand.
static bool has_common_factor(const nex_sum* c) {
    TRACE("nla_cn", tout << "c=" << *c << "\n";);
    auto & ch = *c;
    auto common_vars = get_vars_of_expr(ch[0]);
    for (lpvar j : common_vars) {
        bool divides_the_rest = true;
        for (unsigned i = 1; i < ch.size() && divides_the_rest; i++) {
            if (!ch[i]->contains(j))
                divides_the_rest = false;
        }
        if (divides_the_rest) {
            // Print the expression itself; streaming `c` would print its address.
            TRACE("nla_cn_common_factor", tout << *c << "\n";);
            return true;
        }
    }
    return false;
}
bool proceed_with_common_factor(nex** c, vector& front) {
TRACE("nla_cn", tout << "c=" << **c << "\n";);
nex* f = extract_common_factor(*c);
if (f == nullptr) {
TRACE("nla_cn", tout << "no common factor\n"; );
return false;
}
TRACE("nla_cn", tout << "common factor f=" << *f << "\n";);
nex* c_over_f = m_nex_creator.mk_div(**c, *f);
c_over_f = m_nex_creator.simplify(c_over_f);
TRACE("nla_cn", tout << "c_over_f = " << *c_over_f << std::endl;);
nex_mul* cm;
*c = cm = m_nex_creator.mk_mul(f, c_over_f);
TRACE("nla_cn", tout << "common factor=" << *f << ", c=" << **c << "\ne = " << *m_e << "\n";);
explore_expr_on_front_elem((*cm)[1].ee(), front);
return true;
}
static void push_to_front(vector& front, nex** e) {
TRACE("nla_cn", tout << **e << "\n";);
front.push_back(e);
}
// Snapshots the expressions currently stored in the slots of front, so that
// restore_front() can put them back later.
static vector<nex*> copy_front(const vector<nex**>& front) {
    vector<nex*> v;
    for (nex** n: front)
        v.push_back(*n);
    return v;
}
static void restore_front(const vector ©, vector& front) {
SASSERT(copy.size() == front.size());
for (unsigned i = 0; i < front.size(); i++)
*(front[i]) = copy[i];
}
// Forwards sz to the expression factory's pop().
void pop_allocated(unsigned sz) {
    auto & creator = m_nex_creator;
    creator.pop(sz);
}
void explore_expr_on_front_elem_vars(nex** c, vector& front, const svector & vars) {
TRACE("nla_cn", tout << "save c=" << **c << "; front:"; print_front(front, tout) << "\n";);
nex* copy_of_c = *c;
auto copy_of_front = copy_front(front);
int alloc_size = m_nex_creator.size();
for (lpvar j : vars) {
if (m_var_is_fixed(j)) {
// it does not make sense to explore fixed multupliers
// because the interval products do not become smaller
// after factoring those out
continue;
}
explore_of_expr_on_sum_and_var(c, j, front);
if (m_done)
return;
TRACE("nla_cn", tout << "before restore c=" << **c << "\nm_e=" << *m_e << "\n";);
*c = copy_of_c;
restore_front(copy_of_front, front);
pop_allocated(alloc_size);
TRACE("nla_cn", tout << "after restore c=" << **c << "\nm_e=" << *m_e << "\n";);
}
}
// Prints every (key, value) pair of occurences as "(j<key>-><value>)" inside
// braces, followed by a newline and a flush. T is any container whose
// elements expose streamable `first` and `second` members.
template <typename T>
static std::ostream& dump_occurences(std::ostream& out, const T& occurences) {
    out << "{";
    for (const auto& p: occurences) {
        out << "(j" << p.first << "->" << p.second << ")";
    }
    out << "}" << std::endl;
    return out;
}
// Recomputes the occurrence map of the factory for the summands of e:
// products contribute their variable powers, plain variables contribute a
// single occurrence; other summands are ignored. Variables seen at most once
// are dropped at the end.
void calc_occurences(nex_sum* e) {
    clear_maps();
    for (const auto * term : *e) {
        if (term->is_mul()) {
            term->to_mul().get_powers_from_mul(m_nex_creator.powers());
            update_occurences_with_powers();
            continue;
        }
        if (!term->is_var())
            continue;
        add_var_occs(term->to_var().var());
    }
    remove_singular_occurences();
    TRACE("nla_cn_details", tout << "e=" << *e << "\noccs="; dump_occurences(tout, m_nex_creator.occurences_map()) << "\n";);
}
void fill_vars_from_occurences_map(svector& vars) {
for (auto & p : m_nex_creator.occurences_map())
vars.push_back(p.first);
m_random_bit = m_random() % 2;
TRACE("nla_cn", tout << "m_random_bit = " << m_random_bit << "\n";);
std::sort(vars.begin(), vars.end(), [this](lpvar j, lpvar k)
{
auto it_j = m_nex_creator.occurences_map().find(j);
auto it_k = m_nex_creator.occurences_map().find(k);
const occ& a = it_j->second;
const occ& b = it_k->second;
if (a.m_occs > b.m_occs)
return true;
if (a.m_occs < b.m_occs)
return false;
if (a.m_power > b.m_power)
return true;
if (a.m_power < b.m_power)
return false;
return m_random_bit? j < k : j > k;
});
}
bool proceed_with_common_factor_or_get_vars_to_factor_out(nex** c, svector& vars, vector front) {
calc_occurences(to_sum(*c));
if (proceed_with_common_factor(c, front))
return true;
fill_vars_from_occurences_map(vars);
return false;
}
void explore_expr_on_front_elem(nex** c, vector& front) {
svector vars;
if (proceed_with_common_factor_or_get_vars_to_factor_out(c, vars, front))
return;
TRACE("nla_cn", tout << "m_e=" << *m_e << "\nc=" << **c << ", c vars=";
print_vector(vars, tout) << "; front:"; print_front(front, tout) << "\n";);
if (vars.empty()) {
if (front.empty()) {
TRACE("nla_cn", tout << "got the cn form: =" << *m_e << "\n";);
m_done = m_call_on_result(m_e) || ++m_reported > 100;
#ifdef Z3DEBUG
TRACE("nla_cn", tout << "m_e_clone " << *m_e_clone << "\n";);
SASSERT(nex_creator::equal(m_e, m_e_clone));
#endif
} else {
nex** f = pop_front(front);
explore_expr_on_front_elem(f, front);
}
} else {
explore_expr_on_front_elem_vars(c, front, vars);
}
}
std::ostream& print_front(const vector& front, std::ostream& out) const {
for (auto e : front) {
out << **e << "\n";
}
return out;
}
// c is the sub expressiond which is going to be changed from sum to the cross nested form
// front will be explored more
void explore_of_expr_on_sum_and_var(nex** c, lpvar j, vector front) {
TRACE("nla_cn", tout << "m_e=" << *m_e << "\nc=" << **c << "\nj = " << nex_creator::ch(j) << "\nfront="; print_front(front, tout) << "\n";);
if (!split_with_var(*c, j, front))
return;
TRACE("nla_cn", tout << "after split c=" << **c << "\nfront="; print_front(front, tout) << "\n";);
if (front.empty()) {
#ifdef Z3DEBUG
TRACE("nla_cn", tout << "got the cn form: =" << *m_e << ", clone = " << *m_e_clone << "\n";);
#endif
m_done = m_call_on_result(m_e) || ++m_reported > 100;
#ifdef Z3DEBUG
SASSERT(nex_creator::equal(m_e, m_e_clone));
#endif
return;
}
auto n = pop_front(front);
explore_expr_on_front_elem(n, front);
}
// Records one more occurrence of the plain variable j: a new entry starts as
// occ(1, 1); an existing entry gets its count incremented and its power set
// to 1.
void add_var_occs(lpvar j) {
    auto pos = m_nex_creator.occurences_map().find(j);
    if (pos == m_nex_creator.occurences_map().end()) {
        m_nex_creator.occurences_map().insert(std::make_pair(j, occ(1, 1)));
        return;
    }
    auto & entry = pos->second;
    entry.m_occs++;
    entry.m_power = 1;
}
void update_occurences_with_powers() {
for (auto & p : m_nex_creator.powers()) {
lpvar j = p.first;
unsigned jp = p.second;
auto it = m_nex_creator.occurences_map().find(j);
if (it == m_nex_creator.occurences_map().end()) {
m_nex_creator.occurences_map()[j] = occ(1, jp);
} else {
it->second.m_occs++;
it->second.m_power = std::min(it->second.m_power, jp);
}
}
TRACE("nla_cn_details", tout << "occs="; dump_occurences(tout, m_nex_creator.occurences_map()) << "\n";);
}
void remove_singular_occurences() {
svector r;
for (const auto & p : m_nex_creator.occurences_map()) {
if (p.second.m_occs <= 1) {
r.push_back(p.first);
}
}
for (lpvar j : r)
m_nex_creator.occurences_map().erase(j);
}
// Empties both the occurrence map and the powers container of the factory.
void clear_maps() {
    auto & creator = m_nex_creator;
    creator.occurences_map().clear();
    creator.powers().clear();
}
// j -> the number of expressions j appears in as a multiplier.
// Recomputes the occurrence map for e (as calc_occurences does) and returns
// its entries sorted by decreasing occurrence count, then by decreasing
// power, then by increasing variable index.
vector<std::pair<lpvar, occ>> get_mult_occurences(const nex_sum* e) {
    clear_maps();
    for (const auto * ce : *e) {
        if (ce->is_mul()) {
            to_mul(ce)->get_powers_from_mul(m_nex_creator.powers());
            update_occurences_with_powers();
        } else if (ce->is_var()) {
            add_var_occs(to_var(ce)->var());
        }
    }
    remove_singular_occurences();
    TRACE("nla_cn_details", tout << "e=" << *e << "\noccs="; dump_occurences(tout, m_nex_creator.occurences_map()) << "\n";);
    vector<std::pair<lpvar, occ>> ret;
    for (auto & p : m_nex_creator.occurences_map())
        ret.push_back(p);
    std::sort(ret.begin(), ret.end(), [](const std::pair<lpvar, occ>& a, const std::pair<lpvar, occ>& b) {
                                          if (a.second.m_occs > b.second.m_occs)
                                              return true;
                                          if (a.second.m_occs < b.second.m_occs)
                                              return false;
                                          if (a.second.m_power > b.second.m_power)
                                              return true;
                                          if (a.second.m_power < b.second.m_power)
                                              return false;

                                          return a.first < b.first;
                                      });
    return ret;
}
// True when ce is a product containing j, or is the variable j itself.
static bool is_divisible_by_var(nex const* ce, lpvar j) {
    if (ce->is_mul() && to_mul(ce)->contains(j))
        return true;
    return ce->is_var() && to_var(ce)->var() == j;
}
// Splits the sum e with respect to the variable j:
//   a - the simplified sum of (summand / j) over the summands divisible by j
//   b - the remaining summands: the single summand itself if there is exactly
//       one, otherwise a new sum of them
// e must be simplified; at least two summands must be divisible by j and at
// least one must not be (asserted).
// NOTE(review): the const_cast target types were missing and are
// reconstructed from the types of the cast operands — confirm.
void pre_split(nex_sum * e, lpvar j, nex_sum const*& a, nex const*& b) {
    TRACE("nla_cn_details", tout << "e = " << * e << ", j = " << m_nex_creator.ch(j) << std::endl;);
    SASSERT(m_nex_creator.is_simplified(*e));
    nex_creator::sum_factory sf(m_nex_creator);
    m_b_split_vec.clear();
    for (nex const* ce: *e) {
        TRACE("nla_cn_details", tout << "ce = " << *ce << "\n";);
        if (is_divisible_by_var(ce, j)) {
            sf += m_nex_creator.mk_div(*ce , j);
        } else {
            m_b_split_vec.push_back(const_cast<nex*>(ce));
        }
    }
    a = sf.mk();
    TRACE("nla_cn_details", tout << "a = " << *a << "\n";);
    SASSERT(a->size() >= 2 && m_b_split_vec.size());
    a = to_sum(m_nex_creator.simplify_sum(const_cast<nex_sum*>(a)));

    if (m_b_split_vec.size() == 1) {
        b = m_b_split_vec[0];
        TRACE("nla_cn_details", tout << "b = " << *b << "\n";);
    } else {
        SASSERT(m_b_split_vec.size() > 1);
        b = m_nex_creator.mk_sum(m_b_split_vec);
        TRACE("nla_cn_details", tout << "b = " << *b << "\n";);
    }
}
void update_front_with_split_with_non_empty_b(nex* &e, lpvar j, vector & front, nex_sum const* a, nex const* b) {
TRACE("nla_cn_details", tout << "b = " << *b << "\n";);
e = m_nex_creator.mk_sum(m_nex_creator.mk_mul(m_nex_creator.mk_var(j), a), b); // e = j*a + b
if (!a->is_linear()) {
nex **ptr_to_a = e->to_sum()[0]->to_mul()[1].ee();
push_to_front(front, ptr_to_a);
}
if (b->is_sum() && !to_sum(b)->is_linear()) {
nex **ptr_to_a = &(e->to_sum()[1]);
push_to_front(front, ptr_to_a);
}
}
void update_front_with_split(nex* & e, lpvar j, vector & front, nex_sum const* a, nex const* b) {
if (b == nullptr) {
e = m_nex_creator.mk_mul(m_nex_creator.mk_var(j), a);
if (!to_sum(a)->is_linear())
push_to_front(front, e->to_mul()[1].ee());
} else {
update_front_with_split_with_non_empty_b(e, j, front, a, b);
}
}
// it returns true if the recursion brings a cross-nested form
bool split_with_var(nex*& e, lpvar j, vector & front) {
SASSERT(e->is_sum());
TRACE("nla_cn", tout << "e = " << *e << ", j=" << nex_creator::ch(j) << "\n";);
nex_sum const* a; nex const* b;
pre_split(to_sum(e), j, a, b);
/*
When we have e without a non-trivial common factor then
there is a variable j such that e = jP + Q, where Q has all members
of e that do not have j as a factor, and
P also does not have a non-trivial common factor. It is enough
to explore only such variables to create all cross-nested forms.
*/
if (has_common_factor(a)) {
return false;
}
update_front_with_split(e, j, front, a, b);
return true;
}
// On destruction, calls clear() on the expression factory held by reference.
~cross_nested() {
    auto & creator = m_nex_creator;
    creator.clear();
}
// Reports whether the search has been stopped (see m_done).
bool done() const {
    return m_done;
}
// Debug-only helpers. Guarded with #ifdef, like every other Z3DEBUG section
// in this file, so the guard also works when the macro is defined without a
// value.
#ifdef Z3DEBUG
// Placeholder: normalization of sums is not implemented.
nex * normalize_sum(nex_sum* a) {
    NOT_IMPLEMENTED_YET();
    return nullptr;
}

// Placeholder: normalization of products is not implemented.
nex * normalize_mul(nex_mul* a) {
    TRACE("nla_cn", tout << *a << "\n";);
    NOT_IMPLEMENTED_YET();
    return nullptr;
}

// Returns a itself when it is elementary; otherwise dispatches to the
// product or sum normalizer and sorts the result.
nex * normalize(nex* a) {
    if (a->is_elementary())
        return a;
    nex *r;
    if (a->is_mul()) {
        r = normalize_mul(to_mul(a));
    } else {
        r = normalize_sum(to_sum(a));
    }
    r->sort();
    return r;
}
#endif
};
}