All Downloads are FREE. Search and download functionalities are using the official Maven repository.

z3-z3-4.13.0.src.ast.pattern.pattern_inference.h Maven / Gradle / Ivy

The newest version!
/*++
Copyright (c) 2006 Microsoft Corporation

Module Name:

    pattern_inference.h

Abstract:

    

Author:

    Leonardo de Moura (leonardo) 2006-12-08.

Revision History:

--*/
#pragma once

#include "ast/ast.h"
#include "ast/rewriter/rewriter.h"
#include "ast/rewriter/rewriter_def.h"
#include "params/pattern_inference_params.h"
#include "util/vector.h"
#include "util/uint_set.h"
#include "util/nat_set.h"
#include "util/obj_hashtable.h"
#include "util/obj_pair_hashtable.h"
#include "util/map.h"
#include "ast/pattern/expr_pattern_match.h"

/**
   \brief A pattern p_1 is smaller than a pattern p_2 iff 
   every instance of p_2 is also an instance of p_1.

   Example: f(X) is smaller than f(g(X)) because
   every instance of f(g(X)) is also an instance of f(X).
*/
class smaller_pattern {
    ptr_vector m_bindings;

    typedef std::pair expr_pair;
    typedef obj_pair_hashtable cache;

    svector m_todo;
    cache              m_cache;
    void save(expr * p1, expr * p2);
    bool process(expr * p1, expr * p2);

public:

    smaller_pattern() = default;

    smaller_pattern & operator=(smaller_pattern const &) = delete;

    bool operator()(unsigned num_bindings, expr * p1, expr * p2);
};

class pattern_inference_cfg :  public default_rewriter_cfg {
    ast_manager&               m;
    pattern_inference_params const & m_params;
    family_id                  m_bfid;
    family_id                  m_afid;
    svector         m_forbidden;
    obj_hashtable   m_preferred;        
    smaller_pattern            m_le;
    unsigned                   m_num_bindings;
    unsigned                   m_num_no_patterns;
    expr * const *             m_no_patterns;
    bool                       m_nested_arith_only;
    bool                       m_block_loop_patterns;
    bool                       m_decompose_patterns;

    struct info {
        uint_set m_free_vars;
        unsigned m_size;
        info(uint_set const & vars, unsigned size):
            m_free_vars(vars),
            m_size(size) {
        }
        info():
            m_free_vars(),
            m_size(0) {
        }
    };

    typedef obj_map expr2info;

    expr2info                 m_candidates_info; // candidate -> set of free vars + size
    app_ref_vector            m_candidates;

    ptr_vector           m_tmp1;
    ptr_vector           m_tmp2;
    ptr_vector           m_todo;

    // Compare candidates patterns based on their usefulness
    // p1 < p2 if
    //  - p1 has more free variables than p2
    //  - p1 and p2 has the same number of free variables,
    //    and p1 is smaller than p2.
    struct pattern_weight_lt {
        expr2info & m_candidates_info;
        pattern_weight_lt(expr2info & i):
            m_candidates_info(i) {
        }
        bool operator()(expr * n1, expr * n2) const;
    };

    pattern_weight_lt          m_pattern_weight_lt;

    //
    // Functor for collecting candidates.
    //
    class collect {
        struct entry {
            expr *    m_node;
            unsigned  m_delta;
            entry():m_node(nullptr), m_delta(0) {}
            entry(expr * n, unsigned d):m_node(n), m_delta(d) {}
            unsigned hash() const { 
                return hash_u_u(m_node->get_id(), m_delta);
            }
            bool operator==(entry const & e) const { 
                return m_node == e.m_node && m_delta == e.m_delta; 
            }
        };
        
        struct info {
            expr_ref    m_node;
            uint_set    m_free_vars;
            unsigned    m_size;
            info(ast_manager & m, expr * n, uint_set const & vars, unsigned sz):
                m_node(n, m), m_free_vars(vars), m_size(sz) {}
        };
        
        ast_manager &            m;
        pattern_inference_cfg &     m_owner;
        family_id                m_afid;
        unsigned                 m_num_bindings;
        typedef map, default_eq > cache;
        cache                    m_cache;
        ptr_vector         m_info;
        svector           m_todo;

