All downloads are free. Search and download functionality uses the official Maven repository.

z3-z3-4.13.0.src.ast.sls.sls_valuation.cpp Maven / Gradle / Ivy

The newest version!
/*++
Copyright (c) 2024 Microsoft Corporation

Module Name:

    sls_valuation.cpp

Abstract:

    A Stochastic Local Search (SLS) engine
    Uses invertibility conditions,
    interval annotations, and
    don't-care annotations.

Author:

    Nikolaj Bjorner (nbjorner) 2024-02-07
    
--*/

#include "ast/sls/sls_valuation.h"

namespace bv {

    void bvect::set_bw(unsigned bw) {
        this->bw = bw;
        nw = (bw + sizeof(digit_t) * 8 - 1) / (8 * sizeof(digit_t));
        mask = (1 << (bw % (8 * sizeof(digit_t)))) - 1;
        if (mask == 0)
            mask = ~(digit_t)0;
        reserve(nw + 1);     
    }

    // Comparisons of two bit-vectors as unsigned multi-word numbers via
    // mpn_manager::compare. Only a.nw words are compared on BOTH sides, so b
    // is expected to provide at least a.nw words.
    bool operator==(bvect const& a, bvect const& b) {
        SASSERT(a.nw > 0);
        auto const order = mpn_manager().compare(a.data(), a.nw, b.data(), a.nw);
        return order == 0;
    }

    bool operator<(bvect const& a, bvect const& b) {
        SASSERT(a.nw > 0);
        auto const order = mpn_manager().compare(a.data(), a.nw, b.data(), a.nw);
        return order < 0;
    }

    bool operator>(bvect const& a, bvect const& b) {
        SASSERT(a.nw > 0);
        auto const order = mpn_manager().compare(a.data(), a.nw, b.data(), a.nw);
        return order > 0;
    }

    bool operator<=(bvect const& a, bvect const& b) {
        SASSERT(a.nw > 0);
        auto const order = mpn_manager().compare(a.data(), a.nw, b.data(), a.nw);
        return order <= 0;
    }

    bool operator>=(bvect const& a, bvect const& b) {
        SASSERT(a.nw > 0);
        auto const order = mpn_manager().compare(a.data(), a.nw, b.data(), a.nw);
        return order >= 0;
    }

    // Print v in hexadecimal, most significant word first, with leading zero
    // words suppressed; bits above the width in the top word are masked off.
    // An all-zero vector prints as "0". The stream is switched back to
    // decimal before returning.
    std::ostream& operator<<(std::ostream& out, bvect const& v) {
        out << std::hex;
        bool printed = false;
        unsigned i = v.nw;
        while (i > 0) {
            --i;
            auto word = v[i];
            if (i == v.nw - 1)
                word &= v.mask;
            if (printed) {
                out << std::setw(8) << std::setfill('0') << word;
            }
            else if (word != 0) {
                out << word;
                printed = true;
            }
        }
        if (!printed)
            out << "0";
        out << std::dec;
        return out;
    }

    // Numeric value of the lowest nw digit words of this vector, with word 0
    // being the least significant.
    rational bvect::get_value(unsigned nw) const {
        rational const base = rational::power_of_two(8 * sizeof(digit_t));
        rational weight(1);
        rational result(0);
        for (unsigned k = 0; k < nw; ++k) {
            result += weight * rational((*this)[k]);
            weight *= base;
        }
        return result;
    }

    // Create a valuation of width bw. lo, hi, bits, fixed and eval all get
    // width bw and start out zero; lo == hi means the full range (see
    // in_range). The bits above bw in the top word of `fixed` are marked as
    // fixed so they can never be changed.
    sls_valuation::sls_valuation(unsigned bw) {
        set_bw(bw);
        m_lo.set_bw(bw);
        m_hi.set_bw(bw);
        m_bits.set_bw(bw);
        fixed.set_bw(bw);
        eval.set_bw(bw);
        for (unsigned w = 0; w < nw; ++w) {
            m_lo[w] = 0;
            m_hi[w] = 0;
            m_bits[w] = 0;
            fixed[w] = 0;
            eval[w] = 0;
        }
        fixed[nw - 1] = ~mask;
    }

