From fbb09c17830301a8be9f6403fab1806f6fafd607 Mon Sep 17 00:00:00 2001 From: Steffen Smolka Date: Thu, 5 Dec 2024 16:30:51 -0800 Subject: [PATCH] [NetKAT] Add pretty printer for SymbolicPackets, for debugging only. PiperOrigin-RevId: 703290191 --- gutil/status.h | 3 + gutil/status_matchers.h | 2 + netkat/BUILD.bazel | 59 +++++ netkat/interned_field.cc | 53 ++++ netkat/interned_field.h | 117 +++++++++ netkat/interned_field_test.cc | 76 ++++++ netkat/netkat.proto | 2 +- netkat/paged_stable_vector.h | 74 ++++++ netkat/paged_stable_vector_test.cc | 107 ++++++++ netkat/symbolic_packet.cc | 376 +++++++++++++++++++++++++++++ netkat/symbolic_packet.h | 347 ++++++++++++++++++-------- netkat/symbolic_packet_test.cc | 235 ++++++++++++++++-- 12 files changed, 1336 insertions(+), 115 deletions(-) create mode 100644 netkat/interned_field.cc create mode 100644 netkat/interned_field.h create mode 100644 netkat/interned_field_test.cc create mode 100644 netkat/paged_stable_vector.h create mode 100644 netkat/paged_stable_vector_test.cc create mode 100644 netkat/symbolic_packet.cc diff --git a/gutil/status.h b/gutil/status.h index 8ff4776..66865d3 100644 --- a/gutil/status.h +++ b/gutil/status.h @@ -11,6 +11,9 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. +// +// IWYU pragma: always_keep, since outside of Google this defines key macros. + #ifndef GOOGLE_NETKAT_GUTIL_STATUS_H #define GOOGLE_NETKAT_GUTIL_STATUS_H diff --git a/gutil/status_matchers.h b/gutil/status_matchers.h index 756873b..26d2062 100644 --- a/gutil/status_matchers.h +++ b/gutil/status_matchers.h @@ -11,6 +11,8 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. +// +// IWYU pragma: always_keep, since outside of Google this defines key macros. 
#ifndef GOOGLE_NETKAT_GUTIL_STATUS_MATCHERS_H #define GOOGLE_NETKAT_GUTIL_STATUS_MATCHERS_H diff --git a/netkat/BUILD.bazel b/netkat/BUILD.bazel index 69b5820..b535705 100644 --- a/netkat/BUILD.bazel +++ b/netkat/BUILD.bazel @@ -29,9 +29,21 @@ cc_test( cc_library( name = "symbolic_packet", + srcs = ["symbolic_packet.cc"], hdrs = ["symbolic_packet.h"], deps = [ + ":evaluator", + ":interned_field", ":netkat_cc_proto", + ":paged_stable_vector", + "//gutil:status", + "@com_google_absl//absl/container:fixed_array", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", ], ) @@ -40,9 +52,14 @@ cc_test( name = "symbolic_packet_test", srcs = ["symbolic_packet_test.cc"], deps = [ + ":evaluator", + ":netkat_proto_constructors", ":symbolic_packet", + "//gutil:status_matchers", + "@com_google_absl//absl/base:no_destructor", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", + "@com_google_fuzztest//fuzztest", "@com_google_googletest//:gtest_main", ], ) @@ -69,6 +86,48 @@ cc_library( ], ) +cc_library( + name = "interned_field", + srcs = ["interned_field.cc"], + hdrs = ["interned_field.h"], + deps = [ + "//gutil:status", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + ], +) + +cc_test( + name = "interned_field_test", + srcs = ["interned_field_test.cc"], + deps = [ + ":interned_field", + "//gutil:status_matchers", + "@com_google_absl//absl/base:no_destructor", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "paged_stable_vector", + 
hdrs = ["paged_stable_vector.h"], +) + +cc_test( + name = "paged_stable_vector_test", + srcs = ["paged_stable_vector_test.cc"], + deps = [ + ":paged_stable_vector", + "@com_google_fuzztest//fuzztest", + "@com_google_googletest//:gtest_main", + ], +) + cc_test( name = "netkat_proto_constructors_test", srcs = ["netkat_proto_constructors_test.cc"], diff --git a/netkat/interned_field.cc b/netkat/interned_field.cc new file mode 100644 index 0000000..d580642 --- /dev/null +++ b/netkat/interned_field.cc @@ -0,0 +1,53 @@ +// Copyright 2024 The NetKAT authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "netkat/interned_field.h" + +#include + +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "gutil/status.h" + +namespace netkat { + +InternedField InternedFieldManager::InternField(absl::string_view field_name) { + auto [it, inserted] = interned_field_by_name_.try_emplace( + field_name, InternedField(field_names_.size())); + if (inserted) field_names_.push_back(std::string(field_name)); + return it->second; +} + +std::string InternedFieldManager::GetFieldNameOrDie(InternedField field) const { + CHECK_LT(field.index_, field_names_.size()); // Crash ok + return field_names_[field.index_]; +} + +absl::Status InternedFieldManager::CheckInternalInvariants() const { + for (int i = 0; i < field_names_.size(); ++i) { + auto it = interned_field_by_name_.find(field_names_[i]); + RET_CHECK(it != interned_field_by_name_.end()); + RET_CHECK(it->second.index_ == i); + } + + for (const auto& [name, field] : interned_field_by_name_) { + RET_CHECK(field.index_ < field_names_.size()); + RET_CHECK(field_names_[field.index_] == name); + } + + return absl::OkStatus(); +} + +} // namespace netkat diff --git a/netkat/interned_field.h b/netkat/interned_field.h new file mode 100644 index 0000000..7eabf6a --- /dev/null +++ b/netkat/interned_field.h @@ -0,0 +1,117 @@ +// Copyright 2024 The NetKAT authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +// ----------------------------------------------------------------------------- +// File: interned_field.h +// ----------------------------------------------------------------------------- +// +// A module for "interning" (aka hash-consing) NetKAT packet fields, see +// https://en.wikipedia.org/wiki/String_interning. This makes it cheap to +// compare, hash, copy and store packet fields (small constant time/space). + +#ifndef GOOGLE_NETKAT_NETKAT_INTERNED_FIELD_H_ +#define GOOGLE_NETKAT_NETKAT_INTERNED_FIELD_H_ + +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" + +namespace netkat { + +// An "interned" (aka hash-consed) NetKAT packet field, e.g. "dst_ip". +// +// Technically, a lightweight handle (16 bits) that is very cheap (O(1)) to +// copy, store, hash, and compare. Handles can only be created by an +// `InternedFieldManager` object, which owns the field name (e.g. "dst_ip") +// associated with the handle. +// +// CAUTION: Each `InternedField` is implicitly associated with the manager +// object that created it; using it with a different manager object has +// undefined behavior. +class [[nodiscard]] InternedField { + public: + // `InternedField`s can only be created by `InternedFieldManager`. + InternedField() = delete; + friend class InternedFieldManager; + + // O(1) comparison, thanks to interning/hash-consing. + friend auto operator<=>(InternedField a, InternedField b) = default; + + // Hashing, see https://abseil.io/docs/cpp/guides/hash. + template + friend H AbslHashValue(H h, InternedField field) { + return H::combine(std::move(h), field.index_); + } + + // Formatting, see https://abseil.io/docs/cpp/guides/abslstringify. 
+ template + friend void AbslStringify(Sink& sink, InternedField field) { + absl::Format(&sink, "InternedField<%d>", field.index_); + } + + private: + // An index into the `field_names_` vector of the `InternedFieldManager` + // object associated with this `InternedField`: `field_names_[index_]` is the + // name of the field. The index is otherwise arbitrary and meaningless. + // + // We use a 16-bit index as a tradeoff between minimizing memory usage while + // supporting sufficiently many fields. We expect 100s, but not more than + // 2^16 ~= 65k fields. + uint16_t index_; + explicit InternedField(uint16_t index) : index_(index) {} +}; + +// Protect against regressions in the memory layout, as it affects performance. +static_assert(sizeof(InternedField) <= 2); + +// An "arena" for interning NetKAT packet fields, owning the memory associated +// with the interned fields. +class InternedFieldManager { + public: + InternedFieldManager() = default; + + // Returns an interned representation of field with the given name. + InternedField InternField(absl::string_view field_name); + + // Returns the name of the given interned field, or crashes if the interned + // field was not created by this manager object. + std::string GetFieldNameOrDie(InternedField field) const; + + // Dynamically checks all class invariants. Exposed for testing only. + absl::Status CheckInternalInvariants() const; + + private: + // All field names interned by this manager object. The name of an interned + // field `f` created by this object is `field_names_[f.index_]`. + std::vector field_names_; + + // A so called "unique table" to ensure each field name is added to + // `field_names_` at most once, and thus is represented by a unique index into + // that vector. + // + // Invariant: + // `interned_field_by_name_[n] == f` iff `field_names_[f.index_] == n`. 
+ absl::flat_hash_map interned_field_by_name_; +}; + +} // namespace netkat + +#endif // GOOGLE_NETKAT_NETKAT_INTERNED_FIELD_H_ diff --git a/netkat/interned_field_test.cc b/netkat/interned_field_test.cc new file mode 100644 index 0000000..437db8b --- /dev/null +++ b/netkat/interned_field_test.cc @@ -0,0 +1,76 @@ +// Copyright 2024 The NetKAT authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "netkat/interned_field.h" + +#include "absl/base/no_destructor.h" +#include "absl/container/flat_hash_set.h" +#include "absl/strings/str_cat.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "gutil/status_matchers.h" + +namespace netkat { + +namespace { + +using testing::StartsWith; + +// We use a global manager object across all tests to exercise statefulness. +InternedFieldManager& Manager() { + static absl::NoDestructor manager; + return *manager; +} + +// After executing all tests, we check once that no invariants are violated. 
+class CheckInternedFieldManagerInvariantsOnTearDown + : public testing::Environment { + public: + ~CheckInternedFieldManagerInvariantsOnTearDown() override {} + void SetUp() override {} + void TearDown() override { ASSERT_OK(Manager().CheckInternalInvariants()); } +}; +testing::Environment* const foo_env = testing::AddGlobalTestEnvironment( + new CheckInternedFieldManagerInvariantsOnTearDown); + +TEST(InternedFieldManagerTest, AbslStringifyWorkst) { + EXPECT_THAT(absl::StrCat(Manager().InternField("foo")), + StartsWith("InternedField")); +} + +TEST(InternedFieldManagerTest, AbslHashValueWorks) { + absl::flat_hash_set set = { + Manager().InternField("foo"), + Manager().InternField("bar"), + }; + EXPECT_EQ(set.size(), 2); +} + +TEST(InternedFieldManagerTest, InternFieldReturnsSameFieldForSameName) { + EXPECT_EQ(Manager().InternField("foo"), Manager().InternField("foo")); +} + +TEST(InternedFieldManagerTest, + InternFieldReturnsDifferentFieldForDifferentNames) { + EXPECT_NE(Manager().InternField("foo"), Manager().InternField("bar")); +} + +TEST(InternedFieldManagerTest, GetFieldNameOrDieReturnsNameOfInternedField) { + InternedField foo = Manager().InternField("foo"); + InternedField bar = Manager().InternField("bar"); + EXPECT_EQ(Manager().GetFieldNameOrDie(foo), "foo"); + EXPECT_EQ(Manager().GetFieldNameOrDie(bar), "bar"); +} +} // namespace +} // namespace netkat diff --git a/netkat/netkat.proto b/netkat/netkat.proto index 3ad90c3..33c8b18 100644 --- a/netkat/netkat.proto +++ b/netkat/netkat.proto @@ -13,7 +13,7 @@ // limitations under the License. // // ----------------------------------------------------------------------------- -// netkat.proto +// File: netkat.proto // ----------------------------------------------------------------------------- // // Proto representation of NetKAT programs (predicates and policies). 
diff --git a/netkat/paged_stable_vector.h b/netkat/paged_stable_vector.h new file mode 100644 index 0000000..04ac895 --- /dev/null +++ b/netkat/paged_stable_vector.h @@ -0,0 +1,74 @@ +// Copyright 2024 The NetKAT authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// ----------------------------------------------------------------------------- +// File: paged_stable_vector.h +// ----------------------------------------------------------------------------- + +#ifndef GOOGLE_NETKAT_NETKAT_PAGED_STABLE_VECTOR_H_ +#define GOOGLE_NETKAT_NETKAT_PAGED_STABLE_VECTOR_H_ + +#include +#include +#include + +namespace netkat { + +// A variant of `std::vector` that allocates memory in pages (or "chunks") of +// fixed `PageSize`. This introduces an extra level of indirection and +// introduces some level of discontiguity (depending on `PageSize`), but allows +// the class to guarantee pointer stability: calls to `push_back`/`emplace_back` +// never invalidate pointers/iterators/references to elements previously added +// to the vector. +// +// Allocating memory in pages also avoids the cost of relocation, which may be +// significant for very large vectors in performance-sensitive applications. +// +// The API of this class is kept just large enough to cover our use cases. +template +class PagedStableVector { + public: + PagedStableVector() = default; + + size_t size() const { + return data_.empty() ? 
0 + : (data_.size() - 1) * PageSize + data_.back().size(); + } + + template + void push_back(Value&& value) { + if (size() % PageSize == 0) data_.emplace_back().reserve(PageSize); + data_.back().push_back(std::forward(value)); + } + + template + void emplace_back(Args&&... value) { + if (size() % PageSize == 0) data_.emplace_back().reserve(PageSize); + data_.back().emplace_back(std::forward(value)...); + } + + T& operator[](size_t index) { + return data_[index / PageSize][index % PageSize]; + } + const T& operator[](size_t index) const { + return data_[index / PageSize][index % PageSize]; + } + + private: + std::vector> data_; +}; + +} // namespace netkat + +#endif // GOOGLE_NETKAT_NETKAT_PAGED_STABLE_VECTOR_H_ diff --git a/netkat/paged_stable_vector_test.cc b/netkat/paged_stable_vector_test.cc new file mode 100644 index 0000000..7fab3a6 --- /dev/null +++ b/netkat/paged_stable_vector_test.cc @@ -0,0 +1,107 @@ +// Copyright 2024 The NetKAT authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "netkat/paged_stable_vector.h" + +#include +#include + +#include "fuzztest/fuzztest.h" +#include "gtest/gtest.h" + +namespace netkat { +namespace { + +// A small, but otherwise random page size used throughout the tests. +// Using a small page size is useful for exercising the page replacement logic. 
+static constexpr int kSmallPageSize = 3; + +void PushBackInreasesSize(std::vector elements) { + PagedStableVector vector; + for (const auto& element : elements) { + vector.push_back(element); + } + EXPECT_EQ(vector.size(), elements.size()); +} +FUZZ_TEST(PagedStableVectorTest, PushBackInreasesSize); + +void EmplaceBackInreasesSize(std::vector elements) { + PagedStableVector vector; + for (const auto& element : elements) { + vector.emplace_back(element); + } + EXPECT_EQ(vector.size(), elements.size()); +} +FUZZ_TEST(PagedStableVectorTest, EmplaceBackInreasesSize); + +void PushBackAddsElementToBack(std::vector elements) { + PagedStableVector vector; + for (int i = 0; i < elements.size(); ++i) { + vector.push_back(elements[i]); + // Verify every element added so far, including the one just pushed at + // index `i` (hence the inclusive bound). + for (int j = 0; j <= i; ++j) { + EXPECT_EQ(vector[j], elements[j]); + } + } +} +FUZZ_TEST(PagedStableVectorTest, PushBackAddsElementToBack); + +void EmplaceBackAddsElementToBack(std::vector elements) { + PagedStableVector vector; + for (int i = 0; i < elements.size(); ++i) { + vector.emplace_back(elements[i]); + // Verify every element added so far, including the one just emplaced at + // index `i` (hence the inclusive bound). + for (int j = 0; j <= i; ++j) { + EXPECT_EQ(vector[j], elements[j]); + } + } +} +FUZZ_TEST(PagedStableVectorTest, EmplaceBackAddsElementToBack); + +void BracketAssigmentWorks(std::vector elements) { + PagedStableVector vector; + for (int i = 0; i < elements.size(); ++i) { + vector.push_back("initial value"); + } + for (int i = 0; i < elements.size(); ++i) { + vector[i] = elements[i]; + } + for (int i = 0; i < elements.size(); ++i) { + EXPECT_EQ(vector[i], elements[i]); + } +} +FUZZ_TEST(PagedStableVectorTest, BracketAssigmentWorks); + +TEST(PagedStableVectorTest, ReferencesDontGetInvalidated) { + PagedStableVector vector; + + // Store a few references. + vector.push_back("first element"); + std::string& first_element = vector[0]; + vector.push_back("second element"); + std::string& second_element = vector[1]; + + // Push a ton of elements to trigger page allocation. 
+ // If this were a regular std::vector, the references would be invalidated. + for (int i = 0; i < 100 * kSmallPageSize; ++i) { + vector.push_back("dummy"); + } + + // Check that the references are still valid. + first_element = "new first element"; + EXPECT_EQ(vector[0], "new first element"); + second_element = "new second element"; + EXPECT_EQ(vector[1], "new second element"); +}; + +} // namespace +} // namespace netkat diff --git a/netkat/symbolic_packet.cc b/netkat/symbolic_packet.cc new file mode 100644 index 0000000..70af9c8 --- /dev/null +++ b/netkat/symbolic_packet.cc @@ -0,0 +1,376 @@ +// Copyright 2024 The NetKAT authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "netkat/symbolic_packet.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/container/fixed_array.h" +#include "absl/container/flat_hash_set.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/strings/escaping.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "gutil/status.h" +#include "netkat/evaluator.h" + +namespace netkat { + +// The empty and full set of packets are not decision nodes, and thus we cannot +// associate an index into the `nodes_` vector with them. Instead, we represent +// them using sentinel values, chosen maximally to avoid collisions with proper +// indices. 
+enum SentinelNodeIndex : uint32_t { + kEmptySet = std::numeric_limits::max(), + kFullSet = std::numeric_limits::max() - 1, +}; + +SymbolicPacket::SymbolicPacket() : node_index_(SentinelNodeIndex::kEmptySet) {} + +std::string SymbolicPacket::ToString() const { + if (node_index_ == SentinelNodeIndex::kEmptySet) { + return "SymbolicPacket"; + } else if (node_index_ == SentinelNodeIndex::kFullSet) { + return "SymbolicPacket"; + } else { + return absl::StrFormat("SymbolicPacket<%d>", node_index_); + } +} + +SymbolicPacket SymbolicPacketManager::EmptySet() const { + return SymbolicPacket(SentinelNodeIndex::kEmptySet); +} + +SymbolicPacket SymbolicPacketManager::FullSet() const { + return SymbolicPacket(SentinelNodeIndex::kFullSet); +} + +bool SymbolicPacketManager::IsEmptySet(SymbolicPacket packet) const { + return packet == EmptySet(); +} + +bool SymbolicPacketManager::IsFullSet(SymbolicPacket packet) const { + return packet == FullSet(); +} + +const SymbolicPacketManager::DecisionNode& SymbolicPacketManager::GetNodeOrDie( + SymbolicPacket packet) const { + CHECK_LT(packet.node_index_, nodes_.size()); // Crash ok + return nodes_[packet.node_index_]; +} + +SymbolicPacket SymbolicPacketManager::NodeToPacket(DecisionNode&& node) { + if (node.branch_by_field_value.empty()) return node.default_branch; + +// When in debug mode, we check a node's invariants before interning it. +// We could check the invariants of all nodes by calling +// `CheckInternalInvariants`, but that would be redundant and asymptotically +// expensive. 
+#ifndef NDEBUG + for (const auto& [value, branch] : node.branch_by_field_value) { + CHECK(branch != node.default_branch) << PrettyPrint(node); + if (!IsEmptySet(branch) && !IsFullSet(branch)) { + auto& branch_node = GetNodeOrDie(branch); + CHECK(branch_node.field > node.field) << absl::StreamFormat( + "(%v > %v)\n---branch---\n%s\n---node---\n%s", branch_node.field, + node.field, PrettyPrint(branch), PrettyPrint(node)); + } + } +#endif + + auto [it, inserted] = + packet_by_node_.try_emplace(node, SymbolicPacket(nodes_.size())); + if (inserted) nodes_.push_back(std::move(node)); + return it->second; +} + +bool SymbolicPacketManager::Contains(SymbolicPacket symbolic_packet, + Packet concrete_packet) const { + if (IsEmptySet(symbolic_packet)) return false; + if (IsFullSet(symbolic_packet)) return true; + + const DecisionNode& node = GetNodeOrDie(symbolic_packet); + std::string field = field_manager_.GetFieldNameOrDie(node.field); + auto it = concrete_packet.find(field); + if (it != concrete_packet.end()) { + for (const auto& [value, branch] : node.branch_by_field_value) { + if (it->second == value) return Contains(branch, concrete_packet); + } + } + return Contains(node.default_branch, concrete_packet); +} + +SymbolicPacket SymbolicPacketManager::Compile(const PredicateProto& pred) { + switch (pred.predicate_case()) { + case PredicateProto::kBoolConstant: + return pred.bool_constant().value() ? FullSet() : EmptySet(); + case PredicateProto::kMatch: + return Match(pred.match().field(), pred.match().value()); + case PredicateProto::kAndOp: + return And(Compile(pred.and_op().left()), Compile(pred.and_op().right())); + case PredicateProto::kOrOp: + return Or(Compile(pred.or_op().left()), Compile(pred.or_op().right())); + case PredicateProto::kNotOp: + return Not(Compile(pred.not_op().negand())); + // By convention, uninitialized predicates must be treated like `false`. 
+ case PredicateProto::PREDICATE_NOT_SET: + return EmptySet(); + } + LOG(FATAL) << "Unhandled predicate kind: " << pred.predicate_case(); +} + +SymbolicPacket SymbolicPacketManager::Match(absl::string_view field, + int value) { + return NodeToPacket(DecisionNode{ + .field = field_manager_.InternField(field), + .default_branch = EmptySet(), + .branch_by_field_value = {{value, FullSet()}}, + }); +} + +// TODO(b/382380335): Use complement edges to reduce the complexity of this +// function from O(n) to O(1). +SymbolicPacket SymbolicPacketManager::Not(SymbolicPacket negand) { + // Base cases. + if (IsEmptySet(negand)) return FullSet(); + if (IsFullSet(negand)) return EmptySet(); + + // Compute result the hard way. + const DecisionNode& negand_node = GetNodeOrDie(negand); + DecisionNode result_node{ + .field = negand_node.field, + .default_branch = Not(negand_node.default_branch), + .branch_by_field_value{negand_node.branch_by_field_value.size()}, + }; + + for (int i = 0; i < negand_node.branch_by_field_value.size(); ++i) { + auto [value, branch] = negand_node.branch_by_field_value[i]; + SymbolicPacket negated_branch = Not(branch); + DCHECK(branch != negand_node.default_branch); + DCHECK(negated_branch != result_node.default_branch); + result_node.branch_by_field_value[i] = + std::make_pair(value, negated_branch); + } + + return NodeToPacket(std::move(result_node)); +} + +SymbolicPacket SymbolicPacketManager::And(SymbolicPacket left, + SymbolicPacket right) { + // Base cases. + if (IsEmptySet(left) || IsFullSet(right) || left == right) return left; + if (IsEmptySet(right) || IsFullSet(left)) return right; + + // TODO(b/382379263): Before computing the result recursively, Look up if the + // result has previously been computed using a memoization table. This can + // reduce the number of nodes we need to visit exponentially. + + // Compute result the hard way. 
+ const DecisionNode* left_node = &GetNodeOrDie(left); + const DecisionNode* right_node = &GetNodeOrDie(right); + + // We exploit that `And` is commutative to canonicalize the order of the + // arguments, reducing the number of cases by 1. + if (left_node->field > right_node->field) { + std::swap(left, right); + std::swap(left_node, right_node); + } + + // Case 1: left_node->field < right_node->field: branch on left field. + if (left_node->field < right_node->field) { + SymbolicPacket default_branch = And(left_node->default_branch, right); + absl::FixedArray> branch_by_field_value( + left_node->branch_by_field_value.size()); + int num_branches = 0; + for (const auto& [value, left_branch] : left_node->branch_by_field_value) { + SymbolicPacket branch = And(left_branch, right); + if (branch == default_branch) continue; + branch_by_field_value[num_branches++] = std::make_pair(value, branch); + } + return NodeToPacket(DecisionNode{ + .field = left_node->field, + .default_branch = default_branch, + .branch_by_field_value{ + branch_by_field_value.begin(), + branch_by_field_value.begin() + num_branches, + }, + }); + } + + // Case 2: left_node->field == right_node->field: branch on shared field. 
+ DCHECK(left_node->field == right_node->field); + SymbolicPacket default_branch = + And(left_node->default_branch, right_node->default_branch); + absl::FixedArray> branch_by_field_value( + left_node->branch_by_field_value.size() + + right_node->branch_by_field_value.size()); + int num_branches = 0; + auto add_branch = [&](int value, SymbolicPacket branch) { + if (branch == default_branch) return; + branch_by_field_value[num_branches++] = std::make_pair(value, branch); + }; + auto left_it = left_node->branch_by_field_value.begin(); + auto left_end = left_node->branch_by_field_value.end(); + auto right_it = right_node->branch_by_field_value.begin(); + auto right_end = right_node->branch_by_field_value.end(); + while (left_it != left_end && right_it != right_end) { + auto [left_value, left_branch] = *left_it; + auto [right_value, right_branch] = *right_it; + if (left_value < right_value) { + add_branch(left_value, And(left_branch, right_node->default_branch)); + ++left_it; + } else if (left_value > right_value) { + add_branch(right_value, And(left_node->default_branch, right_branch)); + ++right_it; + } else { // left_value == right_value + add_branch(left_value, And(left_branch, right_branch)); + ++left_it; + ++right_it; + } + } + for (; left_it != left_end; ++left_it) { + auto [left_value, left_branch] = *left_it; + add_branch(left_value, And(left_branch, right_node->default_branch)); + } + for (; right_it != right_end; ++right_it) { + auto [right_value, right_branch] = *right_it; + add_branch(right_value, And(left_node->default_branch, right_branch)); + } + return NodeToPacket(DecisionNode{ + .field = left_node->field, + .default_branch = default_branch, + .branch_by_field_value{ + branch_by_field_value.begin(), + branch_by_field_value.begin() + num_branches, + }, + }); +} + +SymbolicPacket SymbolicPacketManager::Or(SymbolicPacket left, + SymbolicPacket right) { + // Apply De Morgan's law: a || b == !(!a && !b). 
+ // + // This is currently convenient and terribly inefficient. But once we have + // complement edges (b/382380335) and AND-memoization (b/382379263), reducing + // OR to NOT and AND will actually be better than implementing OR directly, + // since it will allow us to recycle the AND-memoization table. + // + // TODO(b/382380335, b/382379263): Implement complement edges and memoization. + return Not(And(Not(left), Not(right))); +} + +// Renders the decision diagram rooted at `packet` as text. Nodes are visited +// breadth-first (FIFO work list) and the `visited` set ensures each reachable +// decision node is printed once, even when it is shared by several branches. +std::string SymbolicPacketManager::PrettyPrint(SymbolicPacket packet) const { + std::string result; + std::deque work_list = {packet}; + absl::flat_hash_set visited = {packet}; + while (!work_list.empty()) { + SymbolicPacket packet = work_list.front(); + work_list.pop_front(); + absl::StrAppend(&result, packet); + + if (IsFullSet(packet) || IsEmptySet(packet)) continue; + + const DecisionNode& node = GetNodeOrDie(packet); + absl::StrAppend(&result, ":\n"); + std::string field = absl::StrFormat( + "%v:'%s'", node.field, + absl::CEscape(field_manager_.GetFieldNameOrDie(node.field))); + for (const auto& [value, branch] : node.branch_by_field_value) { + absl::StrAppendFormat(&result, " %s == %d -> %v\n", field, value, + branch); + if (IsFullSet(branch) || IsEmptySet(branch)) continue; + bool new_branch = visited.insert(branch).second; + if (new_branch) work_list.push_back(branch); + } + SymbolicPacket fallthrough = node.default_branch; + absl::StrAppendFormat(&result, " %s == * -> %v\n", field, fallthrough); + if (IsFullSet(fallthrough) || IsEmptySet(fallthrough)) continue; + bool new_branch = visited.insert(fallthrough).second; + if (new_branch) work_list.push_back(fallthrough); + } + return result; +} + +std::string SymbolicPacketManager::PrettyPrint(const DecisionNode& node) const { + std::string result; + std::vector work_list; + std::string field = absl::StrFormat( + "%v:'%s'", node.field, + absl::CEscape(field_manager_.GetFieldNameOrDie(node.field))); + for (const auto& [value, branch] : node.branch_by_field_value) { + 
absl::StrAppendFormat(&result, " %s == %d -> %v\n", field, value, branch); + if (!IsFullSet(branch) || !IsEmptySet(branch)) work_list.push_back(branch); + } + SymbolicPacket fallthrough = node.default_branch; + absl::StrAppendFormat(&result, " %s == * -> %v\n", field, fallthrough); + if (!IsFullSet(fallthrough) && !IsEmptySet(fallthrough)) { + work_list.push_back(fallthrough); + } + + for (const SymbolicPacket& branch : work_list) { + absl::StrAppend(&result, PrettyPrint(branch)); + } + + return result; +} + +absl::Status SymbolicPacketManager::CheckInternalInvariants() const { + // Invariant: `packet_by_node_[n] = s` iff `nodes_[s.node_index_] == n`. + for (const auto& [node, packet] : packet_by_node_) { + RET_CHECK(packet.node_index_ < nodes_.size()); + RET_CHECK(nodes_[packet.node_index_] == node); + } + for (int i = 0; i < nodes_.size(); ++i) { + const DecisionNode& node = nodes_[i]; + auto it = packet_by_node_.find(node); + RET_CHECK(it != packet_by_node_.end()); + RET_CHECK(it->second == SymbolicPacket(i)); + } + + // Node Invariants. + for (int i = 0; i < nodes_.size(); ++i) { + const DecisionNode& node = nodes_[i]; + // Invariant: `branch_by_field_value` is non-empty. + // Maintained by `NodeToPacket`. + RET_CHECK(!node.branch_by_field_value.empty()); + + // Invariant: node field is strictly smaller than sub-node fields. + RET_CHECK(IsFullSet(node.default_branch) || + IsEmptySet(node.default_branch) || + GetNodeOrDie(node.default_branch).field > node.field); + for (const auto& [value, branch] : node.branch_by_field_value) { + RET_CHECK(IsFullSet(branch) || IsEmptySet(branch) || + GetNodeOrDie(branch).field > node.field); + + // Invariant: Each case in `branch_by_field_value` is != + // `default_branch`. + RET_CHECK(branch != node.default_branch); + } + + // Invariant: node field is interned by `field_manager_`. + field_manager_.GetFieldNameOrDie(node.field); // No crash. 
+ } + + return absl::OkStatus(); +} + +} // namespace netkat diff --git a/netkat/symbolic_packet.h b/netkat/symbolic_packet.h index 30f9cca..3c445b7 100644 --- a/netkat/symbolic_packet.h +++ b/netkat/symbolic_packet.h @@ -13,7 +13,7 @@ // limitations under the License. // // ----------------------------------------------------------------------------- -// symbolic_packet.h +// File: symbolic_packet.h // ----------------------------------------------------------------------------- // // Defines `SymbolicPacket` and its companion class `SymbolicPacketManager`. @@ -29,34 +29,34 @@ // https://en.wikipedia.org/wiki/Binary_decision_diagram. // // ----------------------------------------------------------------------------- -// Why have a `SmbolicPacketManager` class? +// Why have a manager class? // ----------------------------------------------------------------------------- // -// With few exceptions, the APIs for creating, manipulating, and inspecting -// `SymbolicPacket`s are all defined as methods and static functions of the -// `SymbolicPacketManager` class. But why? +// The APIs for creating, manipulating, and inspecting `SymbolicPacket`s are all +// defined as methods of the `SymbolicPacketManager` class. But why? // -// TL;DR, all data associated with a `SymbolicPacket` is stored in data vectors -// owned by the manager class. Under the hood, a `SymbolicPacket` is just an -// index into these data vectors. This design pattern is motivated by -// computational and memory efficiency, and is standard for BDD-based libraries. +// TL;DR, all data associated with `SymbolicPacket`s is stored by the manager +// class; `SymbolicPacket` itself is just a lightweight (32-bit) handle. This +// design pattern is motivated by computational and memory efficiency, and is +// standard for BDD-based libraries. // // The manager object acts as an "arena" that owns and manages all memory // associated with `SymbolicPacket`s, enhancing data locality and sharing. 
This -// technique is known as hash-consing and is similar to the flyweight pattern -// and string interning. It has a long list of benefits, most importantly: +// technique is known as interning or hash-consing and is similar to the +// flyweight pattern. It has a long list of benefits, most importantly: // -// * Canonicity: Can gurantee that semantically identical `SymbolicPacket` are -// represnted by the same index into `SymbolicPacketManager`, making semantic -// `SymbolicPacket` comparison O(1) (just comparing two integers)! +// * Canonicity: Can guarantee that semantically identical `SymbolicPacket` are +// represented by the same handle, making semantic `SymbolicPacket` comparison +// O(1) (just comparing two integers)! // // * Memory efficiency: The graph structures used to encode symbolic packets are -// maximally shared across all packets, avoiding redundant copies. +// maximally shared across all packets, avoiding redundant copies of isomorphic +// subgraphs. // // * Cache friendliness: Storing all data in contiguous arrays within the // manager improves data locality and thus cache utilization. // -// * Light representation: Since `SymbolicPacket`s are simply integres in +// * Light representation: Since `SymbolicPacket`s are simply integers in // memory, they are cheap to store, copy, compare, and hash. // // * Memoization: Thanks to canonicity and lightness of representation, @@ -65,72 +65,66 @@ // // SymbolicPacket, SymbolicPacket -> SymbolicPacket // -// can be memoized as a loockup table of type (int, int) -> int. +// can be memoized as a lookup table of type (int, int) -> int. // // ----------------------------------------------------------------------------- - +// // CAUTION: This implementation has NOT yet been optimized for performance. -// Performance can likely be improved significantly, e.g. as follows: -// * Profiling and benchmarking to identify inefficiencies. 
-// * Using standard techniques described in the literature on BDDs and other -// decision diagrams, see e.g. "Efficient Implementation of a BDD Package". +// See the TODOs in the cc file for low hanging fruit. Beyond known +// inefficiencies, performance can likely be improved significantly further +// through profiling and benchmarking. Also see "Efficient Implementation of a +// BDD Package" for standard techniques to improve performance. #ifndef GOOGLE_NETKAT_NETKAT_SYMBOLIC_PACKET_H_ #define GOOGLE_NETKAT_NETKAT_SYMBOLIC_PACKET_H_ +#include #include -#include +#include #include +#include "absl/container/fixed_array.h" +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" #include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "netkat/evaluator.h" +#include "netkat/interned_field.h" #include "netkat/netkat.pb.h" +#include "netkat/paged_stable_vector.h" namespace netkat { -// TODO(smolkaj): Implement this. -class SymbolicPacketManager; - -// A "symbolic packet" is a data structure representing a set of packets. -// It is a compressed representation that can efficiently encode typical large -// and even infinite sets seen in practice. +// A "symbolic packet" is a lightweight handle (32 bits) that represents a set +// of packets. Handles can only be created by a `SymbolicPacketManager` object, +// which owns the graph-based representation of the set. The representation can +// efficiently encode typical large and even infinite sets seen in practice. +// +// The APIs of this object are almost entirely defined as methods of the +// companion class `SymbolicPacketManager`. See the section "Why have a manager +// class?" at the top of the file to learn why. +// +// CAUTION: Each `SymbolicPacket` is implicitly associated with the manager +// object that created it; using it with a different manager has undefined +// behavior. 
+// +// This data structure enjoys the following powerful *canonicity property*: two +// symbolic packets represent the same set if and only if they have the same +// memory representation. Since the memory representation is just 32 bits, +// semantic set equality is cheap: O(1)! // // Compared to NetKAT predicates, which semantically also represent sets of // packets, symbolic packets have a few advantages: // * Cheap to store, copy, hash, and compare: O(1) // * Cheap to check set equality: O(1) // * Cheap to check set membership and set containment: O(# packet fields) -// -// NOTES ON THE API: -// * The majority of operations on `SymbolicPacket`s are defined as methods on -// the companion class `SymbolicPacketManager`. See the section "Why have a -// `SmbolicPacketManager` class?" at the top of the file to learn why. -// * Each `SymbolicPacket` is implicitly associated with the manager object that -// created it; using it with a different manager object has undefined -// behavior. -// * `EmptySet()` and `FullSet()` are exceptions to the above rule in that they -// work with any `SymbolicPacketManager` object. -class SymbolicPacket { +class [[nodiscard]] SymbolicPacket { public: // Default constructor: the empty set of packets. - SymbolicPacket() = default; - - // TODO(smolkaj): Move the EmptySet/FullSet APIs to the - // `SymbolicPacketManager` class for consistency. - - // The symbolic packet representing the empty set of packets. - static SymbolicPacket EmptySet() { return SymbolicPacket(kEmptySetIndex); } - - // The symbolic packet representing the set of all packets. - static SymbolicPacket FullSet() { return SymbolicPacket(kFullSetIndex); } - - // Returns true iff this symbolic packet represents the empty set of packets. - bool IsEmptySet() const { return node_index_ == kEmptySetIndex; } - - // Returns true iff this symbolic packet represents the set of all packets. 
- bool IsFullSet() const { return node_index_ == kFullSetIndex; } + SymbolicPacket(); // Two symbolic packets compare equal iff they represent the same set of - // concrete packets. Comparison is O(1), thanks to a canonical representation. + // concrete packets. Comparison is O(1), thanks to interning/hash-consing. friend auto operator<=>(SymbolicPacket a, SymbolicPacket b) = default; // Hashing, see https://abseil.io/docs/cpp/guides/hash. @@ -142,56 +136,213 @@ class SymbolicPacket { // Formatting, see https://abseil.io/docs/cpp/guides/abslstringify. template friend void AbslStringify(Sink& sink, SymbolicPacket packet) { - if (packet.IsEmptySet()) { - absl::Format(&sink, "SymbolicPacket"); - } else if (packet.IsFullSet()) { - absl::Format(&sink, "SymbolicPacket"); - } else { - absl::Format(&sink, "SymbolicPacket<%d>", packet.node_index_); - } + absl::Format(&sink, "%s", packet.ToString()); } + std::string ToString() const; + + private: + // An index into the `nodes_` vector of the `SymbolicPacketManager` object + // associated with this `SymbolicPacket`. The semantics of this symbolic + // packet is entirely determined by the node `nodes_[node_index_]`. The index + // is otherwise arbitrary and meaningless. + // + // We use a 32-bit index as a tradeoff between minimizing memory usage and + // maximizing the number of `SymbolicPacket`s that can be created, both + // aspects that impact how well we scale to large NetKAT models. We expect + // millions, but not billions, of symbolic packets in practice, and 2^32 ~= 4 + // billion. + uint32_t node_index_; + explicit SymbolicPacket(uint32_t node_index) : node_index_(node_index) {} + friend class SymbolicPacketManager; +}; + +// Protect against regressions in the memory layout, as it affects performance. +static_assert(sizeof(SymbolicPacket) <= 4); + +// An "arena" in which `SymbolicPacket`s can be created and manipulated. 
+// +// This class defines the majority of operations on `SymbolicPacket`s and owns +// all the memory associated with the `SymbolicPacket`s returned by the class's +// methods. +// +// CAUTION: Using a `SymbolicPacket` returned by one `SymbolicPacketManager` +// object with a different manager is undefined behavior. +class SymbolicPacketManager { + public: + SymbolicPacketManager() = default; + + // The class is move-only: not copyable, but movable. + SymbolicPacketManager(const SymbolicPacketManager&) = delete; + SymbolicPacketManager& operator=(const SymbolicPacketManager&) = delete; + SymbolicPacketManager(SymbolicPacketManager&&) = default; + SymbolicPacketManager& operator=(SymbolicPacketManager&&) = default; + + // Returns true iff this symbolic packet represents the empty set of packets. + bool IsEmptySet(SymbolicPacket packet) const; + + // Returns true iff this symbolic packet represents the set of all packets. + bool IsFullSet(SymbolicPacket packet) const; + + // Returns true if the set represented by `symbolic_packet` contains the given + // `concrete_packet`, or false otherwise. + bool Contains(SymbolicPacket symbolic_packet, Packet concrete_packet) const; + + // Compiles the given `PredicateProto` into a `SymbolicPacket` that + // represents the set of packets satisfying the predicate. + SymbolicPacket Compile(const PredicateProto& pred); + + // The symbolic packet representing the empty set of packets. + SymbolicPacket EmptySet() const; + + // The symbolic packet representing the set of all packets. + SymbolicPacket FullSet() const; + + // Returns the set of packets whose `field` is equal to `value`. + SymbolicPacket Match(absl::string_view field, int value); + + // Returns the set of packets that are in the `left` *AND* in the `right` set. + // Also known as set intersection. + SymbolicPacket And(SymbolicPacket left, SymbolicPacket right); + + // Returns the set of packets that are in the `left` *OR* in the `right` set. + // Also known as set union. 
+ SymbolicPacket Or(SymbolicPacket left, SymbolicPacket right); + + // Returns the set of packets that are *NOT* in the given set. + // Also known as set complement. + SymbolicPacket Not(SymbolicPacket negand); + + // Returns a human-readable string representation of the given `packet`, + // intended for debugging. + [[nodiscard]] std::string PrettyPrint(SymbolicPacket packet) const; + + // Dynamically checks all class invariants. Exposed for testing only. + absl::Status CheckInternalInvariants() const; + + // There are many additional set operations supported by the data structure. + // We may implement them as needed. For example: + // * subset - is one set a subset of another? + // * witness - given a (non-empty) set, return one (or n) elements from the + // set. + // * sample - return a member from the set uniformly at random. private: - // In memory, a `SymbolicPacket` is just an integer, making it cheap to - // store, copy, hash, and compare. + // Internally, this class represents symbolic packets (and thus packet sets) + // as nodes in a directed acyclic graph (DAG). Each node branches based on the + // value of a single packet field, and each branch points to another + // symbolic packet, which in turn is either the full/empty set, or represented + // by another node in the graph. // - // The `node_index_` is either: - // * A sentinel value: - // * `kEmptySetIndex`, representing the empty set of packets. - // * `kFullSetIndex`, representing the full set of packets. - // * An index into the `nodes_` vector of the `SymbolicPacketManager` object - // associated with this `SymbolicPacket`. In this case, the semantics of - // this object is given by the node stored at `nodes_[node_index_]` in the - // manager object. The index is otherwise arbitrary and meaningless, and - // thus, so is this object unless we have access to the associated manager - // object. 
+ // The graph is "ordered", "reduced", and contains no "isomorphic subgraphs": // - // We use a bit width of 32 as a tradeoff between minimizing memory usage - // (which is critical for scaling to large NetKAT models) and maximizing the - // number of `SymbolicPacket`s that can be created (which is also critical for - // scaling to large NetKAT models) -- we expect millions, but probably not - // billions, of symbolic packets in practice, and 2^32 ~= 4 billion. + // * Ordered: Along each path through the graph, fields increase strictly + // monotonically (with respect to `<` defined on `InternedField`s). + // * Reduced: Intuitively, there exist no redundant branches or nodes. + // Invariants 1 and 2 on `branch_by_field_value` formalize this intuition. + // * No isomorphic subgraphs: Nodes are interned by the class, ensuring that + // structurally identical nodes are guaranteed to be stored by the class + // only once. Together with the other two properties, this implies that each + // node stored by the class represents a unique set of packets. // - // The sentinel values are chosen maximally to avoid collisions with valid - // indices, which are assigned dynamically by the manager object starting at - // 0. For performance reasons, there is no runtime protection against - // collisions and overflow if we create too many distinct `SymbolicPacket`s. + // This representation is closely related to Binary Decision Diagrams (BDDs), + // see https://en.wikipedia.org/wiki/Binary_decision_diagram. This variant of + // BDDs is described in the paper "KATch: A Fast Symbolic Verifier for + // NetKAT". + + // A decision node in the symbolic packet DAG. The node branches on the value + // of a single `field`, and (the consequent of) each branch is a + // `SymbolicPacket` corresponding to either another decision node or the + // full/empty set. 
Semantically, represents a cascading conditional of the + // form: // - // This data structure enjoys the following powerful canonicity property: two - // symbolic packets represent the same set if and only if they have the same - // `node_index_`. - uint32_t node_index_ = SentinelValue::kEmptySetIndex; - enum SentinelValue : uint32_t { - kEmptySetIndex = std::numeric_limits::max(), - kFullSetIndex = std::numeric_limits::max() - 1, + // if (field == value_1) then branch_1 + // else if (field == value_2) then branch_2 + // ... + // else default_branch + struct DecisionNode { + // The packet field whose value this decision node branches on. + // + // INVARIANTS: + // * Strictly smaller (`<`) than the fields of other decision nodes + // reachable from this node. + // * Interned by `field_manager_`. + InternedField field; + + // The consequent of the "else" branch of this decision node. + SymbolicPacket default_branch; + + // The "if" branches of the decision node, "keyed" by the value they branch + // on. Each element of the array is a (value, branch)-pair encoding + // "if (field == value) then branch". + // + // CHOICE OF DATA STRUCTURE: + // Logically this is a value -> branch map, but we store it as a fixed-size + // array to optimize memory layout (contiguous, compact, flat), exploiting + // the following observations: + // * Nodes are not mutated after creation, so we can use a fixed-size + // container and save some bytes relative to dynamically-sized containers. + // * We don't need fast lookups, so we can avoid the overhead of + // lookup-optimized data structures like hash maps. + // + // INVARIANTS: + // 1. Maintained by `NodeToPacket`: `branch_by_field_value` is non-empty. + // (If it is empty, the decision node gets replaced by `default_branch`.) + // 2. Each branch is != `default_branch`. + // (If the branch is == `default_branch`, it must be omitted.) + // 3. The pairs are ordered by strictly increasing value. No two branches + // have the same value. 
+ absl::FixedArray, + /*use_heap_allocation_above_size=*/0> + branch_by_field_value; + + // Protect against regressions in memory layout, as it affects performance. + static_assert(sizeof(branch_by_field_value) == 16); + + friend auto operator<=>(const DecisionNode& a, + const DecisionNode& b) = default; + + // Hashing, see https://abseil.io/docs/cpp/guides/hash. + template + friend H AbslHashValue(H h, const DecisionNode& node) { + return H::combine(std::move(h), node.field, node.default_branch, + node.branch_by_field_value); + } }; - explicit SymbolicPacket(uint32_t node_index) : node_index_(node_index) {} - friend class SymbolicPacketManager; -}; -static_assert( - sizeof(SymbolicPacket) <= 4, - "SymbolicPacket should have small memory footprint for performance"); + // Protect against regressions in memory layout, as it affects performance. + static_assert(sizeof(DecisionNode) == 24); + static_assert(alignof(DecisionNode) == 8); + + SymbolicPacket NodeToPacket(DecisionNode&& node); + + // Returns the `DecisionNode` corresponding to the given `SymbolicPacket`, or + // crashes if the `packet` is `EmptySet()` or `FullSet()`. + const DecisionNode& GetNodeOrDie(SymbolicPacket packet) const; + + [[nodiscard]] std::string PrettyPrint(const DecisionNode& node) const; + + // The page size of the `nodes_` vector: 64 MiB or ~ 67 MB. + // Chosen large enough to reduce the cost of dynamic allocation, and small + // enough to avoid excessive memory overhead. + static constexpr size_t kPageSize = (1 << 26) / sizeof(DecisionNode); + + // The decision nodes forming the BDD-style DAG representation of symbolic + // packets. `SymbolicPacket::node_index_` indexes into this vector. + // + // We use a custom vector class that provides pointer stability, allowing us + // to create new nodes while traversing the graph (e.g. during operations like + // `And`, `Or`, `Not`). The class also avoids expensive reallocations. 
+ PagedStableVector nodes_; + + // A so called "unique table" to ensure each node is only added to `nodes_` + // once, and thus has a unique `SymbolicPacket::node_index`. + // + // INVARIANT: `packet_by_node_[n] = s` iff `nodes_[s.node_index_] == n`. + absl::flat_hash_map packet_by_node_; + + // INVARIANT: All `DecisionNode` fields are interned by this manager. + InternedFieldManager field_manager_; +}; } // namespace netkat diff --git a/netkat/symbolic_packet_test.cc b/netkat/symbolic_packet_test.cc index ff2c4bb..6e48115 100644 --- a/netkat/symbolic_packet_test.cc +++ b/netkat/symbolic_packet_test.cc @@ -15,44 +15,247 @@ #include "netkat/symbolic_packet.h" +#include + +#include "absl/base/no_destructor.h" #include "absl/container/flat_hash_set.h" #include "absl/strings/str_cat.h" +#include "fuzztest/fuzztest.h" +#include "gmock/gmock.h" #include "gtest/gtest.h" +#include "gutil/status_matchers.h" +#include "netkat/evaluator.h" +#include "netkat/netkat_proto_constructors.h" namespace netkat { -namespace { -TEST(SymbolicPacketTest, DefaultConstructorYieldsEmptySet) { - EXPECT_TRUE(SymbolicPacket().IsEmptySet()); +// We use a global manager object to exercise statefulness more deeply across +// test cases. This also enables better pretty printing for debugging, see +// `PrintTo`. +SymbolicPacketManager& Manager() { + static absl::NoDestructor manager; + return *manager; +} + +// The default `SymbolicPacket` pretty printer sucks! It does not have access to +// the graph structure representing the packet, since that is stored in the +// manager object. Thus, it returns opaque strings like "SymbolicPacket<123>". +// +// We define this much better override, which GoogleTest gives precedence to. 
+void PrintTo(const SymbolicPacket& packet, std::ostream* os) { + *os << Manager().PrettyPrint(packet); } -TEST(SymbolicPacketTest, EmptySetIsEmptySet) { - EXPECT_TRUE(SymbolicPacket::EmptySet().IsEmptySet()); +namespace { + +using ::testing::StartsWith; + +// After executing all tests, we check once that no invariants are violated, for +// defense in depth. Checking invariants after each test (e.g. using a fixture) +// would likely not scale and seems overkill. +class CheckSymbolicPacketManagerInvariantsOnTearDown + : public testing::Environment { + public: + ~CheckSymbolicPacketManagerInvariantsOnTearDown() override {} + void SetUp() override {} + void TearDown() override { ASSERT_OK(Manager().CheckInternalInvariants()); } +}; +testing::Environment* const foo_env = testing::AddGlobalTestEnvironment( + new CheckSymbolicPacketManagerInvariantsOnTearDown); + +TEST(SymbolicPacketManagerTest, EmptySetIsEmptySet) { + EXPECT_TRUE(Manager().IsEmptySet(Manager().EmptySet())); + EXPECT_FALSE(Manager().IsFullSet(Manager().EmptySet())); } -TEST(SymbolicPacketTest, FullSetIsFullSet) { - EXPECT_TRUE(SymbolicPacket::FullSet().IsFullSet()); +TEST(SymbolicPacketManagerTest, FullSetIsFullSet) { + EXPECT_TRUE(Manager().IsFullSet(Manager().FullSet())); + EXPECT_FALSE(Manager().IsEmptySet(Manager().FullSet())); } -TEST(SymbolicPacketTest, EmptySetDoesNotEqualFullSet) { - EXPECT_NE(SymbolicPacket::EmptySet(), SymbolicPacket::FullSet()); +TEST(SymbolicPacketManagerTest, EmptySetDoesNotEqualFullSet) { + EXPECT_NE(Manager().EmptySet(), Manager().FullSet()); } -TEST(SymbolicPacketTest, AbslStringifyWorksForEmptySet) { - EXPECT_EQ(absl::StrCat(SymbolicPacket::EmptySet()), "SymbolicPacket"); +TEST(SymbolicPacketManagerTest, AbslStringifyWorksForEmptySet) { + EXPECT_THAT(absl::StrCat(Manager().EmptySet()), StartsWith("SymbolicPacket")); } -TEST(SymbolicPacketTest, AbslStringifyWorksForFullSet) { - EXPECT_EQ(absl::StrCat(SymbolicPacket::FullSet()), "SymbolicPacket"); 
+TEST(SymbolicPacketManagerTest, AbslStringifyWorksForFullSet) { + EXPECT_THAT(absl::StrCat(Manager().FullSet()), StartsWith("SymbolicPacket")); } -TEST(SymbolicPacketTest, AbslHashValueWorks) { +TEST(SymbolicPacketManagerTest, AbslHashValueWorks) { absl::flat_hash_set set = { - SymbolicPacket::EmptySet(), - SymbolicPacket::FullSet(), + Manager().EmptySet(), + Manager().FullSet(), }; EXPECT_EQ(set.size(), 2); } +TEST(SymbolicPacketManagerTest, TrueCompilesToFullSet) { + EXPECT_EQ(Manager().Compile(TrueProto()), Manager().FullSet()); +} + +TEST(SymbolicPacketManagerTest, FalseCompilesToEmptySet) { + EXPECT_EQ(Manager().Compile(FalseProto()), Manager().EmptySet()); +} + +void MatchCompilesToMatch(std::string field, int value) { + EXPECT_EQ(Manager().Compile(MatchProto(field, value)), + Manager().Match(field, value)); +} +FUZZ_TEST(SymbolicPacketManagerTest, MatchCompilesToMatch); + +void AndCompilesToAnd(const PredicateProto& left, const PredicateProto& right) { + EXPECT_EQ(Manager().Compile(AndProto(left, right)), + Manager().And(Manager().Compile(left), Manager().Compile(right))); +} +FUZZ_TEST(SymbolicPacketManagerTest, AndCompilesToAnd); + +void OrCompilesToOr(const PredicateProto& left, const PredicateProto& right) { + EXPECT_EQ(Manager().Compile(OrProto(left, right)), + Manager().Or(Manager().Compile(left), Manager().Compile(right))); +} +FUZZ_TEST(SymbolicPacketManagerTest, OrCompilesToOr); + +void NotCompilesToNot(const PredicateProto& pred) { + EXPECT_EQ(Manager().Compile(NotProto(pred)), + Manager().Not(Manager().Compile(pred))); +} +FUZZ_TEST(SymbolicPacketManagerTest, NotCompilesToNot); + +void CompilationPreservesSemantics(const PredicateProto& pred, + const Packet& packet) { + EXPECT_EQ(Manager().Contains(Manager().Compile(pred), packet), + Evaluate(pred, packet)); +} +FUZZ_TEST(SymbolicPacketManagerTest, CompilationPreservesSemantics); + +void EqualPredicatesCompileToEqualSymbolicPackets(const PredicateProto& pred) { + 
EXPECT_EQ(Manager().Compile(pred), Manager().Compile(pred)); +} +FUZZ_TEST(SymbolicPacketManagerTest, + EqualPredicatesCompileToEqualSymbolicPackets); + +void NegationCompilesToDifferentSymbolicPacket(const PredicateProto& pred) { + EXPECT_NE(Manager().Compile(pred), Manager().Compile(NotProto(pred))); +} +FUZZ_TEST(SymbolicPacketManagerTest, NegationCompilesToDifferentSymbolicPacket); + +void DoubleNegationCompilesToSameSymbolicPacket(const PredicateProto& pred) { + EXPECT_EQ(Manager().Compile(pred), + Manager().Compile(NotProto(NotProto(pred)))); +} +FUZZ_TEST(SymbolicPacketManagerTest, + DoubleNegationCompilesToSameSymbolicPacket); + +TEST(SymbolicPacketManagerTest, TrueNotEqualsMatch) { + EXPECT_NE(Manager().Compile(TrueProto()), + Manager().Compile(MatchProto("hi", 42))); +} +TEST(SymbolicPacketManagerTest, FalseNotEqualsMatch) { + EXPECT_NE(Manager().Compile(FalseProto()), + Manager().Compile(MatchProto("hi", 42))); +} +TEST(SymbolicPacketManagerTest, MatchNotEqualsDifferentMatch) { + EXPECT_NE(Manager().Compile(MatchProto("hi", 42)), + Manager().Compile(MatchProto("bye", 42))); + EXPECT_NE(Manager().Compile(MatchProto("hi", 42)), + Manager().Compile(MatchProto("hi", 24))); +} +TEST(SymbolicPacketManagerTest, NotTrueEqualsFalse) { + EXPECT_EQ(Manager().Compile(NotProto(TrueProto())), + Manager().Compile(FalseProto())); +} + +void AndIsIdempotent(const PredicateProto& pred) { + EXPECT_EQ(Manager().Compile(AndProto(pred, pred)), Manager().Compile(pred)); +} +FUZZ_TEST(SymbolicPacketManagerTest, AndIsIdempotent); + +void OrIsIdempotent(const PredicateProto& pred) { + EXPECT_EQ(Manager().Compile(OrProto(pred, pred)), Manager().Compile(pred)); +} +FUZZ_TEST(SymbolicPacketManagerTest, OrIsIdempotent); + +void PredOrItsNegationIsTrue(const PredicateProto& pred) { + EXPECT_EQ(Manager().Compile(OrProto(pred, NotProto(pred))), + Manager().Compile(TrueProto())); +} +FUZZ_TEST(SymbolicPacketManagerTest, PredOrItsNegationIsTrue); + +void PredAndItsNegationIsFalse(const 
PredicateProto& pred) { + EXPECT_EQ(Manager().Compile(AndProto(pred, NotProto(pred))), + Manager().Compile(FalseProto())); +} +FUZZ_TEST(SymbolicPacketManagerTest, PredAndItsNegationIsFalse); + +void AndTrueIsIdentity(const PredicateProto& pred) { + EXPECT_EQ(Manager().Compile(AndProto(pred, TrueProto())), + Manager().Compile(pred)); +} +FUZZ_TEST(SymbolicPacketManagerTest, AndTrueIsIdentity); + +void OrFalseIsIdentity(const PredicateProto& pred) { + EXPECT_EQ(Manager().Compile(OrProto(pred, FalseProto())), + Manager().Compile(pred)); +} +FUZZ_TEST(SymbolicPacketManagerTest, OrFalseIsIdentity); + +void AndFalseIsFalse(const PredicateProto& pred) { + EXPECT_EQ(Manager().Compile(AndProto(pred, FalseProto())), + Manager().Compile(FalseProto())); +} +FUZZ_TEST(SymbolicPacketManagerTest, AndFalseIsFalse); + +void OrTrueIsTrue(const PredicateProto& pred) { + EXPECT_EQ(Manager().Compile(OrProto(pred, TrueProto())), + Manager().Compile(TrueProto())); +} +FUZZ_TEST(SymbolicPacketManagerTest, OrTrueIsTrue); + +void AndIsCommutative(const PredicateProto& a, const PredicateProto& b) { + EXPECT_EQ(Manager().Compile(AndProto(a, b)), + Manager().Compile(AndProto(b, a))); +} +FUZZ_TEST(SymbolicPacketManagerTest, AndIsCommutative); + +void OrIsCommutative(const PredicateProto& a, const PredicateProto& b) { + EXPECT_EQ(Manager().Compile(OrProto(a, b)), Manager().Compile(OrProto(b, a))); +} +FUZZ_TEST(SymbolicPacketManagerTest, OrIsCommutative); + +void DistributiveLawsHolds(const PredicateProto& a, const PredicateProto& b, + const PredicateProto& c) { + EXPECT_EQ(Manager().Compile(AndProto(a, OrProto(b, c))), + Manager().Compile(OrProto(AndProto(a, b), AndProto(a, c)))); + EXPECT_EQ(Manager().Compile(OrProto(a, AndProto(b, c))), + Manager().Compile(AndProto(OrProto(a, b), OrProto(a, c)))); +} +FUZZ_TEST(SymbolicPacketManagerTest, DistributiveLawsHolds); + +void DeMorgansLawsHolds(const PredicateProto& a, const PredicateProto& b) { + EXPECT_EQ(Manager().Compile(NotProto(AndProto(a, 
b))), + Manager().Compile(OrProto(NotProto(a), NotProto(b)))); + EXPECT_EQ(Manager().Compile(NotProto(OrProto(a, b))), + Manager().Compile(AndProto(NotProto(a), NotProto(b)))); +} +FUZZ_TEST(SymbolicPacketManagerTest, DeMorgansLawsHolds); + +void AndIsAssociative(const PredicateProto& a, const PredicateProto& b, + const PredicateProto& c) { + EXPECT_EQ(Manager().Compile(AndProto(a, AndProto(b, c))), + Manager().Compile(AndProto(AndProto(a, b), c))); +} +FUZZ_TEST(SymbolicPacketManagerTest, AndIsAssociative); + +void OrIsAssociative(const PredicateProto& a, const PredicateProto& b, + const PredicateProto& c) { + EXPECT_EQ(Manager().Compile(OrProto(a, OrProto(b, c))), + Manager().Compile(OrProto(OrProto(a, b), c))); +} +FUZZ_TEST(SymbolicPacketManagerTest, OrIsAssociative); + } // namespace } // namespace netkat