All Downloads are FREE. Search and download functionalities are using the official Maven repository.

z3-z3-4.13.0.src.muz.fp.dl_cmds.cpp Maven / Gradle / Ivy

The newest version!
/*++
Copyright (c) 2011 Microsoft Corporation

Module Name:

    dl_cmds.cpp

Abstract:
    Datalog commands for SMT2 front-end.

Author:

    Leonardo (leonardo) 2011-03-28

Notes:

--*/
#include "cmd_context/cmd_context.h"
#include "muz/fp/dl_cmds.h"
#include "muz/rel/dl_external_relation.h"
#include "muz/base/dl_context.h"
#include "muz/fp/dl_register_engine.h"
#include "ast/dl_decl_plugin.h"
#include "muz/rel/dl_instruction.h"
#include "muz/rel/dl_compiler.h"
#include "muz/base/dl_rule.h"
#include "ast/ast_pp.h"
#include "cmd_context/parametric_cmd.h"
#include "util/cancel_eh.h"
#include "util/scoped_ctrl_c.h"
#include "util/scoped_timer.h"
#include "util/trail.h"
#include "muz/base/fp_params.hpp"
#include <climits>
#include <iomanip>


struct dl_context {
    scoped_ptr        m_fparams;
    params_ref                    m_params_ref;
    fp_params                     m_params;
    cmd_context &                 m_cmd;
    datalog::register_engine      m_register_engine;
    dl_collected_cmds*            m_collected_cmds;
    unsigned                      m_ref_count;
    datalog::dl_decl_plugin*      m_decl_plugin;
    scoped_ptr  m_context;
    trail_stack                   m_trail;

    fp_params const& get_params() {
        init();
        return m_context->get_params();
    }

    dl_context(cmd_context & ctx, dl_collected_cmds* collected_cmds):
        m_params(m_params_ref),
        m_cmd(ctx),
        m_collected_cmds(collected_cmds),
        m_ref_count(0),
        m_decl_plugin(nullptr),
        m_trail() {}

    void inc_ref() {
        ++m_ref_count;
    }

    void dec_ref() {
        --m_ref_count;
        if (0 == m_ref_count) {
            dealloc(this);
        }
    }

    smt_params& fparams() {
        if (!m_fparams) m_fparams = alloc(smt_params);
        return *m_fparams.get();
    }

    void init() {
        ast_manager& m = m_cmd.m();
        if (!m_context) {
            m_context = alloc(datalog::context, m, m_register_engine, fparams(), m_params_ref);
        }
        if (!m_decl_plugin) {
            symbol name("datalog_relation");
            if (m.has_plugin(name)) {
                m_decl_plugin = static_cast(m_cmd.m().get_plugin(m.mk_family_id(name)));
            }
            else {
                m_decl_plugin = alloc(datalog::dl_decl_plugin);
                m.register_plugin(symbol("datalog_relation"), m_decl_plugin);
            }
        }
    }

    void reset() {
        m_context = nullptr;
    }

    void register_predicate(func_decl* pred, unsigned num_kinds, symbol const* kinds) {
        if (m_collected_cmds) {
            m_collected_cmds->m_rels.push_back(pred);
            m_trail.push(push_back_vector(m_collected_cmds->m_rels));
        }
        dlctx().register_predicate(pred, false);
        dlctx().set_predicate_representation(pred, num_kinds, kinds);
    }

    void add_rule(expr * rule, symbol const& name, unsigned bound) {
        init();
        if (m_collected_cmds) {
            expr_ref rl = m_context->bind_vars(rule, true);
            m_collected_cmds->m_rules.push_back(rl);
            m_collected_cmds->m_names.push_back(name);
            m_trail.push(push_back_vector(m_collected_cmds->m_rules));
            m_trail.push(push_back_vector >(m_collected_cmds->m_names));
        }
        else {
        m_context->add_rule(rule, name, bound);
        }
    }