    void sls_valuation::set_bw(unsigned b) {
        bw = b;
        nw = (bw + sizeof(digit_t) * 8 - 1) / (8 * sizeof(digit_t));
        mask = (1 << (bw % (8 * sizeof(digit_t)))) - 1;
        if (mask == 0)
            mask = ~(digit_t)0;
    }

    // Copy eval into bits, provided eval agrees with bits on every fixed
    // position and lies in [lo, hi[. Returns false, leaving bits untouched,
    // when either condition fails.
    bool sls_valuation::commit_eval() {
        for (unsigned i = 0; i < nw; ++i) {
            auto const changed = m_bits[i] ^ eval[i];
            if ((changed & fixed[i]) != 0)
                return false;
        }
        if (!in_range(eval))
            return false;
        for (unsigned i = 0; i < nw; ++i)
            m_bits[i] = eval[i];
        SASSERT(well_formed());
        return true;
    }

    // Is `bits` inside the half-open interval [lo, hi[? The interval may wrap
    // around (hi < lo), meaning bits < hi or lo <= bits; lo == hi denotes the
    // full range.
    bool sls_valuation::in_range(bvect const& bits) const {
        mpn_manager m;
        auto const lo_vs_hi = m.compare(m_lo.data(), nw, m_hi.data(), nw);
        SASSERT(!has_overflow(bits));
        if (lo_vs_hi == 0)
            return true;
        bool const at_least_lo = m.compare(m_lo.data(), nw, bits.data(), nw) <= 0;
        bool const below_hi = m.compare(bits.data(), nw, m_hi.data(), nw) < 0;
        // ordinary interval needs both bounds; a wrapping one needs either
        return lo_vs_hi < 0 ? (at_least_lo && below_hi) : (at_least_lo || below_hi);
    }

    //
    // Largest dst <= src such that dst is feasible:
    //   set dst := src & (~fixed | bits)
    //
    // If dst < src, let idx be the most significant position where src is 1
    // and dst is 0. Every position j < idx may then take its largest
    // feasible value bits_j | ~fixed_j (free bits 1, fixed bits their value).
    //
    // Then round down into the interval:
    // if dst < lo < hi:
    //    return false
    // if lo < hi <= dst:
    //    set dst := hi - 1
    // if hi <= dst < lo
    //    set dst := hi - 1
    //
    // NOTE(review): a fixed 1-bit above idx where src is 0 is still left at 0
    // by the first step; this function does not repair that case.
    //
    bool sls_valuation::get_at_most(bvect const& src, bvect& dst) const {
        SASSERT(!has_overflow(src));
        for (unsigned i = 0; i < nw; ++i)
            dst[i] = src[i] & (~fixed[i] | m_bits[i]);

        for (unsigned i = nw; i-- > 0; ) {
            auto const lost = ~dst[i] & src[i];
            if (lost == 0)
                continue;
            auto const idx = log2(lost);
            // Unsigned shift (1u has the digit width): (1 << 31) - 1 in int
            // would overflow.
            auto const below = (1u << idx) - 1;
            auto const top = ~fixed[i] | m_bits[i];
            // Keep bit idx and everything above it. Below idx, use the largest
            // feasible pattern, including the fixed bits' own values. This
            // matches what is done for the lower words.
            dst[i] = (dst[i] & ~below) | (top & below);
            for (unsigned j = i; j-- > 0; )
                dst[j] = ~fixed[j] | m_bits[j];
            break;
        }
        SASSERT(!has_overflow(dst));
        return round_down(dst);
    }

