Skip to content

Commit

Permalink
[NetKAT] Add a function to evaluate a policy on a packet.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 702512910
  • Loading branch information
jonathan-dilorenzo authored and copybara-github committed Dec 4, 2024
1 parent 4ef2697 commit 7ae1002
Show file tree
Hide file tree
Showing 4 changed files with 226 additions and 12 deletions.
4 changes: 4 additions & 0 deletions netkat/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ cc_library(
deps = [
":netkat_cc_proto",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/log",
],
)
Expand Down Expand Up @@ -87,6 +88,9 @@ cc_test(
deps = [
":evaluator",
":netkat_cc_proto",
":netkat_proto_constructors",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_fuzztest//fuzztest",
"@com_google_googletest//:gtest_main",
],
)
53 changes: 53 additions & 0 deletions netkat/evaluator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

#include "netkat/evaluator.h"

#include "absl/container/flat_hash_set.h"
#include "absl/log/log.h"
#include "netkat/netkat.pb.h"

Expand Down Expand Up @@ -46,4 +47,56 @@ bool Evaluate(const PredicateProto& predicate, const Packet& packet) {
<< static_cast<int>(predicate.predicate_case());
}

absl::flat_hash_set<Packet> Evaluate(
const PolicyProto& policy, const absl::flat_hash_set<Packet>& packets) {
absl::flat_hash_set<Packet> result;
for (const Packet& packet : packets) {
result.merge(Evaluate(policy, packet));
}
return result;
}

absl::flat_hash_set<Packet> Evaluate(const PolicyProto& policy,
const Packet& packet) {
switch (policy.policy_case()) {
case PolicyProto::kFilter:
return Evaluate(policy.filter(), packet)
? absl::flat_hash_set<Packet>({packet})
: absl::flat_hash_set<Packet>();
case PolicyProto::kModification: {
Packet modified_packet = packet;
// Adds field if it doesn't exist, and modifies it otherwise.
modified_packet[policy.modification().field()] =
policy.modification().value();
return {modified_packet};
}
case PolicyProto::kRecord:
// Record is treated as a no-op.
return {packet};
case PolicyProto::kSequenceOp:
return Evaluate(policy.sequence_op().right(),
Evaluate(policy.sequence_op().left(), packet));
case PolicyProto::kUnionOp: {
absl::flat_hash_set<Packet> result =
Evaluate(policy.union_op().left(), packet);
result.merge(Evaluate(policy.union_op().right(), packet));
return result;
}
case PolicyProto::kIterateOp: {
// p* = 1 + p + p;p + p;p;p + ...
absl::flat_hash_set<Packet> result = {packet}; // 1
// Evaluate p on result until fixed point, marked by no change in size.
int last_size;
do {
last_size = result.size();
result.merge(Evaluate(policy.iterate_op().iterable(), result)); // p^n
} while (last_size != result.size());
return result;
}
case PolicyProto::POLICY_NOT_SET:
// Unset policy is treated as DENY.
return {};
}
}

} // namespace netkat
17 changes: 16 additions & 1 deletion netkat/evaluator.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,16 @@
//
// Defines a library of functions for evaluating NetKAT predicates and policies
// on concrete packets.
//
// See go/netkat-hld for more details.

#ifndef GOOGLE_NETKAT_NETKAT_EVALUATOR_H_
#define GOOGLE_NETKAT_NETKAT_EVALUATOR_H_

#include <string>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "netkat/netkat.pb.h"

namespace netkat {
Expand All @@ -42,9 +45,21 @@ using Packet = absl::flat_hash_map<std::string, int>;
// Returns true if the given `packet` satisfies the given `predicate`, false
// otherwise.
//
// Note: Empty predicates are considered unsatisfiable.
// Note: Uninitialized predicates are considered unsatisfiable.
bool Evaluate(const PredicateProto& predicate, const Packet& packet);

// Returns the output packets produced by running the given policy on the given
// input packet. Treats `Record` (aka `dup`) as no-op and does not keep track of
// packet histories.
//
// Note: Uninitialized policies are considered DENY, returning the empty set.
absl::flat_hash_set<Packet> Evaluate(const PolicyProto& policy,
const Packet& packet);

// Lifts policy evaluation to sets of packets.
absl::flat_hash_set<Packet> Evaluate(
const PolicyProto& policy, const absl::flat_hash_set<Packet>& packets);

} // namespace netkat