    bool collect_query(func_decl* q) {
        if (m_collected_cmds) {
            ast_manager& m = m_cmd.m();
            expr_ref qr(m);
            expr_ref_vector args(m);
            for (unsigned i = 0; i < q->get_arity(); ++i) {
                args.push_back(m.mk_var(i, q->get_domain(i)));
            }
            qr = m.mk_app(q, args.size(), args.data());
            qr = m_context->bind_vars(qr, false);
            m_collected_cmds->m_queries.push_back(qr);
            m_trail.push(push_back_vector(m_collected_cmds->m_queries));
            return true;
        }
        else {
            return false;
        }
    }

    void push() {
        m_trail.push_scope();
        dlctx().push();
    }

    void pop() {
        m_trail.pop_scope(1);
        dlctx().pop();
    }

    datalog::context & dlctx() {
        init();
        return *m_context;
    }
};


/**
   \brief "rule" command: adds a Horn rule, with an optional name and an
   optional recursion bound, to the shared dl_context.

   The command holds a counted reference to the dl_context; resetting this
   command also resets that context.
*/
class dl_rule_cmd : public cmd {
    ref<dl_context> m_dl_ctx;
    mutable unsigned     m_arg_idx;
    expr*        m_t;
    symbol       m_name;
    unsigned     m_bound;
public:
    dl_rule_cmd(dl_context * dl_ctx):
        cmd("rule"),
        m_dl_ctx(dl_ctx),
        m_arg_idx(0),
        m_t(nullptr),
        m_bound(UINT_MAX) {}
    char const * get_usage() const override { return "(forall (q) (=> (and body) head)) :optional-name :optional-recursion-bound"; }
    char const * get_descr(cmd_context & ctx) const override { return "add a Horn rule."; }
    unsigned get_arity() const override { return VAR_ARITY; }
    /** \brief Argument kinds by position: formula, name symbol, unsigned bound. */
    cmd_arg_kind next_arg_kind(cmd_context & ctx) const override {
        switch(m_arg_idx) {
        case 0: return CPK_EXPR;
        case 1: return CPK_SYMBOL;
        case 2: return CPK_UINT;
        default: return CPK_SYMBOL;
        }
    }
    void set_next_arg(cmd_context & ctx, expr * t) override {
        m_t = t;
        m_arg_idx++;
    }
    void set_next_arg(cmd_context & ctx, symbol const & s) override {
        m_name = s;
        m_arg_idx++;
    }
    void set_next_arg(cmd_context & ctx, unsigned bound) override {
        m_bound = bound;
        m_arg_idx++;
    }
    void reset(cmd_context & ctx) override { m_dl_ctx->reset(); prepare(ctx); }
    /**
       \brief Restore the per-invocation defaults.
       The formula pointer is cleared as well, so that an invocation without a
       formula is rejected by execute() instead of reusing the pointer left
       behind by a previous invocation.
    */
    void prepare(cmd_context& ctx) override {
        m_arg_idx = 0;
        m_t = nullptr;
        m_name = symbol::null;
        m_bound = UINT_MAX;
    }
    void finalize(cmd_context & ctx) override {
    }
    /** \brief Add the parsed rule; throws cmd_exception if no formula was supplied. */
    void execute(cmd_context & ctx) override {
        if (!m_t) throw cmd_exception("invalid rule, expected formula");
        m_dl_ctx->add_rule(m_t, m_name, m_bound);
    }
};

/**
   \brief "query" command: poses a query on a registered predicate.

   When the shared dl_context is in collecting mode the query is only recorded.
   Otherwise the current assertions are passed to the datalog context as
   background, the query is run under the configured timeout and resource
   limit, and the outcome (plus optional answer, certificate and statistics)
   is printed on the regular stream.
*/
class dl_query_cmd : public parametric_cmd {
    ref<dl_context> m_dl_ctx;
    func_decl* m_target;
public:
    dl_query_cmd(dl_context * dl_ctx):
        parametric_cmd("query"),
        m_dl_ctx(dl_ctx),
        m_target(nullptr) {
    }
    char const * get_usage() const override { return "predicate"; }
    char const * get_main_descr() const override {
        return "pose a query to a predicate based on the Horn rules.";
    }