    //
    // Smallest dst >= src such that dst is feasible with respect to this:
    //   set dst := (src & ~fixed) | (fixed & bits)
    //
    // If dst > src, let idx be the most significant position where src is 0
    // and dst is 1. Bit idx and all bits above it are kept. Every position
    // j < idx is reduced to its smallest feasible value: free bits 0, fixed
    // bits unchanged (dst_j & fixed_j).
    //
    // Then round up into the interval:
    // if lo < hi <= dst
    //    return false
    // if dst < lo < hi:
    //    set dst := lo
    // if hi <= dst < lo
    //    set dst := lo
    //
    bool sls_valuation::get_at_least(bvect const& src, bvect& dst) const {
        SASSERT(!has_overflow(src));
        for (unsigned i = 0; i < nw; ++i)
            dst[i] = (~fixed[i] & src[i]) | (fixed[i] & m_bits[i]);

        for (unsigned i = nw; i-- > 0; ) {
            auto const gained = dst[i] & ~src[i];
            if (gained == 0)
                continue;
            auto const idx = log2(gained);
            // Bits idx and above. Previously the word was masked with
            // (fixed | (1 << idx)), which also cleared free bits ABOVE idx
            // and could make dst smaller than src.
            auto const keep = ~((1u << idx) - 1);
            dst[i] &= fixed[i] | keep;
            for (unsigned j = i; j-- > 0; )
                dst[j] &= fixed[j];
            break;
        }
        SASSERT(!has_overflow(dst));
        return round_up(dst);
    }

    // Move dst up into [lo, hi[ when possible. For an ordinary interval
    // (lo < hi), fail if dst >= hi and raise dst to lo if it is below lo.
    // Otherwise, if dst sits in the gap [hi, lo[ of a wrapping interval,
    // raise it to lo.
    bool sls_valuation::round_up(bvect& dst) const {
        bool const wraps = !(m_lo < m_hi);
        if (!wraps && m_hi <= dst)
            return false;
        bool const raise = wraps ? (m_hi <= dst && m_lo > dst) : (m_lo > dst);
        if (raise)
            set(dst, m_lo);
        SASSERT(!has_overflow(dst));
        return true;
    }

    // Move dst down into [lo, hi[ when possible. For an ordinary interval
    // (lo < hi), fail if dst < lo and lower dst to hi - 1 if it is >= hi.
    // Otherwise, if dst sits in the gap [hi, lo[ of a wrapping interval,
    // lower it to hi - 1.
    bool sls_valuation::round_down(bvect& dst) const {
        bool const ordinary = m_lo < m_hi;
        if (ordinary && m_lo > dst)
            return false;
        bool const lower = ordinary ? (m_hi <= dst) : (m_hi <= dst && m_lo > dst);
        if (lower) {
            set(dst, m_hi);
            sub1(dst);
        }
        SASSERT(well_formed());
        return true;
    }

    // Try to assign a random feasible value that is at most src, using tmp as
    // scratch. Half of the time (or when the maximal candidate is already 0)
    // the maximal candidate from get_at_most is used directly. Otherwise a
    // random smaller value is tried, with a fallback to the maximal candidate
    // if it falls below lo. Returns whether try_set succeeded.
    bool sls_valuation::set_random_at_most(bvect const& src, bvect& tmp, random_gen& r) {
        if (!get_at_most(src, tmp))
            return false;
        bool const keep_max = is_zero(tmp) || r() % 2 == 0;
        if (keep_max)
            return try_set(tmp);

        set_random_below(tmp, r);
        bool const acceptable = m_lo == m_hi || is_zero(m_lo) || m_lo <= tmp;
        if (acceptable)
            return try_set(tmp);

        // unlucky draw: recompute the maximal candidate instead
        return get_at_most(src, tmp) && try_set(tmp);
    }

    // Try to assign a random feasible value that is at least src, using tmp as
    // scratch. Half of the time (or when the minimal candidate is all ones)
    // the minimal candidate from get_at_least is used directly. Otherwise a
    // random larger value is tried, with a fallback to the minimal candidate
    // if it is not below hi. Returns whether try_set succeeded.
    bool sls_valuation::set_random_at_least(bvect const& src, bvect& tmp, random_gen& r) {
        if (!get_at_least(src, tmp))
            return false;
        bool const keep_min = is_ones(tmp) || r() % 2 == 0;
        if (keep_min)
            return try_set(tmp);

        set_random_above(tmp, r);
        bool const acceptable = m_lo == m_hi || is_zero(m_hi) || m_hi > tmp;
        if (acceptable)
            return try_set(tmp);

        // unlucky draw: recompute the minimal candidate instead
        return get_at_least(src, tmp) && try_set(tmp);
    }

