// z3-z3-4.13.0.src.sat.sat_cut_simplifier.cpp
// (listing header left over from the web page this copy was taken from; not part of the program)
/*++
Copyright (c) 2020 Microsoft Corporation
Module Name:
sat_cut_simplifier.cpp
Abstract:
extract AIG definitions from clauses
Perform cut-set enumeration to identify equivalences.
Author:
Nikolaj Bjorner 2020-01-02
--*/
#include "sat/sat_cut_simplifier.h"
#include "sat/sat_xor_finder.h"
#include "sat/sat_lut_finder.h"
#include "sat/sat_npn3_finder.h"
#include "sat/sat_elim_eqs.h"
namespace sat {
/**
 * Scoped progress reporter for one run of the simplifier.
 * The constructor snapshots the statistics counters and starts a stopwatch;
 * the destructor prints the counter deltas at verbosity level 2.
 */
struct cut_simplifier::report {
    cut_simplifier& s;
    stopwatch       m_watch;
    unsigned        m_num_eqs;
    unsigned        m_num_units;
    unsigned        m_num_cuts;
    unsigned        m_num_learned_implies;

    report(cut_simplifier& s):
        s(s),
        m_num_eqs(s.m_stats.m_num_eqs),
        m_num_units(s.m_stats.m_num_units),
        m_num_cuts(s.m_stats.m_num_cuts),
        m_num_learned_implies(s.m_stats.m_num_learned_implies) {
        m_watch.start();
    }

    ~report() {
        // differences between the current counters and the snapshot
        auto const& now = s.m_stats;
        unsigned new_eqs     = now.m_num_eqs - m_num_eqs;
        unsigned new_units   = now.m_num_units - m_num_units;
        unsigned new_cuts    = now.m_num_cuts - m_num_cuts;
        unsigned new_implies = now.m_num_learned_implies - m_num_learned_implies;
        IF_VERBOSE(2,
                   verbose_stream() << "(sat.cut-simplifier";
                   if (new_units > 0) verbose_stream() << " :num-units " << new_units;
                   if (new_eqs > 0) verbose_stream() << " :num-eqs " << new_eqs;
                   if (new_implies > 0) verbose_stream() << " :num-bin " << new_implies;
                   if (new_cuts > 0) verbose_stream() << " :num-cuts " << new_cuts;
                   verbose_stream() << " :mb " << mem_stat() << m_watch << ")\n");
    }
};
/**
 * Debugging aid: checks a clause against a copy of the solver by
 * asserting the negation of every literal as an assumption and
 * requiring the check to come back l_false.
 */
struct cut_simplifier::validator {
    solver&        _s;
    params_ref     p;
    literal_vector m_assumptions;

    validator(solver& _s, params_ref const& p): _s(_s), p(p) {}

    /** Array overload; forwards to the vector overload. */
    void validate(unsigned n, literal const* clause) {
        literal_vector lits(n, clause);
        validate(lits);
    }