    /** \brief The first argument is the target predicate; the rest are parameters. */
    cmd_arg_kind next_arg_kind(cmd_context & ctx) const override {
        if (m_target == nullptr) return CPK_FUNC_DECL;
        return parametric_cmd::next_arg_kind(ctx);
    }

    /**
       \brief Store the query target.
       Throws cmd_exception if \c t is interpreted or is not among the
       predicates known to the datalog context.
    */
    void set_next_arg(cmd_context & ctx, func_decl* t) override {
        m_target = t;
        if (t->get_family_id() != null_family_id) {
            throw cmd_exception("Invalid query argument, expected uninterpreted function name, but argument is interpreted");
        }
        datalog::context& dlctx = m_dl_ctx->dlctx();
        if (!dlctx.get_predicates().contains(t)) {
            throw cmd_exception("Invalid query argument, expected a predicate registered as a relation");
        }
    }

    void prepare(cmd_context & ctx) override {
        ctx.m(); // ensure manager is initialized.
        parametric_cmd::prepare(ctx);
        m_target   = nullptr;
    }

    /**
       \brief Run (or record) the query and report the result.
       Throws cmd_exception if no target was given. A z3_error raised by the
       engine is reported and rethrown; any other z3_exception is reported and
       the status then falls through to the regular result printing.
    */
    void execute(cmd_context& ctx) override {
        if (m_target == nullptr) {
            throw cmd_exception("invalid query command, argument expected");
        }
        if (m_dl_ctx->collect_query(m_target)) {
            return;
        }
        datalog::context& dlctx = m_dl_ctx->dlctx();
        set_background(ctx);
        dlctx.updt_params(m_params);
        unsigned timeout   = ctx.params().m_timeout;
        unsigned rlimit    = ctx.params().rlimit();
        cancel_eh eh(ctx.m().limit());
        bool query_exn = false;
        lbool status = l_undef;
        {
            // The scoped guards below are only active while the query runs.
            IF_VERBOSE(10, verbose_stream() << "(query)\n";);
            scoped_ctrl_c ctrlc(eh);
            scoped_timer timer(timeout, &eh);
            scoped_rlimit _rlimit(ctx.m().limit(), rlimit);
            cmd_context::scoped_watch sw(ctx);
            try {
                status = dlctx.rel_query(1, &m_target);
            }
            catch (z3_error & ex) {
                ctx.regular_stream() << "(error \"query failed: " << ex.msg() << "\")" << std::endl;
                print_statistics(ctx);
                // Rethrow the in-flight exception object itself; 'throw ex'
                // would throw a copy of static type z3_error.
                throw;
            }
            catch (z3_exception& ex) {
                ctx.regular_stream() << "(error \"query failed: " << ex.msg() << "\")" << std::endl;
                query_exn = true;
            }
        }
        switch (status) {
        case l_false:
            ctx.regular_stream() << "unsat\n";
            print_certificate(ctx);
            break;
        case l_true:
            ctx.regular_stream() << "sat\n";
            print_answer(ctx);
            print_certificate(ctx);
            break;
        case l_undef:
            if (dlctx.get_status() == datalog::BOUNDED){
                ctx.regular_stream() << "bounded\n";
                print_certificate(ctx);
                break;
            }
            ctx.regular_stream() << "unknown\n";
            // Add a line explaining why the result is unknown, where one exists.
            switch(dlctx.get_status()) {
            case datalog::INPUT_ERROR:
                ctx.regular_stream() << "input error\n";
                break;

            case datalog::MEMOUT:
                ctx.regular_stream() << "memory bounds exceeded\n";
                break;

            case datalog::TIMEOUT:
                ctx.regular_stream() << "timeout\n";
                break;

            case datalog::APPROX:
                ctx.regular_stream() << "approximated relations\n";
                break;

            case datalog::OK:
                (void)query_exn;
                break;

            case datalog::CANCELED:
                ctx.regular_stream() << "canceled\n";
                dlctx.display_profile(ctx.regular_stream());
                break;

            default:
                UNREACHABLE();
                break;
            }
            break;
        }
        dlctx.cleanup();
        print_statistics(ctx);
        m_target = nullptr;
    }