#endif // GOOGLE_NETKAT_NETKAT_EVALUATOR_H_
164 changes: 153 additions & 11 deletions netkat/evaluator_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,24 @@

#include "netkat/evaluator.h"

#include "absl/container/flat_hash_set.h"
#include "fuzztest/fuzztest.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "netkat/netkat.pb.h"
#include "netkat/netkat_proto_constructors.h"

namespace netkat {
namespace {

TEST(EvaluatorTest, TrueIsTrueOnAnyPackets) {
using ::fuzztest::Arbitrary;
using ::fuzztest::InRange;
using ::testing::ContainerEq;
using ::testing::IsEmpty;
using ::testing::IsSupersetOf;
using ::testing::UnorderedElementsAre;

TEST(EvaluatePredicateProtoTest, TrueIsTrueOnAnyPackets) {
PredicateProto true_predicate;
true_predicate.mutable_bool_constant()->set_value(true);

Expand All @@ -31,7 +42,7 @@ TEST(EvaluatorTest, TrueIsTrueOnAnyPackets) {
Packet({{"field1", 1}, {"field2", 2}, {"field3", 3}})));
}

TEST(EvaluatorTest, FalseIsFalseOnAnyPacket) {
TEST(EvaluatePredicateProtoTest, FalseIsFalseOnAnyPacket) {
PredicateProto false_predicate;
false_predicate.mutable_bool_constant()->set_value(false);

Expand All @@ -41,7 +52,7 @@ TEST(EvaluatorTest, FalseIsFalseOnAnyPacket) {
Packet({{"field1", 1}, {"field2", 2}, {"field3", 3}})));
}

TEST(EvaluatorTest, EmptyPredicateIsFalseOnAnyPacket) {
TEST(EvaluatePredicateProtoTest, EmptyPredicateIsFalseOnAnyPacket) {
PredicateProto empty_predicate;

EXPECT_FALSE(Evaluate(empty_predicate, Packet()));
Expand All @@ -50,7 +61,7 @@ TEST(EvaluatorTest, EmptyPredicateIsFalseOnAnyPacket) {
Packet({{"field1", 1}, {"field2", 2}, {"field3", 3}})));
}

TEST(EvaluatorTest, NotTrueIsFalseOnAnyPackets) {
TEST(EvaluatePredicateProtoTest, NotTrueIsFalseOnAnyPackets) {
PredicateProto not_true_predicate;
not_true_predicate.mutable_not_op()
->mutable_negand()
Expand All @@ -63,7 +74,7 @@ TEST(EvaluatorTest, NotTrueIsFalseOnAnyPackets) {
Packet({{"field1", 1}, {"field2", 2}, {"field3", 3}})));
}

TEST(EvaluatorTest, NotFalseIsTrueOnAnyPacket) {
TEST(EvaluatePredicateProtoTest, NotFalseIsTrueOnAnyPacket) {
PredicateProto not_false_predicate;
not_false_predicate.mutable_not_op()
->mutable_negand()
Expand All @@ -76,7 +87,7 @@ TEST(EvaluatorTest, NotFalseIsTrueOnAnyPacket) {
Packet({{"field1", 1}, {"field2", 2}, {"field3", 3}})));
}

TEST(EvaluatorTest, NotNotTrueIsTrueOnAnyPackets) {
TEST(EvaluatePredicateProtoTest, NotNotTrueIsTrueOnAnyPackets) {
PredicateProto not_not_true_predicate;
not_not_true_predicate.mutable_not_op()
->mutable_negand()
Expand All @@ -91,7 +102,7 @@ TEST(EvaluatorTest, NotNotTrueIsTrueOnAnyPackets) {
Packet({{"field1", 1}, {"field2", 2}, {"field3", 3}})));
}

TEST(EvaluatorTest, MatchesFieldWithCorrectValue) {
TEST(EvaluatePredicateProtoTest, MatchesFieldWithCorrectValue) {
PredicateProto match_predicate;
match_predicate.mutable_match()->set_field("field1");
match_predicate.mutable_match()->set_value(1);
Expand All @@ -102,7 +113,7 @@ TEST(EvaluatorTest, MatchesFieldWithCorrectValue) {
Packet({{"field1", 1}, {"field2", 2}, {"field3", 3}})));
}

TEST(EvaluatorTest, DoesNotMatchFieldWithWrongValue) {
TEST(EvaluatePredicateProtoTest, DoesNotMatchFieldWithWrongValue) {
PredicateProto match_predicate;
match_predicate.mutable_match()->set_field("field1");
match_predicate.mutable_match()->set_value(2);
Expand All @@ -113,7 +124,7 @@ TEST(EvaluatorTest, DoesNotMatchFieldWithWrongValue) {
Packet({{"field1", 1}, {"field2", 2}, {"field3", 3}})));
}

TEST(EvaluatorTest, AndIsLogicallyCorrect) {
TEST(EvaluatePredicateProtoTest, AndIsLogicallyCorrect) {
PredicateProto true_and_true_predicate;
true_and_true_predicate.mutable_and_op()
->mutable_left()
Expand Down Expand Up @@ -160,7 +171,7 @@ TEST(EvaluatorTest, AndIsLogicallyCorrect) {
Packet({{"field1", 1}, {"field2", 2}, {"field3", 3}})));
}

TEST(EvaluatorTest, OrIsLogicallyCorrect) {
TEST(EvaluatePredicateProtoTest, OrIsLogicallyCorrect) {
PredicateProto true_or_true_predicate;
true_or_true_predicate.mutable_or_op()
->mutable_left()
Expand Down Expand Up @@ -207,7 +218,7 @@ TEST(EvaluatorTest, OrIsLogicallyCorrect) {
Packet({{"field1", 1}, {"field2", 2}, {"field3", 3}})));
}

TEST(EvaluatorTest, DeMorganHolds) {
TEST(EvaluatePredicateProtoTest, DeMorganHolds) {
const Packet kEmptyPacket = Packet();
const Packet kOneFieldPacket = Packet({{"field1", 1}});
const Packet kThreeFieldsPacket =
Expand Down Expand Up @@ -294,5 +305,136 @@ TEST(EvaluatorTest, DeMorganHolds) {
}
}

/*--- Basic policy properties ------------------------------------------------*/

void LiftedEvaluationIsCorrect(absl::flat_hash_set<Packet> packets,
PolicyProto policy) {
absl::flat_hash_set<Packet> expected_packets;
for (const Packet& packet : packets) {
expected_packets.merge(Evaluate(policy, packet));
}
EXPECT_THAT(Evaluate(policy, packets), ContainerEq(expected_packets));
}
FUZZ_TEST(EvaluatePolicyProtoTest, LiftedEvaluationIsCorrect);

void RecordIsAccept(Packet packet) {
EXPECT_THAT(Evaluate(RecordProto(), packet), UnorderedElementsAre(packet));
}
FUZZ_TEST(EvaluatePolicyProtoTest, RecordIsAccept);

void UninitializedPolicyIsDeny(Packet packet) {
EXPECT_THAT(Evaluate(PolicyProto(), packet), IsEmpty());
}
FUZZ_TEST(EvaluatePolicyProtoTest, UninitializedPolicyIsDeny);

void FilterIsCorrect(Packet packet, PredicateProto predicate) {
if (Evaluate(predicate, packet)) {
EXPECT_THAT(Evaluate(FilterProto(predicate), packet),
UnorderedElementsAre(packet));
} else {
EXPECT_THAT(Evaluate(FilterProto(predicate), packet), IsEmpty());
}
}
FUZZ_TEST(EvaluatePolicyProtoTest, FilterIsCorrect);

void ModifyModifies(Packet packet, std::string field, int value) {
Packet expected_packet = packet;
expected_packet[field] = value;
EXPECT_THAT(Evaluate(ModificationProto(field, value), packet),
UnorderedElementsAre(expected_packet));
}
FUZZ_TEST(EvaluatePolicyProtoTest, ModifyModifies);

void UnionCombines(Packet packet, PolicyProto left, PolicyProto right) {
absl::flat_hash_set<Packet> expected_packets = Evaluate(left, packet);
expected_packets.merge(Evaluate(right, packet));

EXPECT_THAT(Evaluate(UnionProto(left, right), packet),
ContainerEq(expected_packets));
}
FUZZ_TEST(EvaluatePolicyProtoTest, UnionCombines);

void SequenceSequences(Packet packet, PolicyProto left, PolicyProto right) {
absl::flat_hash_set<Packet> expected_packets =
Evaluate(right, Evaluate(left, packet));

EXPECT_THAT(Evaluate(SequenceProto(left, right), packet),
ContainerEq(expected_packets));
}
FUZZ_TEST(EvaluatePolicyProtoTest, SequenceSequences);

PolicyProto UnionUpToNthPower(PolicyProto iterable, int n) {
PolicyProto union_policy = AcceptProto();
PolicyProto next_sequence = iterable;
for (int i = 1; i <= n; ++i) {
union_policy = UnionProto(union_policy, next_sequence);
next_sequence = SequenceProto(iterable, next_sequence);
}
return union_policy;
}

void IterateIsSupersetOfUnionOfNSequences(Packet packet, PolicyProto iterable,
int n) {
EXPECT_THAT(Evaluate(IterateProto(iterable), packet),
IsSupersetOf(Evaluate(UnionUpToNthPower(iterable, n), packet)));
}
FUZZ_TEST(EvaluatePolicyProtoTest, IterateIsSupersetOfUnionOfNSequences)
.WithDomains(/*packet=*/Arbitrary<Packet>(),
/*iterable=*/Arbitrary<PolicyProto>(),
/*n=*/InRange(0, 100));

void IterateIsUnionOfNSequencesForSomeN(Packet packet, PolicyProto iterable) {
absl::flat_hash_set<Packet> iterate_output_packets =
Evaluate(IterateProto(iterable), packet);

// Evaluate successively larger unions until we find one that matches all
// packets in `iterate_packets`.
absl::flat_hash_set<Packet> union_output_packets;
int last_size;
int n = 0;
do {
last_size = union_output_packets.size();
union_output_packets = Evaluate(UnionUpToNthPower(iterable, n++), packet);
} while (iterate_output_packets != union_output_packets &&
union_output_packets.size() > last_size);

EXPECT_THAT(iterate_output_packets, ContainerEq(union_output_packets));
}
FUZZ_TEST(EvaluatePolicyProtoTest, IterateIsUnionOfNSequencesForSomeN);

TEST(EvaluatePolicyProtoTest, SimpleIterateThroughFiltersAndModifies) {
// f == 0; f:=1 + f == 1; f := 2 + f == 2; f := 3
PolicyProto iterable = UnionProto(
SequenceProto(FilterProto(MatchProto("f", 0)), ModificationProto("f", 1)),
UnionProto(SequenceProto(FilterProto(MatchProto("f", 1)),
ModificationProto("f", 2)),
SequenceProto(FilterProto(MatchProto("f", 2)),
ModificationProto("f", 3))));

// If the packet contains the field, then the output is the union of the
// input and the modified packets.
EXPECT_THAT(Evaluate(IterateProto(iterable), Packet({{"f", 0}})),
UnorderedElementsAre(Packet({{"f", 0}}), Packet({{"f", 1}}),
Packet({{"f", 2}}), Packet({{"f", 3}})));

// If the packet doesn't contain the field, then the only output is the
// input.
EXPECT_THAT(Evaluate(IterateProto(iterable), Packet()),
UnorderedElementsAre(Packet()));
}

/*--- Advanced policy properties ---------------------------------------------*/
void ModifyThenMatchIsEquivalentToModify(Packet packet, std::string field,
int value) {
// f := n;f == n is equivalent to f := n.
EXPECT_THAT(Evaluate(SequenceProto(ModificationProto(field, value),
FilterProto(MatchProto(field, value))),
packet),
ContainerEq(Evaluate(ModificationProto(field, value), packet)));
}
FUZZ_TEST(EvaluatePolicyProtoTest, ModifyThenMatchIsEquivalentToModify);

// TODO(dilo): Add tests for each of the NetKAT axioms.

} // namespace
} // namespace netkat

0 comments on commit 7ae1002

Please sign in to comment.