    //
    // Try to assign a random feasible value between lo and hi, both ends
    // inclusive (judging by the `lo <= tmp && hi >= tmp` checks below).
    // tmp is scratch space. A coin flip selects one of two symmetric
    // strategies:
    //  - start from get_at_least(lo) and optionally randomize upwards, then
    //    clear free bits from the top until tmp <= hi and tmp is in range;
    //  - start from get_at_most(hi) and optionally randomize downwards, then
    //    set free bits from the bottom until tmp >= lo and tmp is in range.
    // If the randomized candidate is rejected, the un-randomized bound is
    // recomputed and tried instead. Returns whether try_set succeeded.
    //
    bool sls_valuation::set_random_in_range(bvect const& lo, bvect const& hi, bvect& tmp, random_gen& r) {
        if (0 == r() % 2) {
            if (!get_at_least(lo, tmp))
                return false;
            SASSERT(in_range(tmp));
            // the candidate at or above lo already exceeds hi: give up
            if (hi < tmp)
                return false;

            // nothing can lie above all-ones; otherwise keep it on a coin flip
            if (is_ones(tmp) || (0 == r() % 2))
                return try_set(tmp);
            set_random_above(tmp, r);
            round_down(tmp, [&](bvect const& t) { return hi >= t && in_range(t); });
            if (in_range(tmp) && lo <= tmp && hi >= tmp)
                return try_set(tmp);
            // randomized value rejected: fall back to the value from get_at_least
            return get_at_least(lo, tmp) && hi >= tmp && try_set(tmp);
        }
        else {
            if (!get_at_most(hi, tmp))
                return false;
            SASSERT(in_range(tmp));
            // the candidate at or below hi is already below lo: give up
            if (lo > tmp)
                return false;
            // nothing can lie below zero; otherwise keep it on a coin flip
            if (is_zero(tmp) || (0 == r() % 2))
                return try_set(tmp);
            set_random_below(tmp, r);
            round_up(tmp, [&](bvect const& t) { return lo <= t && in_range(t); });
            if (in_range(tmp) && lo <= tmp && hi >= tmp)
                return try_set(tmp);
            // randomized value rejected: fall back to the value from get_at_most
            return get_at_most(hi, tmp) && lo <= tmp && try_set(tmp);
        }
    }

    void sls_valuation::round_down(bvect& dst, std::function const& is_feasible) {      
        for (unsigned i = bw; !is_feasible(dst) && i-- > 0; )
            if (!fixed.get(i) && dst.get(i))
                dst.set(i, false);      
        repair_sign_bits(dst);
    }

    void sls_valuation::round_up(bvect& dst, std::function const& is_feasible) {
        for (unsigned i = 0; !is_feasible(dst) && i < bw; ++i)
            if (!fixed.get(i) && !dst.get(i))
                dst.set(i, true);
        repair_sign_bits(dst);
    }

    // OR random bits into the non-fixed positions of dst, so that (before the
    // sign-prefix repair) dst can only grow. Then repair the signed prefix.
    void sls_valuation::set_random_above(bvect& dst, random_gen& r) {
        for (unsigned w = 0; w < nw; ++w) {
            auto const noise = random_bits(r) & ~fixed[w];
            dst[w] |= noise;
        }
        repair_sign_bits(dst);
    }

    // Randomly pick one non-fixed 1-bit of dst. A running candidate is
    // replaced with probability 1/count as candidates are scanned. Clear the
    // chosen bit and randomize every non-fixed bit below it, so the result
    // (before the sign-prefix repair) is smaller than dst. dst is left
    // unchanged when it is zero or has no non-fixed 1-bit.
    void sls_valuation::set_random_below(bvect& dst, random_gen& r) {
        if (is_zero(dst))
            return;
        unsigned seen = 0;
        unsigned pivot = UINT_MAX;
        for (unsigned i = 0; i < bw; ++i) {
            if (!dst.get(i) || fixed.get(i))
                continue;
            ++seen;
            if (r() % seen == 0)
                pivot = i;
        }
        if (pivot == UINT_MAX)
            return;
        dst.set(pivot, false);
        for (unsigned i = 0; i < pivot; ++i)
            if (!fixed.get(i))
                dst.set(i, r() % 2 == 0);
        repair_sign_bits(dst);
    }