    /** \brief The command accepts the parameters of the datalog context. */
    void init_pdescrs(cmd_context & ctx, param_descrs & p) override {
        m_dl_ctx->dlctx().collect_params(p);
    }


private:
    /** \brief Pass every current assertion of \c ctx to the datalog context. */
    void set_background(cmd_context& ctx) {
        datalog::context& dlctx = m_dl_ctx->dlctx();
        for (expr * e : ctx.assertions()) {
            dlctx.assert_expr(e);
        }
    }

    /** \brief Print the answer formula if the print_answer parameter is set. */
    void print_answer(cmd_context& ctx) {
        if (m_dl_ctx->get_params().print_answer()) {
            datalog::context& dlctx = m_dl_ctx->dlctx();
            ast_manager& m = ctx.m();
            expr_ref query_result(dlctx.get_answer_as_formula(), m);
            sbuffer<symbol> var_names;
            unsigned num_decls = 0;
            ctx.display(ctx.regular_stream(), query_result, 0, num_decls, "X", var_names);
            ctx.regular_stream() << std::endl;
        }
    }

    /** \brief Print engine statistics plus elapsed time if statistics are enabled. */
    void print_statistics(cmd_context& ctx) {
        if (ctx.params().m_statistics) {
            statistics st;
            datalog::context& dlctx = m_dl_ctx->dlctx();
            dlctx.collect_statistics(st);
            st.update("time", ctx.get_seconds());
            st.display_smt2(ctx.regular_stream());
        }
    }

    /** \brief Print the certificate if the print_certificate parameter is set. */
    void print_certificate(cmd_context& ctx) {
        if (m_dl_ctx->get_params().print_certificate()) {
            datalog::context& dlctx = m_dl_ctx->dlctx();
            dlctx.display_certificate(ctx.regular_stream());
            ctx.regular_stream() << "\n";
        }
    }
};

/**
   \brief "declare-rel" command: declares a Boolean-valued predicate over the
   given argument sorts and registers it as a relation, optionally with a list
   of representation kinds.
*/
class dl_declare_rel_cmd : public cmd {
    ref<dl_context>  m_dl_ctx;
    unsigned         m_arg_idx;        // number of arguments received so far
    mutable unsigned m_query_arg_idx;  // number of times next_arg_kind() was asked
    symbol           m_rel_name;
    ptr_vector<sort> m_domain;
    svector<symbol>  m_kinds;


public:
    dl_declare_rel_cmd(dl_context * dl_ctx):
        cmd("declare-rel"),
        m_dl_ctx(dl_ctx),
        m_arg_idx(0),
        m_query_arg_idx(0),
        m_domain(0) {}

    char const * get_usage() const override { return "<symbol> (<arg1 sort> ...) <representation>*"; }
    char const * get_descr(cmd_context & ctx) const override { return "declare new relation"; }
    unsigned get_arity() const override { return VAR_ARITY; }

    void prepare(cmd_context & ctx) override {
        ctx.m(); // ensure manager is initialized.
        m_arg_idx = 0;
        m_query_arg_idx = 0;
        m_domain.reset();
        m_kinds.reset();
    }
    /** \brief Kind of the next argument; every call advances the query position. */
    cmd_arg_kind next_arg_kind(cmd_context & ctx) const override {
        switch(m_query_arg_idx++) {
        case 0: return CPK_SYMBOL;     // relation name
        case 1: return CPK_SORT_LIST;  // arguments
        default: return CPK_SYMBOL;    // optional representation specification
        }
    }
    /** \brief Receive the list of argument sorts. */
    void set_next_arg(cmd_context & ctx, unsigned num, sort * const * slist) override {
        m_domain.reset();
        m_domain.append(num, slist);
        m_arg_idx++;
    }
    /** \brief The first symbol is the relation name; later ones are representation kinds. */
    void set_next_arg(cmd_context & ctx, symbol const & s) override {
        if(m_arg_idx==0) {
            m_rel_name = s;
        }
        else {
            SASSERT(m_arg_idx>1);
            m_kinds.push_back(s);
        }
        m_arg_idx++;
    }
    /**
       \brief Create the predicate declaration, insert it into the command
       context and register it with the shared dl_context.
       Throws cmd_exception if fewer than two arguments were supplied.
    */
    void execute(cmd_context & ctx) override {
        if(m_arg_idx<2) {
            throw cmd_exception("at least 2 arguments expected");
        }
        ast_manager& m = ctx.m();

        func_decl_ref pred(
            m.mk_func_decl(m_rel_name, m_domain.size(), m_domain.data(), m.mk_bool_sort()), m);
        ctx.insert(pred);
        m_dl_ctx->register_predicate(pred, m_kinds.size(), m_kinds.data());
    }

};

