All Downloads are FREE. Search and download functionalities are using the official Maven repository.

cvc5-cvc5-1.2.0.src.theory.quantifiers.sygus_sampler.cpp Maven / Gradle / Ivy

The newest version!
/******************************************************************************
 * Top contributors (to current version):
 *   Andrew Reynolds, Mathias Preiner, Gereon Kremer
 *
 * This file is part of the cvc5 project.
 *
 * Copyright (c) 2009-2024 by the authors listed in the file AUTHORS
 * in the top-level source directory and their institutional affiliations.
 * All rights reserved.  See the file COPYING in the top-level source
 * directory for licensing information.
 * ****************************************************************************
 *
 * Implementation of sygus_sampler.
 */

#include "theory/quantifiers/sygus_sampler.h"

#include <sstream>

#include "expr/dtype.h"
#include "expr/dtype_cons.h"
#include "expr/node_algorithm.h"
#include "options/base_options.h"
#include "options/quantifiers_options.h"
#include "printer/printer.h"
#include "theory/quantifiers/lazy_trie.h"
#include "theory/quantifiers/sygus/term_database_sygus.h"
#include "theory/rewriter.h"
#include "util/bitvector.h"
#include "util/random.h"
#include "util/rational.h"
#include "util/sampler.h"
#include "util/string.h"

using namespace cvc5::internal::kind;