    //
    // Repair dst into a candidate for eval. First force the fixed positions to
    // their values in bits and repair the signed prefix. If the result lies
    // in [lo, hi[, store it in eval and return true. Otherwise flip non-fixed
    // bits one at a time (re-checking in_range after each flip), repair the
    // signed prefix again, and store the result in eval if it is now in range.
    // Returns true iff eval was updated.
    // NOTE(review): try_down is not used by this implementation.
    // NOTE(review): dst's width is set to bw before the repair loops and reset
    // to 0 at the end of that path, but not on the early-return path; confirm
    // this is what callers expect.
    //
    bool sls_valuation::set_repair(bool try_down, bvect& dst) {
        for (unsigned i = 0; i < nw; ++i)
            dst[i] = (~fixed[i] & dst[i]) | (fixed[i] & m_bits[i]);

        repair_sign_bits(dst);
        if (in_range(dst)) {
            set(eval, dst);
            return true;
        }
        bool repaired = false;
        dst.set_bw(bw);
        if (m_lo < m_hi) {
            // ordinary interval: if at/above hi, clear free 1-bits from the top
            for (unsigned i = bw; m_hi <= dst && !in_range(dst) && i-- > 0; )
                if (!fixed.get(i) && dst.get(i))
                    dst.set(i, false);
            // if below lo, set free 0-bits from the bottom
            for (unsigned i = 0; i < bw && dst < m_lo && !in_range(dst); ++i)
                if (!fixed.get(i) && !dst.get(i))
                    dst.set(i, true);        
        }
        else {
            // wrapping interval: first try setting free 0-bits from the bottom,
            // then clearing free 1-bits from the top
            for (unsigned i = 0; !in_range(dst) && i < bw; ++i)
                if (!fixed.get(i) && !dst.get(i))
                    dst.set(i, true);
            for (unsigned i = bw; !in_range(dst) && i-- > 0;)
                if (!fixed.get(i) && dst.get(i))
                    dst.set(i, false);
        }
        repair_sign_bits(dst);
        if (in_range(dst)) {
            set(eval, dst);
            repaired = true;
        }
        dst.set_bw(0);
        return repaired;
    }

    // Lower candidate value: lo for an ordinary interval (lo < hi). Otherwise
    // the fixed positions take their values from bits and every free bit is 0.
    // The signed prefix is repaired afterwards.
    void sls_valuation::min_feasible(bvect& out) const {
        if (!(m_lo < m_hi)) {
            for (unsigned w = 0; w < nw; ++w)
                out[w] = m_bits[w] & fixed[w];
        }
        else {
            m_lo.copy_to(nw, out);
        }
        repair_sign_bits(out);
        SASSERT(!has_overflow(out));
    }

    // Upper candidate value: hi - 1 for an ordinary interval (lo < hi).
    // Otherwise the fixed positions take their values from bits and every
    // free bit is 1. The signed prefix is repaired afterwards.
    void sls_valuation::max_feasible(bvect& out) const {
        if (!(m_lo < m_hi)) {
            for (unsigned w = 0; w < nw; ++w)
                out[w] = m_bits[w] | ~fixed[w];
        }
        else {
            m_hi.copy_to(nw, out);
            sub1(out);
        }
        repair_sign_bits(out);
        SASSERT(!has_overflow(out));
    }

    // Index of the most significant 1-bit of src, or bw when src is zero.
    unsigned sls_valuation::msb(bvect const& src) const {
        SASSERT(!has_overflow(src));
        unsigned const bits_per_word = 8 * sizeof(digit_t);
        unsigned w = nw;
        while (w > 0) {
            --w;
            if (src[w] != 0)
                return w * bits_per_word + log2(src[w]);
        }
        return bw;
    }

    // Store the low bw bits of n into bits (bit k of n goes to position k),
    // then clear any overflow bits above bw.
    void sls_valuation::set_value(bvect& bits, rational const& n) {
        for (unsigned pos = 0; pos < bw; ++pos) {
            bool const b = n.get_bit(pos);
            bits.set(pos, b);
        }
        clear_overflow_bits(bits);
    }

