All Downloads are FREE. Search and download functionalities are using the official Maven repository.

z3-z3-4.13.0.src.opt.opt_sls_solver.h Maven / Gradle / Ivy

The newest version!
/*++
Copyright (c) 2014 Microsoft Corporation

Module Name:

    opt_sls_solver.h

Abstract:

    Wraps a solver with SLS for improving a solution using an objective function.

Author:

    Nikolaj Bjorner (nbjorner) 2014-4-18

Notes:

   
--*/
#pragma once

#include "solver/solver_na2as.h"
#include "tactic/arith/card2bv_tactic.h"
#include "tactic/core/nnf_tactic.h"
#include "opt/pb_sls.h"
#include "tactic/sls/bvsls_opt_engine.h"


namespace opt {
    
    class sls_solver : public solver_na2as {
        ast_manager&     m;
        ref      m_solver;
        scoped_ptr m_bvsls;
        scoped_ptr      m_pbsls;
        pb::card_pb_rewriter m_pb2bv;
        vector m_weights;
        expr_ref_vector  m_soft;
        model_ref        m_model;
        params_ref       m_params;
        symbol           m_engine; 
    public:
        sls_solver(ast_manager & m, solver* s, 
                   expr_ref_vector const& soft, 
                   vector const& weights, 
                   params_ref & p):
            solver_na2as(m),
            m(m),
            m_solver(s),
            m_bvsls(0),
            m_pbsls(0),
            m_pb2bv(m),
            m_weights(weights),
            m_soft(soft)
        {            
            updt_params(p);
        }
        virtual ~sls_solver() = default;

        virtual void updt_params(params_ref & p) {
            m_solver->updt_params(p);
            m_params.copy(p);
            opt_params _p(p);
            m_engine = _p.sls_engine();
        }
        virtual void collect_param_descrs(param_descrs & r) {
            m_solver->collect_param_descrs(r);
        }
        virtual void collect_statistics(statistics & st) const {
            m_solver->collect_statistics(st);
            if (m_bvsls) m_bvsls->collect_statistics(st);
            if (m_pbsls) m_pbsls->collect_statistics(st);
        }
        virtual void assert_expr(expr * t) {
            m_solver->assert_expr(t);
        }
        virtual void get_unsat_core(ptr_vector & r) {
            m_solver->get_unsat_core(r);
        }
        virtual void get_model(model_ref & m) {
            m = m_model;
        }
        virtual proof * get_proof() {
            return m_solver->get_proof();
        }
        virtual std::string reason_unknown() const {
            return m_solver->reason_unknown();
        }
        virtual void get_labels(svector & r) {
            m_solver->get_labels(r);
        }
        virtual void set_progress_callback(progress_callback * callback) {
            m_solver->set_progress_callback(callback);
        }
        virtual unsigned get_num_assertions() const {
            return m_solver->get_num_assertions();
        }
        virtual expr * get_assertion(unsigned idx) const {
            return m_solver->get_assertion(idx);
        }
        virtual void display(std::ostream & out) const {
            m_solver->display(out);
            // if (m_bvsls) m_bvsls->display(out);
        }

        void opt(model_ref& mdl) {
            if (m_engine == symbol("pb")) {
                pbsls_opt(mdl);
            }
            else {
                bvsls_opt(mdl);
            }
        }

        static expr_ref soft2bv(expr_ref_vector const& soft, vector const& weights) {
            ast_manager& m = soft.get_manager();
            pb::card_pb_rewriter pb2bv(m);
            rational upper(1);
            expr_ref objective(m);
            for (unsigned i = 0; i < weights.size(); ++i) {
                upper += weights[i];
            }
            expr_ref zero(m), tmp(m);
            bv_util bv(m);
            expr_ref_vector es(m);
            rational num = numerator(upper);
            rational den = denominator(upper);
            rational maxval = num*den;
            unsigned bv_size = maxval.get_num_bits();
            zero = bv.mk_numeral(rational(0), bv_size);
            for (unsigned i = 0; i < soft.size(); ++i) {
                pb2bv(soft[i], tmp);
                es.push_back(m.mk_ite(tmp, bv.mk_numeral(den*weights[i], bv_size), zero));
            }
            if (es.empty()) {
                objective = bv.mk_numeral(0, bv_size);
            }
            else {
                objective = es[0].get();
                for (unsigned i = 1; i < es.size(); ++i) {
                    objective = bv.mk_bv_add(objective, es[i].get());
                }
            }
            return objective;
        }

    protected:
        typedef bvsls_opt_engine::optimization_result opt_result;

        virtual lbool check_sat_core(unsigned num_assumptions, expr * const * assumptions) {            
            lbool r = m_solver->check_sat(num_assumptions, assumptions);
            if (r == l_true) {
                m_solver->get_model(m_model);
                opt(m_model);
            }
            return r;
        }
        virtual void push_core() {
            m_solver->push();
        }
        virtual void pop_core(unsigned n) {
            m_solver->pop(n);
        }


    private:
        // convert soft constraints to bit-vector objective.

        void assertions2sls() {
            expr_ref tmp(m);
            goal_ref g(alloc(goal, m, true, false));
            for (unsigned i = 0; i < m_solver->get_num_assertions(); ++i) {
                m_pb2bv(m_solver->get_assertion(i), tmp);                
                g->assert_expr(tmp);
            }
            TRACE("opt", g->display(tout););
            tactic_ref simplify = mk_nnf_tactic(m);
            proof_converter_ref pc;
            expr_dependency_ref core(m);
            goal_ref_buffer result;
            model_converter_ref model_converter;
            (*simplify)(g, result, model_converter, pc, core);
            SASSERT(result.size() == 1);
            goal* r = result[0];
            for (unsigned i = 0; i < r->size(); ++i) {
                m_bvsls->assert_expr(r->form(i));
            }
            TRACE("opt", m_bvsls->display(tout););
        }

        void pbsls_opt(model_ref& mdl) {
            if (m_pbsls) {
                m_pbsls->reset();
            }
            else {
                m_pbsls = alloc(smt::pb_sls, m);
            }
            m_pbsls->set_model(mdl);
            m_pbsls->updt_params(m_params);
            for (unsigned i = 0; i < m_solver->get_num_assertions(); ++i) {
                m_pbsls->add(m_solver->get_assertion(i));
            }
            for (unsigned i = 0; i < m_soft.size(); ++i) {
                m_pbsls->add(m_soft[i].get(), m_weights[i]);
            }
            (*m_pbsls.get())();
            m_pbsls->get_model(m_model);            
            mdl = m_model.get();
        }

        void bvsls_opt(model_ref& mdl) {
            m_bvsls = alloc(bvsls_opt_engine, m, m_params);            
            assertions2sls();
            expr_ref objective = soft2bv(m_soft, m_weights);
            TRACE("opt", tout << objective << "\n";);
            opt_result res(m);
            res.is_sat = l_undef;
            try {
                res = m_bvsls->optimize(objective, mdl, true);
            }
            catch (...) {
                
            }
            SASSERT(res.is_sat == l_true || res.is_sat == l_undef);
            if (res.is_sat == l_true) {
                m_bvsls->get_model(m_model);
                mdl = m_model.get();
            }
        }

    };
}





© 2015 - 2024 Weber Informatics LLC | Privacy Policy