    /**
     * Validate 'clause'. Any answer other than l_false is reported
     * (clause and solver state at verbosity 0) and hits UNREACHABLE().
     */
    void validate(literal_vector const& clause) {
        // l or ~l needs no check
        bool const is_tautology = clause.size() == 2 && clause[0] == ~clause[1];
        if (is_tautology)
            return;
        solver s(p, _s.rlimit());
        s.copy(_s, false);
        IF_VERBOSE(10, verbose_stream() << "validate: " << clause << "\n");
        m_assumptions.reset();
        for (unsigned i = 0; i < clause.size(); ++i)
            m_assumptions.push_back(~clause[i]);
        if (s.check(clause.size(), m_assumptions.data()) == l_false)
            return;
        IF_VERBOSE(0,
                   verbose_stream() << "not validated: " << clause << "\n";
                   s.display(verbose_stream()););
        UNREACHABLE();
    }
};
/**
 * Create the validator on first use. Its solver parameters switch off
 * "aig" and DRAT output and set "max_conflicts" to 10000.
 */
void cut_simplifier::ensure_validator() {
    if (m_validator)
        return;
    params_ref p;
    p.set_bool("aig", false);
    p.set_bool("drat.check_unsat", false);
    p.set_sym("drat.file", symbol());
    p.set_uint("max_conflicts", 10000);
    m_validator = alloc(validator, s, p);
}
/**
 * Construct the simplifier for solver _s and install clause callbacks on
 * the AIG cut engine: with DRAT enabled, added/deleted clauses are logged
 * to the DRAT trace; otherwise, if m_validate_cuts is set, every added
 * clause is passed to the validator.
 */
cut_simplifier::cut_simplifier(solver& _s):
    s(_s),
    m_trail_size(0),
    m_validator(nullptr) {
    if (s.get_config().m_drat) {
        std::function<void(literal_vector const&)> _on_add =
            [this](literal_vector const& clause) { s.m_drat.add(clause); };
        std::function<void(literal_vector const&)> _on_del =
            [this](literal_vector const& clause) { s.m_drat.del(clause); };
        m_aig_cuts.set_on_clause_add(_on_add);
        m_aig_cuts.set_on_clause_del(_on_del);
        return;
    }
    if (!m_config.m_validate_cuts)
        return;
    ensure_validator();
    std::function<void(literal_vector const&)> _on_add =
        [this](literal_vector const& clause) {
            m_validator->validate(clause);
        };
    m_aig_cuts.set_on_clause_add(_on_add);
}
/** Release the validator (a no-op object if it was never created: m_validator starts as nullptr). */
cut_simplifier::~cut_simplifier()
{
    dealloc(m_validator);
}
/**
 * Register the definition head == lits[0] and ... and lits[sz-1].
 * The head variable must not occur among the arguments.
 */
void cut_simplifier::add_and(literal head, unsigned sz, literal const* lits) {
    m_aig_cuts.add_node(head, and_op, sz, lits);
    for (literal const* it = lits; it != lits + sz; ++it)
        VERIFY(head.var() != it->var());
    m_stats.m_num_ands++;
}
// A disjunction is stored as a conjunction using De Morgan:
//   head == l1 or l2 or l3
//   <=>
//   ~head == ~l1 and ~l2 and ~l3
void cut_simplifier::add_or(literal head, unsigned sz, literal const* lits) {
    m_lits.reset();
    for (unsigned k = 0; k < sz; ++k)
        m_lits.push_back(~lits[k]);
    m_aig_cuts.add_node(~head, and_op, sz, m_lits.data());
    m_stats.m_num_ands++;
}
/** Register the definition head == lits[0] xor ... xor lits[sz-1]. */
void cut_simplifier::add_xor(literal head, unsigned sz, literal const* lits)
{
    m_aig_cuts.add_node(head, xor_op, sz, lits);
    m_stats.m_num_xors++;
}
/** Register the definition head == if c then t else e. */
void cut_simplifier::add_ite(literal head, literal c, literal t, literal e) {
    literal args[3] = { c, t, e };
    m_aig_cuts.add_node(head, ite_op, 3, args);
    m_stats.m_num_ites++;
}
/**
 * Register the definition head == (l1 <=> l2).
 * It is stored as an xor node over l1 and ~l2.
 */
void cut_simplifier::add_iff(literal head, literal l1, literal l2) {
    literal args[2] = { l1, ~l2 };
    m_aig_cuts.add_node(head, xor_op, 2, args);
    m_stats.m_num_xors++;
}
/** Forward the root literal r of variable v to the AIG cut engine. */
void cut_simplifier::set_root(bool_var v, literal r)
{
    m_aig_cuts.set_root(v, r);
}
/**
 * Main entry point. Alternates clauses2aig / aig2clauses while a round
 * produces new equalities or units and the round budget allows it:
 * fewer than 5 rounds when m_cut_force is set, or rounds^2 smaller than
 * the number of calls made so far.
 */
void cut_simplifier::operator()() {
    bool const force = s.m_config.m_cut_force;
    report _report(*this);
    TRACE("cut_simplifier", s.display(tout););
    ++m_stats.m_num_calls;
    unsigned rounds = 0;
    for (;;) {
        unsigned const before = m_stats.m_num_eqs + m_stats.m_num_units;
        clauses2aig();
        aig2clauses();
        ++rounds;
        bool const within_budget = (force && rounds < 5) || rounds * rounds < m_stats.m_num_calls;
        bool const made_progress = before < m_stats.m_num_eqs + m_stats.m_num_units;
        if (!within_budget || !made_progress)
            break;
    }
}
/**
\brief extract AIG definitions from clauses
Ensure that they are sorted and variables have unique definitions.
*/
void cut_simplifier::clauses2aig() {
for (; m_config.m_enable_units && m_trail_size < s.init_trail_size(); ++m_trail_size) {
literal lit = s.trail_literal(m_trail_size);
m_aig_cuts.add_node(lit, and_op, 0, nullptr);
}
clause_vector clauses(s.clauses());
if (m_config.m_learned2aig) clauses.append(s.learned());
std::function on_and =
[&,this](literal head, literal_vector const& ands) {
m_aig_cuts.add_node(head, and_op, ands.size(), ands.data());
m_stats.m_xands++;
};
std::function on_ite =
[&,this](literal head, literal c, literal t, literal e) {
literal args[3] = { c, t, e };
m_aig_cuts.add_node(head, ite_op, 3, args);
m_stats.m_xites++;
};
if (s.m_config.m_cut_aig) {
aig_finder af(s);
af.set(on_and);
af.set(on_ite);
af(clauses);
}
std::function on_xor =
[&,this](literal_vector const& xors) {
SASSERT(xors.size() > 1);
unsigned max_level = xors.back().var();
unsigned index = xors.size() - 1;
for (unsigned i = index; i-- > 0; ) {
literal l = xors[i];
if (l.var() > max_level) {
max_level = l.var();
index = i;
}
}
// head + t1 + t2 + .. = 1
// <=>
// ~head = t1 + t2 + ..
literal head = ~xors[index];
TRACE("cut_simplifier", tout << xors << "\n";);
unsigned sz = xors.size() - 1;
m_lits.reset();
for (unsigned i = xors.size(); i-- > 0; ) {
if (i != index)
m_lits.push_back(xors[i]);
}
m_aig_cuts.add_node(head, xor_op, sz, m_lits.data());
m_lits.reset();
m_stats.m_xxors++;
};
if (s.m_config.m_cut_xor) {
xor_finder xf(s);
xf.set(on_xor);
xf(clauses);
}
std::function on_lut =
[&,this](uint64_t lut, bool_var_vector const& vars, bool_var v) {
m_stats.m_xluts++;
// m_aig_cuts.add_cut(v, lut, vars);
m_aig_cuts.add_node(v, lut, vars.size(), vars.data());
};
if (s.m_config.m_cut_npn3) {
npn3_finder nf(s);
// TBD: stubs for npn3
// question: perhaps just use a LUT interface?
// nf.set_on_mux
// nf.set_on_maj
// nf.set_on_orand
// nf.set_on_and
// nf.set_on_xor
// nf.set_on_andxor
// nf.set_on_xorand
// nf.set_on_gamble
// nf.set_on_onehot
// nf.set_on_dot
// nf(clauses);
}
if (s.m_config.m_cut_lut) {
lut_finder lf(s);
lf.set(on_lut);
lf(clauses);
}
#if 0
statistics st;
collect_statistics(st);
st.display(std::cout);
exit(0);
#endif
}
/**
 * Obtain the cut sets from m_aig_cuts, record the number of cuts, and run
 * the consumers over them in this order: don't-care computation,
 * equivalence extraction, implication learning, simulation.
 */
void cut_simplifier::aig2clauses() {
    // 'auto' takes the container type from the call; the explicit type in
    // this copy of the file had lost its template argument.
    auto const& cuts = m_aig_cuts();
    m_stats.m_num_cuts = m_aig_cuts.num_cuts();
    add_dont_cares(cuts);
    cuts2equiv(cuts);
    cuts2implies(cuts);
    simulate_eqs();
}
/**
 * Derive units and literal equivalences from the cut sets.
 * cuts[i] is treated as the cut set of variable i. A cut that is_true()
 * or is_false() yields a unit (when m_enable_units is set). Otherwise the
 * cut is looked up in cut2id: a hit for the cut or for its negation makes
 * variable i equivalent to the registered variable (positively or
 * negatively); a miss registers the cut for variable i.
 * Equivalences are collected in a union-find over literal indices and
 * passed to uf2equiv if any were found.
 */
// NOTE(review): 'vector' (parameter type) and 'map' (cut2id) appear without
// template arguments in this copy of the file - presumably lost in
// extraction. Restore them from the class header before compiling.
void cut_simplifier::cuts2equiv(vector const& cuts) {
map cut2id;
bool new_eq = false;
union_find_default_ctx ctx;
union_find<> uf(ctx);
// one union-find node per literal index
for (unsigned i = 2*s.num_vars(); i--> 0; ) uf.mk_var();
// merge l1 with l2 and, symmetrically, ~l1 with ~l2
auto add_eq = [&](literal l1, literal l2) {
uf.merge(l1.index(), l2.index());
uf.merge((~l1).index(), (~l2).index());
new_eq = true;
};
for (unsigned i = cuts.size(); i-- > 0; ) {
literal u(i, false);
for (auto& c : cuts[i]) {
unsigned j = 0;
// nc is the negated copy of c, used for the lookup of ~u
cut nc(c);
nc.negate();
if (m_config.m_enable_units && c.is_true()) {
assign_unit(c, u);
}
else if (m_config.m_enable_units && c.is_false()) {
assign_unit(nc, ~u);
}
else if (cut2id.find(&c, j)) {
// variable j registered the same cut: u == j
literal v(j, false);
assign_equiv(c, u, v);
add_eq(u, v);
}
else if (cut2id.find(&nc, j)) {
// variable j registered the negated cut: u == ~j
literal v(j, true);
assign_equiv(c, u, v);
add_eq(u, v);
}
else {
// the map stores a pointer to c, which lives in 'cuts'
cut2id.insert(&c, i);
}
}
}
if (new_eq) {
uf2equiv(uf);
}
}
/**
 * Assert 'lit' as a unit justified by cut c.
 * Does nothing if the literal already has a value in the solver.
 */
void cut_simplifier::assign_unit(cut const& c, literal lit) {
    bool const already_assigned = s.value(lit) != l_undef;
    if (already_assigned) {
        return;
    }
    IF_VERBOSE(10, verbose_stream() << "new unit " << lit << "\n");
    validate_unit(lit);
    certify_unit(lit, c);
    s.assign_unit(lit);
    ++m_stats.m_num_units;
}
/**
 * Certify (DRAT) and validate the equivalence u == v obtained from cut c.
 * Literals over the same variable are skipped.
 */
void cut_simplifier::assign_equiv(cut const& c, literal u, literal v) {
    if (u.var() != v.var()) {
        IF_VERBOSE(10, verbose_stream() << u << " " << v << " " << c << "\n";);
        TRACE("cut_simplifier", tout << u << " == " << v << "\n";);
        certify_equivalence(u, v, c);
        validate_eq(u, v);
    }
}
/**
 * Convert a union-find over literals into input for elim_eqs.
 * Only literals that were not eliminated take part; if a literal ends up
 * in the same class as its negation the solver is put into conflict.
 */
void cut_simplifier::uf2equiv(union_find<> const& uf) {
    union_find_default_ctx ctx;
    union_find<> uf2(ctx);
    for (unsigned i = 2*s.num_vars(); i-- > 0; )
        uf2.mk_var();

    bool new_eq = false;
    // extract equivalences over non-eliminated literals.
    for (unsigned idx = 0; idx < uf.get_num_vars(); ++idx) {
        if (!uf.is_root(idx) || uf.size(idx) == 1)
            continue;
        // walk the class of idx once around; the first non-eliminated
        // literal becomes the representative in uf2
        literal root = null_literal;
        unsigned cur = idx;
        do {
            literal lit = to_literal(cur);
            if (s.was_eliminated(lit)) {
                // eliminated literals are left out
            }
            else if (root == null_literal) {
                root = lit;
            }
            else {
                uf2.merge(lit.index(), root.index());
                new_eq = true;
                ++m_stats.m_num_eqs;
            }
            cur = uf.next(cur);
        }
        while (cur != idx);
    }

    for (unsigned i = s.num_vars(); i-- > 0; ) {
        literal pos(i, false);
        bool const contradictory = uf2.find(pos.index()) == uf2.find((~pos).index());
        if (contradictory) {
            s.set_conflict();
            return;
        }
    }

    if (!new_eq)
        return;
    elim_eqs elim(s);
    elim(uf2);
}
/**
 * Extract binary clauses from cuts.
 * A bit encoding of a LUT of u
 * that sets a subset of bits for LUT' of v establishes
 * that u implies v.
 *
 * Cuts of non-eliminated variables are first grouped via cut2tables; each
 * group is a list of (variable, cut) pairs in var_tables. Within a group
 * every pair of entries is compared on their tables and negated tables,
 * and a subset relation is handed to learn_implies. Does nothing unless
 * m_learn_implies is set.
 */
// NOTE(review): 'vector' (parameter), 'vector>>' (var_tables), 'vector>()'
// and 'map' (cut2tables) have lost their template arguments in this copy of
// the file - presumably an extraction artifact. Restore them from the
// original sources before compiling. Presumably the map groups cuts over
// the same inputs - confirm against its hash/equality functors.
void cut_simplifier::cuts2implies(vector const& cuts) {
if (!m_config.m_learn_implies) return;
vector>> var_tables;
map cut2tables;
unsigned j = 0;
big big(s.rand());
big.init(s, true);
for (auto const& cs : cuts) {
if (s.was_eliminated(cs.var()))
continue;
for (auto const& c : cs) {
// constant cuts carry no implication information
if (c.is_false() || c.is_true())
continue;
// first cut of its group: open a new list and register it
if (!cut2tables.find(&c, j)) {
j = var_tables.size();
var_tables.push_back(vector>());
cut2tables.insert(&c, j);
}
var_tables[j].push_back(std::make_pair(cs.var(), &c));
}
}
for (unsigned i = 0; i < var_tables.size(); ++i) {
auto const& vt = var_tables[i];
// note: this 'j' shadows the outer 'j' used while grouping
for (unsigned j = 0; j < vt.size(); ++j) {
literal u(vt[j].first, false);
cut const& c1 = *vt[j].second;
cut nc1(c1);
nc1.negate();
// t1 / n1: table of c1 and of its negated copy
uint64_t t1 = c1.table();
uint64_t n1 = nc1.table();
for (unsigned k = j + 1; k < vt.size(); ++k) {
literal v(vt[k].first, false);
cut const& c2 = *vt[k].second;
uint64_t t2 = c2.table();
uint64_t n2 = c2.ntable();
if (t1 == t2 || t1 == n2) {
// already handled
}
// (a | b) == b holds exactly when the bits of a are a subset of the bits of b
else if ((t1 | t2) == t2) {
learn_implies(big, c1, u, v);
}
else if ((t1 | n2) == n2) {
learn_implies(big, c1, u, ~v);
}
else if ((n1 | t2) == t2) {
learn_implies(big, nc1, ~u, v);
}
else if ((n1 | n2) == n2) {
learn_implies(big, nc1, ~u, ~v);
}
}
}
}
}
/**
 * Learn the binary clause ~u or v (u implies v) justified by cut c.
 * u == ~v degenerates to the unit v; u == v is trivial. The clause is
 * skipped when the pair is already recorded in m_bins with an op-code,
 * when 'big' already connects u to v, or when the watch list of u
 * already contains the binary clause.
 */
void cut_simplifier::learn_implies(big& big, cut const& c, literal u, literal v) {
    if (u == ~v) {
        assign_unit(c, v);
        return;
    }
    if (u == v)
        return;

    bin_rel p(~u, v);
    bin_rel q;
    bool const recorded = m_bins.find(p, q) && q.op != op_code::none;
    if (recorded || big.connected(u, v))
        return;
    for (auto const& w : s.get_wlist(u)) {
        bool const same_binary = w.is_binary_clause() && v == w.get_literal();
        if (same_binary)
            return;
    }

    certify_implies(u, v, c);
    s.mk_clause(~u, v, sat::status::redundant());
    // m_bins owns reference to ~u or v created by certify_implies
    m_bins.insert(p);
    ++m_stats.m_num_learned_implies;
}
void cut_simplifier::simulate_eqs() {
if (!m_config.m_simulate_eqs) return;
auto var2val = m_aig_cuts.simulate(4);
// Assign higher cutset budgets to equality candidates that come from simulation
// touch them to trigger recomputation of cutsets.
u64_map val2lit;
unsigned i = 0, num_eqs = 0;
for (cut_val val : var2val) {
if (!s.was_eliminated(i) && s.value(i) == l_undef) {
literal u(i, false), v;
if (val2lit.find(val.m_t, v)) {
m_aig_cuts.inc_max_cutset_size(i);
m_aig_cuts.inc_max_cutset_size(v.var());
num_eqs++;
}
else if (val2lit.find(val.m_f, v)) {
m_aig_cuts.inc_max_cutset_size(i);
m_aig_cuts.inc_max_cutset_size(v.var());
num_eqs++;
}
else {
val2lit.insert(val.m_t, u);
val2lit.insert(val.m_f, ~u);
}
}
++i;
}
IF_VERBOSE(2, verbose_stream() << "(sat.cut-simplifier num simulated eqs " << num_eqs << ")\n");
}
/** Log the binary clause encoded by relation p to the DRAT trace (DRAT mode only). */
void cut_simplifier::track_binary(bin_rel const& p) {
    if (s.m_config.m_drat) {
        literal u, v;
        p.to_binary(u, v);
        track_binary(u, v);
    }
}
/** Delete the binary clause encoded by relation p from the DRAT trace (DRAT mode only). */
void cut_simplifier::untrack_binary(bin_rel const& p) {
    if (s.m_config.m_drat) {
        literal u, v;
        p.to_binary(u, v);
        untrack_binary(u, v);
    }
}
/** Add the binary clause (u, v) to the DRAT trace as a redundant clause (DRAT mode only). */
void cut_simplifier::track_binary(literal u, literal v) {
    if (!s.m_config.m_drat)
        return;
    s.m_drat.add(u, v, sat::status::redundant());
}
/** Delete the binary clause (u, v) from the DRAT trace (DRAT mode only). */
void cut_simplifier::untrack_binary(literal u, literal v) {
    if (!s.m_config.m_drat)
        return;
    s.m_drat.del(u, v);
}
/** Certify the unit u as the implication ~u implies u, using cut c. */
void cut_simplifier::certify_unit(literal u, cut const& c)
{
    certify_implies(~u, u, c);
}
/**
 * Equivalences modulo cuts are not necessarily DRAT derivable.
 * To ensure that there is a DRAT derivation we create all resolvents
 * of the LUT clauses until deriving the binary clauses u or ~v and ~u or v.
 * Each resolvent is DRAT derivable because there are two previous lemmas
 * that contain complementary literals.
 *
 * The equivalence is certified as two implications, one per direction.
 */
void cut_simplifier::certify_equivalence(literal u, literal v, cut const& c)
{
    certify_implies(u, v, c);   // u implies v
    certify_implies(v, u, c);   // v implies u
}
/**
 * certify that u implies v, where c is the cut for u.
 * Then every position in c where u is true, it has to be
 * the case that v is too.
 * Where u is false, v can have any value.
 * Thus, for every clause C or u', where u' is u or ~u,
 * it follows that C or ~u or v
 *
 * Only active in DRAT mode. The definition clauses of the cut are added to
 * the DRAT trace with their last literal replaced by ~u (and v appended
 * unless v is ~u), then resolvents are added round by round. Every clause
 * except the last one is deleted from the trace again at the end.
 */
void cut_simplifier::certify_implies(literal u, literal v, cut const& c) {
    if (!s.m_config.m_drat) return;

    // every clause added to the trace, in order of addition
    vector<literal_vector> clauses;
    std::function<void(literal_vector const&)> on_clause =
        [&](literal_vector const& clause) {
            SASSERT(clause.back().var() == u.var());
            clauses.push_back(clause);
            clauses.back().back() = ~u;
            if (~u != v) clauses.back().push_back(v);
            s.m_drat.add(clauses.back());
        };
    m_aig_cuts.cut2def(on_clause, c, u);

    // Nothing was emitted for this cut: there is no lemma to keep and
    // nothing to delete. This also keeps pop_back() below off an empty vector.
    if (clauses.empty()) return;

    // create all resolvents over C. C is assumed to
    // contain all combinations of some set of literals.
    // Each round takes the clauses added in the previous round whose first
    // literal is negative and adds them without that literal.
    unsigned i = 0, sz = clauses.size();
    while (sz - i > 1) {
        SASSERT((sz & (sz - 1)) == 0 && "sz is a power of 2");
        for (; i < sz; ++i) {
            auto const& clause = clauses[i];
            if (clause[0].sign()) {
                // 'cl' is a copy, so the push_back below cannot invalidate it
                literal_vector cl(clause.size() - 1, clause.data() + 1);
                clauses.push_back(cl);
                s.m_drat.add(cl);
            }
        }
        i = sz;
        sz = clauses.size();
    }

    IF_VERBOSE(10, for (auto const& clause : clauses) verbose_stream() << clause << "\n";);

    // once we established equivalence, don't need auxiliary clauses for DRAT.
    // The last clause is the derived lemma; it stays in the trace.
    clauses.pop_back();
    for (auto const& clause : clauses) {
        s.m_drat.del(clause);
    }
}
/**
 * Optional pre-processing of the cut sets before equivalences are extracted.
 * With m_cut_dont_cares: collect binary relations from the cuts, compute
 * their don't-care op-codes and apply them to the cuts.
 * With m_cut_redundancies: call m_aig_cuts.simplify().
 */
// NOTE(review): the parameter type 'vector' has lost its template argument
// in this copy of the file (presumably an extraction artifact); restore it
// from the class header before compiling.
void cut_simplifier::add_dont_cares(vector const& cuts) {
if (s.m_config.m_cut_dont_cares) {
// the three steps run in this order: collect, classify, apply
cuts2bins(cuts);
bins2dont_cares();
dont_cares2cuts(cuts);
}
if (s.m_config.m_cut_redundancies) {
m_aig_cuts.simplify();
}
}
/**
 * Collect binary relations between variables that occur in cut sets.
 * m_bins is rebuilt from scratch with one entry per pair of elements that
 * occur together in some cut. Entries from the previous round that already
 * carry an op-code are re-inserted if their pair still occurs; otherwise
 * they are untracked.
 */
// NOTE(review): 'vector' (parameter) and 'svector' (dcs) have lost their
// template arguments in this copy of the file - presumably an extraction
// artifact. dcs holds elements of m_bins (they are passed to
// untrack_binary(bin_rel const&)). Restore the declarations before compiling.
void cut_simplifier::cuts2bins(vector const& cuts) {
svector dcs;
// save the relations that already have an op-code
for (auto const& p : m_bins)
if (p.op != op_code::none)
dcs.push_back(p);
m_bins.reset();
// j < i, so each pair is inserted as (c[j], c[i]) with the lower position first
for (auto const& cs : cuts)
for (auto const& c : cs)
for (unsigned i = c.size(); i-- > 0; )
for (unsigned j = i; j-- > 0; )
m_bins.insert(bin_rel(c[j],c[i]));
// don't lose previous don't cares
for (auto const& p : dcs) {
if (m_bins.contains(p)) {
// NOTE(review): presumably this insert replaces the fresh entry so the saved op-code survives - confirm against the hashtable's insert semantics
m_bins.insert(p);
}
else {
untrack_binary(p);
}
}
}
/**
 * Compute masks for binary relations.
 * For every relation in m_bins without an op-code, the binary implication
 * graph of the solver is queried for a connection between the two
 * variables in each polarity combination. The first hit fixes the
 * op-code, and the relation is then tracked for DRAT.
 */
void cut_simplifier::bins2dont_cares() {
    big b(s.rand());
    b.init(s, true);
    for (auto& p : m_bins) {
        if (p.op != op_code::none)
            continue;
        literal u(p.u, false), v(p.v, false);
        // u -> v, then u & ~v is impossible
        if (b.connected(u, v))
            p.op = op_code::pn;
        else if (b.connected(u, ~v))
            p.op = op_code::pp;
        else if (b.connected(~u, v))
            p.op = op_code::nn;
        else if (b.connected(~u, ~v))
            p.op = op_code::np;
        else
            continue;   // no connection: the relation stays without an op-code
        track_binary(p);
    }
    IF_VERBOSE(2, {
            unsigned n = 0; for (auto const& p : m_bins) if (p.op != op_code::none) ++n;
            verbose_stream() << n << " / " << m_bins.size() << " don't cares\n";
        });
}
/**
 * Loop over cuts, if it is possible to add a new don't care combination
 * to a cut, then ensure that the variable is "touched" so that it participates
 * in the next propagation.
 * Each successful addition also increments m_num_dont_care_reductions.
 */
// NOTE(review): the parameter type 'vector' has lost its template argument
// in this copy of the file (presumably an extraction artifact); restore it
// from the class header before compiling.
void cut_simplifier::dont_cares2cuts(vector const& cuts) {
for (auto& cs : cuts) {
for (auto const& c : cs) {
// add_dont_care returns true only when the don't-care mask of c changed
if (add_dont_care(c)) {
m_aig_cuts.touch(cs.var());
m_stats.m_num_dont_care_reductions++;
}
}
}
}
/**
 * compute the don't-care mask for cut positions i, j and op-code p.op
 *
 * Bit t of the result stands for the table row with index t, where bit k
 * of t is the value of the k'th cut element. The op-code names one
 * impossible value combination (x_i, x_j); the mask has bit t set for
 * every row t whose bits i and j match that combination. All other
 * bits of t are free.
 *
 * Example: i = 0, j = 2, combination (true, true) marks rows 5 (101b) and
 * 7 (111b) among the first eight rows. The previous implementation only
 * varied the index bits above j and therefore missed rows such as 7, where
 * another element below position j is true.
 *
 * \pre i < j < 6
 * \return 0 if p carries no op-code, otherwise the mask described above.
 */
uint64_t cut_simplifier::op2dont_care(unsigned i, unsigned j, bin_rel const& p) {
    SASSERT(i < j && j < 6);
    if (p.op == op_code::none) return 0ull;
    // column[k] has bit t set iff bit k of the row index t is 1
    static constexpr uint64_t column[6] = {
        0xAAAAAAAAAAAAAAAAull,
        0xCCCCCCCCCCCCCCCCull,
        0xF0F0F0F0F0F0F0F0ull,
        0xFF00FF00FF00FF00ull,
        0xFFFF0000FFFF0000ull,
        0xFFFFFFFF00000000ull
    };
    // value of element i / j in the impossible combination
    bool i_is_0 = (p.op == op_code::np || p.op == op_code::nn);
    bool j_is_0 = (p.op == op_code::pn || p.op == op_code::nn);
    uint64_t rows_i = i_is_0 ? ~column[i] : column[i];
    uint64_t rows_j = j_is_0 ? ~column[j] : column[j];
    return rows_i & rows_j;
}
/**
 * Apply obtained dont_cares to cut sets.
 * The don't care bits are added to the LUT, so that the
 * output is always 1 on don't care combinations.
 *
 * \return true iff the mask computed from m_bins differs from the mask
 * already stored in the cut (in which case it is added to the cut).
 */
bool cut_simplifier::add_dont_care(cut const & c) {
    uint64_t dc = 0;
    unsigned const n = c.size();
    for (unsigned i = 0; i < n; ++i) {
        for (unsigned j = i + 1; j < n; ++j) {
            bin_rel p(c[i], c[j]);
            // find() overwrites p with the stored relation, including its op-code
            if (!m_bins.find(p, p))
                continue;
            if (p.op == op_code::none)
                continue;
            dc |= op2dont_care(i, j, p);
        }
    }
    if (dc == c.dont_care())
        return false;
    c.add_dont_care(dc);
    return true;
}
/** Export the simplifier counters into st under the "sat-cut.*" keys. */
void cut_simplifier::collect_statistics(statistics& st) const {
    auto const& counters = m_stats;
    // derived results
    st.update("sat-cut.eqs",       counters.m_num_eqs);
    st.update("sat-cut.cuts",      counters.m_num_cuts);
    // definitions added through add_and / add_or / add_ite / add_xor / add_iff
    st.update("sat-cut.ands",      counters.m_num_ands);
    st.update("sat-cut.ites",      counters.m_num_ites);
    st.update("sat-cut.xors",      counters.m_num_xors);
    // definitions extracted from clauses by the finders
    st.update("sat-cut.xands",     counters.m_xands);
    st.update("sat-cut.xites",     counters.m_xites);
    st.update("sat-cut.xxors",     counters.m_xxors);
    st.update("sat-cut.xluts",     counters.m_xluts);
    st.update("sat-cut.dc-reduce", counters.m_num_dont_care_reductions);
}
/** Validate the unit 'lit' with the validator when m_validate_lemmas is set. */
void cut_simplifier::validate_unit(literal lit) {
    if (m_config.m_validate_lemmas) {
        ensure_validator();
        m_validator->validate(1, &lit);
    }
}
/**
 * Validate the equivalence a == b when m_validate_lemmas is set.
 * The two clauses (a or ~b) and (~a or b) are validated, in that order.
 */
void cut_simplifier::validate_eq(literal a, literal b) {
    if (m_config.m_validate_lemmas) {
        ensure_validator();
        literal b_implies_a[2] = { a, ~b };
        literal a_implies_b[2] = { ~a, b };
        m_validator->validate(2, b_implies_a);
        m_validator->validate(2, a_implies_b);
    }
}
}