From ca4c503c448cfa3ed3cd4de2b8f9911276c9092e Mon Sep 17 00:00:00 2001 From: Steffen Smolka Date: Tue, 26 Nov 2024 22:49:52 -0800 Subject: [PATCH] [NetKAT] Finish (initial, inefficient) symbolic packet implementation. PiperOrigin-RevId: 700572230 --- gutil/BUILD.bazel | 149 +++++++++ gutil/README.md | 7 + gutil/proto.cc | 198 ++++++++++++ gutil/proto.h | 127 ++++++++ gutil/proto_matchers.h | 258 +++++++++++++++ gutil/proto_matchers_test.cc | 215 +++++++++++++ gutil/proto_string_error_collector.h | 56 ++++ gutil/proto_test.cc | 241 ++++++++++++++ gutil/proto_test.proto | 24 ++ gutil/status.cc | 63 ++++ gutil/status.h | 341 ++++++++++++++++++++ gutil/status_matchers.h | 146 +++++++++ gutil/status_matchers_test.cc | 117 +++++++ gutil/testing.cc | 49 +++ gutil/testing.h | 53 ++++ gutil/testing_test.cc | 47 +++ netkat/BUILD.bazel | 36 +++ netkat/netkat.proto | 2 +- netkat/netkat_proto_constructors.cc | 67 ++++ netkat/netkat_proto_constructors.h | 45 +++ netkat/netkat_proto_constructors_test.cc | 75 +++++ netkat/symbolic_packet.cc | 382 +++++++++++++++++++++++ netkat/symbolic_packet.h | 348 +++++++++++++++------ netkat/symbolic_packet_test.cc | 226 +++++++++++++- 24 files changed, 3163 insertions(+), 109 deletions(-) create mode 100644 gutil/BUILD.bazel create mode 100644 gutil/README.md create mode 100644 gutil/proto.cc create mode 100644 gutil/proto.h create mode 100644 gutil/proto_matchers.h create mode 100644 gutil/proto_matchers_test.cc create mode 100644 gutil/proto_string_error_collector.h create mode 100644 gutil/proto_test.cc create mode 100644 gutil/proto_test.proto create mode 100644 gutil/status.cc create mode 100644 gutil/status.h create mode 100644 gutil/status_matchers.h create mode 100644 gutil/status_matchers_test.cc create mode 100644 gutil/testing.cc create mode 100644 gutil/testing.h create mode 100644 gutil/testing_test.cc create mode 100644 netkat/netkat_proto_constructors.cc create mode 100644 netkat/netkat_proto_constructors.h create mode 100644 netkat/netkat_proto_constructors_test.cc create mode 100644 netkat/symbolic_packet.cc diff --git a/gutil/BUILD.bazel b/gutil/BUILD.bazel new file mode 100644 index 0000000..d663416 --- /dev/null +++ b/gutil/BUILD.bazel @@ -0,0 +1,149 @@ + + +package( + default_applicable_licenses = ["//:license"], + default_visibility = ["//visibility:private"], +) + +cc_library( + name = "testing", + testonly = True, + srcs = ["testing.cc"], + hdrs = ["testing.h"], + visibility = ["//:__subpackages__"], + deps = [ + ":proto", + ":status", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + ], +) + +cc_test( + name = "testing_test", + srcs = ["testing_test.cc"], + deps = [ + ":testing", + "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "status", + srcs = ["status.cc"], + hdrs = ["status.h"], + visibility = ["//:__subpackages__"], + deps = [ + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:cord", + ], +) + +cc_library( + name = "status_matchers", + testonly = True, + hdrs = ["status_matchers.h"], + visibility = ["//:__subpackages__"], + deps = [ + ":status", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_googletest//:gtest", + ], +) + +cc_test( + name = "status_matchers_test", + srcs = ["status_matchers_test.cc"], + deps = [ + ":status", + ":status_matchers", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "proto_string_error_collector", + hdrs = ["proto_string_error_collector.h"], + deps = [ + "@com_google_absl//absl/strings", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "proto", + srcs = ["proto.cc"], + hdrs = ["proto.h"], + visibility = ["//:__subpackages__"], + deps = [ + ":proto_string_error_collector", + ":status", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "proto_test", + srcs = ["proto_test.cc"], + deps = [ + ":proto", + ":proto_matchers", + ":proto_test_cc_proto", + ":status", + ":status_matchers", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest_main", + ], +) + +proto_library( + name = "proto_test_proto", + srcs = ["proto_test.proto"], +) + +cc_proto_library( + name = "proto_test_cc_proto", + deps = [":proto_test_proto"], +) + +cc_library( + name = "proto_matchers", + testonly = True, + hdrs = ["proto_matchers.h"], + visibility = ["//:__subpackages__"], + deps = [ + ":proto", + "@com_google_absl//absl/log", + "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "proto_matchers_test", + srcs = ["proto_matchers_test.cc"], + deps = [ + ":proto_matchers", + ":proto_test_cc_proto", + ":testing", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/gutil/README.md b/gutil/README.md new file mode 100644 index 0000000..9709876 --- /dev/null +++ b/gutil/README.md @@ -0,0 +1,7 @@ +# Gutil + +Code in this folder has been copied/modified from +https://github.com/google/gutil. + +THIS IS A TEMPORARY WORKAROUND until we can depend on the gutil repository +directly, which is tracked in b/380957915. diff --git a/gutil/proto.cc b/gutil/proto.cc new file mode 100644 index 0000000..a47cd49 --- /dev/null +++ b/gutil/proto.cc @@ -0,0 +1,198 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "gutil/proto.h" + +#include + +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/io/zero_copy_stream_impl.h" +#include "google/protobuf/message.h" +#include "google/protobuf/text_format.h" +#include "google/protobuf/util/json_util.h" +#include "gutil/proto_string_error_collector.h" +#include "gutil/status.h" + +namespace netkat { + +bool IsEmptyProto(const google::protobuf::Message &message) { + return message.ByteSizeLong() == 0; +} + +absl::Status ReadProtoFromFile(absl::string_view filename, + google::protobuf::Message *message) { + // Verifies that the version of the library that we linked against is + // compatible with the version of the headers we compiled against. + GOOGLE_PROTOBUF_VERIFY_VERSION; + + int fd = open(std::string(filename).c_str(), O_RDONLY); + if (fd < 0) { + return InvalidArgumentErrorBuilder() + << "Error opening the file " << filename; + } + + google::protobuf::io::FileInputStream file_stream(fd); + file_stream.SetCloseOnDelete(true); + + if (!google::protobuf::TextFormat::Parse(&file_stream, message)) { + return InvalidArgumentErrorBuilder() << "Failed to parse file " << filename; + } + + return absl::OkStatus(); +} + +absl::Status ReadProtoFromString(absl::string_view proto_string, + google::protobuf::Message *message) { + // Verifies that the version of the library that we linked against is + // compatible with the version of the headers we compiled against. + GOOGLE_PROTOBUF_VERIFY_VERSION; + + google::protobuf::TextFormat::Parser parser; + std::string all_errors; + StringErrorCollector collector(&all_errors); + parser.RecordErrorsTo(&collector); + + if (!parser.ParseFromString(std::string(proto_string), message)) { + return InvalidArgumentErrorBuilder() + << "string <" << proto_string << "> did not parse as a " + << message->GetTypeName() << ":\n" + << all_errors; + } + + return absl::OkStatus(); +} + +absl::Status SaveProtoToFile(absl::string_view filename, + const google::protobuf::Message &message) { + // Verifies that the version of the library that we linked against is + // compatible with the version of the headers we compiled against. + GOOGLE_PROTOBUF_VERIFY_VERSION; + int fd = + open(std::string(filename).c_str(), O_WRONLY | O_CREAT | O_TRUNC, 0644); + if (fd < 0) { + return InvalidArgumentErrorBuilder() + << "Error opening the file " << filename; + } + + google::protobuf::io::FileOutputStream file_stream(fd); + file_stream.SetCloseOnDelete(true); + + if (!google::protobuf::TextFormat::Print(message, &file_stream)) { + return InvalidArgumentErrorBuilder() + << "Failed to print proto to file " << filename; + } + + file_stream.Flush(); + return absl::OkStatus(); +} + +absl::StatusOr ProtoDiff( + const google::protobuf::Message &message1, + const google::protobuf::Message &message2, + google::protobuf::util::MessageDifferencer &differ) { + if (message1.GetDescriptor() != message2.GetDescriptor()) { + return netkat::InvalidArgumentErrorBuilder() + << "cannot compute diff for messages of incompatible descriptors `" + << message1.GetDescriptor()->full_name() << "' vs '" + << message2.GetDescriptor()->full_name() << "'"; + } + + std::string diff; + differ.ReportDifferencesToString(&diff); + ProtoEqual(message1, message2, differ); + return diff; +} + +// Calls `ProtoDiff` with default `MessageDifferencer`. +absl::StatusOr ProtoDiff( + const google::protobuf::Message &message1, + const google::protobuf::Message &message2) { + google::protobuf::util::MessageDifferencer differ = + google::protobuf::util::MessageDifferencer(); + return ProtoDiff(message1, message2, differ); +} + +bool ProtoEqual(const google::protobuf::Message &message1, + const google::protobuf::Message &message2, + google::protobuf::util::MessageDifferencer &differ) { + if (message1.GetDescriptor() != message2.GetDescriptor()) { + return false; + } + + return differ.Compare(message1, message2); +} + +// Calls `ProtoEqual` with default `MessageDifferencer`. +bool ProtoEqual(const google::protobuf::Message &message1, + const google::protobuf::Message &message2) { + google::protobuf::util::MessageDifferencer differ = + google::protobuf::util::MessageDifferencer(); + return ProtoEqual(message1, message2, differ); +} + +absl::StatusOr GetOneOfFieldName( + const google::protobuf::Message &message, const std::string &oneof_name) { + const auto *oneof_descriptor = + message.GetDescriptor()->FindOneofByName(oneof_name); + const auto *field = message.GetReflection()->GetOneofFieldDescriptor( + message, oneof_descriptor); + if (!field) { + return netkat::NotFoundErrorBuilder() + << "Oneof field \"" << oneof_name << "\" is not set"; + } + return std::string(field->name()); +} + +std::string PrintTextProto(const google::protobuf::Message &message) { + std::string message_text; + + google::protobuf::TextFormat::Printer printer; + printer.SetExpandAny(true); + + printer.PrintToString(message, &message_text); + + return message_text; +} + +// Print proto in TextFormat with single line mode enabled. +std::string PrintShortTextProto(const google::protobuf::Message &message) { + std::string message_short_text; + + google::protobuf::TextFormat::Printer printer; + printer.SetSingleLineMode(true); + printer.SetExpandAny(true); + + printer.PrintToString(message, &message_short_text); + // Single line mode currently might have an extra space at the end. + if (!message_short_text.empty() && message_short_text.back() == ' ') { + message_short_text.pop_back(); + } + + return message_short_text; +} + +absl::StatusOr SerializeProtoAsJson( + const google::protobuf::Message &proto) { + std::string json; + RETURN_IF_ERROR(netkat::ToAbslStatus( + google::protobuf::util::MessageToJsonString(proto, &json))); + return json; +} + +} // namespace netkat diff --git a/gutil/proto.h b/gutil/proto.h new file mode 100644 index 0000000..e082b4e --- /dev/null +++ b/gutil/proto.h @@ -0,0 +1,127 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GOOGLE_NETKAT_GUTIL_PROTO_H +#define GOOGLE_NETKAT_GUTIL_PROTO_H + +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "google/protobuf/message.h" +#include "google/protobuf/util/json_util.h" +#include "google/protobuf/util/message_differencer.h" +#include "gutil/status.h" + +namespace netkat { + +// Returns `true` if the given `message` has no fields set, `false` otherwise. +bool IsEmptyProto(const google::protobuf::Message &message); + +// Read the contents of the file into a protobuf. +absl::Status ReadProtoFromFile(absl::string_view filename, + google::protobuf::Message *message); + +// Read the contents of the string into a protobuf. +absl::Status ReadProtoFromString(absl::string_view proto_string, + google::protobuf::Message *message); + +// Saves the content of a protobuf into a file. +absl::Status SaveProtoToFile(absl::string_view filename, + const google::protobuf::Message &message); + +// Read the contents of the given string into a protobuf and returns it. +template +absl::StatusOr ParseTextProto(absl::string_view proto_string) { + T message; + RETURN_IF_ERROR(ReadProtoFromString(proto_string, &message)); + return message; +} + +// Returns diff of the given protobuf messages, provided they have the same +// `Descriptor` (message1.GetDescriptor() == message2.GetDescriptor()), or +// returns `InvalidArgumentError` otherwise. Optionally, a `differ` can be +// provided for fine-grained control over how to compute the diff. +absl::StatusOr ProtoDiff( + const google::protobuf::Message &message1, + const google::protobuf::Message &message2, + google::protobuf::util::MessageDifferencer &differ); +absl::StatusOr ProtoDiff( + const google::protobuf::Message &message1, + const google::protobuf::Message &message2); + +// Similar to `ProtoDiff`, except returns boolean result of equality comparison. +bool ProtoEqual(const google::protobuf::Message &message1, + const google::protobuf::Message &message2, + google::protobuf::util::MessageDifferencer &differ); +bool ProtoEqual(const google::protobuf::Message &message1, + const google::protobuf::Message &message2); + +// Get the name of the oneof enum that is set. +// Eg: +// message IrValue { +// oneof format { +// string hex_str = 1; +// string ipv4 = 2; +// string ipv6 = 3; +// string mac = 4; +// string str = 5; +// } +// } +// IrValue value; +// value.set_hex_str("0xf00d"); +// std::string name = GetOneOfFieldName(value, std::string("format")); +// EXPECT_EQ(name, "hex_str"); +absl::StatusOr GetOneOfFieldName( + const google::protobuf::Message &message, const std::string &oneof_name); + +// Print proto in TextFormat. +std::string PrintTextProto(const google::protobuf::Message &message); + +// Print proto in TextFormat in a single line. +std::string PrintShortTextProto(const google::protobuf::Message &message); + +// Parses the given JSON string into a proto of type `T`. +template +absl::StatusOr ParseJsonAsProto(absl::string_view raw_json_string, + bool ignore_unknown_fields = false); + +// Serializes the given proto message as a JSON string. +absl::StatusOr SerializeProtoAsJson( + const google::protobuf::Message &proto); + +// -- END OF PUBLIC INTERFACE - Implementation details follow ------------------ + +template +absl::StatusOr ParseJsonAsProto(absl::string_view raw_json_string, + bool ignore_unknown_fields) { + google::protobuf::util::JsonParseOptions options; + options.ignore_unknown_fields = ignore_unknown_fields; + T proto; + // OS protobuf uses its own `Status`-like and `string_view`-like classes, so + // some gymnastics are required here: + // - ToAbslStatus converts any `Status`-like type to an absl::Status. + // - We pass in `{raw_json_string.data(), raw_json_string.size()}` instead of + // `raw_json_string`, constructing a new object of the appropriate + // `string_view`-like type implicitly. + RETURN_IF_ERROR( + netkat::ToAbslStatus(google::protobuf::util::JsonStringToMessage( + {raw_json_string.data(), raw_json_string.size()}, &proto, options))); + return proto; +} + +} // namespace netkat + +#endif // GOOGLE_NETKAT_GUTIL_PROTO_H diff --git a/gutil/proto_matchers.h b/gutil/proto_matchers.h new file mode 100644 index 0000000..d88bb5f --- /dev/null +++ b/gutil/proto_matchers.h @@ -0,0 +1,258 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GOOGLE_NETKAT_GUTIL_PROTO_MATCHERS_H +#define GOOGLE_NETKAT_GUTIL_PROTO_MATCHERS_H + +#include +#include +#include + +#include "absl/log/log.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "gmock/gmock.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" +#include "google/protobuf/util/message_differencer.h" +#include "gtest/gtest.h" +#include "gutil/proto.h" + +namespace netkat { + +// -- EqualsProto matcher ------------------------------------------------------ + +// Implements a protobuf matcher interface that verifies 2 protobufs are equal +// while ignoring the repeated field ordering. +// +// Sample usage: +// EXPECT_THAT(MyCall(), EqualsProto(R"pb( +// table_name: "ROUTER_INTERFACE_TABLE" +// priority: 123 +// matches { +// name: "router_interface_id" +// exact { str: "16" } +// } +// )pb")); +// +// Sample output on failure: +// Value of: MyCall() +// Expected: +// table_name: "ROUTER_INTERFACE_TABLE" +// matches { +// name: "router_interface_id" +// exact { +// str: "16" +// } +// } +// priority: 123 +// +// Actual: 96-byte object <58-AC 77-4D 5C-55 00-00 00-00 00-00 00-00 00-00 +// 00-00 00-00 00-00 00-00 01-00 00-00 04-00 00-00 30-8D 80-4D 5C-55 00-00 +// 30-55 80-4D 5C-55 00-00 20-25 79-4D 5C-55 00-00 00-00 00-00 00-00 00-00 +// 00-00 00-00 00-00 00-00 C8-01 00-00 5C-55 00-00 00-E6 12-7F 70-7F 00-00 +// 00-00 00-00 00-00 00-00> (of type pdpi::IrTableEntry), +// table_name: "ROUTER_INTERFACE_TABLE" +// matches { +// name: "router_interface_id" +// exact { +// str: "16" +// } +// } +// priority: 456 +class ProtobufEqMatcher { + public: + ProtobufEqMatcher(const google::protobuf::Message& expected) + : expected_(expected.New()), expected_text_(PrintTextProto(expected)) { + expected_->CopyFrom(expected); + } + ProtobufEqMatcher(absl::string_view expected_text) + : expected_text_{expected_text} {} + + void DescribeTo(std::ostream* os, bool negated) const { + *os << "is " << (negated ? "not " : "") << "equal to " + << (expected_ == nullptr ? "" + : absl::StrCat(expected_->GetTypeName(), " ")) + << "<\n" + << expected_text_ << ">"; + } + void DescribeTo(std::ostream* os) const { return DescribeTo(os, false); } + void DescribeNegationTo(std::ostream* os) const { + return DescribeTo(os, true); + } + + template + bool MatchAndExplain(const ProtoType& actual, + ::testing::MatchResultListener* listener) const { + std::string diff; + google::protobuf::util::MessageDifferencer differ; + differ.set_scope(comparison_scope_); + differ.ReportDifferencesToString(&diff); + // Order does not matter for repeated fields. + differ.set_repeated_field_comparison( + google::protobuf::util::MessageDifferencer::RepeatedFieldComparison:: + AS_SET); + + // When parsing from a proto text string we must first create a temporary + // with the same proto type as the "acutal" argument. + if (expected_ == nullptr) { + absl::StatusOr expected = + netkat::ParseTextProto(expected_text_); + if (expected.ok()) { + expected_ = std::make_shared(std::move(*expected)); + } else { + *listener << "where the expected proto " << expected.status().message(); + return false; + } + } + + // Otherwise we can compare directly with the passed protobuf message. + bool equal = differ.Compare(*expected_, actual); + if (!equal) { + *listener << "with diff:\n" << diff; + } + return equal; + } + void SetComparePartially() { + comparison_scope_ = google::protobuf::util::MessageDifferencer::PARTIAL; + } + + private: + mutable std::shared_ptr expected_; + std::string expected_text_; + google::protobuf::util::MessageDifferencer::Scope comparison_scope_ = + google::protobuf::util::MessageDifferencer::FULL; +}; + +inline ::testing::PolymorphicMatcher EqualsProto( + const google::protobuf::Message& proto) { + return ::testing::MakePolymorphicMatcher(ProtobufEqMatcher(proto)); +} + +inline ::testing::PolymorphicMatcher EqualsProto( + absl::string_view proto_text) { + return ::testing::MakePolymorphicMatcher(ProtobufEqMatcher(proto_text)); +} + +// Checks that a pair of protos are equal. Useful in combination with +// `Pointwise`. +MATCHER(EqualsProto, "is a pair of equal protobufs") { + const auto& [x, y] = arg; + return testing::ExplainMatchResult(EqualsProto(x), y, result_listener); +} + +// Checks that a sequences of protos is equal to a given sequence. +template +auto EqualsProtoSequence(T&& sequence) { + return testing::Pointwise(EqualsProto(), std::forward(sequence)); +} + +// -- HasOneofCaseMatcher matcher ---------------------------------------------- + +template +class HasOneofCaseMatcher { + public: + using is_gtest_matcher = void; + using OneofCase = int; + HasOneofCaseMatcher(absl::string_view oneof_name, + OneofCase expected_oneof_case) + : oneof_name_{oneof_name}, expected_oneof_case_(expected_oneof_case) {} + + void DescribeTo(std::ostream* os, bool negate) const { + if (os == nullptr) return; + *os << "is a `" << GetMessageDescriptor().full_name() + << "` protobuf message whose oneof field `" << oneof_name_ << "`"; + if (negate) { + *os << " does not have case "; + } else { + *os << " has case "; + } + *os << "`" << GetOneofCaseName(expected_oneof_case_) << "`"; + } + void DescribeTo(std::ostream* os) const { DescribeTo(os, false); } + void DescribeNegationTo(std::ostream* os) const { DescribeTo(os, true); } + + bool MatchAndExplain(const ProtoMessage& message, + testing::MatchResultListener* listener) const { + const google::protobuf::Message& m = message; + const google::protobuf::FieldDescriptor* set_oneof_field = + m.GetReflection()->GetOneofFieldDescriptor(m, GetOneofDescriptor()); + *listener << "the oneof `" << oneof_name_ << "` is "; + if (set_oneof_field == nullptr) { + *listener << "unset"; + return false; + } else { + *listener << "set to `" << set_oneof_field->name() << "`"; + return set_oneof_field->number() == expected_oneof_case_; + } + } + + private: + std::string oneof_name_; + OneofCase expected_oneof_case_; + + const google::protobuf::Descriptor& GetMessageDescriptor() const { + auto* descriptor = ProtoMessage::descriptor(); + if (descriptor == nullptr) { + LOG(FATAL) // Crash ok: test + << "ProtoMessage::descriptor() returned null."; + } + return *descriptor; + } + const google::protobuf::OneofDescriptor* GetOneofDescriptor() const { + return GetMessageDescriptor().FindOneofByName(oneof_name_); + } + const google::protobuf::FieldDescriptor* GetOneofCaseDescriptor( + OneofCase oneof_case) const { + return GetMessageDescriptor().FindFieldByNumber(oneof_case); + } + std::string GetOneofCaseName(OneofCase oneof_case) const { + const google::protobuf::FieldDescriptor* descriptor = + GetOneofCaseDescriptor(oneof_case); + return descriptor == nullptr ? "" + : std::string(descriptor->name()); + } +}; + +// Protobuf matcher that checks if the oneof field with the given `oneof_name` +// is set to the given `expected_oneof_case`. +// That is, checks `proto.oneof_name_case() == expected_oneof_case`. +// +// Sample usage: +// ``` +// EXPECT_THAT(packet.headers(0), +// HasOneofCase( +// "header", packetlib::Header::kIpv4Header)); +// ``` +template +HasOneofCaseMatcher HasOneofCase(absl::string_view oneof_name, + int expected_oneof_case) { + return HasOneofCaseMatcher(oneof_name, expected_oneof_case); +} + +// Partially(m) returns a matcher that is the same as m, except that +// only fields present in the expected protobuf are considered (using +// google::protobuf::util::MessageDifferencer's PARTIAL comparison option). For +// example, Partially(EqualsProto(p)) will ignore any field that's +// not set in p when comparing the protobufs. The inner matcher m can +// be any of the Equals* and EquivTo* protobuf matchers above. +template +inline InnerProtoMatcher Partially(InnerProtoMatcher inner_proto_matcher) { + inner_proto_matcher.mutable_impl().SetComparePartially(); + return inner_proto_matcher; +} + +} // namespace netkat + +#endif // GOOGLE_NETKAT_GUTIL_PROTO_MATCHERS_H diff --git a/gutil/proto_matchers_test.cc b/gutil/proto_matchers_test.cc new file mode 100644 index 0000000..5eabec5 --- /dev/null +++ b/gutil/proto_matchers_test.cc @@ -0,0 +1,215 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "gutil/proto_matchers.h" + +#include +#include +#include + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "gutil/proto_test.pb.h" +#include "gutil/testing.h" + +namespace netkat { +namespace { + +using ::testing::Not; + +TEST(ProtoMatcher, EqualsProto) { + TestMessage message; + message.set_int_field(123); + message.set_string_field("foo"); + message.set_bool_field(true); + + EXPECT_THAT(message, EqualsProto(message)); +} + +TEST(ProtoMatcher, EqualsProtoFromText) { + TestMessage message; + message.set_int_field(123); + message.set_string_field("foo"); + message.set_bool_field(true); + + EXPECT_THAT(message, EqualsProto(R"pb(int_field: 123 + string_field: "foo" + bool_field: true)pb")); +} + +TEST(ProtoMatcher, DescribeEqualsProto) { + auto matcher = EqualsProto(netkat::ParseProtoOrDie(R"pb( + int_field: 123 + string_field: "foo" + bool_field: true + )pb")); + + EXPECT_EQ(testing::DescribeMatcher(matcher), + R"(is equal to netkat.TestMessage < +int_field: 123 +string_field: "foo" +bool_field: true +>)"); + EXPECT_EQ(testing::DescribeMatcher(Not(matcher)), + R"(is not equal to netkat.TestMessage < +int_field: 123 +string_field: "foo" +bool_field: true +>)"); +} + +TEST(ProtoMatcher, DescribeEqualsProtoFromText) { + std::string text = + R"pb(int_field: 123 string_field: "foo" bool_field: true)pb"; + auto matcher = EqualsProto(text); + + EXPECT_EQ(testing::DescribeMatcher(matcher), + R"(is equal to < +int_field: 123 string_field: "foo" bool_field: true>)"); + EXPECT_EQ(testing::DescribeMatcher(Not(matcher)), + R"(is not equal to < +int_field: 123 string_field: "foo" bool_field: true>)"); +} + +TEST(BinaryEqualsProtoTest, EqualPairWorks) { + auto arbitrary_proto_1 = ParseProtoOrDie("int_field: 24"); + EXPECT_THAT(std::make_pair(arbitrary_proto_1, arbitrary_proto_1), + EqualsProto()); + + auto arbitrary_proto_2 = ParseProtoOrDie("string_field: \"hi\""); + EXPECT_THAT(std::make_pair(arbitrary_proto_2, arbitrary_proto_2), + EqualsProto()); +} + +TEST(BinaryEqualsProtoTest, UnequalPairWorks) { + auto arbitrary_proto_1 = ParseProtoOrDie("int_field: 24"); + auto arbitrary_proto_2 = ParseProtoOrDie("string_field: \"hi\""); + EXPECT_THAT(std::make_pair(arbitrary_proto_1, arbitrary_proto_2), + Not(EqualsProto())); + EXPECT_THAT(std::make_pair(arbitrary_proto_2, arbitrary_proto_1), + Not(EqualsProto())); +} + +TEST(BinaryEqualsProtoTest, WorksWithPointwiseMatcher) { + std::vector protos = { + ParseProtoOrDie("int_field: 24"), + ParseProtoOrDie("string_field: \"hi\""), + }; + EXPECT_THAT(protos, testing::Pointwise(EqualsProto(), protos)); +} + +TEST(BinaryEqualsProtoTest, DescribeBinaryEqualsProto) { + auto matcher = EqualsProto(); + EXPECT_EQ( + (testing::DescribeMatcher>(matcher)), + "is a pair of equal protobufs"); +} + +TEST(EqualsProtoSequenceTest, EqualSequencesWork) { + std::vector protos = { + ParseProtoOrDie("int_field: 24"), + ParseProtoOrDie("string_field: \"hi\""), + }; + EXPECT_THAT(protos, EqualsProtoSequence(protos)); +} + +TEST(EqualsProtoSequenceTest, UnequalSequencesWork) { + std::vector protos1 = { + ParseProtoOrDie("int_field: 24"), + ParseProtoOrDie("string_field: \"hi\""), + }; + std::vector protos2 = { + ParseProtoOrDie("int_field: 42"), + }; + EXPECT_THAT(protos1, Not(EqualsProtoSequence(protos2))); +} + +TEST(PartiallyMatcherTest, IdenticalProtosAreAlsoPartiallyEqual) { + TestMessage message; + message.set_int_field(123); + + EXPECT_THAT(message, Partially(EqualsProto(message))); +} + +TEST(PartiallyMatcherTest, PartiallyEqualsProtoOnlyComparePresentFields) { + TestMessage message; + message.set_int_field(123); + + EXPECT_THAT(message, Partially(EqualsProto(R"pb( + int_field: 123)pb"))); +} + +TEST(PartiallyMatcherTest, DifferentlProtosDoNotMatch) { + TestMessage message; + message.set_int_field(123); + message.set_string_field("foo"); + + // Proto differs in one field and remains the same for another field does not + // match. + EXPECT_THAT(message, Not(Partially(EqualsProto(R"pb( + int_field: 1234 + string_field: "foo")pb")))); + // Proto differs in both fields should not match. + EXPECT_THAT(message, Not(Partially(EqualsProto(R"pb( + int_field: 1234 + string_field: "bar")pb")))); +} + +TEST(HasOneofCaseTest, NotHasOneofCase) { + EXPECT_THAT(netkat::ParseProtoOrDie(R"pb()pb"), + Not(HasOneofCase( + "foo", TestMessageWithOneof::kStringFoo))); + EXPECT_THAT(netkat::ParseProtoOrDie(R"pb( + int_foo: 42 + )pb"), + Not(HasOneofCase( + "foo", TestMessageWithOneof::kStringFoo))); + EXPECT_THAT(netkat::ParseProtoOrDie(R"pb( + int_foo: 42 + )pb"), + Not(HasOneofCase( + "foo", TestMessageWithOneof::kBoolFoo))); +} + +TEST(HasOneofCaseTest, DoesHaveOneofCase) { + EXPECT_THAT(netkat::ParseProtoOrDie(R"pb( + string_foo: "hi" + )pb"), + HasOneofCase( + "foo", TestMessageWithOneof::kStringFoo)); + EXPECT_THAT( + netkat::ParseProtoOrDie(R"pb( + int_foo: 42 + )pb"), + HasOneofCase("foo", TestMessageWithOneof::kIntFoo)); + EXPECT_THAT(netkat::ParseProtoOrDie(R"pb( + bool_foo: false + )pb"), + HasOneofCase( + "foo", TestMessageWithOneof::kBoolFoo)); +} + +TEST(HasOneofCaseTest, Description) { + auto matcher = HasOneofCase( + "foo", TestMessageWithOneof::kStringFoo); + EXPECT_EQ(testing::DescribeMatcher(matcher), + "is a `netkat.TestMessageWithOneof` protobuf message whose oneof " + "field `foo` has case `string_foo`"); + EXPECT_EQ(testing::DescribeMatcher(Not(matcher)), + "is a `netkat.TestMessageWithOneof` protobuf message whose oneof " + "field `foo` does not have case `string_foo`"); +} + +} // namespace +} // namespace netkat diff --git a/gutil/proto_string_error_collector.h b/gutil/proto_string_error_collector.h new file mode 100644 index 0000000..cddcecc --- /dev/null +++ b/gutil/proto_string_error_collector.h @@ -0,0 +1,56 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef GOOGLE_NETKAT_GUTIL_PROTO_STRING_ERROR_COLLECTOR_H_ +#define GOOGLE_NETKAT_GUTIL_PROTO_STRING_ERROR_COLLECTOR_H_ + +#include + +#include "absl/strings/string_view.h" +#include "absl/strings/substitute.h" +#include "google/protobuf/io/tokenizer.h" + +namespace netkat { + +// Collects errors by appending them to a given string. +class StringErrorCollector : public google::protobuf::io::ErrorCollector { + public: + // String error_text is unowned and must remain valid during the use of + // StringErrorCollector. + explicit StringErrorCollector(std::string* error_text) + : error_text_{error_text} {}; + StringErrorCollector(const StringErrorCollector&) = delete; + StringErrorCollector& operator=(const StringErrorCollector&) = delete; + + // Implementation of protobuf::io::ErrorCollector::AddError. + void RecordError(int line, int column, absl::string_view message) override { + if (error_text_ != nullptr) { + absl::SubstituteAndAppend(error_text_, "$0($1): $2\n", line, column, + message); + } + } + + // Implementation of protobuf::io::ErrorCollector::RecordWarning. + void RecordWarning(int line, int column, absl::string_view message) override { + RecordError(line, column, message); + } + + private: + std::string* const error_text_; +}; + +} // namespace netkat + +#endif // GOOGLE_NETKAT_GUTIL_PROTO_STRING_ERROR_COLLECTOR_H_ diff --git a/gutil/proto_test.cc b/gutil/proto_test.cc new file mode 100644 index 0000000..5396400 --- /dev/null +++ b/gutil/proto_test.cc @@ -0,0 +1,241 @@ +#include "gutil/proto.h" + +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "gutil/proto_matchers.h" +#include "gutil/proto_test.pb.h" +#include "gutil/status.h" +#include "gutil/status_matchers.h" + +namespace netkat { +namespace { + +using ::netkat::IsOkAndHolds; +using ::netkat::StatusIs; +using ::testing::Eq; +using ::testing::IsEmpty; +using ::testing::Not; +using ::testing::ResultOf; + +// Get a writeable directory where bazel tests can save output files to. +// https://docs.bazel.build/versions/main/test-encyclopedia.html#initial-conditions +absl::StatusOr GetTestTmpDir() { + char* test_tmpdir = std::getenv("TEST_TMPDIR"); + if (test_tmpdir == nullptr) { + return netkat::InternalErrorBuilder() + << "Could not find environment variable ${TEST_TMPDIR}. Is this a " + "bazel test run?"; + } + return test_tmpdir; +} + +TEST(IsEmptyProto, ReturnsTrueForEmptyProto) { + EXPECT_TRUE(IsEmptyProto(TestMessage())); + + // Same things, but a bit more convoluted. + TestMessage message; + message.set_int_field(42); + message.set_int_field(0); + EXPECT_TRUE(IsEmptyProto(message)) + << "where message = " << message.DebugString(); +} + +TEST(IsEmptyProto, ReturnsFalseForNonEmptyProto) { + EXPECT_THAT(ParseTextProto("int_field: 42"), + IsOkAndHolds(ResultOf(IsEmptyProto, Eq(false)))); +} + +TEST(ParseTextProto, EmptyTextProtoIsOk) { + EXPECT_THAT(ParseTextProto(""), IsOk()); +} + +TEST(ParseTextProto, InvalidTextProtoIsNotOk) { + EXPECT_THAT(ParseTextProto("bytes_field: true"), Not(IsOk())); +} + +TEST(ParseTextProto, NonEmptyValidTextProtoIsParsedCorrectly) { + auto proto = ParseTextProto(R"pb( + int_field: 42 + string_field: "hello!" + bool_field: true + )pb"); + ASSERT_THAT(proto, IsOk()); + EXPECT_EQ(proto->int_field(), 42); + EXPECT_EQ(proto->string_field(), "hello!"); + EXPECT_EQ(proto->bool_field(), true); +} + +TEST(ProtoDiff, ReturnsErrorForIncompatibleMessages) { + ASSERT_OK_AND_ASSIGN(auto message1, ParseTextProto(R"pb( + int_field: 42 + string_field: "hello!" + bool_field: true + )pb")); + ASSERT_OK_AND_ASSIGN(auto message2, ParseTextProto(R"pb( + int_field: 42 + string_field: "hello!" + bool_field: true + )pb")); + EXPECT_THAT(ProtoDiff(message1, message2).status(), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST(ProtoDiff, ReturnsEmptyDiffForEqualMessages) { + ASSERT_OK_AND_ASSIGN(auto message1, ParseTextProto(R"pb( + int_field: 42 + string_field: "hello!" + bool_field: true + )pb")); + EXPECT_THAT(ProtoDiff(message1, message1), IsOkAndHolds(IsEmpty())); +} + +TEST(ProtoDiff, ReturnsNonEmptyDiffForUnequalMessages) { + ASSERT_OK_AND_ASSIGN(auto message1, ParseTextProto(R"pb( + int_field: 42 + string_field: "hello!" + bool_field: true + )pb")); + ASSERT_OK_AND_ASSIGN(auto message2, ParseTextProto(R"pb( + int_field: 43 + string_field: "bye" + bool_field: false + )pb")); + EXPECT_THAT(ProtoDiff(message1, message2), IsOkAndHolds(Not(IsEmpty()))); +} + +TEST(ProtoEqual, ReturnsErrorForIncompatibleMessages) { + ASSERT_OK_AND_ASSIGN(auto message1, ParseTextProto(R"pb( + int_field: 42 + string_field: "hello!" + bool_field: true + )pb")); + ASSERT_OK_AND_ASSIGN(auto message2, ParseTextProto(R"pb( + int_field: 42 + string_field: "hello!" + bool_field: true + )pb")); + EXPECT_THAT(ProtoEqual(message1, message2), Eq(false)); +} + +TEST(ProtoEqual, ReturnsTrueForEqualMessages) { + ASSERT_OK_AND_ASSIGN(auto message1, ParseTextProto(R"pb( + int_field: 42 + string_field: "hello!" + bool_field: true + )pb")); + EXPECT_THAT(ProtoEqual(message1, message1), Eq(true)); +} + +TEST(ProtoEqual, ReturnsFalseForUnequalMessages) { + ASSERT_OK_AND_ASSIGN(auto message1, ParseTextProto(R"pb( + int_field: 42 + string_field: "hello!" + bool_field: true + )pb")); + ASSERT_OK_AND_ASSIGN(auto message2, ParseTextProto(R"pb( + int_field: 43 + string_field: "bye" + bool_field: false + )pb")); + EXPECT_THAT(ProtoEqual(message1, message2), Eq(false)); +} + +TEST(TextProtoHelpers, PrintTextProto) { + TestMessage message; + message.set_int_field(42); + message.set_string_field("bye"); + message.set_bool_field(true); + EXPECT_THAT(PrintTextProto(message), + "int_field: 42\nstring_field: \"bye\"\nbool_field: true\n"); +} + +TEST(TextProtoHelpers, PrintShortTextProto) { + TestMessage message; + message.set_int_field(42); + message.set_string_field("bye"); + message.set_bool_field(true); + EXPECT_THAT(PrintShortTextProto(message), + "int_field: 42 string_field: \"bye\" bool_field: true"); +} + +TEST(ParseJsonAsProto, ParsesTestMessage) { + EXPECT_THAT(ParseJsonAsProto(R"json({ + "int_field" : 42, + "string_field" : "bye", + "bool_field" : true + })json"), + IsOkAndHolds(EqualsProto(R"pb( + int_field: 42 + string_field: "bye" + bool_field: true + )pb"))); +} + +TEST(ParseJsonAsProto, CanIgnoreUnknownFields) { + EXPECT_THAT(ParseJsonAsProto(R"json({ + "int_field" : 42, + "string_field" : "bye", + "bool_field" : true, + "unknown_field": "please ignore" + })json", + /*ignore_unknown_field=*/false), + Not(IsOk())); + EXPECT_THAT(ParseJsonAsProto(R"json({ + "int_field" : 42, + "string_field" : "bye", + "bool_field" : true, + "unknown_field": "please ignore" + })json", + /*ignore_unknown_field=*/true), + IsOkAndHolds(EqualsProto(R"pb( + int_field: 42 + string_field: "bye" + bool_field: true + )pb"))); +} + +TEST(SerializeProtoAsJson, RoundTripsWithParseJsonAsProto) { + ASSERT_OK_AND_ASSIGN(auto proto, ParseTextProto(R"pb( + int_field: 42 + string_field: "hello!" + bool_field: true + )pb")); + ASSERT_OK_AND_ASSIGN(std::string json, SerializeProtoAsJson(proto)); + EXPECT_THAT(ParseJsonAsProto(json), + IsOkAndHolds(EqualsProto(proto))); +} + +TEST(SaveProtoToFile, SavesProtoToFileTruncatesFileOnOverwrite) { + ASSERT_OK_AND_ASSIGN(std::string test_tmpdir, GetTestTmpDir()); + std::string proto_save_path = + absl::StrCat(test_tmpdir, "/forwarding_config.pb.txt"); + ASSERT_OK_AND_ASSIGN(netkat::TestMessage proto, + ParseTextProto(R"pb( + int_field: 42 + string_field: "hello!" + bool_field: true + )pb")); + ASSERT_OK(SaveProtoToFile(proto_save_path, proto)); + + netkat::TestMessage read_proto; + ASSERT_OK(netkat::ReadProtoFromFile(proto_save_path, &read_proto)); + EXPECT_THAT(read_proto, EqualsProto(proto)); + + netkat::TestMessage empty_proto; + // Overite the saved file with empty proto. + ASSERT_OK(SaveProtoToFile(proto_save_path, empty_proto)); + + netkat::TestMessage read_empty_proto; + ASSERT_OK(netkat::ReadProtoFromFile(proto_save_path, &read_empty_proto)); + // Verify the file is truncated. + EXPECT_THAT(read_empty_proto, EqualsProto(empty_proto)); +} + +} // namespace +} // namespace netkat diff --git a/gutil/proto_test.proto b/gutil/proto_test.proto new file mode 100644 index 0000000..ac4679e --- /dev/null +++ b/gutil/proto_test.proto @@ -0,0 +1,24 @@ +syntax = "proto3"; + +package netkat; + +message TestMessage { + int32 int_field = 1; + string string_field = 2; + bool bool_field = 3; +} + +message AnotherTestMessage { + int32 int_field = 1; + string string_field = 2; + bool bool_field = 3; +} + +message TestMessageWithOneof { + oneof foo { + string string_foo = 1; + int32 int_foo = 2; + bool bool_foo = 3; + bytes bytes_foo = 4; + } +} diff --git a/gutil/status.cc b/gutil/status.cc new file mode 100644 index 0000000..41680a9 --- /dev/null +++ b/gutil/status.cc @@ -0,0 +1,63 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "gutil/status.h" + +#include +#include + +#include "absl/status/status.h" +#include "absl/strings/cord.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" + +namespace netkat { + +std::string StableStatusToString(const absl::Status& status) { + return absl::StrCat(absl::StatusCodeToString(status.code()), ": ", + status.message(), "\n"); +} + +absl::Status netkat::StatusBuilder::GetStatusAndLog() const { + std::string message = source_; + switch (join_style_) { + case MessageJoinStyle::kPrepend: + absl::StrAppend(&message, stream_.str(), status_.message()); + break; + case MessageJoinStyle::kAppend: + absl::StrAppend(&message, status_.message(), stream_.str()); + break; + case MessageJoinStyle::kAnnotate: + default: { + if (!status_.message().empty() && !stream_.str().empty()) { + absl::StrAppend(&message, status_.message(), "; ", stream_.str()); + } else if (status_.message().empty()) { + absl::StrAppend(&message, stream_.str()); + } else { + absl::StrAppend(&message, status_.message()); + } + break; + } + } + if (log_error_ && status_.code() != absl::StatusCode::kOk) { + std::cout << message << std::endl; + } + absl::Status new_status(status_.code(), message); + status_.ForEachPayload( + [&new_status](absl::string_view url, const absl::Cord& cord) { + new_status.SetPayload(url, cord); + }); + return new_status; +} + +} // namespace netkat diff --git a/gutil/status.h b/gutil/status.h new file mode 100644 index 0000000..8ff4776 --- /dev/null +++ b/gutil/status.h @@ -0,0 +1,341 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#ifndef GOOGLE_NETKAT_GUTIL_STATUS_H +#define GOOGLE_NETKAT_GUTIL_STATUS_H + +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/cord.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" + +namespace netkat { + +// Converts `status` to a readable string. The current absl `ToString` method is +// not stable, which causes issues while golden testing. This function is +// stable. +std::string StableStatusToString(const absl::Status& status); + +// Protobuf and some other Google projects use Status classes that are isomorph, +// but not equal to absl::Status (outside of google3). +// This auxiliary function converts such Status classes to absl::Status. +template +absl::Status ToAbslStatus(T status) { + return absl::Status( + static_cast(status.code()), + absl::string_view(status.message().data(), status.message().size())); +} + +// A proxy type and function for template type deduction for logging +// `absl::StatusOr`s. +// +// Can be removed once `absl::StatusOr` supports `operator<<`. +template +class StreamableStatusOrProxy { + public: + explicit StreamableStatusOrProxy(const absl::StatusOr& status_or) + : status_or_(status_or) {} + + StreamableStatusOrProxy(const StreamableStatusOrProxy&) = delete; + StreamableStatusOrProxy& operator=(const StreamableStatusOrProxy&) = delete; + + friend std::ostream& operator<<(std::ostream& os, + const StreamableStatusOrProxy& logger) { + if (logger.status_or_.ok()) return os << *logger.status_or_; + return os << logger.status_or_.status().code() << ": " + << logger.status_or_.status().message(); + } + + private: + const absl::StatusOr& status_or_; +}; + +template +StreamableStatusOrProxy StreamableStatusOr( + const absl::StatusOr& status_or) { + return StreamableStatusOrProxy(status_or); +} + +// StatusBuilder facilitates easier construction of Status objects with streamed +// message building. +// +// Example usage: +// absl::Status foo(int i) { +// if (i < 0) { +// return StatusBuilder(absl::StatusCode::kInvalidArgument) << "i=" << i; +// } +// } +class ABSL_MUST_USE_RESULT StatusBuilder { + public: + StatusBuilder(std::string file, int line, absl::StatusCode code) + : status_(absl::Status(code, "")), + log_error_(false), + join_style_(MessageJoinStyle::kAnnotate) { + source_ = absl::StrCat("[", file, ":", line, "]: "); + } + + explicit StatusBuilder(absl::StatusCode code) + : status_(absl::Status(code, "")), + log_error_(false), + join_style_(MessageJoinStyle::kAnnotate) {} + + explicit StatusBuilder(absl::Status status) + : status_(status), + log_error_(false), + join_style_(MessageJoinStyle::kAnnotate) {} + + StatusBuilder(const StatusBuilder& other) + : source_(other.source_), + status_(other.status_), + log_error_(other.log_error_), + join_style_(other.join_style_) { + stream_ << other.stream_.str(); + } + + // Streaming to the StatusBuilder appends to the error message. + template + ABSL_MUST_USE_RESULT StatusBuilder& operator<<(t val) { + stream_ << val; + return *this; + } + + // Makes the StatusBuilder print the error message (with source) when + // converting to a different type. + ABSL_MUST_USE_RESULT StatusBuilder& LogError() { + log_error_ = true; + return *this; + } + + // The additional message is prepended to the pre-existing status error + // message. No separator is placed between the messages. + ABSL_MUST_USE_RESULT StatusBuilder& SetPrepend() { + join_style_ = MessageJoinStyle::kPrepend; + return *this; + } + + // The additional message is appended to the pre-existing status error + // message. No separator is placed between the messages. + ABSL_MUST_USE_RESULT StatusBuilder& SetAppend() { + join_style_ = MessageJoinStyle::kAppend; + return *this; + } + + // Override the StatusCode in status_ to the given value. + ABSL_MUST_USE_RESULT StatusBuilder& SetCode(absl::StatusCode code) { + status_ = absl::Status(code, status_.message()); + return *this; + } + + ABSL_MUST_USE_RESULT StatusBuilder& SetPayload(absl::string_view url, + absl::Cord payload) { + status_.SetPayload(url, std::move(payload)); + return *this; + } + + // Implicit type conversions. + operator absl::Status() const { return GetStatusAndLog(); } + template + operator absl::StatusOr() const { + return absl::StatusOr(static_cast(*this)); + } + + private: + enum class MessageJoinStyle { + kAnnotate, + kAppend, + kPrepend, + }; + + std::string source_; + absl::Status status_; + std::stringstream stream_; + bool log_error_; + MessageJoinStyle join_style_; + + absl::Status GetStatusAndLog() const; +}; + +// Custom allocators for default StatusCodes. +class ABSL_MUST_USE_RESULT CancelledErrorBuilder : public StatusBuilder { + public: + CancelledErrorBuilder() : StatusBuilder(absl::StatusCode::kCancelled) {} +}; +class ABSL_MUST_USE_RESULT UnknownErrorBuilder : public StatusBuilder { + public: + UnknownErrorBuilder() : StatusBuilder(absl::StatusCode::kUnknown) {} +}; +class ABSL_MUST_USE_RESULT InvalidArgumentErrorBuilder : public StatusBuilder { + public: + InvalidArgumentErrorBuilder() + : StatusBuilder(absl::StatusCode::kInvalidArgument) {} +}; +class ABSL_MUST_USE_RESULT DeadlineExceededErrorBuilder : public StatusBuilder { + public: + DeadlineExceededErrorBuilder() + : StatusBuilder(absl::StatusCode::kDeadlineExceeded) {} +}; +class ABSL_MUST_USE_RESULT NotFoundErrorBuilder : public StatusBuilder { + public: + NotFoundErrorBuilder() : StatusBuilder(absl::StatusCode::kNotFound) {} +}; +class ABSL_MUST_USE_RESULT AlreadyExistsErrorBuilder : public StatusBuilder { + public: + AlreadyExistsErrorBuilder() + : StatusBuilder(absl::StatusCode::kAlreadyExists) {} +}; +class ABSL_MUST_USE_RESULT PermissionDeniedErrorBuilder : public StatusBuilder { + public: + PermissionDeniedErrorBuilder() + : StatusBuilder(absl::StatusCode::kPermissionDenied) {} +}; +class ABSL_MUST_USE_RESULT ResourceExhaustedErrorBuilder + : public StatusBuilder { + public: + ResourceExhaustedErrorBuilder() + : StatusBuilder(absl::StatusCode::kResourceExhausted) {} +}; +class ABSL_MUST_USE_RESULT FailedPreconditionErrorBuilder + : public StatusBuilder { + public: + FailedPreconditionErrorBuilder() + : StatusBuilder(absl::StatusCode::kFailedPrecondition) {} +}; +class ABSL_MUST_USE_RESULT AbortedErrorBuilder : public StatusBuilder { + public: + AbortedErrorBuilder() : StatusBuilder(absl::StatusCode::kAborted) {} +}; +class ABSL_MUST_USE_RESULT OutOfRangeErrorBuilder : public StatusBuilder { + public: + OutOfRangeErrorBuilder() : StatusBuilder(absl::StatusCode::kOutOfRange) {} +}; +class ABSL_MUST_USE_RESULT UnimplementedErrorBuilder : public StatusBuilder { + public: + UnimplementedErrorBuilder() + : StatusBuilder(absl::StatusCode::kUnimplemented) {} +}; +class ABSL_MUST_USE_RESULT InternalErrorBuilder : public StatusBuilder { + public: + InternalErrorBuilder() : StatusBuilder(absl::StatusCode::kInternal) {} +}; +class ABSL_MUST_USE_RESULT UnavailableErrorBuilder : public StatusBuilder { + public: + UnavailableErrorBuilder() : StatusBuilder(absl::StatusCode::kUnavailable) {} +}; +class ABSL_MUST_USE_RESULT DataLossErrorBuilder : public StatusBuilder { + public: + DataLossErrorBuilder() : StatusBuilder(absl::StatusCode::kDataLoss) {} +}; +class ABSL_MUST_USE_RESULT UnauthenticatedErrorBuilder : public StatusBuilder { + public: + UnauthenticatedErrorBuilder() + : StatusBuilder(absl::StatusCode::kUnauthenticated) {} +}; + +// status.h internal classes. Not for public use. +namespace status_internal { +// Holds a status builder in the '_' parameter. +class StatusBuilderHolder { + public: + StatusBuilderHolder(const absl::Status& status) : builder_(status) {} + StatusBuilderHolder(absl::Status&& status) : builder_(std::move(status)) {} + + StatusBuilder builder_; +}; +} // namespace status_internal + +} // namespace netkat + +// RETURN_IF_ERROR evaluates an expression that returns a absl::Status. If the +// result is not ok, returns a StatusBuilder for the result. Otherwise, +// continues. Because the macro ends in an unterminated StatusBuilder, all +// StatusBuilder extensions can be used. +// +// Example: +// absl::Status Foo() {...} +// absl::Status Bar() { +// RETURN_IF_ERROR(Foo()).LogError() << "Additional Info"; +// return absl::OkStatus() +// } +#define RETURN_IF_ERROR(expr) \ + for (absl::Status status = expr; !status.ok();) \ + return netkat::StatusBuilder(std::move(status)) + +// These macros help create unique variable names for ASSIGN_OR_RETURN. Not for +// public use. +#define __ASSIGN_OR_RETURN_VAL_DIRECT(arg) __ASSIGN_OR_RETURN_RESULT_##arg +#define __ASSIGN_OR_RETURN_VAL(arg) __ASSIGN_OR_RETURN_VAL_DIRECT(arg) + +// An implementation of ASSIGN_OR_RETURN that does not include a StatusBuilder. +// Not for public use. +#define __ASSIGN_OR_RETURN(dest, expr) \ + auto __ASSIGN_OR_RETURN_VAL(__LINE__) = expr; \ + if (!__ASSIGN_OR_RETURN_VAL(__LINE__).ok()) { \ + return __ASSIGN_OR_RETURN_VAL(__LINE__).status(); \ + } \ + dest = std::move(__ASSIGN_OR_RETURN_VAL(__LINE__)).value() + +// An implementation of ASSIGN_OR_RETURN that provides a StatusBuilder for extra +// processing. Not for public use. +#define __ASSIGN_OR_RETURN_STREAM(dest, expr, stream) \ + auto __ASSIGN_OR_RETURN_VAL(__LINE__) = expr; \ + if (!__ASSIGN_OR_RETURN_VAL(__LINE__).ok()) { \ + return ::netkat::status_internal::StatusBuilderHolder( \ + __ASSIGN_OR_RETURN_VAL(__LINE__).status()) \ + .builder##stream; \ + } \ + dest = std::move(__ASSIGN_OR_RETURN_VAL(__LINE__)).value() + +// Macro to choose the correct implementation for ASSIGN_OR_RETURN based on +// the number of inputs. Not for public use. +#define __ASSIGN_OR_RETURN_PICK(dest, expr, stream, func, ...) func + +// ASSIGN_OR_RETURN evaluates an expression that returns a StatusOr. If the +// result is ok, the value is saved to dest. Otherwise, the status is returned. +// +// Example: +// absl::StatusOr Foo() {...} +// absl::Status Bar() { +// ASSIGN_OR_RETURN(int value, Foo()); +// std::cout << "value: " << value; +// return absl::OkStatus(); +// } +// +// ASSIGN_OR_RETURN optionally takes in a third parameter that takes in +// absl::StatusBuilder commands. Usage should assume a StatusBuilder object is +// available and referred to as '_'. +// +// Example: +// absl::StatusOr Foo() {...} +// absl::Status Bar() { +// ASSIGN_OR_RETURN(int value, Foo(), _.LogError() << "Additional Info"); +// std::cout << "value: " << value; +// return absl::OkStatus(); +// } +#define ASSIGN_OR_RETURN(...) \ + __ASSIGN_OR_RETURN_PICK(__VA_ARGS__, __ASSIGN_OR_RETURN_STREAM, \ + __ASSIGN_OR_RETURN) \ + (__VA_ARGS__) + +// Returns an error if `cond` doesn't hold. +#define RET_CHECK(cond) \ + while (!(cond)) \ + return netkat::InternalErrorBuilder() << "(" << #cond << ") failed" + +#endif // GOOGLE_NETKAT_GUTIL_STATUS_H_ diff --git a/gutil/status_matchers.h b/gutil/status_matchers.h new file mode 100644 index 0000000..756873b --- /dev/null +++ b/gutil/status_matchers.h @@ -0,0 +1,146 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GOOGLE_NETKAT_GUTIL_STATUS_MATCHERS_H +#define GOOGLE_NETKAT_GUTIL_STATUS_MATCHERS_H + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "gutil/status.h" // IWYU pragma: keep + +namespace netkat { + +namespace internal { + +template +const Status& GetStatus(const Status& status) { + return status; +} + +template +const absl::Status& GetStatus(const absl::StatusOr& statusor) { + return statusor.status(); +} + +} // namespace internal + +MATCHER(IsOk, negation ? "is not OK" : "is OK") { + return internal::GetStatus(arg).ok(); +} + +// Convenience macros for checking that a status return type is okay. +#define EXPECT_OK(expression) EXPECT_THAT(expression, ::netkat::IsOk()) +#define ASSERT_OK(expression) ASSERT_THAT(expression, ::netkat::IsOk()) + +#ifndef __ASSIGN_OR_RETURN_VAL_DIRECT +#define __ASSIGN_OR_RETURN_VAL_DIRECT(arg) __ASSIGN_OR_RETURN_RESULT_##arg +#define __ASSIGN_OR_RETURN_VAL(arg) __ASSIGN_OR_RETURN_VAL_DIRECT(arg) +#endif + +// ASSERT_OK_AND_ASSIGN evaluates the expression (which needs to evaluate to a +// StatusOr) and asserts that the expression has status OK. It then assigns the +// result to lhs, and otherwise fails. +#define ASSERT_OK_AND_ASSIGN(lhs, expression) \ + auto __ASSIGN_OR_RETURN_VAL(__LINE__) = expression; \ + if (!__ASSIGN_OR_RETURN_VAL(__LINE__).status().ok()) { \ + FAIL() << #expression \ + << " failed: " << __ASSIGN_OR_RETURN_VAL(__LINE__).status(); \ + } \ + lhs = std::move(__ASSIGN_OR_RETURN_VAL(__LINE__)).value(); + +MATCHER_P(StatusIs, status_code, + absl::StrCat(negation ? "is not " : "is ", + absl::StatusCodeToString(status_code))) { + return internal::GetStatus(arg).code() == status_code; +} + +MATCHER_P2(StatusIs, status_code, message_matcher, + absl::StrFormat("is %s%s, %s has a status message that %s", + negation ? "not " : "", + absl::StatusCodeToString(status_code), + negation ? "or" : "and", + testing::DescribeMatcher( + message_matcher, negation))) { + const absl::Status& status = internal::GetStatus(arg); + return status.code() == status_code && + testing::ExplainMatchResult(message_matcher, status.message(), + result_listener); +} + +template +class IsOkAndHoldsMatcherImpl : public testing::MatcherInterface { + public: + using T = typename std::remove_reference_t::value_type; + + template + explicit IsOkAndHoldsMatcherImpl(InnerMatcher&& inner_matcher) + : inner_matcher_(testing::SafeMatcherCast( + std::forward(inner_matcher))) {} + + bool MatchAndExplain(StatusOrT t, + testing::MatchResultListener* listener) const override { + if (!t.ok()) { + *listener << "which has status " << t.status(); + return false; + } + return inner_matcher_.MatchAndExplain(*t, listener); + } + + void DescribeTo(std::ostream* os) const override { + *os << "is OK and has a value that "; + inner_matcher_.DescribeTo(os); + } + void DescribeNegationTo(std::ostream* os) const override { + *os << "is not OK or has a value that "; + inner_matcher_.DescribeNegationTo(os); + } + + private: + testing::Matcher inner_matcher_; +}; + +template +class IsOkAndHoldsMatcher { + public: + explicit IsOkAndHoldsMatcher(InnerMatcher&& inner_matcher) + : inner_matcher_(std::forward(inner_matcher)) {} + + template + operator testing::Matcher() const { // NOLINT + return testing::Matcher( + new IsOkAndHoldsMatcherImpl(inner_matcher_)); + } + + private: + InnerMatcher inner_matcher_; +}; + +template +IsOkAndHoldsMatcher IsOkAndHolds(InnerMatcher&& inner_matcher) { + return IsOkAndHoldsMatcher( + std::forward(inner_matcher)); +} + +} // namespace netkat + +#endif // GOOGLE_NETKAT_GUTIL_STATUS_MATCHERS_H diff --git a/gutil/status_matchers_test.cc b/gutil/status_matchers_test.cc new file mode 100644 index 0000000..b932472 --- /dev/null +++ b/gutil/status_matchers_test.cc @@ -0,0 +1,117 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "gutil/status_matchers.h" + +#include +#include + +#include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +namespace netkat { +namespace { + +using ::testing::_; +using ::testing::Eq; +using ::testing::HasSubstr; +using ::testing::Not; + +TEST(AbseilStatusMatcher, IsOk) { EXPECT_THAT(absl::Status(), IsOk()); } + +TEST(AbseilStatusMatcher, IsNotOk) { + EXPECT_THAT(absl::UnknownError("unknown error"), Not(IsOk())); +} + +TEST(AbseilStatusMatcher, StatusIs) { + EXPECT_THAT(absl::UnknownError("unknown error"), + StatusIs(absl::StatusCode::kUnknown)); +} + +TEST(AbseilStatusMatcher, StatusIsNot) { + EXPECT_THAT(absl::UnknownError("unknown error"), + Not(StatusIs(absl::StatusCode::kInvalidArgument))); +} + +TEST(AbseilStatusMatcher, StatusIsWithMessage) { + EXPECT_THAT(absl::UnknownError("unknown error"), + StatusIs(absl::StatusCode::kUnknown, "unknown error")); +} + +TEST(AbseilStatusMatcher, StatusIsWithMessageNot) { + EXPECT_THAT(absl::UnknownError("unknown error"), + Not(StatusIs(absl::StatusCode::kInvalidArgument, "unknown"))); +} + +TEST(AbslStatusOrMatcher, IsOk) { EXPECT_THAT(absl::StatusOr(1), IsOk()); } + +TEST(AbslStatusOrMatcher, IsNotOk) { + EXPECT_THAT(absl::StatusOr(absl::UnknownError("unknown error")), + Not(IsOk())); +} + +TEST(AbslStatusOrMatcher, StatusIs) { + EXPECT_THAT(absl::StatusOr(absl::UnknownError("unknown error")), + StatusIs(absl::StatusCode::kUnknown)); +} + +TEST(AbslStatusOrMatcher, StatusIsNot) { + EXPECT_THAT(absl::StatusOr(absl::UnknownError("unknown error")), + Not(StatusIs(absl::StatusCode::kInvalidArgument))); +} + +TEST(AbslStatusOrMatcher, StatusIsWithMessage) { + EXPECT_THAT(absl::StatusOr(absl::UnknownError("unknown error")), + StatusIs(absl::StatusCode::kUnknown, HasSubstr("unknown"))); +} + +TEST(AbslStatusOrMatcher, StatusIsWithMessageNot) { + EXPECT_THAT(absl::StatusOr(absl::UnknownError("unknown error")), + Not(StatusIs(absl::StatusCode::kInvalidArgument, "unknown"))); +} + +TEST(AbslStatusOrMatcher, StatusIsOkAndHolds) { + EXPECT_THAT(absl::StatusOr(1320), IsOkAndHolds(1320)); +} + +TEST(AbslStatusOrMatcher, StatusIsNotOkAndHolds) { + EXPECT_THAT(absl::StatusOr(1320), Not(IsOkAndHolds(0))); +} + +TEST(AbslStatusOrMatcher, StatusIsOkAndHoldsWithExpectation) { + EXPECT_THAT(absl::StatusOr("The quick brown fox"), + IsOkAndHolds(HasSubstr("fox"))); +} + +// This test will fail to build if the macro doesn't work. +TEST(AbslStatusOrMatcher, AssignOrReturnWorksWithMoveOnlyTypes) { + ASSERT_OK_AND_ASSIGN( + auto value_from_expression, + absl::StatusOr>(absl::make_unique(0))); +} + +TEST(IsOkAndHoldsTest, Description) { + auto describe = [](const auto& matcher) { + return testing::DescribeMatcher>(matcher); + }; + EXPECT_EQ(describe(IsOkAndHolds(_)), + "is OK and has a value that is anything"); + EXPECT_EQ(describe(Not(IsOkAndHolds(Eq(4)))), + "is not OK or has a value that isn't equal to 4"); +} + +} // namespace +} // namespace netkat diff --git a/gutil/testing.cc b/gutil/testing.cc new file mode 100644 index 0000000..876ece4 --- /dev/null +++ b/gutil/testing.cc @@ -0,0 +1,49 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "gutil/testing.h" + +#include + +#include "absl/strings/ascii.h" +#include "absl/strings/string_view.h" + +namespace netkat { + +// Implementation taken from protobuf descriptor library. +std::string SnakeCaseToCamelCase(absl::string_view input, bool lower_first) { + bool capitalize_next = true; + std::string result; + result.reserve(input.size()); + + for (char character : input) { + if (character == '_') { + capitalize_next = true; + } else if (capitalize_next) { + result.push_back(absl::ascii_toupper(character)); + capitalize_next = false; + } else { + result.push_back(character); + } + } + + // Lower-case the first letter. + if (lower_first && !result.empty()) { + result[0] = absl::ascii_tolower(result[0]); + } + + return result; +} + +} // namespace netkat diff --git a/gutil/testing.h b/gutil/testing.h new file mode 100644 index 0000000..f4fcfad --- /dev/null +++ b/gutil/testing.h @@ -0,0 +1,53 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GOOGLE_NETKAT_GUTIL_TESTING_H +#define GOOGLE_NETKAT_GUTIL_TESTING_H + +#include + +#include "absl/log/check.h" +#include "absl/strings/string_view.h" +#include "gutil/proto.h" + +namespace netkat { + +// Parses a protobuf from a string, and crashes if parsing failed. Only use in +// tests. +template +T ParseProtoOrDie(absl::string_view proto_string) { + T message; + CHECK_OK(ReadProtoFromString(proto_string, &message)); // Crash OK + return message; +} + +// Parses a protobuf from a file, and crashes if parsing failed. Only use in +// tests. +template +T ParseProtoFileOrDie(absl::string_view proto_file) { + T message; + CHECK_OK(ReadProtoFromFile(proto_file, &message)); // Crash OK + return message; +} + +// Takes a snake_case string and returns a CamelCase string. If `lower_first` is +// set, the first character will be lowercase (if a letter) and otherwise it +// will be uppercase. +// Used to e.g. convert snake case strings to GTEST compatible test names. +std::string SnakeCaseToCamelCase(absl::string_view input, + bool lower_first = false); + +} // namespace netkat + +#endif // GOOGLE_NETKAT_GUTIL_TESTING_H diff --git a/gutil/testing_test.cc b/gutil/testing_test.cc new file mode 100644 index 0000000..6f78c05 --- /dev/null +++ b/gutil/testing_test.cc @@ -0,0 +1,47 @@ +#include "gutil/testing.h" + +#include "absl/strings/str_cat.h" +#include "gtest/gtest.h" + +namespace netkat { +namespace { + +TEST(SnakeCaseToCamelCaseTest, WorksForSomeStandardInputs) { + EXPECT_EQ(SnakeCaseToCamelCase("my_camel_case"), "MyCamelCase"); + EXPECT_EQ(SnakeCaseToCamelCase("word"), "Word"); + EXPECT_EQ(SnakeCaseToCamelCase("two_words"), "TwoWords"); + EXPECT_EQ(SnakeCaseToCamelCase("3_words"), "3Words"); + EXPECT_EQ(SnakeCaseToCamelCase("_my_camel_case_"), "MyCamelCase"); +} + +TEST(SnakeCaseToCamelCaseTest, LowerFirstWorks) { + EXPECT_EQ(SnakeCaseToCamelCase("my_camel_case", /*lower_first=*/true), + "myCamelCase"); + EXPECT_EQ(SnakeCaseToCamelCase("word", /*lower_first=*/true), "word"); + EXPECT_EQ(SnakeCaseToCamelCase("two_words", /*lower_first=*/true), + "twoWords"); + EXPECT_EQ(SnakeCaseToCamelCase("3_words", /*lower_first=*/true), "3Words"); + EXPECT_EQ(SnakeCaseToCamelCase("_my_camel_case_", /*lower_first=*/true), + "myCamelCase"); +} + +TEST(SnakeCaseToCamelCaseTest, WorksForWeirdInputs) { + for (bool lower_first : {true, false}) { + EXPECT_EQ(SnakeCaseToCamelCase("_with__extra_underlines_", lower_first), + absl::StrCat(lower_first ? "w" : "W", "ithExtraUnderlines")); + EXPECT_EQ(SnakeCaseToCamelCase("alreadyCamelCase", lower_first), + absl::StrCat(lower_first ? "a" : "A", "lreadyCamelCase")); + // Note that only the first letter after each '_' and the first letter + // changes case. + EXPECT_EQ(SnakeCaseToCamelCase("wEiRd_cASiNg", lower_first), + absl::StrCat(lower_first ? "w" : "W", "EiRdCASiNg")); + EXPECT_EQ(SnakeCaseToCamelCase("?weird_first_character", lower_first), + "?weirdFirstCharacter"); + EXPECT_EQ( + SnakeCaseToCamelCase("many_\nnon-letter..._char:acters", lower_first), + absl::StrCat(lower_first ? "m" : "M", "any\nnon-letter...Char:acters")); + } +} + +} // namespace +} // namespace netkat diff --git a/netkat/BUILD.bazel b/netkat/BUILD.bazel index 876f055..f133666 100644 --- a/netkat/BUILD.bazel +++ b/netkat/BUILD.bazel @@ -29,9 +29,19 @@ cc_test( cc_library( name = "symbolic_packet", + srcs = ["symbolic_packet.cc"], hdrs = ["symbolic_packet.h"], deps = [ + ":evaluator", ":netkat_cc_proto", + "//gutil:status", + "@com_google_absl//absl/container:fixed_array", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", ], ) @@ -40,9 +50,13 @@ cc_test( name = "symbolic_packet_test", srcs = ["symbolic_packet_test.cc"], deps = [ + ":evaluator", + ":netkat_proto_constructors", ":symbolic_packet", + "@com_google_absl//absl/base:no_destructor", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", + "@com_google_fuzztest//fuzztest", "@com_google_googletest//:gtest_main", ], ) @@ -58,6 +72,28 @@ cc_library( ], ) +cc_library( + name = "netkat_proto_constructors", + srcs = ["netkat_proto_constructors.cc"], + hdrs = ["netkat_proto_constructors.h"], + deps = [ + ":netkat_cc_proto", + "@com_google_absl//absl/strings:string_view", + ], +) + +cc_test( + name = "netkat_proto_constructors_test", + srcs = ["netkat_proto_constructors_test.cc"], + deps = [ + ":netkat_cc_proto", + ":netkat_proto_constructors", + "//gutil:proto_matchers", + "@com_google_fuzztest//fuzztest", + "@com_google_googletest//:gtest_main", + ], +) + cc_test( name = "evaluator_test", srcs = ["evaluator_test.cc"], diff --git a/netkat/netkat.proto b/netkat/netkat.proto index 0b7bef8..d61a5e2 100644 --- a/netkat/netkat.proto +++ b/netkat/netkat.proto @@ -13,7 +13,7 @@ // limitations under the License. // // ----------------------------------------------------------------------------- -// netkat.proto +// File: netkat.proto // ----------------------------------------------------------------------------- // // Proto representation of NetKAT programs (predicates and policies). diff --git a/netkat/netkat_proto_constructors.cc b/netkat/netkat_proto_constructors.cc new file mode 100644 index 0000000..e62b4fe --- /dev/null +++ b/netkat/netkat_proto_constructors.cc @@ -0,0 +1,67 @@ +// Copyright 2024 The NetKAT authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +#include "netkat/netkat_proto_constructors.h" + +#include +#include + +#include "absl/strings/string_view.h" +#include "netkat/netkat.pb.h" + +namespace netkat { + +PredicateProto TrueProto() { + auto result = PredicateProto(); + result.mutable_bool_constant()->set_value(true); + return result; +} + +PredicateProto FalseProto() { + auto result = PredicateProto(); + result.mutable_bool_constant()->set_value(false); + return result; +} + +PredicateProto MatchProto(absl::string_view field, int value) { + auto result = PredicateProto(); + PredicateProto::Match& match = *result.mutable_match(); + match.set_field(std::string(field)); + match.set_value(value); + return result; +} +PredicateProto AndProto(PredicateProto left, PredicateProto right) { + auto result = PredicateProto(); + PredicateProto::And& and_op = *result.mutable_and_op(); + *and_op.mutable_left() = std::move(left); + *and_op.mutable_right() = std::move(right); + return result; +} + +PredicateProto OrProto(PredicateProto left, PredicateProto right) { + auto result = PredicateProto(); + PredicateProto::Or& or_op = *result.mutable_or_op(); + *or_op.mutable_left() = std::move(left); + *or_op.mutable_right() = std::move(right); + return result; +} +PredicateProto NotProto(PredicateProto negand) { + auto result = PredicateProto(); + PredicateProto::Not& not_op = *result.mutable_not_op(); + *not_op.mutable_negand() = std::move(negand); + return result; +} + +} // namespace netkat diff --git a/netkat/netkat_proto_constructors.h b/netkat/netkat_proto_constructors.h new file mode 100644 index 0000000..40973b1 --- /dev/null +++ b/netkat/netkat_proto_constructors.h @@ -0,0 +1,45 @@ +// Copyright 2024 The NetKAT authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// ----------------------------------------------------------------------------- +// File: netkat_proto_helpers.h +// ----------------------------------------------------------------------------- +// +// Helper functions to make constructing netkat.proto messages more readable, +// specifically in unit test where readability is key. + +#ifndef GOOGLE_NETKAT_NETKAT_NETKAT_PROTO_CONSTRUCTORS_H_ +#define GOOGLE_NETKAT_NETKAT_NETKAT_PROTO_CONSTRUCTORS_H_ + +#include "absl/strings/string_view.h" +#include "netkat/netkat.pb.h" + +namespace netkat { + +// -- Predicate constructors --------------------------------------------------- + +PredicateProto TrueProto(); +PredicateProto FalseProto(); +PredicateProto MatchProto(absl::string_view field, int value); +PredicateProto AndProto(PredicateProto left, PredicateProto right); +PredicateProto OrProto(PredicateProto left, PredicateProto right); +PredicateProto NotProto(PredicateProto negand); + +// -- Policy constructors ------------------------------------------------------ + +// TODO - smolkaj: Add policy constructors when needed. + +} // namespace netkat + +#endif // GOOGLE_NETKAT_NETKAT_NETKAT_PROTO_CONSTRUCTORS_H_ diff --git a/netkat/netkat_proto_constructors_test.cc b/netkat/netkat_proto_constructors_test.cc new file mode 100644 index 0000000..155e556 --- /dev/null +++ b/netkat/netkat_proto_constructors_test.cc @@ -0,0 +1,75 @@ +// Copyright 2024 The NetKAT authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "netkat/netkat_proto_constructors.h" + +#include + +#include "fuzztest/fuzztest.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "gutil/proto_matchers.h" +#include "netkat/netkat.pb.h" + +namespace netkat { +namespace { + +using ::netkat::EqualsProto; + +TEST(TrueProtoTest, ReturnsTrueProto) { + EXPECT_THAT(TrueProto(), EqualsProto(R"pb(bool_constant { value: true })pb")); +} + +TEST(FalseProtoTest, ReturnsFalseProto) { + EXPECT_THAT(FalseProto(), + EqualsProto(R"pb(bool_constant { value: false })pb")); +} + +void MatchProtoReturnsMatch(std::string field, int value) { + auto match_proto = PredicateProto(); + auto& match = *match_proto.mutable_match(); + match.set_field(field); + match.set_value(value); + EXPECT_THAT(MatchProto(field, value), EqualsProto(match_proto)); +} +FUZZ_TEST(AndProtoTest, MatchProtoReturnsMatch); + +void AndProtoReturnsAnd(PredicateProto left, PredicateProto right) { + auto and_proto = PredicateProto(); + auto& and_op = *and_proto.mutable_and_op(); + *and_op.mutable_left() = left; + *and_op.mutable_right() = right; + EXPECT_THAT(AndProto(left, right), EqualsProto(and_proto)); +} +FUZZ_TEST(AndProtoTest, AndProtoReturnsAnd); + +void OrProtoReturnsOr(PredicateProto left, PredicateProto right) { + auto or_proto = PredicateProto(); + auto& or_op = *or_proto.mutable_or_op(); + *or_op.mutable_left() = left; + *or_op.mutable_right() = right; + EXPECT_THAT(OrProto(left, right), EqualsProto(or_proto)); +} +FUZZ_TEST(OrProtoTest, OrProtoReturnsOr); + +void NotProtoReturnsNot(PredicateProto negand) { + auto not_proto = PredicateProto(); + auto& not_op = *not_proto.mutable_not_op(); + *not_op.mutable_negand() = negand; + EXPECT_THAT(NotProto(negand), EqualsProto(not_proto)); +} +FUZZ_TEST(NotProtoTest, NotProtoReturnsNot); + +} // namespace +} // namespace netkat diff --git a/netkat/symbolic_packet.cc b/netkat/symbolic_packet.cc new file mode 100644 index 0000000..79d3426 --- /dev/null +++ b/netkat/symbolic_packet.cc @@ -0,0 +1,382 @@ +// Copyright 2024 The NetKAT authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ----------------------------------------------------------------------------- + +#include "netkat/symbolic_packet.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/strings/escaping.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "gutil/status.h" +#include "netkat/evaluator.h" + +namespace netkat { + +// The empty and full set of packets cannot be represented as decision nodes, +// and thus we cannot associate an index into the `nodes_` vector with them. +// Instead, we represent them using sentinel values, chosen maximally to avoid +// collisions with proper indices. +enum SentinelNodeIndex : uint32_t { + kEmptySet = std::numeric_limits::max(), + kFullSet = std::numeric_limits::max() - 1, +}; + +SymbolicPacket::SymbolicPacket() : node_index_(SentinelNodeIndex::kEmptySet) {} + +std::string SymbolicPacket::ToString() const { + if (node_index_ == SentinelNodeIndex::kEmptySet) { + return "SymbolicPacket"; + } else if (node_index_ == SentinelNodeIndex::kFullSet) { + return "SymbolicPacket"; + } else { + return absl::StrFormat("SymbolicPacket<%d>", node_index_); + } +} + +// The symbolic packet representing the empty set of packets. +SymbolicPacket SymbolicPacketManager::EmptySet() { + return SymbolicPacket(SentinelNodeIndex::kEmptySet); +} + +// The symbolic packet representing the set of all packets. +SymbolicPacket SymbolicPacketManager::FullSet() { + return SymbolicPacket(SentinelNodeIndex::kFullSet); +} + +// Returns true iff this symbolic packet represents the empty set of +bool SymbolicPacketManager::IsEmptySet(SymbolicPacket packet) const { + return packet == EmptySet(); +} + +// Returns true iff this symbolic packet represents the set of all packets. +bool SymbolicPacketManager::IsFullSet(SymbolicPacket packet) const { + return packet == FullSet(); +} + +const SymbolicPacketManager::DecisionNode& SymbolicPacketManager::GetNodeOrDie( + SymbolicPacket packet) const { + CHECK_LT(packet.node_index_, nodes_.size()); // Crash ok + return nodes_[packet.node_index_]; +} + +const std::string& SymbolicPacketManager::GetFieldNameOrDie(Field field) const { + CHECK_LT(field.index_, fields_.size()); // Crash ok + return fields_[field.index_]; +} + +Field SymbolicPacketManager::GetField(const std::string& field_name) { + auto [it, inserted] = + field_by_name_.try_emplace(field_name, Field(fields_.size())); + if (inserted) fields_.push_back(field_name); + return it->second; +} + +SymbolicPacket SymbolicPacketManager::NodeToPacket(DecisionNode&& node) { + if (node.branch_by_field_value.empty()) return node.default_branch; + +// When in debug mode, we check key invariants before creating a new node, to +// ease debugging. +#ifndef NDEBUG + for (const auto& [value, branch] : node.branch_by_field_value) { + CHECK(branch != node.default_branch) << PrettyPrint(node); + if (!IsEmptySet(branch) && !IsFullSet(branch)) { + auto& branch_node = GetNodeOrDie(branch); + CHECK(branch_node.field > node.field) << absl::StreamFormat( + "(%v > %v)\n---branch---\n%s\n---node---\n%s", branch_node.field, + node.field, PrettyPrint(branch), PrettyPrint(node)); + } + } +#endif + + auto [it, inserted] = + packet_by_node_.try_emplace(node, SymbolicPacket(nodes_.size())); + if (inserted) nodes_.push_back(std::move(node)); + return it->second; +} + +bool SymbolicPacketManager::Contains(SymbolicPacket symbolic_packet, + Packet concrete_packet) const { + if (IsEmptySet(symbolic_packet)) return false; + if (IsFullSet(symbolic_packet)) return true; + + const DecisionNode& node = GetNodeOrDie(symbolic_packet); + const std::string& field = GetFieldNameOrDie(node.field); + auto it = concrete_packet.find(field); + if (it != concrete_packet.end()) { + for (const auto& [value, branch] : node.branch_by_field_value) { + if (it->second == value) return Contains(branch, concrete_packet); + } + } + return Contains(node.default_branch, concrete_packet); +} + +SymbolicPacket SymbolicPacketManager::Compile(const PredicateProto& pred) { + switch (pred.predicate_case()) { + case PredicateProto::kBoolConstant: + return pred.bool_constant().value() ? FullSet() : EmptySet(); + case PredicateProto::kMatch: + return Match(pred.match().field(), pred.match().value()); + case PredicateProto::kAndOp: + return And(Compile(pred.and_op().left()), Compile(pred.and_op().right())); + case PredicateProto::kOrOp: + return Or(Compile(pred.or_op().left()), Compile(pred.or_op().right())); + case PredicateProto::kNotOp: + return Not(Compile(pred.not_op().negand())); + // By convention, uninitialized predicates must be treated like `false`. + case PredicateProto::PREDICATE_NOT_SET: + return EmptySet(); + } + LOG(FATAL) << "Unhandled predicate kind: " << pred.predicate_case(); +} + +SymbolicPacket SymbolicPacketManager::Match(const std::string& field, + int value) { + return NodeToPacket(DecisionNode{ + .field = GetField(field), + .default_branch = EmptySet(), + .branch_by_field_value = {{value, FullSet()}}, + }); +} + +// TODO(smolkaj): Use complement edges. +SymbolicPacket SymbolicPacketManager::Not(SymbolicPacket negand) { + // Base cases. + if (IsEmptySet(negand)) return FullSet(); + if (IsFullSet(negand)) return EmptySet(); + + // Compute result the hard way. + const DecisionNode& negand_node = GetNodeOrDie(negand); + DecisionNode result_node{ + .field = negand_node.field, + .default_branch = Not(negand_node.default_branch), + .branch_by_field_value{negand_node.branch_by_field_value.size()}, + }; + + for (int i = 0; i < negand_node.branch_by_field_value.size(); ++i) { + auto [value, branch] = negand_node.branch_by_field_value[i]; + SymbolicPacket negated_branch = Not(branch); + DCHECK(branch != negand_node.default_branch); + DCHECK(negated_branch != result_node.default_branch); + result_node.branch_by_field_value[i] = + std::make_pair(value, negated_branch); + } + + return NodeToPacket(std::move(result_node)); +} + +SymbolicPacket SymbolicPacketManager::And(SymbolicPacket left, + SymbolicPacket right) { + // Base cases. + if (IsEmptySet(left) || IsFullSet(right) || left == right) return left; + if (IsEmptySet(right) || IsFullSet(left)) return right; + + // Compute result the hard way. + const DecisionNode* left_node = &GetNodeOrDie(left); + const DecisionNode* right_node = &GetNodeOrDie(right); + + // We exploit that `And` is commutative to canonicalize the order of the + // arguments, reducing the number of cases by 1. + if (left_node->field > right_node->field) { + std::swap(left, right); + std::swap(left_node, right_node); + } + + // Case 1: left_node->field < right_node->field: branch on left field. + if (left_node->field < right_node->field) { + SymbolicPacket default_branch = And(left_node->default_branch, right); + std::vector> branch_by_field_value; + branch_by_field_value.reserve(left_node->branch_by_field_value.size()); + for (const auto& [value, left_branch] : left_node->branch_by_field_value) { + SymbolicPacket branch = And(left_branch, right); + if (branch == default_branch) continue; + branch_by_field_value.push_back(std::make_pair(value, branch)); + } + return NodeToPacket(DecisionNode{ + .field = left_node->field, + .default_branch = default_branch, + .branch_by_field_value{ + branch_by_field_value.begin(), + branch_by_field_value.end(), + }, + }); + } + + // Case 2: left_node->field == right_node->field: branch on shared field. + DCHECK(left_node->field == right_node->field); + SymbolicPacket default_branch = + And(left_node->default_branch, right_node->default_branch); + std::vector> branch_by_field_value; + branch_by_field_value.reserve(left_node->branch_by_field_value.size() + + right_node->branch_by_field_value.size()); + auto add_branch = [&](int value, SymbolicPacket branch) { + if (branch == default_branch) return; + branch_by_field_value.push_back(std::make_pair(value, branch)); + }; + auto left_it = left_node->branch_by_field_value.begin(); + auto left_end = left_node->branch_by_field_value.end(); + auto right_it = right_node->branch_by_field_value.begin(); + auto right_end = right_node->branch_by_field_value.end(); + while (left_it != left_end && right_it != right_end) { + auto [left_value, left_branch] = *left_it; + auto [right_value, right_branch] = *right_it; + if (left_value < right_value) { + add_branch(left_value, And(left_branch, right_node->default_branch)); + ++left_it; + } else if (left_value > right_value) { + add_branch(right_value, And(left_node->default_branch, right_branch)); + ++right_it; + } else { // left_value == right_value + add_branch(left_value, And(left_branch, right_branch)); + ++left_it; + ++right_it; + } + } + for (; left_it != left_end; ++left_it) { + auto [left_value, left_branch] = *left_it; + add_branch(left_value, And(left_branch, right_node->default_branch)); + } + for (; right_it != right_end; ++right_it) { + auto [right_value, right_branch] = *right_it; + add_branch(right_value, And(left_node->default_branch, right_branch)); + } + return NodeToPacket(DecisionNode{ + .field = left_node->field, + .default_branch = default_branch, + .branch_by_field_value{ + branch_by_field_value.begin(), + branch_by_field_value.end(), + }, + }); +} + +SymbolicPacket SymbolicPacketManager::Or(SymbolicPacket left, + SymbolicPacket right) { + // Apply De Morgan's law: a || b == !(!a && !b). + // TODO(smolkaj): Implement complement edges so this becomes efficient. + return Not(And(Not(left), Not(right))); +} + +std::string SymbolicPacketManager::PrettyPrint(SymbolicPacket packet) const { + std::string result; + std::deque work_list = {packet}; + absl::flat_hash_set visited = {packet}; + while (!work_list.empty()) { + SymbolicPacket packet = work_list.front(); + work_list.pop_front(); + absl::StrAppend(&result, packet); + + if (IsFullSet(packet) || IsEmptySet(packet)) continue; + + const DecisionNode& node = GetNodeOrDie(packet); + absl::StrAppend(&result, ":\n"); + std::string field = absl::StrFormat( + "%v:'%s'", node.field, absl::CEscape(GetFieldNameOrDie(node.field))); + for (const auto& [value, branch] : node.branch_by_field_value) { + absl::StrAppendFormat(&result, " %s == %d -> %v\n", field, value, + branch); + if (IsFullSet(branch) || IsEmptySet(branch)) continue; + bool new_branch = visited.insert(branch).second; + if (new_branch) work_list.push_back(branch); + } + SymbolicPacket fallthrough = node.default_branch; + absl::StrAppendFormat(&result, " %s == * -> %v\n", field, fallthrough); + if (IsFullSet(fallthrough) || IsEmptySet(fallthrough)) continue; + bool new_branch = visited.insert(fallthrough).second; + if (new_branch) work_list.push_back(fallthrough); + } + return result; +} + +std::string SymbolicPacketManager::PrettyPrint(const DecisionNode& node) const { + std::string result; + std::vector work_list; + std::string field = absl::StrFormat( + "%v:'%s'", node.field, absl::CEscape(GetFieldNameOrDie(node.field))); + for (const auto& [value, branch] : node.branch_by_field_value) { + absl::StrAppendFormat(&result, " %s == %d -> %v\n", field, value, branch); + if (!IsFullSet(branch) || !IsEmptySet(branch)) work_list.push_back(branch); + } + SymbolicPacket fallthrough = node.default_branch; + absl::StrAppendFormat(&result, " %s == * -> %v\n", field, fallthrough); + if (!IsFullSet(fallthrough) && !IsEmptySet(fallthrough)) { + work_list.push_back(fallthrough); + } + + for (const SymbolicPacket& branch : work_list) { + absl::StrAppend(&result, PrettyPrint(branch)); + } + + return result; +} + +absl::Status SymbolicPacketManager::CheckInternalInvariants() const { + // Invariant: `packet_by_node_[n] = s` iff `nodes_[s.node_index_] == n`. + for (const auto& [node, packet] : packet_by_node_) { + RET_CHECK(packet.node_index_ < nodes_.size()); + RET_CHECK(nodes_[packet.node_index_] == node); + } + for (int i = 0; i < nodes_.size(); ++i) { + const DecisionNode& node = nodes_[i]; + auto it = packet_by_node_.find(node); + RET_CHECK(it != packet_by_node_.end()); + RET_CHECK(it->second == SymbolicPacket(i)); + } + + // Invariant: `field_by_name_[n] = f` iff `fields_[f.index_] == n`. + for (const auto& [field_name, field] : field_by_name_) { + RET_CHECK(field.index_ < fields_.size()); + RET_CHECK(fields_[field.index_] == field_name); + } + for (int i = 0; i < fields_.size(); ++i) { + const std::string& field_name = fields_[i]; + auto it = field_by_name_.find(field_name); + RET_CHECK(it != field_by_name_.end()); + RET_CHECK(it->second == Field(i)); + } + + // Node Invariants. + for (const DecisionNode& node : nodes_) { + // Invariant: `branch_by_field_value` is non-empty. + // Maintained by `NodeToPacket`. + RET_CHECK(!node.branch_by_field_value.empty()); + + // Invariant: node field is strictly smaller than sub-node fields. + RET_CHECK(IsFullSet(node.default_branch) || + IsEmptySet(node.default_branch) || + GetNodeOrDie(node.default_branch).field > node.field); + for (const auto& [value, branch] : node.branch_by_field_value) { + RET_CHECK(IsFullSet(branch) || IsEmptySet(branch) || + GetNodeOrDie(branch).field > node.field); + + // Invariant: Each case in `branch_by_field_value` is != + // `default_branch`. + RET_CHECK(branch != node.default_branch); + } + } + + return absl::OkStatus(); +} + +} // namespace netkat diff --git a/netkat/symbolic_packet.h b/netkat/symbolic_packet.h index 30f9cca..62f42ce 100644 --- a/netkat/symbolic_packet.h +++ b/netkat/symbolic_packet.h @@ -13,7 +13,7 @@ // limitations under the License. // // ----------------------------------------------------------------------------- -// symbolic_packet.h +// File: symbolic_packet.h // ----------------------------------------------------------------------------- // // Defines `SymbolicPacket` and its companion class `SymbolicPacketManager`. @@ -32,31 +32,32 @@ // Why have a `SmbolicPacketManager` class? // ----------------------------------------------------------------------------- // -// With few exceptions, the APIs for creating, manipulating, and inspecting -// `SymbolicPacket`s are all defined as methods and static functions of the -// `SymbolicPacketManager` class. But why? +// The APIs for creating, manipulating, and inspecting `SymbolicPacket`s are all +// defined as methods and static functions of the `SymbolicPacketManager` class. +// But why? // -// TL;DR, all data associated with a `SymbolicPacket` is stored in data vectors -// owned by the manager class. Under the hood, a `SymbolicPacket` is just an -// index into these data vectors. This design pattern is motivated by -// computational and memory efficiency, and is standard for BDD-based libraries. +// TL;DR, all data associated with `SymbolicPacket`s is stored by the manager +// class; `SymbolicPacket` itself is just a lightweight (32-bit) handle. This +// design pattern is motivated by computational and memory efficiency, and is +// standard for BDD-based libraries. // // The manager object acts as an "arena" that owns and manages all memory // associated with `SymbolicPacket`s, enhancing data locality and sharing. This -// technique is known as hash-consing and is similar to the flyweight pattern -// and string interning. It has a long list of benefits, most importantly: +// technique is known as interning or hash-consing and is similar to the +// flyweight pattern. It has a long list of benefits, most importantly: // // * Canonicity: Can gurantee that semantically identical `SymbolicPacket` are -// represnted by the same index into `SymbolicPacketManager`, making semantic -// `SymbolicPacket` comparison O(1) (just comparing two integers)! +// represnted by the same handle, making semantic `SymbolicPacket` comparison +// O(1) (just comparing two integers)! // // * Memory efficiency: The graph structures used to encode symbolic packets are -// maximally shared across all packets, avoiding redundant copies. +// maximally shared across all packets, avoiding redundant copies of isomorph +// subgraphs. // // * Cache friendliness: Storing all data in contiguous arrays within the // manager improves data locality and thus cache utilization. // -// * Light representation: Since `SymbolicPacket`s are simply integres in +// * Light representation: Since `SymbolicPacket`s are simply integers in // memory, they are cheap to store, copy, compare, and hash. // // * Memoization: Thanks to canonicity and lightness of representation, @@ -68,7 +69,7 @@ // can be memoized as a loockup table of type (int, int) -> int. // // ----------------------------------------------------------------------------- - +// // CAUTION: This implementation has NOT yet been optimized for performance. // Performance can likely be improved significantly, e.g. as follows: // * Profiling and benchmarking to identify inefficiencies. @@ -79,58 +80,50 @@ #define GOOGLE_NETKAT_NETKAT_SYMBOLIC_PACKET_H_ #include -#include +#include #include +#include +#include "absl/container/fixed_array.h" +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" #include "absl/strings/str_format.h" +#include "netkat/evaluator.h" #include "netkat/netkat.pb.h" namespace netkat { -// TODO(smolkaj): Implement this. -class SymbolicPacketManager; - -// A "symbolic packet" is a data structure representing a set of packets. -// It is a compressed representation that can efficiently encode typical large -// and even infinite sets seen in practice. +// A "symbolic packet" is a lightweight handle (32 bits) that represents a set +// of packets. Handles can only be created by a `SymbolicPacketManager` object, +// which owns the graph-based representation of the set. The representation can +// efficiently encode typical large and even infinite sets seen in practice. +// +// The APIs of this object are almost entirely defined as methods and static +// members function of the companion class `SymbolicPacketManager`. See the +// section "Why have a `SmbolicPacketManager` class?" at the top of the file to +// learn why. +// +// CAUTION: Each `SymbolicPacket` is implicitly associated with the manager +// object that created it; using it with a different manager has undefined +// behavior. +// +// This data structure enjoys the following powerful *canonicity property*: two +// symbolic packets represent the same set if and only if they have the same +// memory representation. Since the memory representation is just 32 bits, set +// equality is cheap: O(1)! // // Compared to NetKAT predicates, which semantically also represent sets of // packets, symbolic packets have a few advantages: // * Cheap to store, copy, hash, and compare: O(1) // * Cheap to check set equality: O(1) // * Cheap to check set membership and set containment: O(# packet fields) -// -// NOTES ON THE API: -// * The majority of operations on `SymbolicPacket`s are defined as methods on -// the companion class `SymbolicPacketManager`. See the section "Why have a -// `SmbolicPacketManager` class?" at the top of the file to learn why. -// * Each `SymbolicPacket` is implicitly associated with the manager object that -// created it; using it with a different manager object has undefined -// behavior. -// * `EmptySet()` and `FullSet()` are exceptions to the above rule in that they -// work with any `SymbolicPacketManager` object. -class SymbolicPacket { +class [[nodiscard]] SymbolicPacket { public: // Default constructor: the empty set of packets. - SymbolicPacket() = default; - - // TODO(smolkaj): Move the EmptySet/FullSet APIs to the - // `SymbolicPacketManager` class for consistency. - - // The symbolic packet representing the empty set of packets. - static SymbolicPacket EmptySet() { return SymbolicPacket(kEmptySetIndex); } - - // The symbolic packet representing the set of all packets. - static SymbolicPacket FullSet() { return SymbolicPacket(kFullSetIndex); } - - // Returns true iff this symbolic packet represents the empty set of packets. - bool IsEmptySet() const { return node_index_ == kEmptySetIndex; } - - // Returns true iff this symbolic packet represents the set of all packets. - bool IsFullSet() const { return node_index_ == kFullSetIndex; } + SymbolicPacket(); // Two symbolic packets compare equal iff they represent the same set of - // concrete packets. Comparison is O(1), thanks to a canonical representation. + // concrete packets. Comparison is O(1), thanks to interning/hash-consing. friend auto operator<=>(SymbolicPacket a, SymbolicPacket b) = default; // Hashing, see https://abseil.io/docs/cpp/guides/hash. @@ -142,56 +135,227 @@ class SymbolicPacket { // Formatting, see https://abseil.io/docs/cpp/guides/abslstringify. template friend void AbslStringify(Sink& sink, SymbolicPacket packet) { - if (packet.IsEmptySet()) { - absl::Format(&sink, "SymbolicPacket"); - } else if (packet.IsFullSet()) { - absl::Format(&sink, "SymbolicPacket"); - } else { - absl::Format(&sink, "SymbolicPacket<%d>", packet.node_index_); - } + absl::Format(&sink, "%s", packet.ToString()); } + std::string ToString() const; private: - // In memory, a `SymbolicPacket` is just an integer, making it cheap to - // store, copy, hash, and compare. - // - // The `node_index_` is either: - // * A sentinel value: - // * `kEmptySetIndex`, representing the empty set of packets. - // * `kFullSetIndex`, representing the full set of packets. - // * An index into the `nodes_` vector of the `SymbolicPacketManager` object - // associated with this `SymbolicPacket`. In this case, the semantics of - // this object is given by the node stored at `nodes_[node_index_]` in the - // manager object. The index is otherwise arbitrary and meaningless, and - // thus, so is this object unless we have access to the associated manager - // object. - // - // We use a bit width of 32 as a tradeoff between minimizing memory usage - // (which is critical for scaling to large NetKAT models) and maximizing the - // number of `SymbolicPacket`s that can be created (which is also critical for - // scaling to large NetKAT models) -- we expect millions, but probably not - // billions, of symbolic packets in practice, and 2^32 ~= 4 billion. - // - // The sentinel values are chosen maximally to avoid collisions with valid - // indices, which are assigned dynamically by the manager object starting at - // 0. For performance reasons, there is no runtime protection against - // collisions and overflow if we create too many distinct `SymbolicPacket`s. + // An index into the `nodes_` vector of the `SymbolicPacketManager` object + // associated with this `SymbolicPacket`. The semantics of this symbolic + // packet is entirely determined by the node `nodes_[node_index_]`. The index + // is otherwise arbitrary and meaningless. // - // This data structure enjoys the following powerful canonicity property: two - // symbolic packets represent the same set if and only if they have the same - // `node_index_`. - uint32_t node_index_ = SentinelValue::kEmptySetIndex; - enum SentinelValue : uint32_t { - kEmptySetIndex = std::numeric_limits::max(), - kFullSetIndex = std::numeric_limits::max() - 1, - }; + // We use a 32-bit index as a tradeoff between minimizing memory usage and + // maximizing the number of `SymbolicPacket`s that can be created, both + // aspects that impact how well we scale to large NetKAT models. We expect + // millions, but not billions, of symbolic packets in practice, and 2^32 ~= 4 + // billion. + uint32_t node_index_; explicit SymbolicPacket(uint32_t node_index) : node_index_(node_index) {} friend class SymbolicPacketManager; }; -static_assert( - sizeof(SymbolicPacket) <= 4, - "SymbolicPacket should have small memory footprint for performance"); +// Protect against regressions in the memory layout, as it effects performance. +static_assert(sizeof(SymbolicPacket) <= 4); + +// A lightweight handle (16 bits) that represents a packet field like "dst_ip". +// Handles can only be created by a `SymbolicPacketManager` object, which stores +// the string. Interning/hash-consing fields in this way saves memory and makes +// fields cheap to store, copy, hash, and compare: O(1). +// +// The APIs of this object are almost entirely defined as methods and static +// members function of the companion class `SymbolicPacketManager`. See the +// section "Why have a `SmbolicPacketManager` class?" at the top of the file to +// learn why. +// +// CAUTION: Each `Field` is implicitly associated with the manager object that +// created it; using it with a different manager object has undefined behavior. +class [[nodiscard]] Field { + public: + // `Field`s can only be created by `SymbolicPacketManager`. + Field() = delete; + friend class SymbolicPacketManager; + + // O(1) comparison, thanks to interning/hash-consing. + friend auto operator<=>(Field a, Field b) = default; + + // Hashing, see https://abseil.io/docs/cpp/guides/hash. + template + friend H AbslHashValue(H h, Field field) { + return H::combine(std::move(h), field.index_); + } + + // Formatting, see https://abseil.io/docs/cpp/guides/abslstringify. + template + friend void AbslStringify(Sink& sink, Field field) { + absl::Format(&sink, "Field<%d>", field.index_); + } + + private: + // An index into the `fields_` vector of the `SymbolicPacketManager` object + // associated with this `Field`; `fields_[index_]` is the name of the field. + // The index is otherwise arbitrary and meaningless. + // + // We use a 16-bit index as a tradeoff between minimizing memory usage while + // supporting sufficiently many fields. We expect 100s, but not more than + // 2^16 ~= 65k fields. + uint16_t index_; + explicit Field(uint16_t index) : index_(index) {} +}; + +// Protect against regressions in the memory layout, as it effects performance. +static_assert(sizeof(Field) <= 2); + +// An "arena" in which `SymbolicPacket`s can be created and manipulated. +// +// This class defines the majority of operations on `SymbolicPacket`s and owns +// all the memory associated with the `SymbolicPacket`s returned by the class's +// methods. +// +// CAUTION: Using a `SymbolicPacket` returned by one `SymbolicPacketManager` +// object with a different manager is undefined behavior. +class SymbolicPacketManager { + public: + SymbolicPacketManager() = default; + + // The class is move-only: not copyable, but movable. + SymbolicPacketManager(const SymbolicPacketManager&) = delete; + SymbolicPacketManager& operator=(const SymbolicPacketManager&) = delete; + SymbolicPacketManager(SymbolicPacketManager&&) = default; + SymbolicPacketManager& operator=(SymbolicPacketManager&&) = default; + + // The symbolic packet representing the empty set of packets. + static SymbolicPacket EmptySet(); + + // The symbolic packet representing the set of all packets. + static SymbolicPacket FullSet(); + + // Returns true iff this symbolic packet represents the empty set of packets. + [[nodiscard]] bool IsEmptySet(SymbolicPacket packet) const; + + // Returns true iff this symbolic packet represents the set of all packets. + [[nodiscard]] bool IsFullSet(SymbolicPacket packet) const; + + // Returns true if the set represented by `symbolic_packet` contains the given + // `concrete_packet`, or false otherwise. + [[nodiscard]] bool Contains(SymbolicPacket symbolic_packet, + netkat::Packet concrete_packet) const; + + // Compiles the given `PredicateProto` into a `SymbolicPacket` that + // represents the set of packets satisfying the predicate. + SymbolicPacket Compile(const PredicateProto& pred); + + // Returns the set of packets whose `field` is equal to `value`. + SymbolicPacket Match(const std::string& field, int value); + + // Returns the set of packets that are in the `left` *AND* in the `right` set. + // Also known as set intersection. + SymbolicPacket And(SymbolicPacket left, SymbolicPacket right); + + // Returns the set of packets that are in the `left` *OR* in the `right` set. + // Also known as set union. + SymbolicPacket Or(SymbolicPacket left, SymbolicPacket right); + + // Returns the set of packets that are *NOT* in the given set. + // Also known as set complement. + SymbolicPacket Not(SymbolicPacket negand); + + // Returns a human-readable string representation of the given `packet`, + // intended for debugging. + [[nodiscard]] std::string PrettyPrint(SymbolicPacket packet) const; + + // Dynamically checks all class invariants. Exposed for testing. + absl::Status CheckInternalInvariants() const; + + // There are many additional set operations supported by the data structure. + // We may implement them as needed. For example: + // * subset - is one set a subset of another? + // * witness - given a (non-empty) set, return one (or n) elements from the + // set. + // * sample - return a memember from the set uniformly at random. + + private: + // A decision node in the symbolic packet DAG. The node branches on the value + // of a single `field`, and (the consequent of) each branch is another + // `SymbolicPacket`. Semantically, represents a cascading conditional of the + // form: + // + // if (field == value_1) then branch_1 + // else if (field == value_2) then branch_2 + // ... + // else default_branch + struct DecisionNode { + // The packet field whose value this decision node branches on. + // + // INVARIANT: `field` is strictly smaller than field indices of all + // sub-nodes. + Field field; + + // The consequent of the "else" branch of this decision node. + SymbolicPacket default_branch; + + // The "if" branches of the decision node, "keyed" by the value they branch + // on. Each element of the array is a (value, branch)-pair encoding + // "if (field == value) then branch". + // + // INVARIANTS: + // 1. Maintained by `NodeToPacket`: `branch_by_field_value` is non-empty. + // (If it is empty, the decision node gets replaced by `default_branch`.) + // 2. Maintained by `AddCase`: Each branch is != `default_branch`. (If it + // is == `default_branch`, it gets omitted.) + // 3. Maintained by the callers of `AddCase`: The pairs are ordered by + // strictly increasing value. No two branches have the same value. + // + // Choice of data structure: + // * Logically this is a map, but we don't require fast look ups and thus + // optimize for a compact, contiguous memory layout without indirection. + // * No need to dynamically resize, hence we can safe some bytes that + // dynamic data structures need for bookkeeping. + absl::FixedArray, + /*use_heap_allocation_above_size=*/0> + branch_by_field_value; + + // Protect against regressions in memory layout, as it effects performance. + static_assert(sizeof(branch_by_field_value) == 16); + + friend auto operator<=>(const DecisionNode& a, + const DecisionNode& b) = default; + + // Hashing, see https://abseil.io/docs/cpp/guides/hash. + template + friend H AbslHashValue(H h, const DecisionNode& node) { + return H::combine(std::move(h), node.field, node.default_branch, + node.branch_by_field_value); + } + }; + + // Protect against regressions in memory layout, as it effects performance. + static_assert(sizeof(DecisionNode) == 24); + static_assert(alignof(DecisionNode) == 8); + + // The decision nodes forming the BDD-style DAG representation of symbolic + // packets. `SymbolicPacket::node_index_` indexes into this vector. + std::vector nodes_; + + // A so called "unique table" to ensure each node is only added to `nodes_` + // once, and thus has a unique `SymbolicPacket::node_index`. + // INVARIANT: `packet_by_node_[n] = s` iff `nodes_[s.node_index_] == n`. + absl::flat_hash_map packet_by_node_; + + // `Field::index_` indexes into this vector. + std::vector fields_; + + // A so called "unique table" to ensure each field is only added to `fields_` + // once, and thus has a unique `Field::index_`. + // INVARIANT: `field_by_name_[n] = field` iff `fields_[field.index_] == n`. + absl::flat_hash_map field_by_name_; + + SymbolicPacket NodeToPacket(DecisionNode&& node); + const DecisionNode& GetNodeOrDie(SymbolicPacket packet) const; + const std::string& GetFieldNameOrDie(Field field) const; + Field GetField(const std::string& field_name); + [[nodiscard]] std::string PrettyPrint(const DecisionNode& node) const; +}; } // namespace netkat diff --git a/netkat/symbolic_packet_test.cc b/netkat/symbolic_packet_test.cc index ff2c4bb..86ae9f1 100644 --- a/netkat/symbolic_packet_test.cc +++ b/netkat/symbolic_packet_test.cc @@ -15,44 +15,238 @@ #include "netkat/symbolic_packet.h" +#include + +#include "absl/base/no_destructor.h" #include "absl/container/flat_hash_set.h" #include "absl/strings/str_cat.h" +#include "fuzztest/fuzztest.h" +#include "gmock/gmock.h" #include "gtest/gtest.h" +#include "netkat/evaluator.h" +#include "netkat/netkat_proto_constructors.h" namespace netkat { -namespace { -TEST(SymbolicPacketTest, DefaultConstructorYieldsEmptySet) { - EXPECT_TRUE(SymbolicPacket().IsEmptySet()); +// We use a global manager object across all tests to exercise statefulness. +// This also enables pretty printing for debugging. +SymbolicPacketManager& Manager() { + static absl::NoDestructor manager; + return *manager; +} + +// Custom pretty printer that GoogleTest will use instead of `AbslStringify`. +void PrintTo(const SymbolicPacket& packet, std::ostream* os) { + *os << Manager().PrettyPrint(packet); } -TEST(SymbolicPacketTest, EmptySetIsEmptySet) { - EXPECT_TRUE(SymbolicPacket::EmptySet().IsEmptySet()); +class CheckSymbolicPacketManagerInvariantsOnTearDown + : public testing::Environment { + public: + ~CheckSymbolicPacketManagerInvariantsOnTearDown() override {} + void SetUp() override {} + void TearDown() override { ASSERT_OK(Manager().CheckInternalInvariants()); } +}; +testing::Environment* const foo_env = testing::AddGlobalTestEnvironment( + new CheckSymbolicPacketManagerInvariantsOnTearDown); + +namespace { + +using testing::StartsWith; + +TEST(SymbolicPacketManagerTest, EmptySetIsEmptySet) { + EXPECT_TRUE(Manager().IsEmptySet(Manager().EmptySet())); + EXPECT_FALSE(Manager().IsFullSet(Manager().EmptySet())); } -TEST(SymbolicPacketTest, FullSetIsFullSet) { - EXPECT_TRUE(SymbolicPacket::FullSet().IsFullSet()); +TEST(SymbolicPacketManagerTest, FullSetIsFullSet) { + EXPECT_TRUE(Manager().IsFullSet(Manager().FullSet())); + EXPECT_FALSE(Manager().IsEmptySet(Manager().FullSet())); } -TEST(SymbolicPacketTest, EmptySetDoesNotEqualFullSet) { - EXPECT_NE(SymbolicPacket::EmptySet(), SymbolicPacket::FullSet()); +TEST(SymbolicPacketManagerTest, EmptySetDoesNotEqualFullSet) { + EXPECT_NE(Manager().EmptySet(), Manager().FullSet()); } -TEST(SymbolicPacketTest, AbslStringifyWorksForEmptySet) { - EXPECT_EQ(absl::StrCat(SymbolicPacket::EmptySet()), "SymbolicPacket"); +TEST(SymbolicPacketManagerTest, AbslStringifyWorksForEmptySet) { + EXPECT_THAT(absl::StrCat(SymbolicPacketManager::EmptySet()), + StartsWith("SymbolicPacket")); } -TEST(SymbolicPacketTest, AbslStringifyWorksForFullSet) { - EXPECT_EQ(absl::StrCat(SymbolicPacket::FullSet()), "SymbolicPacket"); +TEST(SymbolicPacketManagerTest, AbslStringifyWorksForFullSet) { + EXPECT_THAT(absl::StrCat(SymbolicPacketManager::FullSet()), + StartsWith("SymbolicPacket")); } -TEST(SymbolicPacketTest, AbslHashValueWorks) { +TEST(SymbolicPacketManagerTest, AbslHashValueWorks) { absl::flat_hash_set set = { - SymbolicPacket::EmptySet(), - SymbolicPacket::FullSet(), + SymbolicPacketManager::EmptySet(), + SymbolicPacketManager::FullSet(), }; EXPECT_EQ(set.size(), 2); } +TEST(SymbolicPacketManagerTest, TrueCompilesToFullSet) { + EXPECT_EQ(Manager().Compile(TrueProto()), Manager().FullSet()); +} + +TEST(SymbolicPacketManagerTest, FalseCompilesToEmptySet) { + EXPECT_EQ(Manager().Compile(FalseProto()), Manager().EmptySet()); +} + +void MatchCompilesToMatch(std::string field, int value) { + EXPECT_EQ(Manager().Compile(MatchProto(field, value)), + Manager().Match(field, value)); +} +FUZZ_TEST(SymbolicPacketManagerTest, MatchCompilesToMatch); + +void AndCompilesToAnd(const PredicateProto& left, const PredicateProto& right) { + EXPECT_EQ(Manager().Compile(AndProto(left, right)), + Manager().And(Manager().Compile(left), Manager().Compile(right))); +} +FUZZ_TEST(SymbolicPacketManagerTest, AndCompilesToAnd); + +void OrCompilesToOr(const PredicateProto& left, const PredicateProto& right) { + EXPECT_EQ(Manager().Compile(OrProto(left, right)), + Manager().Or(Manager().Compile(left), Manager().Compile(right))); +} +FUZZ_TEST(SymbolicPacketManagerTest, OrCompilesToOr); + +void NotCompilesToNot(const PredicateProto& pred) { + EXPECT_EQ(Manager().Compile(NotProto(pred)), + Manager().Not(Manager().Compile(pred))); +} +FUZZ_TEST(SymbolicPacketManagerTest, NotCompilesToNot); + +void CompilationPreservesSemantics(const PredicateProto& pred, + const Packet& packet) { + SymbolicPacketManager& mgr = Manager(); + EXPECT_EQ(mgr.Contains(mgr.Compile(pred), packet), Evaluate(pred, packet)); +} +FUZZ_TEST(SymbolicPacketManagerTest, CompilationPreservesSemantics); + +void EqualPredicatesCompileToEqualSymbolicPackets(const PredicateProto& pred) { + EXPECT_EQ(Manager().Compile(pred), Manager().Compile(pred)); +} +FUZZ_TEST(SymbolicPacketManagerTest, + EqualPredicatesCompileToEqualSymbolicPackets); + +void NegationCompilesToDifferentSymbolicPacket(const PredicateProto& pred) { + EXPECT_NE(Manager().Compile(pred), Manager().Compile(NotProto(pred))); +} +FUZZ_TEST(SymbolicPacketManagerTest, NegationCompilesToDifferentSymbolicPacket); + +void DoubleNegationCompilesToSameSymbolicPacket(const PredicateProto& pred) { + EXPECT_EQ(Manager().Compile(pred), + Manager().Compile(NotProto(NotProto(pred)))); +} +FUZZ_TEST(SymbolicPacketManagerTest, + DoubleNegationCompilesToSameSymbolicPacket); + +TEST(SymbolicPacketManagerTest, TrueNotEqualsMatch) { + EXPECT_NE(Manager().Compile(TrueProto()), + Manager().Compile(MatchProto("hi", 42))); +} +TEST(SymbolicPacketManagerTest, FalseNotEqualsMatch) { + EXPECT_NE(Manager().Compile(FalseProto()), + Manager().Compile(MatchProto("hi", 42))); +} +TEST(SymbolicPacketManagerTest, MatchNotEqualsDifferentMatch) { + EXPECT_NE(Manager().Compile(MatchProto("hi", 42)), + Manager().Compile(MatchProto("bye", 42))); + EXPECT_NE(Manager().Compile(MatchProto("hi", 42)), + Manager().Compile(MatchProto("hi", 24))); +} +TEST(SymbolicPacketManagerTest, NotTrueEqualsFalse) { + EXPECT_EQ(Manager().Compile(NotProto(TrueProto())), + Manager().Compile(FalseProto())); +} + +void AndIsIdempotent(const PredicateProto& pred) { + EXPECT_EQ(Manager().Compile(AndProto(pred, pred)), Manager().Compile(pred)); +} +FUZZ_TEST(SymbolicPacketManagerTest, AndIsIdempotent); + +void OrIsIdempotent(const PredicateProto& pred) { + EXPECT_EQ(Manager().Compile(OrProto(pred, pred)), Manager().Compile(pred)); +} +FUZZ_TEST(SymbolicPacketManagerTest, OrIsIdempotent); + +void PredOrItsNegationIsTrue(const PredicateProto& pred) { + EXPECT_EQ(Manager().Compile(OrProto(pred, NotProto(pred))), + Manager().Compile(TrueProto())); +} +FUZZ_TEST(SymbolicPacketManagerTest, PredOrItsNegationIsTrue); + +void PredAndItsNegationIsFalse(const PredicateProto& pred) { + EXPECT_EQ(Manager().Compile(AndProto(pred, NotProto(pred))), + Manager().Compile(FalseProto())); +} +FUZZ_TEST(SymbolicPacketManagerTest, PredAndItsNegationIsFalse); + +void AndTrueIsIdentity(const PredicateProto& pred) { + EXPECT_EQ(Manager().Compile(AndProto(pred, TrueProto())), + Manager().Compile(pred)); +} +FUZZ_TEST(SymbolicPacketManagerTest, AndTrueIsIdentity); + +void OrFalseIsIdentity(const PredicateProto& pred) { + EXPECT_EQ(Manager().Compile(OrProto(pred, FalseProto())), + Manager().Compile(pred)); +} +FUZZ_TEST(SymbolicPacketManagerTest, OrFalseIsIdentity); + +void AndFalseIsFalse(const PredicateProto& pred) { + EXPECT_EQ(Manager().Compile(AndProto(pred, FalseProto())), + Manager().Compile(FalseProto())); +} +FUZZ_TEST(SymbolicPacketManagerTest, AndFalseIsFalse); + +void OrTrueIsTrue(const PredicateProto& pred) { + EXPECT_EQ(Manager().Compile(OrProto(pred, TrueProto())), + Manager().Compile(TrueProto())); +} +FUZZ_TEST(SymbolicPacketManagerTest, OrTrueIsTrue); + +void AndIsCommutative(const PredicateProto& a, const PredicateProto& b) { + EXPECT_EQ(Manager().Compile(AndProto(a, b)), + Manager().Compile(AndProto(b, a))); +} +FUZZ_TEST(SymbolicPacketManagerTest, AndIsCommutative); + +void OrIsCommutative(const PredicateProto& a, const PredicateProto& b) { + EXPECT_EQ(Manager().Compile(OrProto(a, b)), Manager().Compile(OrProto(b, a))); +} +FUZZ_TEST(SymbolicPacketManagerTest, OrIsCommutative); + +void DistribiutiveLawHolds(const PredicateProto& a, const PredicateProto& b, + const PredicateProto& c) { + EXPECT_EQ(Manager().Compile(AndProto(a, OrProto(b, c))), + Manager().Compile(OrProto(AndProto(a, b), AndProto(a, c)))); +} +FUZZ_TEST(SymbolicPacketManagerTest, DistribiutiveLawHolds); + +void DeMorgansLawsHolds(const PredicateProto& a, const PredicateProto& b) { + EXPECT_EQ(Manager().Compile(NotProto(AndProto(a, b))), + Manager().Compile(OrProto(NotProto(a), NotProto(b)))); + EXPECT_EQ(Manager().Compile(NotProto(OrProto(a, b))), + Manager().Compile(AndProto(NotProto(a), NotProto(b)))); +} +FUZZ_TEST(SymbolicPacketManagerTest, DeMorgansLawsHolds); + +void AndIsAssociative(const PredicateProto& a, const PredicateProto& b, + const PredicateProto& c) { + EXPECT_EQ(Manager().Compile(AndProto(a, AndProto(b, c))), + Manager().Compile(AndProto(AndProto(a, b), c))); +} +FUZZ_TEST(SymbolicPacketManagerTest, AndIsAssociative); + +void OrIsAssociative(const PredicateProto& a, const PredicateProto& b, + const PredicateProto& c) { + EXPECT_EQ(Manager().Compile(OrProto(a, OrProto(b, c))), + Manager().Compile(OrProto(OrProto(a, b), c))); +} +FUZZ_TEST(SymbolicPacketManagerTest, OrIsAssociative); + } // namespace } // namespace netkat