From aecf47142b83563095420c08bf9ede4347c0f806 Mon Sep 17 00:00:00 2001 From: Matthew Fernandez Date: Sun, 30 Aug 2020 09:18:58 -0700 Subject: [PATCH] implement the beginnings of generic translation to SMT Github: related to #13 "partial order reduction" --- librumur/CMakeLists.txt | 2 + librumur/include/rumur/SymContext.h | 54 ++++++ librumur/include/rumur/rumur.h | 2 + librumur/include/rumur/smt.h | 52 ++++++ librumur/src/SymContext.cc | 49 ++++++ librumur/src/smt.cc | 258 ++++++++++++++++++++++++++++ 6 files changed, 417 insertions(+) create mode 100644 librumur/include/rumur/SymContext.h create mode 100644 librumur/include/rumur/smt.h create mode 100644 librumur/src/SymContext.cc create mode 100644 librumur/src/smt.cc diff --git a/librumur/CMakeLists.txt b/librumur/CMakeLists.txt index 158530e4..5ab705bc 100644 --- a/librumur/CMakeLists.txt +++ b/librumur/CMakeLists.txt @@ -70,7 +70,9 @@ add_library(librumur src/Property.cc src/resolve-symbols.cc src/Rule.cc + src/smt.cc src/Stmt.cc + src/SymContext.cc src/traverse.cc src/TypeExpr.cc src/validate.cc diff --git a/librumur/include/rumur/SymContext.h b/librumur/include/rumur/SymContext.h new file mode 100644 index 00000000..a472cbc0 --- /dev/null +++ b/librumur/include/rumur/SymContext.h @@ -0,0 +1,54 @@ +#pragma once + +#include +#include +#include +#include +#include + +namespace rumur { + +// Symbolic context, for maintaining a mapping between Murphi variables and +// external (generated) symbols. This has extremely limited functionality, only +// enough to support the translation to SMT (see smt.h). 
+class SymContext { + + private: + // stack of symbol tables (innermost scope last), each mapping AST unique IDs to external names + std::vector> scope; + + // monotonic counter used for generating unique symbols; it is only ever incremented + size_t counter = 0; + + public: + SymContext(); // opens the outermost scope + + // descend into (push a fresh innermost table) or ascend from (drop the innermost table) a variable scope + void open_scope(); + void close_scope(); + + /// add a new known symbol + /// + /// This registers the symbol in the current innermost scope, replacing any name already registered for the same ID in that scope. + /// + /// \param id Unique identifier of the source AST node + /// \return A unique name created for this symbol ("s" followed by a decimal counter value) + std::string register_symbol(size_t id); + + /// lookup a previously registered symbol + /// + /// This lookup is performed in all known variable scopes, going from + /// innermost to outermost in preference order + /// + /// \param id Unique identifier of the AST node being looked up + /// \param origin The node that caused this lookup (used for error + /// diagnostics) + /// \return The unique name this symbol maps to; if no scope contains the ID, an Error carrying origin's location is thrown + std::string lookup_symbol(size_t id, const Node &origin) const; + + private: + std::string make_symbol(); // generates the next unused name from the counter + +}; + +} diff --git a/librumur/include/rumur/rumur.h b/librumur/include/rumur/rumur.h index 621478c3..877a72ff 100644 --- a/librumur/include/rumur/rumur.h +++ b/librumur/include/rumur/rumur.h @@ -20,7 +20,9 @@ #include #include #include +#include #include +#include #include #include #include diff --git a/librumur/include/rumur/smt.h b/librumur/include/rumur/smt.h new file mode 100644 index 00000000..e3ed6a64 --- /dev/null +++ b/librumur/include/rumur/smt.h @@ -0,0 +1,52 @@ +// functionality related to interacting with a Satisfiability Modulo Theories +// solver + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace rumur { + +struct SMTConfig { + + // use bitvectors instead of integers for numeric values? 
+ bool prefer_bitvectors = false; + + // bit width to use to represent numerical values if using bitvectors (consulted when emitting numeric literals) + size_t bitvector_width = 64; + + // various SMT operators whose selection is dependent on prefer_bitvectors; the bitwise and shift operators throw an Error located at origin when bitvectors are not in use + std::string add (const Node &origin) const; + std::string band(const Node &origin) const; + std::string bnot(const Node &origin) const; + std::string bor (const Node &origin) const; + std::string bxor(const Node &origin) const; + std::string div (const Node &origin) const; + std::string geq (const Node &origin) const; + std::string gt (const Node &origin) const; + std::string leq (const Node &origin) const; + std::string lsh (const Node &origin) const; + std::string lt (const Node &origin) const; + std::string mod (const Node &origin) const; + std::string mul (const Node &origin) const; + std::string neg (const Node &origin) const; + std::string rsh (const Node &origin) const; + std::string sub (const Node &origin) const; + std::string numeric_literal(const mpz_class &value, const Number &origin) const; // SMT spelling of a constant; negative values are wrapped in the neg operator + +}; + +// translate a given expression to SMTLIBv2, writing the result to out; variable references are resolved through ctxt +void to_smt(std::ostream &out, const Expr &n, SymContext &ctxt, SMTConfig &conf); + +// wrapper around the above for when you do not need a long lived output buffer; returns the translation as a string +std::string to_smt(const Expr &n, SymContext &ctxt, SMTConfig &conf); + +} diff --git a/librumur/src/SymContext.cc b/librumur/src/SymContext.cc new file mode 100644 index 00000000..332de699 --- /dev/null +++ b/librumur/src/SymContext.cc @@ -0,0 +1,49 @@ +#include +#include +#include +#include +#include +#include + +namespace rumur { + +SymContext::SymContext() { + open_scope(); // begin with one (outermost) scope so symbols can be registered immediately +} + +void SymContext::open_scope() { + scope.push_back(std::unordered_map{}); // a new, empty innermost symbol table +} + +void SymContext::close_scope() { + scope.pop_back(); // drops every mapping registered in the innermost scope; NOTE(review): undefined behaviour if no scope is open, so calls must be balanced with open_scope + // TODO: we probably need to record symbols somewhere +} + +std::string SymContext::register_symbol(size_t id) { + // invent a new symbol and map this ID to it (overwriting any existing mapping for this ID in the innermost scope) + std::string s = make_symbol(); + scope.back()[id] = s; + return s; +} + 
+std::string SymContext::lookup_symbol(size_t id, const Node &origin) const { + + // lookup the symbol in enclosing scopes from innermost to outermost, so an inner registration shadows an outer one for the same ID + for (auto it = scope.rbegin(); it != scope.rend(); ++it) { + auto it2 = it->find(id); + if (it2 != it->end()) + return it2->second; + } + + // we expect any symbol encountered in a well-formed AST to be associated with + // a previously encountered definition + throw Error("unknown symbol encountered; applying SMT translation to an " + "unvalidated AST?", origin.loc); +} + +std::string SymContext::make_symbol() { + return "s" + std::to_string(counter++); // post-increment: the first name handed out is "s0" +} + +} diff --git a/librumur/src/smt.cc b/librumur/src/smt.cc new file mode 100644 index 00000000..fedab8d5 --- /dev/null +++ b/librumur/src/smt.cc @@ -0,0 +1,258 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace rumur { + +std::string SMTConfig::add(const Node&) const { // spelled in both representations, so the origin node is unused + return prefer_bitvectors ? "bvadd" : "+"; +} + +std::string SMTConfig::band(const Node &origin) const { // bitwise operators are only emitted in bitvector mode; otherwise an Error at origin's location is thrown + if (prefer_bitvectors) + return "bvand"; + throw Error("SMT translation involving bitwise AND is only supported when " + "using bitvector representations", origin.loc); +} + +std::string SMTConfig::bnot(const Node &origin) const { + if (prefer_bitvectors) + return "bvnot"; + throw Error("SMT translation involving bitwise NOT is only supported when " + "using bitvector representations", origin.loc); +} + +std::string SMTConfig::bor(const Node &origin) const { + if (prefer_bitvectors) + return "bvor"; + throw Error("SMT translation involving bitwise OR is only supported when " + "using bitvector representations", origin.loc); +} + +std::string SMTConfig::bxor(const Node &origin) const { + if (prefer_bitvectors) + return "bvxor"; + throw Error("SMT translation involving bitwise XOR is only supported when " + "using bitvector representations", origin.loc); +} + +std::string SMTConfig::div(const Node&) const { + // 
solvers like CVC4 may fail with an error when given "div" but just ignore + // this for now + return prefer_bitvectors ? "bvsdiv" : "div"; // the bitvector form is the signed division operator +} + +std::string SMTConfig::geq(const Node&) const { + return prefer_bitvectors ? "bvsge" : ">="; // bitvector comparisons use the signed variants +} + +std::string SMTConfig::gt(const Node&) const { + return prefer_bitvectors ? "bvsgt" : ">"; +} + +std::string SMTConfig::leq(const Node&) const { + return prefer_bitvectors ? "bvsle" : "<="; +} + +std::string SMTConfig::lsh(const Node &origin) const { // shifts, like the bitwise operators, are only emitted in bitvector mode + if (prefer_bitvectors) + return "bvshl"; + throw Error("SMT translation involving left shift is only supported when " + "using bitvector representations", origin.loc); +} + +std::string SMTConfig::lt(const Node&) const { + return prefer_bitvectors ? "bvslt" : "<"; +} + +std::string SMTConfig::mod(const Node&) const { + return prefer_bitvectors ? "bvsmod" : "mod"; // NOTE(review): bvsmod takes the sign of the divisor whereas bvsrem takes the sign of the dividend; confirm which one matches the source language's % operator +} + +std::string SMTConfig::mul(const Node&) const { + return prefer_bitvectors ? "bvmul" : "*"; +} + +std::string SMTConfig::neg(const Node&) const { + return prefer_bitvectors ? "bvneg" : "-"; // in integer mode unary negation shares its spelling with sub +} + +std::string SMTConfig::rsh(const Node &origin) const { + if (prefer_bitvectors) + return "bvashr"; // arithmetic (sign-preserving) right shift + throw Error("SMT translation involving right shift is only supported when " + "using bitvector representations", origin.loc); +} + +std::string SMTConfig::sub(const Node&) const { + return prefer_bitvectors ? 
"bvsub" : "-"; +} + +std::string SMTConfig::numeric_literal(const mpz_class &value, + const Number &origin) const { + + if (value < 0) // negative values are emitted as the neg operator applied to the literal of the absolute value + return "(" + neg(origin) + " " + numeric_literal(-value, origin) + ")"; + + if (prefer_bitvectors) { // indexed bitvector literal: decimal value followed by the width; NOTE(review): nothing here checks that the value fits in bitvector_width bits + return "(_ bv" + value.get_str() + " " + std::to_string(bitvector_width) + + ")"; + } + + return value.get_str(); // integer mode: plain decimal numeral +} + +namespace { class Translator : public ConstTraversal { // file-local visitor that writes an expression to out as an SMT term, one visit_* method per expression kind + + private: + std::ostream &out; + SymContext &ctxt; + const SMTConfig &conf; + + public: + Translator(std::ostream &out_, SymContext &ctxt_, const SMTConfig &conf_) + : out(out_), ctxt(ctxt_), conf(conf_) { } + + Translator &operator<<(const std::string &s) { // emit literal text + out << s; + return *this; + } + + Translator &operator<<(const Expr &e) { // translate a subexpression in place by dispatching back into this visitor + dispatch(e); + return *this; + } + + void visit_add(const Add &n) { + *this << "(" << conf.add(n) << " " << *n.lhs << " " << *n.rhs << ")"; + } + + void visit_and(const And &n) { + *this << "(and " << *n.lhs << " " << *n.rhs << ")"; + } + + void visit_band(const Band &n) { + *this << "(" << conf.band(n) << " " << *n.lhs << " " << *n.rhs << ")"; + } + + void visit_bnot(const Bnot &n) { + *this << "(" << conf.bnot(n) << " " << *n.rhs << ")"; + } + + void visit_bor(const Bor &n) { + *this << "(" << conf.bor(n) << " " << *n.lhs << " " << *n.rhs << ")"; + } + + void visit_element(const Element &n) { + *this << "(select " << *n.array << " " << *n.index << ")"; // array indexing is emitted as a select term + } + + void visit_exprid(const ExprID &n) { + *this << ctxt.lookup_symbol(n.value->unique_id, n); // n.value is dereferenced unconditionally, so the reference must already be resolved; unknown IDs make lookup_symbol throw + } + + void visit_eq(const Eq &n) { + *this << "(= " << *n.lhs << " " << *n.rhs << ")"; + } + + void visit_div(const Div &n) { + *this << "(" << conf.div(n) << " " << *n.lhs << " " << *n.rhs << ")"; + } + + void visit_geq(const Geq &n) { + *this << "(" << conf.geq(n) << " " << *n.lhs << " " << *n.rhs << ")"; + } + + void visit_gt(const Gt &n) { + *this << "(" << conf.gt(n) << " " << *n.lhs << " " << *n.rhs << ")"; + } + + void visit_implication(const Implication &n) { + *this 
<< "(=> " << *n.lhs << " " << *n.rhs << ")"; + } + + void visit_isundefined(const IsUndefined &n) { + throw Error("SMT translation of isundefined expressions is not supported", + n.loc); + } + + void visit_leq(const Leq &n) { + *this << "(" << conf.leq(n) << " " << *n.lhs << " " << *n.rhs << ")"; + } + + void visit_lsh(const Lsh &n) { + *this << "(" << conf.lsh(n) << " " << *n.lhs << " " << *n.rhs << ")"; + } + + void visit_lt(const Lt &n) { + *this << "(" << conf.lt(n) << " " << *n.lhs << " " << *n.rhs << ")"; + } + + void visit_mod(const Mod &n) { + *this << "(" << conf.mod(n) << " " << *n.lhs << " " << *n.rhs << ")"; + } + + void visit_mul(const Mul &n) { + *this << "(" << conf.mul(n) << " " << *n.lhs << " " << *n.rhs << ")"; + } + + void visit_negative(const Negative &n) { + *this << "(" << conf.neg(n) << " " << *n.rhs << ")"; + } + + void visit_neq(const Neq &n) { + *this << "(not (= " << *n.lhs << " " << *n.rhs << "))"; // inequality is expressed as a negated equality + } + + void visit_number(const Number &n) { + *this << conf.numeric_literal(n.value, n); + } + + void visit_not(const Not &n) { + *this << "(not " << *n.rhs << ")"; + } + + void visit_or(const Or &n) { + *this << "(or " << *n.lhs << " " << *n.rhs << ")"; + } + + void visit_rsh(const Rsh &n) { + *this << "(" << conf.rsh(n) << " " << *n.lhs << " " << *n.rhs << ")"; + } + + void visit_sub(const Sub &n) { + *this << "(" << conf.sub(n) << " " << *n.lhs << " " << *n.rhs << ")"; + } + + void visit_ternary(const Ternary &n) { + *this << "(ite " << *n.cond << " " << *n.lhs << " " << *n.rhs << ")"; // conditional expressions become ite terms + } + + void visit_xor(const Xor &n) { + *this << "(" << conf.bxor(n) << " " << *n.lhs << " " << *n.rhs << ")"; // translated with the bitwise operator, so this throws unless bitvectors are in use + } + +}; } + +void to_smt(std::ostream &out, const Expr &n, SymContext &ctxt, + SMTConfig &conf) { + + Translator t(out, ctxt, conf); // the translator keeps conf by const reference, so conf is only read from here on + t.dispatch(n); +} + +std::string to_smt(const Expr &n, SymContext &ctxt, SMTConfig &conf) { + std::ostringstream buf; // collect the translation in memory and hand it back as a string + to_smt(buf, n, ctxt, conf); + return buf.str(); +} + +}