/**
   \brief "declare-var" command: declares a constant of the given sort and
   registers it as a variable with the datalog context.
*/
class dl_declare_var_cmd : public cmd {
    unsigned m_arg_idx;
    symbol   m_var_name;
    sort*    m_var_sort;
    ref<dl_context>   m_dl_ctx;
public:
    dl_declare_var_cmd(dl_context* dl_ctx):
        cmd("declare-var"),
        m_arg_idx(0),
        m_var_sort(nullptr),
        m_dl_ctx(dl_ctx)
    {}

    char const * get_usage() const override { return "<symbol> <sort>"; }
    char const * get_descr(cmd_context & ctx) const override { return "declare constant as variable"; }
    unsigned get_arity() const override { return 2; }

    void prepare(cmd_context & ctx) override {
        ctx.m(); // ensure manager is initialized.
        m_arg_idx = 0;
    }
    /** \brief First argument is the variable name, second its sort. */
    cmd_arg_kind next_arg_kind(cmd_context & ctx) const override {
        SASSERT(m_arg_idx <= 1);
        if (m_arg_idx == 0) {
            return CPK_SYMBOL;
        }
        return CPK_SORT;
    }

    void set_next_arg(cmd_context & ctx, sort* s) override {
        m_var_sort = s;
        ++m_arg_idx;
    }

    void set_next_arg(cmd_context & ctx, symbol const & s) override {
        m_var_name = s;
        ++m_arg_idx;
    }

    /**
       \brief Create the nullary declaration, insert it into the command
       context and register it as a variable of the datalog context.
    */
    void execute(cmd_context & ctx) override {
        ast_manager& m = ctx.m();
        // The cast gives the null domain pointer an explicit type for mk_func_decl.
        func_decl_ref var(m.mk_func_decl(m_var_name, 0, static_cast<sort*const*>(nullptr), m_var_sort), m);
        ctx.insert(var);
        m_dl_ctx->dlctx().register_variable(var);
    }
};


/**
   \brief Allocate one dl_context for \c ctx and install the rule, query,
   declare-rel and declare-var commands, all of which share that context.
   \param collected_cmds optional collector forwarded to the dl_context; may be null.
*/
static void install_dl_cmds_aux(cmd_context& ctx, dl_collected_cmds* collected_cmds) {
    dl_context * const shared_state = alloc(dl_context, ctx, collected_cmds);

    auto * rule_cmd = alloc(dl_rule_cmd, shared_state);
    ctx.insert(rule_cmd);

    auto * query_cmd = alloc(dl_query_cmd, shared_state);
    ctx.insert(query_cmd);

    auto * declare_rel_cmd = alloc(dl_declare_rel_cmd, shared_state);
    ctx.insert(declare_rel_cmd);

    auto * declare_var_cmd = alloc(dl_declare_var_cmd, shared_state);
    ctx.insert(declare_var_cmd);
}

/**
   \brief Install the Datalog commands into \c ctx without a command collector.
*/
void install_dl_cmds(cmd_context & ctx) {
    dl_collected_cmds * const no_collector = nullptr;
    install_dl_cmds_aux(ctx, no_collector);
}

/**
   \brief Install the Datalog commands into \c ctx with \c collected_cmds as
   the collector handed to the shared dl_context.
*/
void install_dl_collect_cmds(dl_collected_cmds& collected_cmds, cmd_context & ctx) {
    dl_collected_cmds * const collector = &collected_cmds;
    install_dl_cmds_aux(ctx, collector);
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy