From c1f6d9f9c8ef7d316576522b91b750dec363c4f6 Mon Sep 17 00:00:00 2001 From: Joseph Cosentino Date: Wed, 18 Sep 2024 12:30:34 -0700 Subject: [PATCH] chore: passed existing tests, add more cases for combinations --- .../clientdevices/auth/util/WildcardTrie.java | 89 +++++++++++-------- .../auth/util/WildcardTrieTest.java | 24 +++-- 2 files changed, 69 insertions(+), 44 deletions(-) diff --git a/src/main/java/com/aws/greengrass/clientdevices/auth/util/WildcardTrie.java b/src/main/java/com/aws/greengrass/clientdevices/auth/util/WildcardTrie.java index da3ae031f..d4446d95b 100644 --- a/src/main/java/com/aws/greengrass/clientdevices/auth/util/WildcardTrie.java +++ b/src/main/java/com/aws/greengrass/clientdevices/auth/util/WildcardTrie.java @@ -17,11 +17,15 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.function.Function; import java.util.function.Supplier; @RequiredArgsConstructor public class WildcardTrie { + private static final Function EXCEPTION_UNSUPPORTED_WILDCARD_TYPE = wildcardType -> + new UnsupportedOperationException("wildcard type " + wildcardType.name() + " not supported"); + private Node root; private final MatchOptions opts; @@ -42,14 +46,17 @@ private Node withPattern(@NonNull Node n, @NonNull String s) { for (int i = 0; i < s.length(); i++) { char c = s.charAt(i); if (isWildcard(s.charAt(i))) { + // create child node from non-wildcard chars that have been accumulated so far + Node node = token.length() > 0 ? n.children.get(token.toString()) : n; + // create child node for wildcard char itself WildcardType type = WildcardType.from(c); - Node node = n.children.get(type.val()); + node = node.children.get(type.val()); node.wildcardType = type; if (i == s.length() - 1) { // we've reached the last token return node; } - return withPattern(node, s.substring(token.length() + 2)); + return withPattern(node, s.substring(i + 1)); } else { token.append(c); } @@ -80,55 +87,59 @@ public boolean matches(@NonNull String s) { return matches(root, s); } - private boolean matches(@NonNull Node n, String s) { + private boolean matches(@NonNull Node n, @NonNull String s) { if (n.isTerminal()) { - if (n.isWildcard()) { - switch (n.wildcardType) { - case SINGLE: - return s.length() == 1; - case GLOB: - return true; - default: - throw new UnsupportedOperationException("wildcard type " + n.wildcardType.name() + " not supported"); - } - } else { - return s.isEmpty(); + return n.wildcardType == WildcardType.GLOB || s.isEmpty(); + } + + if (n.isWildcard()) { + switch (n.wildcardType) { + case SINGLE: + return n.children.keySet().stream().anyMatch(token -> { + Node child = n.children.get(token); + // skip over one character for single wildcard + if (child.isWildcard()) { + return !s.isEmpty() && matches(child, s.substring(1)); + } else { + return !s.isEmpty() && s.startsWith(token.substring(0, 1)) && matches(child, s.substring(1)); + } + }); + case GLOB: + return n.children.keySet().stream().anyMatch(token -> { + Node child = n.children.get(token); + if (child.isWildcard()) { + return true;// TODO + } else { + // consume the input string to find a match + return allIndicesOf(s, token).stream() + .anyMatch(tokenIndex -> + matches(child, s.substring(tokenIndex + token.length())) + ); + } + }); + default: + throw EXCEPTION_UNSUPPORTED_WILDCARD_TYPE.apply(n.wildcardType); } } - for (String token : n.children.keySet()) { + return n.children.keySet().stream().anyMatch(token -> { Node child = n.children.get(token); - - if (n.isWildcard()) { // parent is a wildcard - switch (n.wildcardType) { + if (child.isWildcard()) { + switch (child.wildcardType) { case SINGLE: - // skip over one character for single wildcard - return matches(child, s.substring(1)); + // skip past the next character for ? matching + return !s.isEmpty() && matches(child, s.substring(1)); case GLOB: - // consume the input string to find a match - return allIndicesOf(s, token).stream() - .anyMatch(tokenIndex -> - matches(child, s.substring(tokenIndex + token.length())) - ); + // skip past token and figure out retroactively if the glob matched + return matches(child, s); default: - throw new UnsupportedOperationException("wildcard type " + n.wildcardType.name() + " not supported"); + throw EXCEPTION_UNSUPPORTED_WILDCARD_TYPE.apply(child.wildcardType); } - } - - if (child.isWildcard()) { - // skip past the wildcard node, - // on the next iteration we need to figure out - // the part the wildcard matched (if at all). - return matches(child, s); } else { // match found, keep following this trie branch - if (s.startsWith(token)) { - return matches(child, s.substring(token.length())); - } + return s.startsWith(token) && matches(child, s.substring(token.length())); } - } - - return false; + }); } private static List allIndicesOf(@NonNull String s, @NonNull String sub) { diff --git a/src/test/java/com/aws/greengrass/clientdevices/auth/util/WildcardTrieTest.java b/src/test/java/com/aws/greengrass/clientdevices/auth/util/WildcardTrieTest.java index d195d8f95..a4ddf1f76 100644 --- a/src/test/java/com/aws/greengrass/clientdevices/auth/util/WildcardTrieTest.java +++ b/src/test/java/com/aws/greengrass/clientdevices/auth/util/WildcardTrieTest.java @@ -18,7 +18,7 @@ import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.params.provider.Arguments.arguments; -// TODO fix failing cases + class WildcardTrieTest { static Stream validMatches() { @@ -36,7 +36,14 @@ static Stream validMatches() { // single character wildcard ? arguments("?", asList("f", "*", "?")), arguments("??", asList("ff", "**", "??", "*?")), - arguments("?f?", asList("fff", "f", "ff", "*f", "*f*", "f*")) + arguments("?f?", asList("fff", "*f*")), + // glob and single + arguments("?*", asList("?", "*", "a", "??", "**", "ab", "***", "???", "abc")), + arguments("*?", asList("?", "*", "a", "??", "**", "ab", "***", "???", "abc")), + arguments("?*?", asList("??", "**", "aa", "???", "***", "aaa", "????", "****", "aaaa")), + arguments("*?*", asList("?", "*", "a", "??", "**", "ab", "***", "???", "abc")), + arguments("a?*b", asList("acb", "a?b", "a*b", "a?cb")), + arguments("a*?b", asList("acb", "a?b", "a*b", "a?cb")) ); } @@ -45,7 +52,8 @@ static Stream validMatches() { void GIVEN_trie_with_wildcards_WHEN_valid_matches_provided_THEN_pass(String pattern, List matches) { WildcardTrie.MatchOptions opts = WildcardTrie.MatchOptions.builder().useSingleCharWildcard(true).build(); WildcardTrie trie = new WildcardTrie(opts).withPattern(pattern); - matches.forEach(m -> assertTrue(trie.matches(m))); + matches.forEach(m -> assertTrue(trie.matches(m), + String.format("String \"%s\" did not match the pattern \"%s\"", m, pattern))); } @@ -63,7 +71,12 @@ static Stream invalidMatches() { // single character wildcard ? arguments("?", asList("aa", "??", "**")), arguments("??", asList("aaa", "???", "***")), - arguments("?a?", asList("aaaa", "fff")) + arguments("?a?", asList("fff", "aaaa")), + arguments("?f?", asList("ff", "f", "*f", "f*")), + // glob and single + arguments("?*?", asList("a", "?", "*")), + arguments("a?*b", asList("ab", "abc")), + arguments("a*?b", asList("ab", "abc")) ); } @@ -72,6 +85,7 @@ static Stream invalidMatches() { void GIVEN_trie_with_wildcards_WHEN_invalid_matches_provided_THEN_fail(String pattern, List matches) { WildcardTrie.MatchOptions opts = WildcardTrie.MatchOptions.builder().useSingleCharWildcard(true).build(); WildcardTrie trie = new WildcardTrie(opts).withPattern(pattern); - matches.forEach(m -> assertFalse(trie.matches(m))); + matches.forEach(m -> assertFalse(trie.matches(m), + String.format("String \"%s\" incorrectly matched the pattern \"%s\"", m, pattern))); } }