Skip to content

Commit

Permalink
perf: construct trie once
Browse files Browse the repository at this point in the history
  • Loading branch information
jcosentino11 committed Feb 2, 2024
1 parent 611f7c9 commit 85bf780
Show file tree
Hide file tree
Showing 2 changed files with 338 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@

package com.aws.greengrass.clientdevices.auth;

import com.aws.greengrass.authorization.WildcardTrie;
import com.aws.greengrass.clientdevices.auth.configuration.GroupManager;
import com.aws.greengrass.clientdevices.auth.configuration.Permission;
import com.aws.greengrass.clientdevices.auth.exception.PolicyException;
import com.aws.greengrass.clientdevices.auth.session.Session;
import com.aws.greengrass.clientdevices.auth.util.WildcardTrie;
import com.aws.greengrass.logging.api.Logger;
import com.aws.greengrass.logging.impl.LogManager;
import com.aws.greengrass.util.Utils;
Expand Down Expand Up @@ -37,6 +37,7 @@ public final class PermissionEvaluationUtils {
private static final Pattern SERVICE_RESOURCE_PATTERN = Pattern.compile(
String.format(SERVICE_RESOURCE_FORMAT, SERVICE_PATTERN_STRING, SERVICE_RESOURCE_TYPE_PATTERN_STRING,
SERVICE_RESOURCE_NAME_PATTERN_STRING), Pattern.UNICODE_CHARACTER_CLASS);
private final WildcardTrie wildcardTrie = new WildcardTrie();
private final GroupManager groupManager;

/**
Expand Down Expand Up @@ -131,9 +132,8 @@ private boolean compareResource(Resource requestResource, String policyResource)
return true;
}

WildcardTrie trie = new WildcardTrie();
trie.add(policyResource);
return trie.matchesStandard(requestResource.getResourceStr());
wildcardTrie.set(policyResource);
return wildcardTrie.matchesStandard(requestResource.getResourceStr());
}

private Operation parseOperation(String operationStr) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,334 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/

package com.aws.greengrass.clientdevices.auth.util;

import com.aws.greengrass.authorization.AuthorizationHandler.ResourceLookupPolicy;
import com.aws.greengrass.util.DefaultConcurrentHashMap;

import java.util.HashMap;
import java.util.Map;

/**
* Copied from nucleus with some customizations for performance
*/
public class WildcardTrie {
protected static final String GLOB_WILDCARD = "*";
protected static final String MQTT_MULTILEVEL_WILDCARD = "#";
protected static final String MQTT_SINGLELEVEL_WILDCARD = "+";
protected static final String MQTT_LEVEL_SEPARATOR = "/";
protected static final char nullChar = '\0';
protected static final char escapeChar = '$';
protected static final char wildcardChar = GLOB_WILDCARD.charAt(0);
protected static final char multiLevelWildcardChar = MQTT_MULTILEVEL_WILDCARD.charAt(0);
protected static final char singleLevelWildcardChar = MQTT_SINGLELEVEL_WILDCARD.charAt(0);
protected static final char levelSeparatorChar = MQTT_LEVEL_SEPARATOR.charAt(0);


private boolean isTerminal;
private boolean isTerminalLevel;
private boolean isWildcard;
private boolean isMQTTWildcard;
private boolean matchAll;
private final Map<String, WildcardTrie> children =
new DefaultConcurrentHashMap<>(WildcardTrie::new);

public void set(String subject) {
children.clear();
add(subject);
}

/**
* Add allowed resources for a particular operation.
* - A new node is created for every occurrence of a wildcard (*, #, +).
* - Only nodes with valid usage of wildcards are marked with isWildcard or isMQTTWildcard.
* - Any other characters are grouped together to form a node.
* - Just a '*' or '#' creates a Node setting matchAll to true and would match all resources
*
* @param subject resource pattern
*/
public void add(String subject) {
if (subject == null) {
return;
}
if (subject.equals(GLOB_WILDCARD)) {
WildcardTrie initial = this.children.get(GLOB_WILDCARD);
initial.matchAll = true;
initial.isTerminal = true;
initial.isWildcard = true;
return;
}
if (subject.equals(MQTT_MULTILEVEL_WILDCARD)) {
WildcardTrie initial = this.children.get(MQTT_MULTILEVEL_WILDCARD);
initial.matchAll = true;
initial.isTerminal = true;
initial.isMQTTWildcard = true;
return;
}
if (subject.equals(MQTT_SINGLELEVEL_WILDCARD)) {
WildcardTrie initial = this.children.get(MQTT_SINGLELEVEL_WILDCARD);
initial.isTerminal = true;
initial.isMQTTWildcard = true;
return;
}
if (subject.startsWith(MQTT_SINGLELEVEL_WILDCARD + MQTT_LEVEL_SEPARATOR)) {
WildcardTrie initial = this.children.get(MQTT_SINGLELEVEL_WILDCARD);
initial.isMQTTWildcard = true;
initial.add(subject.substring(1), true);
}

add(subject, true);
}

@SuppressWarnings("PMD.AvoidDeeplyNestedIfStmts")
private WildcardTrie add(String subject, boolean isTerminal) {
if (subject.isEmpty()) {
this.isTerminal |= isTerminal;
return this;
}
int subjectLength = subject.length();
WildcardTrie current = this;
StringBuilder sb = new StringBuilder(subjectLength);
for (int i = 0; i < subjectLength; i++) {
char currentChar = subject.charAt(i);
// Create separate Nodes for wildcards *, # and +
// Also tag them wildcard if its a valid usage
if (currentChar == wildcardChar) {
current = current.add(sb.toString(), false);
current = current.children.get(GLOB_WILDCARD);
current.isWildcard = true;
// If the string ends with *, then the wildcard is a terminal
if (i == subjectLength - 1) {
current.isTerminal = isTerminal;
return current;
}
return current.add(subject.substring(i + 1), true);
}
if (currentChar == multiLevelWildcardChar) {
WildcardTrie terminalLevel = current.add(sb.toString(), false);
current = terminalLevel.children.get(MQTT_MULTILEVEL_WILDCARD);
if (i == subjectLength - 1) {
current.isTerminal = true;
// check if # wildcard usage is valid
if (i > 0 && subject.charAt(i - 1) == levelSeparatorChar) {
current.isMQTTWildcard = true;
current.matchAll = true;
terminalLevel.isTerminalLevel = true;
}
return current;
}
return current.add(subject.substring(i + 1), true);
}
if (currentChar == singleLevelWildcardChar) {
current = current.add(sb.toString(), false);
current = current.children.get(MQTT_SINGLELEVEL_WILDCARD);
if (i == subjectLength - 1) {
current.isTerminal = true;
// check if '+' wildcard usage is valid
// if it's used at the last level
if (i > 0 && subject.charAt(i - 1) == levelSeparatorChar) {
current.isMQTTWildcard = true;
}
return current;
}
// check if '+' wildcard usage is valid
// if it's used in middle levels
if (i > 0 && subject.charAt(i - 1) == levelSeparatorChar
&& subject.charAt(i + 1) == levelSeparatorChar) {
current.isMQTTWildcard = true;
}
return current.add(subject.substring(i + 1), true);
}
if (currentChar == escapeChar) {
char actualChar = getActualChar(subject.substring(i));
if (actualChar != nullChar) {
sb.append(actualChar);
i = i + 3;
continue;
}
}
sb.append(currentChar);
}
// Handle non-wildcard value
current = current.children.get(sb.toString());
current.isTerminal |= isTerminal;
return current;
}

/**
* The method tries to parse the given string using escape sequence ${c} (where c is a character to be escaped)
* and returns the character c if the pattern is matched. In any other scenario it returns null character ('\0')
*
* @param str string provided to get
*/
static char getActualChar(String str) {
if (str.length() < 4) {
return nullChar;
}
// Match the escape format ${c}
if (str.charAt(0) == escapeChar && str.charAt(1) == '{' && str.charAt(3) == '}') {
return str.charAt(2);
}
return nullChar;
}

/**
* Match given string to the corresponding allowed resources trie. MQTT wildcards are not processed.
*
* @param str string to match.
*/
@SuppressWarnings({"PMD.UselessParentheses", "PMD.CollapsibleIfStatements"})
public boolean matchesStandard(String str) {
if (str == null) {
return true;
}
if ((isWildcard && isTerminal) || (isTerminal && str.isEmpty())) {
return true;
}

boolean hasMatch = false;
Map<String, WildcardTrie> matchingChildren = new HashMap<>();
for (Map.Entry<String, WildcardTrie> e : children.entrySet()) {
// Succeed fast
if (hasMatch) {
return true;
}
String key = e.getKey();
WildcardTrie value = e.getValue();

// Process * wildcards
if (value.isWildcard && key.equals(GLOB_WILDCARD)) {
hasMatch = value.matchesStandard(str);
continue;
}

// Match normal characters
if (str.startsWith(key)) {
hasMatch = value.matchesStandard(str.substring(key.length()));
// Succeed fast
if (hasMatch) {
return true;
}
}

// If I'm a wildcard, then I need to maybe chomp many characters to match my children
if (isWildcard) {
int foundChildIndex = str.indexOf(key);
int keyLength = key.length();
while (foundChildIndex >= 0) {
matchingChildren.put(str.substring(foundChildIndex + keyLength), value);
foundChildIndex = str.indexOf(key, foundChildIndex + 1);
}
}
}
// Succeed fast
if (hasMatch) {
return true;
}
if (isWildcard && !matchingChildren.isEmpty()) {
return matchingChildren.entrySet().stream().anyMatch((e) -> e.getValue().matchesStandard(e.getKey()));
}

return false;
}

/**
* Match given string to the corresponding allowed resources trie. MQTT wildcards are processed only if
* its a valid usage, otherwise treated as normal characters.
*
* @param str string to match
*/
@SuppressWarnings({"PMD.UselessParentheses", "PMD.CollapsibleIfStatements"})
public boolean matchesMQTT(String str) {
if (str == null) {
return true;
}
if ((isWildcard && isTerminal) || (isTerminal && str.isEmpty())) {
return true;
}
if (isMQTTWildcard) {
if (matchAll || (isTerminal && (str.indexOf(MQTT_LEVEL_SEPARATOR) == -1))) {
return true;
}
}

boolean hasMatch = false;
Map<String, WildcardTrie> matchingChildren = new HashMap<>();
for (Map.Entry<String, WildcardTrie> e : children.entrySet()) {
// Succeed fast
if (hasMatch) {
return true;
}
String key = e.getKey();
WildcardTrie value = e.getValue();

// Process *, # and + wildcards (only process MQTT wildcards that have valid usages)
if ((value.isWildcard && key.equals(GLOB_WILDCARD))
|| (value.isMQTTWildcard && (key.equals(MQTT_SINGLELEVEL_WILDCARD)
|| key.equals(MQTT_MULTILEVEL_WILDCARD)))) {
hasMatch = value.matchesMQTT(str);
continue;
}

// Match normal characters
if (str.startsWith(key)) {
hasMatch = value.matchesMQTT(str.substring(key.length()));
// Succeed fast
if (hasMatch) {
return true;
}
}

// Check if it's terminalLevel to allow matching of string without "/" in the end
// "abc/#" should match "abc".
// "abc/*xy/#" should match "abc/12xy"
String terminalKey = key.substring(0, key.length() - 1);
if (value.isTerminalLevel) {
if (str.equals(terminalKey)) {
return true;
}
if (str.endsWith(terminalKey)) {
key = terminalKey;
}
}

int keyLength = key.length();
// If I'm a wildcard, then I need to maybe chomp many characters to match my children
if (isWildcard) {
int foundChildIndex = str.indexOf(key);
while (foundChildIndex >= 0 && foundChildIndex < str.length()) {
matchingChildren.put(str.substring(foundChildIndex + keyLength), value);
foundChildIndex = str.indexOf(key, foundChildIndex + 1);
}
}
// If I'm a MQTT wildcard (specifically +, as # is already covered),
// then I need to maybe chomp many characters to match my children
if (isMQTTWildcard) {
int foundChildIndex = str.indexOf(key);
// Matched characters inside + should not contain a "/"
while (foundChildIndex >= 0
&& foundChildIndex < str.length()
&& (str.substring(0,foundChildIndex).indexOf(MQTT_LEVEL_SEPARATOR) == -1)) {
matchingChildren.put(str.substring(foundChildIndex + keyLength), value);
foundChildIndex = str.indexOf(key, foundChildIndex + 1);
}
}
}
// Succeed fast
if (hasMatch) {
return true;
}
if ((isWildcard || isMQTTWildcard) && !matchingChildren.isEmpty()) {
return matchingChildren.entrySet().stream().anyMatch((e) -> e.getValue().matchesMQTT(e.getKey()));
}

return false;
}

public boolean matches(String str, ResourceLookupPolicy lookupPolicy) {
return lookupPolicy == ResourceLookupPolicy.MQTT_STYLE ? matchesMQTT(str)
: matchesStandard(str);
}
}

0 comments on commit 85bf780

Please sign in to comment.