    // Copy the nw words of the current value (bits) into dst.
    void sls_valuation::get(bvect& dst) const {
        unsigned const words = nw;
        m_bits.copy_to(words, dst);
    }

    // Build a random digit by XOR-ing one random_gen draw shifted to each byte
    // offset of the digit.
    digit_t sls_valuation::random_bits(random_gen& rand) {
        digit_t r = 0;
        for (unsigned i = 0; i < sizeof(digit_t); ++i) {
            // Convert to digit_t before shifting. Shifting a signed int draw by
            // 24 moves set bits past the sign bit, which is undefined behavior.
            // An unsigned shift simply discards the high bits.
            r ^= static_cast<digit_t>(rand()) << (8 * i);
        }
        return r;
    }

    // Random value that agrees with bits on every fixed position and is random
    // on the free positions. The signed prefix is repaired and overflow bits
    // are cleared.
    void sls_valuation::get_variant(bvect& dst, random_gen& r) const {
        for (unsigned w = 0; w < nw; ++w) {
            auto const free_part = random_bits(r) & ~fixed[w];
            auto const fixed_part = m_bits[w] & fixed[w];
            dst[w] = free_part | fixed_part;
        }
        repair_sign_bits(dst);
        clear_overflow_bits(dst);
    }

    //
    // When m_signed_prefix > 0, make a run of top bits of dst agree with its
    // most significant bit, dst[bw - 1].
    // Because the loop test uses post-decrement (i-- >= bw - m_signed_prefix),
    // the body visits indices bw-1 down to bw-1-m_signed_prefix, i.e.
    // m_signed_prefix + 1 bits.
    // If a fixed bit in that run disagrees with the sign, every non-fixed bit
    // of the run is set to !sign (the fixed bit's value) and the function
    // returns.
    // NOTE(review): if m_signed_prefix == bw the test is `i-- >= 0`, which is
    // always true for unsigned i, so i wraps past 0; callers presumably keep
    // m_signed_prefix < bw - confirm. The inner loop shadows `i`.
    //
    void sls_valuation::repair_sign_bits(bvect& dst) const {
        if (m_signed_prefix == 0)
            return;
        bool sign = dst.get(bw - 1);
        for (unsigned i = bw; i-- >= bw - m_signed_prefix; ) {
            if (dst.get(i) != sign) {
                if (fixed.get(i)) {
                    for (unsigned i = bw; i-- >= bw - m_signed_prefix; )
                        if (!fixed.get(i))
                            dst.set(i, !sign);
                    return;
                }
                else
                    dst.set(i, sign);
            }
        }
    }

    //
    // Can new_bits be assigned? Every position where new_bits differs from
    // bits must be non-fixed:
    //   0 == (new_bits ^ bits) & fixed
    // and new_bits must lie in [lo, hi[.
    //
    bool sls_valuation::can_set(bvect const& new_bits) const {
        SASSERT(!has_overflow(new_bits));
        bool fixed_ok = true;
        for (unsigned i = 0; fixed_ok && i < nw; ++i)
            fixed_ok = ((new_bits[i] ^ m_bits[i]) & fixed[i]) == 0;
        return fixed_ok && in_range(new_bits);
    }

    // Read bits as an unsigned number, capped via max_n. The bits whose weight
    // 2^i is below max_n are summed. If any bit of weight >= max_n is set,
    // max_n is returned instead. The sum of the lower bits may itself still
    // exceed max_n.
    unsigned sls_valuation::to_nat(unsigned max_n) {
        bvect const& d = m_bits;
        SASSERT(!has_overflow(d));
        SASSERT(max_n < UINT_MAX / 2);
        unsigned value = 0;
        unsigned weight = 1;
        unsigned i = 0;
        for (; i < bw && weight < max_n; ++i, weight <<= 1)
            if (d.get(i))
                value += weight;
        for (; i < bw; ++i)
            if (d.get(i))
                return max_n;
        return value;
    }

