From 85bf78079da3ae5c1418e6f195ab17fbe0067add Mon Sep 17 00:00:00 2001 From: Joseph Cosentino Date: Fri, 2 Feb 2024 09:46:59 -0800 Subject: [PATCH] perf: construct trie once --- .../auth/PermissionEvaluationUtils.java | 8 +- .../clientdevices/auth/util/WildcardTrie.java | 334 ++++++++++++++++++ 2 files changed, 338 insertions(+), 4 deletions(-) create mode 100644 src/main/java/com/aws/greengrass/clientdevices/auth/util/WildcardTrie.java diff --git a/src/main/java/com/aws/greengrass/clientdevices/auth/PermissionEvaluationUtils.java b/src/main/java/com/aws/greengrass/clientdevices/auth/PermissionEvaluationUtils.java index 8c487374f..ad95b143c 100644 --- a/src/main/java/com/aws/greengrass/clientdevices/auth/PermissionEvaluationUtils.java +++ b/src/main/java/com/aws/greengrass/clientdevices/auth/PermissionEvaluationUtils.java @@ -5,11 +5,11 @@ package com.aws.greengrass.clientdevices.auth; -import com.aws.greengrass.authorization.WildcardTrie; import com.aws.greengrass.clientdevices.auth.configuration.GroupManager; import com.aws.greengrass.clientdevices.auth.configuration.Permission; import com.aws.greengrass.clientdevices.auth.exception.PolicyException; import com.aws.greengrass.clientdevices.auth.session.Session; +import com.aws.greengrass.clientdevices.auth.util.WildcardTrie; import com.aws.greengrass.logging.api.Logger; import com.aws.greengrass.logging.impl.LogManager; import com.aws.greengrass.util.Utils; @@ -37,6 +37,7 @@ public final class PermissionEvaluationUtils { private static final Pattern SERVICE_RESOURCE_PATTERN = Pattern.compile( String.format(SERVICE_RESOURCE_FORMAT, SERVICE_PATTERN_STRING, SERVICE_RESOURCE_TYPE_PATTERN_STRING, SERVICE_RESOURCE_NAME_PATTERN_STRING), Pattern.UNICODE_CHARACTER_CLASS); + private final WildcardTrie wildcardTrie = new WildcardTrie(); private final GroupManager groupManager; /** @@ -131,9 +132,8 @@ private boolean compareResource(Resource requestResource, String policyResource) return true; } - WildcardTrie trie = new WildcardTrie(); - trie.add(policyResource); - return trie.matchesStandard(requestResource.getResourceStr()); + wildcardTrie.set(policyResource); + return wildcardTrie.matchesStandard(requestResource.getResourceStr()); } private Operation parseOperation(String operationStr) { diff --git a/src/main/java/com/aws/greengrass/clientdevices/auth/util/WildcardTrie.java b/src/main/java/com/aws/greengrass/clientdevices/auth/util/WildcardTrie.java new file mode 100644 index 000000000..268a8f1ae --- /dev/null +++ b/src/main/java/com/aws/greengrass/clientdevices/auth/util/WildcardTrie.java @@ -0,0 +1,334 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.aws.greengrass.clientdevices.auth.util; + +import com.aws.greengrass.authorization.AuthorizationHandler.ResourceLookupPolicy; +import com.aws.greengrass.util.DefaultConcurrentHashMap; + +import java.util.HashMap; +import java.util.Map; + +/** + * Copied from nucleus with some customizations for performance + */ +public class WildcardTrie { + protected static final String GLOB_WILDCARD = "*"; + protected static final String MQTT_MULTILEVEL_WILDCARD = "#"; + protected static final String MQTT_SINGLELEVEL_WILDCARD = "+"; + protected static final String MQTT_LEVEL_SEPARATOR = "/"; + protected static final char nullChar = '\0'; + protected static final char escapeChar = '$'; + protected static final char wildcardChar = GLOB_WILDCARD.charAt(0); + protected static final char multiLevelWildcardChar = MQTT_MULTILEVEL_WILDCARD.charAt(0); + protected static final char singleLevelWildcardChar = MQTT_SINGLELEVEL_WILDCARD.charAt(0); + protected static final char levelSeparatorChar = MQTT_LEVEL_SEPARATOR.charAt(0); + + + private boolean isTerminal; + private boolean isTerminalLevel; + private boolean isWildcard; + private boolean isMQTTWildcard; + private boolean matchAll; + private final Map children = + new DefaultConcurrentHashMap<>(WildcardTrie::new); + + public void set(String subject) { + children.clear(); + add(subject); + } + + /** + * Add allowed resources for a particular operation. + * - A new node is created for every occurrence of a wildcard (*, #, +). + * - Only nodes with valid usage of wildcards are marked with isWildcard or isMQTTWildcard. + * - Any other characters are grouped together to form a node. + * - Just a '*' or '#' creates a Node setting matchAll to true and would match all resources + * + * @param subject resource pattern + */ + public void add(String subject) { + if (subject == null) { + return; + } + if (subject.equals(GLOB_WILDCARD)) { + WildcardTrie initial = this.children.get(GLOB_WILDCARD); + initial.matchAll = true; + initial.isTerminal = true; + initial.isWildcard = true; + return; + } + if (subject.equals(MQTT_MULTILEVEL_WILDCARD)) { + WildcardTrie initial = this.children.get(MQTT_MULTILEVEL_WILDCARD); + initial.matchAll = true; + initial.isTerminal = true; + initial.isMQTTWildcard = true; + return; + } + if (subject.equals(MQTT_SINGLELEVEL_WILDCARD)) { + WildcardTrie initial = this.children.get(MQTT_SINGLELEVEL_WILDCARD); + initial.isTerminal = true; + initial.isMQTTWildcard = true; + return; + } + if (subject.startsWith(MQTT_SINGLELEVEL_WILDCARD + MQTT_LEVEL_SEPARATOR)) { + WildcardTrie initial = this.children.get(MQTT_SINGLELEVEL_WILDCARD); + initial.isMQTTWildcard = true; + initial.add(subject.substring(1), true); + } + + add(subject, true); + } + + @SuppressWarnings("PMD.AvoidDeeplyNestedIfStmts") + private WildcardTrie add(String subject, boolean isTerminal) { + if (subject.isEmpty()) { + this.isTerminal |= isTerminal; + return this; + } + int subjectLength = subject.length(); + WildcardTrie current = this; + StringBuilder sb = new StringBuilder(subjectLength); + for (int i = 0; i < subjectLength; i++) { + char currentChar = subject.charAt(i); + // Create separate Nodes for wildcards *, # and + + // Also tag them wildcard if its a valid usage + if (currentChar == wildcardChar) { + current = current.add(sb.toString(), false); + current = current.children.get(GLOB_WILDCARD); + current.isWildcard = true; + // If the string ends with *, then the wildcard is a terminal + if (i == subjectLength - 1) { + current.isTerminal = isTerminal; + return current; + } + return current.add(subject.substring(i + 1), true); + } + if (currentChar == multiLevelWildcardChar) { + WildcardTrie terminalLevel = current.add(sb.toString(), false); + current = terminalLevel.children.get(MQTT_MULTILEVEL_WILDCARD); + if (i == subjectLength - 1) { + current.isTerminal = true; + // check if # wildcard usage is valid + if (i > 0 && subject.charAt(i - 1) == levelSeparatorChar) { + current.isMQTTWildcard = true; + current.matchAll = true; + terminalLevel.isTerminalLevel = true; + } + return current; + } + return current.add(subject.substring(i + 1), true); + } + if (currentChar == singleLevelWildcardChar) { + current = current.add(sb.toString(), false); + current = current.children.get(MQTT_SINGLELEVEL_WILDCARD); + if (i == subjectLength - 1) { + current.isTerminal = true; + // check if '+' wildcard usage is valid + // if it's used at the last level + if (i > 0 && subject.charAt(i - 1) == levelSeparatorChar) { + current.isMQTTWildcard = true; + } + return current; + } + // check if '+' wildcard usage is valid + // if it's used in middle levels + if (i > 0 && subject.charAt(i - 1) == levelSeparatorChar + && subject.charAt(i + 1) == levelSeparatorChar) { + current.isMQTTWildcard = true; + } + return current.add(subject.substring(i + 1), true); + } + if (currentChar == escapeChar) { + char actualChar = getActualChar(subject.substring(i)); + if (actualChar != nullChar) { + sb.append(actualChar); + i = i + 3; + continue; + } + } + sb.append(currentChar); + } + // Handle non-wildcard value + current = current.children.get(sb.toString()); + current.isTerminal |= isTerminal; + return current; + } + + /** + * The method tries to parse the given string using escape sequence ${c} (where c is a character to be escaped) + * and returns the character c if the pattern is matched. In any other scenario it returns null character ('\0') + * + * @param str string provided to get + */ + static char getActualChar(String str) { + if (str.length() < 4) { + return nullChar; + } + // Match the escape format ${c} + if (str.charAt(0) == escapeChar && str.charAt(1) == '{' && str.charAt(3) == '}') { + return str.charAt(2); + } + return nullChar; + } + + /** + * Match given string to the corresponding allowed resources trie. MQTT wildcards are not processed. + * + * @param str string to match. + */ + @SuppressWarnings({"PMD.UselessParentheses", "PMD.CollapsibleIfStatements"}) + public boolean matchesStandard(String str) { + if (str == null) { + return true; + } + if ((isWildcard && isTerminal) || (isTerminal && str.isEmpty())) { + return true; + } + + boolean hasMatch = false; + Map matchingChildren = new HashMap<>(); + for (Map.Entry e : children.entrySet()) { + // Succeed fast + if (hasMatch) { + return true; + } + String key = e.getKey(); + WildcardTrie value = e.getValue(); + + // Process * wildcards + if (value.isWildcard && key.equals(GLOB_WILDCARD)) { + hasMatch = value.matchesStandard(str); + continue; + } + + // Match normal characters + if (str.startsWith(key)) { + hasMatch = value.matchesStandard(str.substring(key.length())); + // Succeed fast + if (hasMatch) { + return true; + } + } + + // If I'm a wildcard, then I need to maybe chomp many characters to match my children + if (isWildcard) { + int foundChildIndex = str.indexOf(key); + int keyLength = key.length(); + while (foundChildIndex >= 0) { + matchingChildren.put(str.substring(foundChildIndex + keyLength), value); + foundChildIndex = str.indexOf(key, foundChildIndex + 1); + } + } + } + // Succeed fast + if (hasMatch) { + return true; + } + if (isWildcard && !matchingChildren.isEmpty()) { + return matchingChildren.entrySet().stream().anyMatch((e) -> e.getValue().matchesStandard(e.getKey())); + } + + return false; + } + + /** + * Match given string to the corresponding allowed resources trie. MQTT wildcards are processed only if + * its a valid usage, otherwise treated as normal characters. + * + * @param str string to match + */ + @SuppressWarnings({"PMD.UselessParentheses", "PMD.CollapsibleIfStatements"}) + public boolean matchesMQTT(String str) { + if (str == null) { + return true; + } + if ((isWildcard && isTerminal) || (isTerminal && str.isEmpty())) { + return true; + } + if (isMQTTWildcard) { + if (matchAll || (isTerminal && (str.indexOf(MQTT_LEVEL_SEPARATOR) == -1))) { + return true; + } + } + + boolean hasMatch = false; + Map matchingChildren = new HashMap<>(); + for (Map.Entry e : children.entrySet()) { + // Succeed fast + if (hasMatch) { + return true; + } + String key = e.getKey(); + WildcardTrie value = e.getValue(); + + // Process *, # and + wildcards (only process MQTT wildcards that have valid usages) + if ((value.isWildcard && key.equals(GLOB_WILDCARD)) + || (value.isMQTTWildcard && (key.equals(MQTT_SINGLELEVEL_WILDCARD) + || key.equals(MQTT_MULTILEVEL_WILDCARD)))) { + hasMatch = value.matchesMQTT(str); + continue; + } + + // Match normal characters + if (str.startsWith(key)) { + hasMatch = value.matchesMQTT(str.substring(key.length())); + // Succeed fast + if (hasMatch) { + return true; + } + } + + // Check if it's terminalLevel to allow matching of string without "/" in the end + // "abc/#" should match "abc". + // "abc/*xy/#" should match "abc/12xy" + String terminalKey = key.substring(0, key.length() - 1); + if (value.isTerminalLevel) { + if (str.equals(terminalKey)) { + return true; + } + if (str.endsWith(terminalKey)) { + key = terminalKey; + } + } + + int keyLength = key.length(); + // If I'm a wildcard, then I need to maybe chomp many characters to match my children + if (isWildcard) { + int foundChildIndex = str.indexOf(key); + while (foundChildIndex >= 0 && foundChildIndex < str.length()) { + matchingChildren.put(str.substring(foundChildIndex + keyLength), value); + foundChildIndex = str.indexOf(key, foundChildIndex + 1); + } + } + // If I'm a MQTT wildcard (specifically +, as # is already covered), + // then I need to maybe chomp many characters to match my children + if (isMQTTWildcard) { + int foundChildIndex = str.indexOf(key); + // Matched characters inside + should not contain a "/" + while (foundChildIndex >= 0 + && foundChildIndex < str.length() + && (str.substring(0,foundChildIndex).indexOf(MQTT_LEVEL_SEPARATOR) == -1)) { + matchingChildren.put(str.substring(foundChildIndex + keyLength), value); + foundChildIndex = str.indexOf(key, foundChildIndex + 1); + } + } + } + // Succeed fast + if (hasMatch) { + return true; + } + if ((isWildcard || isMQTTWildcard) && !matchingChildren.isEmpty()) { + return matchingChildren.entrySet().stream().anyMatch((e) -> e.getValue().matchesMQTT(e.getKey())); + } + + return false; + } + + public boolean matches(String str, ResourceLookupPolicy lookupPolicy) { + return lookupPolicy == ResourceLookupPolicy.MQTT_STYLE ? matchesMQTT(str) + : matchesStandard(str); + } +}