        void visit(expr * n, unsigned delta, bool & visited);
        bool visit_children(expr * n, unsigned delta);
        void save(expr * n, unsigned delta, info * i);
        void save_candidate(expr * n, unsigned delta);
        void reset();
    public:
        collect(ast_manager & m, pattern_inference_cfg & o):m(m), m_owner(o), m_afid(m.mk_family_id("arith")) {}
        void operator()(expr * n, unsigned num_bindings);
    };

    collect                    m_collect;
    
    void add_candidate(app * n, uint_set const & s, unsigned size);

    void filter_looping_patterns(ptr_vector & result);

    bool has_preferred_patterns(ptr_vector & candidate_patterns, app_ref_buffer & result);

    void filter_bigger_patterns(ptr_vector const & patterns, ptr_vector & result);

    class contains_subpattern {
        pattern_inference_cfg & m_owner;
        nat_set              m_already_processed; 
        ptr_vector     m_todo;
        void save(expr * n);
    public:
        contains_subpattern(pattern_inference_cfg & owner):
            m_owner(owner) {}
        bool operator()(expr * n);
    };

    contains_subpattern        m_contains_subpattern;
        
    bool contains_subpattern(expr * n);
    
    struct pre_pattern {
        ptr_vector  m_exprs;     // elements of the pattern.
        uint_set         m_free_vars; // set of free variables in m_exprs
        unsigned         m_idx;       // idx of the next candidate to process.
        pre_pattern():
            m_idx(0) {
        }
    };

    ptr_vector      m_pre_patterns;
    expr_pattern_match           m_database;

    ptr_buffer m_args;
    app* mk_pattern(app* candidate);

    void candidates2unary_patterns(ptr_vector const & candidate_patterns,
                                   ptr_vector & remaining_candidate_patterns,
                                   app_ref_buffer & result);
    
    void candidates2multi_patterns(unsigned max_num_patterns, 
                                   ptr_vector const & candidate_patterns, 
                                   app_ref_buffer & result);
    
    void reset_pre_patterns();

    /**
       \brief All minimal unary patterns (i.e., expressions that
       contain all bound variables) are copied to result.  If there
       are unary patterns, then at most num_extra_multi_patterns multi
       patterns are created.  If there are no unary pattern, then at
       most 1 + num_extra_multi_patterns multi_patterns are created.
    */
    void mk_patterns(unsigned num_bindings,              // IN number of bindings.
                     expr * n,                           // IN node where the patterns are going to be extracted.
                     unsigned num_no_patterns,           // IN num. patterns that should not be used.
                     expr * const * no_patterns,         // IN patterns that should not be used.
                     app_ref_buffer & result);           // OUT result
    
public:
    pattern_inference_cfg(ast_manager & m, pattern_inference_params const & params);
    
    void register_forbidden_family(family_id fid) {
        SASSERT(fid != m_bfid);
        m_forbidden.push_back(fid);
    }

    /**
       \brief Register f as a preferred function symbol. The inference algorithm
       gives preference to patterns rooted by this kind of function symbol.
    */
    void register_preferred(func_decl * f) {
        m_preferred.insert(f);
    }

    bool reduce_quantifier(quantifier * old_q, 
                           expr * new_body, 
                           expr * const * new_patterns, 
                           expr * const * new_no_patterns,
                           expr_ref & result,
                           proof_ref & result_pr);

    void register_preferred(unsigned num, func_decl * const * fs) { for (unsigned i = 0; i < num; i++) register_preferred(fs[i]); }
    
    bool is_forbidden(func_decl const * decl) const {
        family_id fid = decl->get_family_id();
        if (fid == m_bfid && decl->get_decl_kind() != OP_TRUE && decl->get_decl_kind() != OP_FALSE) 
            return true;
        return std::find(m_forbidden.begin(), m_forbidden.end(), fid) != m_forbidden.end();
    }

    bool is_forbidden(app * n) const;
};

class pattern_inference_rw : public rewriter_tpl {
    pattern_inference_cfg m_cfg;
public:
    pattern_inference_rw(ast_manager& m, pattern_inference_params const & params);
};






© 2015 - 2024 Weber Informatics LLC | Privacy Policy