    // out := bits >> shift (logical shift), written bit by bit over bw
    // positions; vacated high positions become 0.
    void sls_valuation::shift_right(bvect& out, unsigned shift) const {
        SASSERT(shift < bw);
        unsigned const kept = shift < bw ? bw - shift : 0;
        for (unsigned i = 0; i < kept; ++i)
            out.set(i, m_bits.get(i + shift));
        for (unsigned i = kept; i < bw; ++i)
            out.set(i, false);
        SASSERT(well_formed());
    }

    // Narrow the interval [lo, hi[ using the half-open interval [l, h[. Both
    // bounds are taken modulo 2^bw, so [l, h[ may wrap around. An interval
    // with l == h follows the lo == hi "full range" convention of in_range
    // and is ignored. Afterwards tighten_range() reconciles the interval with
    // the fixed bits.
    void sls_valuation::add_range(rational l, rational h) {
        l = mod(l, rational::power_of_two(bw));
        h = mod(h, rational::power_of_two(bw));
        if (h == l)
            return;

        if (m_lo == m_hi) {
            // currently the full range: adopt [l, h[ as is
            set_value(m_lo, l);
            set_value(m_hi, h);
        }
        else {
            auto old_lo = lo();
            auto old_hi = hi();
            if (old_lo < old_hi) {
                // raise lo to l when l lies strictly inside the interval
                if (old_lo < l && l < old_hi) {
                    set_value(m_lo, l);
                    old_lo = l;
                }
                // lower hi to h when h lies strictly inside the (possibly
                // raised) interval. This was previously guarded by
                // `old_hi < h && h < old_hi`, which is never true.
                if (old_lo < h && h < old_hi)
                    set_value(m_hi, h);
            }
            else {
                SASSERT(old_hi < old_lo);
                // wrapping interval: move lo to l when l is one of its members
                if (old_lo < l || l < old_hi) {
                    set_value(m_lo, l);
                    old_lo = l;
                }
                if (old_lo < h && h < old_hi)
                    set_value(m_hi, h);
                else if (old_hi < old_lo && (h < old_hi || old_lo < h))
                    set_value(m_hi, h);
            }
        }

        SASSERT(!has_overflow(m_lo));
        SASSERT(!has_overflow(m_hi));

        tighten_range();
        SASSERT(well_formed());
    }