namespace cvc5::internal {
namespace theory {
namespace quantifiers {

/**
 * Construct a sampler attached to the environment `env`.
 *
 * The object starts with a null term-database pointer, with
 * d_use_sygus_type cleared, and with d_is_valid set to false; the
 * initialize* methods of this class are what set d_is_valid to true.
 */
SygusSampler::SygusSampler(Env& env)
    : EnvObj(env),
      d_tds(nullptr),
      d_use_sygus_type(false),
      d_is_valid(false)
{
}

/**
 * Initialize this sampler for an explicit list of variables and generate
 * sample points for them.
 *
 * Every variable is assigned a "type id"; variables sharing a type id are
 * stored together, in the order given, in d_type_vars.
 *
 * @param tn Not read by this function.
 * @param vars The variables to sample; copied into d_vars.
 * @param nsamples The number of sample points requested from
 * initializeSamples.
 * @param unique_type_ids If true, each variable receives its own type id.
 * Otherwise all variables having the same type share one type id.
 */
void SygusSampler::initialize(TypeNode tn,
                              const std::vector<Node>& vars,
                              unsigned nsamples,
                              bool unique_type_ids)
{
  d_tds = nullptr;
  d_use_sygus_type = false;
  d_is_valid = true;
  d_ftn = TypeNode::null();
  d_type_vars.clear();
  d_vars.clear();
  // forget the per-variable bookkeeping of any previous initialization, so
  // that variables that are no longer sampled are not treated as tracked
  d_var_index.clear();
  d_type_ids.clear();
  d_rvalue_cindices.clear();
  d_rvalue_null_cindices.clear();
  d_rstring_alphabet.clear();
  d_var_sygus_types.clear();
  d_const_sygus_types.clear();
  d_vars.insert(d_vars.end(), vars.begin(), vars.end());
  std::map<TypeNode, unsigned> type_to_type_id;
  unsigned type_id_counter = 0;
  for (const Node& sv : d_vars)
  {
    TypeNode svt = sv.getType();
    unsigned tnid = 0;
    if (unique_type_ids)
    {
      tnid = type_id_counter;
      type_id_counter++;
    }
    else
    {
      std::map<TypeNode, unsigned>::iterator itt = type_to_type_id.find(svt);
      if (itt == type_to_type_id.end())
      {
        // first variable of this type: it must itself use the fresh id that
        // later variables of the same type will look up
        tnid = type_id_counter;
        type_to_type_id[svt] = tnid;
        type_id_counter++;
      }
      else
      {
        tnid = itt->second;
      }
    }
    Trace("sygus-sample-debug")
        << "Type id for " << sv << " is " << tnid << std::endl;
    // position of sv among the variables sharing its type id
    d_var_index[sv] = d_type_vars[tnid].size();
    d_type_vars[tnid].push_back(sv);
    d_type_ids[sv] = tnid;
  }
  initializeSamples(nsamples);
}

/**
 * Initialize this sampler from the sygus datatype `ftn` and generate sample
 * points for the variables in its sygus variable list.
 *
 * Two variables share a type id when they have the same type and the same
 * list of sygus types recorded for them by registerSygusType.
 *
 * NOTE(review): this function does not assign d_tds, while registerSygusType
 * calls d_tds->getArgType; no code visible in this file sets d_tds to a
 * non-null value. Confirm against the header and TermDbSygus.
 *
 * @param ftn The sygus datatype; asserted to be a sygus datatype.
 * @param nsamples The number of sample points requested from
 * initializeSamples.
 */
void SygusSampler::initializeSygus(TypeNode ftn, unsigned nsamples)
{
  d_is_valid = true;
  d_ftn = ftn;
  Assert(d_ftn.isDatatype());
  const DType& dt = d_ftn.getDType();
  Assert(dt.isSygus());

  Trace("sygus-sample") << "Register sampler for " << ftn << std::endl;

  // reset all state that is recomputed below or by registerSygusType
  d_vars.clear();
  d_type_vars.clear();
  d_var_index.clear();
  d_type_ids.clear();
  d_rvalue_cindices.clear();
  d_rvalue_null_cindices.clear();
  d_rstring_alphabet.clear();
  d_var_sygus_types.clear();
  d_const_sygus_types.clear();
  // get the sygus variable list
  Node var_list = dt.getSygusVarList();
  if (!var_list.isNull())
  {
    for (const Node& sv : var_list)
    {
      d_vars.push_back(sv);
    }
  }
  // register sygus type
  registerSygusType(d_ftn);
  // Variables are associated with type ids based on the set of sygus types they
  // appear in.
  std::map<Node, unsigned> var_to_type_id;
  unsigned type_id_counter = 0;
  for (const Node& sv : d_vars)
  {
    TypeNode svt = sv.getType();
    // is it equivalent to a previous variable?
    for (const auto& v : var_to_type_id)
    {
      Node svc = v.first;
      if (svc.getType() != svt)
      {
        continue;
      }
      // equivalent if the recorded lists of sygus types coincide
      if (d_var_sygus_types[sv] == d_var_sygus_types[svc])
      {
        var_to_type_id[sv] = var_to_type_id[svc];
      }
    }
    if (var_to_type_id.find(sv) == var_to_type_id.end())
    {
      var_to_type_id[sv] = type_id_counter;
      type_id_counter++;
    }
    unsigned tnid = var_to_type_id[sv];
    Trace("sygus-sample-debug")
        << "Type id for " << sv << " is " << tnid << std::endl;
    // position of sv among the variables sharing its type id
    d_var_index[sv] = d_type_vars[tnid].size();
    d_type_vars[tnid].push_back(sv);
    d_type_ids[sv] = tnid;
  }

  initializeSamples(nsamples);
}

void SygusSampler::initializeSamples(unsigned nsamples)
{
  d_samples.clear();
  std::vector types;
  for (const Node& v : d_vars)
  {
    TypeNode vt = v.getType();
    types.push_back(vt);
    Trace("sygus-sample") << "  var #" << types.size() << " : " << v << " : "
                          << vt << std::endl;
  }
  std::map >::iterator> sts;
  if (options().quantifiers.sygusSampleGrammar)
  {
    for (unsigned j = 0, size = types.size(); j < size; j++)
    {
      sts[j] = d_var_sygus_types.find(d_vars[j]);
    }
  }

  unsigned nduplicates = 0;
  for (unsigned i = 0; i < nsamples; i++)
  {
    std::vector sample_pt;
    for (unsigned j = 0, size = types.size(); j < size; j++)
    {
      Node v = d_vars[j];
      Node r;
      if (options().quantifiers.sygusSampleGrammar)
      {
        // choose a random start sygus type, if possible
        if (sts[j] != d_var_sygus_types.end())
        {
          unsigned ntypes = sts[j]->second.size();
          if(ntypes > 0)
          {
            unsigned index = Random::getRandom().pick(0, ntypes - 1);
            if (index < ntypes)
            {
              // currently hard coded to 0.0, 0.5
              r = getSygusRandomValue(sts[j]->second[index], 0.0, 0.5);
            }
          }
        }
      }
      if (r.isNull())
      {
        r = getRandomValue(types[j]);
        if (r.isNull())
        {
          d_is_valid = false;
        }
      }
      sample_pt.push_back(r);
    }
    if (d_samples_trie.add(sample_pt))
    {
      if (TraceIsOn("sygus-sample"))
      {
        Trace("sygus-sample") << "Sample point #" << i << " : ";
        for (const Node& r : sample_pt)
        {
          Trace("sygus-sample") << r << " ";
        }
        Trace("sygus-sample") << std::endl;
      }
      d_samples.push_back(sample_pt);
    }
    else
    {
      i--;
      nduplicates++;
      if (nduplicates == nsamples * 10)
      {
        Trace("sygus-sample")
            << "...WARNING: excessive duplicates, cut off sampling at " << i
            << "/" << nsamples << " points." << std::endl;
        break;
      }
    }
  }

  d_trie.clear();
}

bool SygusSampler::PtTrie::add(std::vector& pt)
{
  PtTrie* curr = this;
  for (unsigned i = 0, size = pt.size(); i < size; i++)
  {
    curr = &(curr->d_children[pt[i]]);
  }
  bool retVal = curr->d_children.empty();
  curr = &(curr->d_children[Node::null()]);
  return retVal;
}

/**
 * Register the term `n` with this sampler.
 *
 * If the sampler is valid, n is added to the lazy trie kept for the type of
 * n, over the sample-point indices 0 .. d_samples.size(), and the result of
 * that addition is returned. If the sampler is not valid, n itself is
 * returned and nothing is recorded.
 *
 * @param n The term to register.
 * @param forceKeep Forwarded unchanged to the trie's add method.
 * @return The node returned by the trie, or n if the sampler is not valid.
 */
Node SygusSampler::registerTerm(Node n, bool forceKeep)
{
  if (d_is_valid)
  {
    // the tries are indexed by the (original) type of n
    TypeNode ntype = n.getType();
    return d_trie[ntype].add(n, this, 0, d_samples.size(), forceKeep);
  }
  // invalid sampler: nothing to do
  return n;
}

/**
 * Check whether the tracked variables occurring in `n` are contiguous.
 *
 * For every type id, the variables of that type id are scanned in the order
 * stored in d_type_vars: once one of them does not occur in n, no later one
 * may occur in n.
 *
 * @param n The term to check.
 * @return false if, for some type id, a variable occurs in n after an
 * earlier variable of the same type id that does not; true otherwise.
 */
bool SygusSampler::isContiguous(Node n)
{
  // compute free variables in n
  std::vector<Node> fvs;
  computeFreeVariables(n, fvs);
  // compute contiguous condition
  for (const auto& p : d_type_vars)
  {
    bool foundNotFv = false;
    for (const Node& v : p.second)
    {
      const bool hasFv = std::find(fvs.begin(), fvs.end(), v) != fvs.end();
      if (!hasFv)
      {
        foundNotFv = true;
      }
      else if (foundNotFv)
      {
        // v occurs although an earlier variable of its type id does not
        return false;
      }
    }
  }
  return true;
}

void SygusSampler::computeFreeVariables(Node n, std::vector& fvs)
{
  std::unordered_set visited;
  std::unordered_set::iterator it;
  std::vector visit;
  TNode cur;
  visit.push_back(n);
  do
  {
    cur = visit.back();
    visit.pop_back();
    if (visited.find(cur) == visited.end())
    {
      visited.insert(cur);
      if (cur.isVar())
      {
        if (d_var_index.find(cur) != d_var_index.end())
        {
          fvs.push_back(cur);
        }
      }
      for (const Node& cn : cur)
      {
        visit.push_back(cn);
      }
    }
  } while (!visit.empty());
}

/**
 * Run checkVariables on `n` with only the order check enabled.
 *
 * @param n The term to check.
 * @return The result of checkVariables(n, true, false).
 */
bool SygusSampler::isOrdered(Node n)
{
  const bool checkOrder = true;
  const bool checkLinear = false;
  return checkVariables(n, checkOrder, checkLinear);
}

/**
 * Run checkVariables on `n` with only the linearity check enabled.
 *
 * @param n The term to check.
 * @return The result of checkVariables(n, false, true).
 */
bool SygusSampler::isLinear(Node n)
{
  const bool checkOrder = false;
  const bool checkLinear = true;
  return checkVariables(n, checkOrder, checkLinear);
}

/**
 * Check order and/or linearity conditions on the tracked variables of `n`.
 *
 * The term is traversed depth-first, children left to right, visiting each
 * distinct subterm once. For each variable that has an entry in d_var_index:
 *  - if checkOrder is set, its index in d_var_index must equal the number of
 *    variables with the same type id encountered before it;
 *  - if checkLinear is set, expr::hasSubtermMulti(n, cur) must be false.
 *
 * @param n The term to check.
 * @param checkOrder Whether to perform the order check.
 * @param checkLinear Whether to perform the linearity check.
 * @return false as soon as an enabled check fails, true otherwise.
 */
bool SygusSampler::checkVariables(Node n, bool checkOrder, bool checkLinear)
{
  // compute free variables in n for each type
  std::map<unsigned, std::vector<Node> > fvs;

  std::unordered_set<TNode> visited;
  std::vector<TNode> visit;
  TNode cur;
  visit.push_back(n);
  do
  {
    cur = visit.back();
    visit.pop_back();
    if (!visited.insert(cur).second)
    {
      // already processed this subterm
      continue;
    }
    if (cur.isVar())
    {
      auto itv = d_var_index.find(cur);
      if (itv != d_var_index.end())
      {
        if (checkOrder)
        {
          unsigned tnid = d_type_ids[cur];
          // if this variable is out of order
          if (itv->second != fvs[tnid].size())
          {
            return false;
          }
          fvs[tnid].push_back(cur);
        }
        if (checkLinear)
        {
          if (expr::hasSubtermMulti(n, cur))
          {
            return false;
          }
        }
      }
    }
    // push children in reverse so that they are popped left to right
    for (unsigned j = 0, nchildren = cur.getNumChildren(); j < nchildren; j++)
    {
      visit.push_back(cur[(nchildren - j) - 1]);
    }
  } while (!visit.empty());
  return true;
}

/**
 * Check whether the variables of `b` are among the tracked variables of `a`.
 *
 * The tracked variables of a are those collected by computeFreeVariables.
 * Every variable subterm of b must be one of them. If `strict` is set, b
 * must additionally not contain all of them.
 *
 * @param a The term whose tracked variables form the allowed set.
 * @param b The term whose variables are checked.
 * @param strict Whether b is required to miss at least one of the tracked
 * variables of a.
 * @return true if the conditions above hold, false otherwise.
 */
bool SygusSampler::containsFreeVariables(Node a, Node b, bool strict)
{
  // compute free variables in a
  std::vector<Node> fvs;
  computeFreeVariables(a, fvs);
  // the variables of a found in b so far (maintained only if strict)
  std::vector<Node> fv_found;

  std::unordered_set<TNode> visited;
  std::vector<TNode> visit;
  TNode cur;
  visit.push_back(b);
  do
  {
    cur = visit.back();
    visit.pop_back();
    if (!visited.insert(cur).second)
    {
      // already processed this subterm
      continue;
    }
    if (cur.isVar())
    {
      if (std::find(fvs.begin(), fvs.end(), cur) == fvs.end())
      {
        // b has a variable that is not a tracked variable of a
        return false;
      }
      else if (strict)
      {
        if (fv_found.size() + 1 == fvs.size())
        {
          // cur would complete the set: b contains every variable of a
          return false;
        }
        // cur should only be visited once
        Assert(std::find(fv_found.begin(), fv_found.end(), cur)
               == fv_found.end());
        fv_found.push_back(cur);
      }
    }
    for (const Node& cn : cur)
    {
      visit.push_back(cn);
    }
  } while (!visit.empty());
  return true;
}

void SygusSampler::getVariables(std::vector& vars) const
{
  vars.insert(vars.end(), d_vars.begin(), d_vars.end());
}

const std::vector& SygusSampler::getSamplePoint(size_t index) const
{
  Assert(index < d_samples.size());
  return d_samples[index];
}

void SygusSampler::addSamplePoint(const std::vector& pt)
{
  Assert(pt.size() == d_vars.size());
  d_samples.push_back(pt);
}

/**
 * Evaluate the term `n` at the sample point with the given index.
 *
 * The term is first rewritten, then evaluated through the environment with
 * d_vars substituted by the values of the sample point.
 *
 * @param n The term to evaluate.
 * @param index The index of the sample point; asserted to be in range.
 * @return The result of the evaluation; asserted to be non-null.
 */
Node SygusSampler::evaluate(Node n, unsigned index)
{
  Assert(index < d_samples.size());
  // rewriting up front performs the beta-reductions in n
  n = d_env.getRewriter()->rewrite(n);
  // substitution and rewriting are done in one step by the evaluator
  Node ev = d_env.evaluate(n, d_vars, d_samples[index], true);
  Assert(!ev.isNull());
  Trace("sygus-sample-ev") << "Evaluate ( " << n << ", " << index << " ) -> "
                           << ev << std::endl;
  return ev;
}

/**
 * Find the first sample point on which `a` and `b` evaluate differently.
 *
 * @param a The first term.
 * @param b The second term.
 * @return The index of the first sample point at which the evaluations of a
 * and b differ, or -1 if they agree on every stored sample point.
 */
int SygusSampler::getDiffSamplePointIndex(Node a, Node b)
{
  const unsigned nsamp = d_samples.size();
  unsigned pt = 0;
  while (pt < nsamp)
  {
    // a is evaluated before b
    Node aval = evaluate(a, pt);
    Node bval = evaluate(b, pt);
    if (aval != bval)
    {
      return pt;
    }
    pt++;
  }
  // no distinguishing sample point
  return -1;
}

/**
 * Generate a random constant of type `tn`.
 *
 * - Boolean: true or false with probability 0.5 each.
 * - Bit-vector: Sampler::pickBvUniform for the width of tn.
 * - Floating-point: Sampler::pickFpUniform or Sampler::pickFpBiased,
 *   selected by the option sygusSampleFpUniform.
 * - String / Integer: a sequence of "digits" is built, extended with
 *   probability 0.5 at each step. For strings the digits index into
 *   d_rstring_alphabet, which is computed on first use from the string
 *   constants in d_const_sygus_types plus one fresh character (two if there
 *   are no such constants). For integers the digits are base-10 digits, with
 *   the first generated digit least significant, and the result is negated
 *   with probability 0.5.
 * - Real: the quotient of two random integers, or the first of them if the
 *   second is zero.
 * - Otherwise: a term from the type enumerator d_tenum at a random index
 *   (incremented with probability 0.5 per step), falling back to index 0
 *   if that index yields a null node.
 *
 * @param tn The type of the value to generate.
 * @return The generated value; this is whatever d_tenum returns in the
 * default case, which callers in this file check for null.
 */
Node SygusSampler::getRandomValue(TypeNode tn)
{
  NodeManager* nm = nodeManager();
  if (tn.isBoolean())
  {
    return nm->mkConst(Random::getRandom().pickWithProb(0.5));
  }
  else if (tn.isBitVector())
  {
    unsigned sz = tn.getBitVectorSize();
    return nm->mkConst(Sampler::pickBvUniform(sz));
  }
  else if (tn.isFloatingPoint())
  {
    unsigned e = tn.getFloatingPointExponentSize();
    unsigned s = tn.getFloatingPointSignificandSize();
    return nm->mkConst(options().quantifiers.sygusSampleFpUniform
                           ? Sampler::pickFpUniform(e, s)
                           : Sampler::pickFpBiased(e, s));
  }
  else if (tn.isString() || tn.isInteger())
  {
    // if string, determine the alphabet
    if (tn.isString() && d_rstring_alphabet.empty())
    {
      Trace("sygus-sample-str-alpha")
          << "Setting string alphabet..." << std::endl;
      std::unordered_set<unsigned> alphas;
      for (const auto& c : d_const_sygus_types)
      {
        if (c.first.getType().isString())
        {
          Trace("sygus-sample-str-alpha")
              << "...have constant " << c.first << std::endl;
          Assert(c.first.isConst());
          std::vector<unsigned> svec = c.first.getConst<String>().getVec();
          for (unsigned ch : svec)
          {
            alphas.insert(ch);
          }
        }
      }
      // can limit to 1 extra characters beyond those in the grammar (2 if
      // there are none in the grammar)
      unsigned num_fresh_char = alphas.empty() ? 2 : 1;
      unsigned fresh_char = 0;
      for (unsigned i = 0; i < num_fresh_char; i++)
      {
        // smallest character code not yet in the alphabet
        while (alphas.find(fresh_char) != alphas.end())
        {
          fresh_char++;
        }
        alphas.insert(fresh_char);
      }
      Trace("sygus-sample-str-alpha")
          << "Sygus sampler: limit strings alphabet to : " << std::endl
          << " ";
      for (unsigned ch : alphas)
      {
        d_rstring_alphabet.push_back(ch);
        Trace("sygus-sample-str-alpha") << " \\u" << ch;
      }
      Trace("sygus-sample-str-alpha") << std::endl;
    }

    std::vector<unsigned> vec;
    double ext_freq = .5;
    unsigned base = tn.isString() ? d_rstring_alphabet.size() : 10;
    while (Random::getRandom().pickWithProb(ext_freq))
    {
      // add a digit
      unsigned digit = Random::getRandom().pick(0, base - 1);
      if (tn.isString())
      {
        digit = d_rstring_alphabet[digit];
      }
      vec.push_back(digit);
    }
    if (tn.isString())
    {
      return nm->mkConst(String(vec));
    }
    // otherwise tn is Integer, by the condition of this branch
    Rational baser(base);
    Rational curr(1);
    std::vector<Node> sum;
    for (unsigned j = 0, size = vec.size(); j < size; j++)
    {
      // digit j contributes vec[j] * base^j
      Node digit = nm->mkConstInt(Rational(vec[j]) * curr);
      sum.push_back(digit);
      curr = curr * baser;
    }
    Node ret;
    if (sum.empty())
    {
      ret = nm->mkConstInt(Rational(0));
    }
    else if (sum.size() == 1)
    {
      ret = sum[0];
    }
    else
    {
      ret = nm->mkNode(Kind::ADD, sum);
    }

    if (Random::getRandom().pickWithProb(0.5))
    {
      // negative
      ret = nm->mkNode(Kind::NEG, ret);
    }
    ret = d_env.getRewriter()->rewrite(ret);
    Assert(ret.isConst());
    Assert(ret.getType() == tn);
    return ret;
  }
  else if (tn.isReal())
  {
    Node s = getRandomValue(nm->integerType());
    Node r = getRandomValue(nm->integerType());
    if (!s.isNull() && !r.isNull())
    {
      Rational sr = s.getConst<Rational>();
      Rational rr = r.getConst<Rational>();
      if (rr.sgn() == 0)
      {
        // avoid division by zero: use the numerator alone
        return nm->mkConstReal(sr);
      }
      return nm->mkConstReal(sr / rr);
    }
  }
  // default: use type enumerator
  unsigned counter = 0;
  while (Random::getRandom().pickWithProb(0.5))
  {
    counter++;
  }
  Node ret = d_tenum.getEnumerateTerm(tn, counter);
  if (ret.isNull())
  {
    // beyond bounds, return the first
    ret = d_tenum.getEnumerateTerm(tn, 0);
  }
  return ret;
}

/**
 * Generate a random value for `tn` by following its sygus grammar.
 *
 * If tn is not a sygus datatype, this is getRandomValue(tn). Otherwise a
 * constructor of tn registered by registerSygusType is picked at random,
 * values are generated recursively for its arguments, and the resulting
 * term is built and rewritten. If no constructor is picked (one extra index
 * is reserved for this), if some argument yields a null node, or if the
 * rewritten term is not a constant, the result is a random value of the
 * sygus type of the datatype, obtained from getRandomValue.
 *
 * NOTE(review): d_tds is dereferenced here, while no code visible in this
 * file assigns it a non-null value. Confirm against the header and
 * TermDbSygus.
 *
 * @param tn The type to generate a value for.
 * @param rchance The probability that this call terminates, i.e. restricts
 * the choice to constructors without arguments.
 * @param rinc Determines the termination probability used in recursive
 * calls: rchance + (1 - rchance) * rinc.
 * @param depth The current recursion depth; at depth 10 or more the call
 * always terminates.
 * @return The generated value.
 */
Node SygusSampler::getSygusRandomValue(TypeNode tn,
                                       double rchance,
                                       double rinc,
                                       unsigned depth)
{
  if (!tn.isDatatype())
  {
    return getRandomValue(tn);
  }
  const DType& dt = tn.getDType();
  if (!dt.isSygus())
  {
    return getRandomValue(tn);
  }
  Assert(d_rvalue_cindices.find(tn) != d_rvalue_cindices.end());
  Trace("sygus-sample-grammar")
      << "Sygus random value " << tn << ", depth = " << depth
      << ", rchance = " << rchance << std::endl;
  // check if we terminate on this call
  // we refuse to enumerate terms of 10+ depth as a hard limit
  bool terminate = Random::getRandom().pickWithProb(rchance) || depth >= 10;
  // if we terminate, only nullary constructors can be chosen
  std::vector<unsigned>& cindices =
      terminate ? d_rvalue_null_cindices[tn] : d_rvalue_cindices[tn];
  unsigned ncons = cindices.size();
  // select a random constructor, or random value when index=ncons.
  unsigned index = Random::getRandom().pick(0, ncons);
  Trace("sygus-sample-grammar")
      << "Random index 0..." << ncons << " was : " << index << std::endl;
  if (index < ncons)
  {
    Trace("sygus-sample-grammar")
        << "Recurse constructor index #" << index << std::endl;
    unsigned cindex = cindices[index];
    Assert(cindex < dt.getNumConstructors());
    const DTypeConstructor& dtc = dt[cindex];
    // more likely to terminate in recursive calls
    double rchance_new = rchance + (1.0 - rchance) * rinc;
    // the generated arguments, by argument position
    // NOTE(review): the key type must match the map parameter of
    // TermDbSygus::mkGeneric, which is not visible here. TODO confirm.
    std::map<int, Node> pre;
    bool success = true;
    // generate random values for all arguments
    for (unsigned i = 0, nargs = dtc.getNumArgs(); i < nargs; i++)
    {
      TypeNode tnc = d_tds->getArgType(dtc, i);
      Node c = getSygusRandomValue(tnc, rchance_new, rinc, depth + 1);
      if (c.isNull())
      {
        success = false;
        Trace("sygus-sample-grammar") << "...fail." << std::endl;
        break;
      }
      Trace("sygus-sample-grammar")
          << "  child #" << i << " : " << c << std::endl;
      pre[i] = c;
    }
    if (success)
    {
      Trace("sygus-sample-grammar") << "mkGeneric" << std::endl;
      Node ret = d_tds->mkGeneric(dt, cindex, pre);
      Trace("sygus-sample-grammar") << "...returned " << ret << std::endl;
      ret = d_env.getRewriter()->rewrite(ret);
      Trace("sygus-sample-grammar") << "...after rewrite " << ret << std::endl;
      // A rare case where we generate a non-constant value from constant
      // leaves is (/ n 0).
      if (ret.isConst())
      {
        return ret;
      }
    }
  }
  Trace("sygus-sample-grammar") << "...resort to random value" << std::endl;
  // if we did not generate based on the grammar, pick a random value
  return getRandomValue(dt.getSygusType());
}

// recursion depth bounded by number of types in grammar (small)
/**
 * Register the type `tn`, and recursively all argument types of its sygus
 * constructors, with this sampler.
 *
 * A type is processed at most once, keyed on its presence in
 * d_rvalue_cindices. For each constructor of a sygus datatype:
 *  - if its sygus operator is one of d_vars, tn is appended to
 *    d_var_sygus_types for that variable;
 *  - otherwise its index is recorded in d_rvalue_cindices[tn] and, if it
 *    takes no arguments, also in d_rvalue_null_cindices[tn]; if in addition
 *    the operator is a constant, tn is appended to d_const_sygus_types for
 *    that constant.
 *
 * @param tn The type to register.
 */
void SygusSampler::registerSygusType(TypeNode tn)
{
  if (d_rvalue_cindices.find(tn) != d_rvalue_cindices.end())
  {
    // this type was already processed
    return;
  }
  // create the entry for tn, which also marks it as processed
  d_rvalue_cindices[tn].clear();
  if (!tn.isDatatype())
  {
    return;
  }
  const DType& dt = tn.getDType();
  if (!dt.isSygus())
  {
    return;
  }
  const unsigned ncons = dt.getNumConstructors();
  for (unsigned cindex = 0; cindex < ncons; cindex++)
  {
    const DTypeConstructor& dtc = dt[cindex];
    Node sop = dtc.getSygusOp();
    if (std::find(d_vars.begin(), d_vars.end(), sop) != d_vars.end())
    {
      // the operator is a variable: remember that it occurs in this type
      d_var_sygus_types[sop].push_back(tn);
    }
    else
    {
      // a constructor that can be used for generating sygus random values
      d_rvalue_cindices[tn].push_back(cindex);
      if (dtc.getNumArgs() == 0)
      {
        d_rvalue_null_cindices[tn].push_back(cindex);
        if (sop.isConst())
        {
          d_const_sygus_types[sop].push_back(tn);
        }
      }
    }
    // register the types of all arguments of this constructor
    const unsigned nargs = dtc.getNumArgs();
    for (unsigned a = 0; a < nargs; a++)
    {
      TypeNode argType = d_tds->getArgType(dtc, a);
      registerSygusType(argType);
    }
  }
}

}  // namespace quantifiers
}  // namespace theory
}  // namespace cvc5::internal




© 2015 - 2024 Weber Informatics LLC | Privacy Policy