    //
    // update bits based on ranges
    // tighten lo/hi based on fixed bits.
    //   lo[bit_i] != fixedbit[bit_i] 
    //     let bit_i be most significant bit position of disagreement.
    //     if fixedbit = 1, lo = 0, increment lo
    //     if fixedbit = 0, lo = 1, lo := fixed & bits
    //   (hi-1)[bit_i] != fixedbit[bit_i]
    //     if fixedbit = 0, hi-1 = 1, set hi-1 := 0, maximize below bit_i
    //     if fixedbit = 1, hi-1 = 0, hi := fixed & bits
    // tighten fixed bits based on lo/hi
    //  lo + 1 = hi -> set bits = lo
    //  lo < hi, set most significant bits based on hi
    //
    // What the code below actually does:
    //  1. If bits is not in [lo, hi[, replace it with a value derived from lo:
    //     lo itself when lo agrees with bits on every fixed position, else a
    //     value built in tmp from lo and the fixed bits.
    //  2. At the most significant fixed position where bits and lo disagree,
    //     adjust lo. Only lo is updated here; hi is left unchanged.
    //  Nothing is done when lo == hi (full range).
    //
    void sls_valuation::tighten_range() {

        // verbose_stream() << "tighten " << *this << "\n";
        if (m_lo == m_hi)
            return;

        if (!in_range(m_bits)) {
            // verbose_stream() << "not in range\n";
            // does lo agree with bits on every fixed position?
            bool compatible = true;
            for (unsigned i = 0; i < nw && compatible; ++i)
                compatible = 0 == (fixed[i] & (m_bits[i] ^ m_lo[i]));
            //verbose_stream() << (fixed[0] & (m_bits[0] ^ m_lo[0])) << "\n";
            //verbose_stream() << bw << " " << m_lo[0] << " " << m_bits[0] << "\n";
            if (compatible) {
                //verbose_stream() << "compatible\n";
                set(m_bits, m_lo);
            }
            else {
                bvect tmp(m_bits.nw);
                tmp.set_bw(bw);
                set(tmp, m_lo);
                // most significant fixed position where bits and lo differ
                unsigned max_diff = bw;
                for (unsigned i = 0; i < bw; ++i) {
                    if (fixed.get(i) && (m_bits.get(i) ^ m_lo.get(i))) 
                        max_diff = i;                    
                }
                SASSERT(max_diff != bw);

                // at and below max_diff: fixed bits take their value, free bits 0
                for (unsigned i = 0; i <= max_diff; ++i)
                    tmp.set(i, fixed.get(i) && m_bits.get(i));

                // above max_diff: the lowest position that is 0 and free in lo is
                // set to 1; every other position takes lo's bit if it is fixed,
                // else 0.
                // NOTE(review): free 1-bits of lo above max_diff are cleared
                // here, so tmp need not be >= lo - confirm intent.
                bool found0 = false;
                for (unsigned i = max_diff + 1; i < bw; ++i) {
                    if (found0 || m_lo.get(i) || fixed.get(i))
                        tmp.set(i, m_lo.get(i) && fixed.get(i));
                    else {
                        tmp.set(i, true);
                        found0 = true;
                    }
                }
                set(m_bits, tmp);
            }
        }
        // update lo, hi to be feasible.
        
        for (unsigned i = bw; i-- > 0; ) {
            if (!fixed.get(i))
                continue;
            if (m_bits.get(i) == m_lo.get(i))
                continue;
            if (m_bits.get(i)) {
                // fixed 1 where lo has 0: set the bit in lo, and make every
                // lower position its fixed value (free bits 0)
                m_lo.set(i, true);
                for (unsigned j = i; j-- > 0; )
                    m_lo.set(j, fixed.get(j) && m_bits.get(j));
            }
            else {
                // fixed 0 where lo has 1: lo := fixed & bits on all positions.
                // NOTE(review): this clears bit i and free bits above it, so lo
                // can decrease here.
                for (unsigned j = bw; j-- > 0; )
                    m_lo.set(j, fixed.get(j) && m_bits.get(j));
            }
            break;
        }

        SASSERT(well_formed());
    }

    // out := a - b over nw words, then overflow bits above bw are cleared.
    // The borrow reported by mpn_manager::sub is ignored.
    void sls_valuation::set_sub(bvect& out, bvect const& a, bvect const& b) const {
        digit_t borrow;
        mpn_manager mgr;
        mgr.sub(a.data(), nw, b.data(), nw, out.data(), &borrow);
        clear_overflow_bits(out);
    }

    // out := a + b, truncated to bw bits. Returns true when the sum does not
    // fit: either the extra word out[nw] is non-zero or bits above bw are set
    // in the top word. out must provide nw + 1 words.
    bool sls_valuation::set_add(bvect& out, bvect const& a, bvect const& b) const {
        digit_t c;
        mpn_manager().add(a.data(), nw, b.data(), nw, out.data(), nw + 1, &c);
        bool const carried = out[nw] != 0;
        bool const overflowed = carried || has_overflow(out);
        clear_overflow_bits(out);
        return overflowed;
    }

    // out := a * b, truncated to bw bits. With check_overflow, returns true
    // when the full product has bits above bw (in the top word or in words
    // nw .. 2*nw-1); otherwise returns false. out is read up to word 2*nw-1,
    // so it presumably must hold 2*nw words.
    bool sls_valuation::set_mul(bvect& out, bvect const& a, bvect const& b, bool check_overflow) const {
        mpn_manager().mul(a.data(), nw, b.data(), nw, out.data());
        bool overflow = false;
        if (check_overflow) {
            overflow = has_overflow(out);
            for (unsigned w = nw; !overflow && w < 2 * nw; ++w)
                overflow = out[w] != 0;
        }
        clear_overflow_bits(out);
        return overflow;
    }

    // True iff exactly one bit is set across the nw words of src.
    bool sls_valuation::is_power_of2(bvect const& src) const {
        unsigned ones = 0;
        for (unsigned w = 0; w < nw && ones <= 1; ++w)
            ones += get_num_1bits(src[w]);
        return ones == 1;
    }


}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy