Skip to content

Commit

Permalink
Add support for functions ownership
Browse files Browse the repository at this point in the history
  • Loading branch information
homar committed Jan 21, 2025
1 parent ce79254 commit d70bdfa
Show file tree
Hide file tree
Showing 6 changed files with 122 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,16 @@ public Optional<Identity> getFunctionRunAsIdentity(Session session, CatalogSchem
return Optional.empty();
}

@Override
public void setFunctionOwner(Session session, CatalogSchemaFunctionName function, TrinoPrincipal principal)
{
throw notSupportedException(function.getCatalogName());
}

@Override
public void functionCreated(Session session, CatalogSchemaFunctionName function, TrinoPrincipal principal)
{}

@Override
public void schemaCreated(Session session, CatalogSchemaName schema) {}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@
import io.trino.spi.predicate.TupleDomain;
import io.trino.spi.security.GrantInfo;
import io.trino.spi.security.Identity;
import io.trino.spi.security.PrincipalType;
import io.trino.spi.security.Privilege;
import io.trino.spi.security.RoleGrant;
import io.trino.spi.security.TrinoPrincipal;
Expand Down Expand Up @@ -2712,7 +2713,15 @@ public void createLanguageFunction(Session session, QualifiedObjectName name, La
CatalogHandle catalogHandle = catalogMetadata.getCatalogHandle();
ConnectorMetadata metadata = catalogMetadata.getMetadata(session);

metadata.createLanguageFunction(session.toConnectorSession(catalogHandle), name.asSchemaFunctionName(), function, replace);
SchemaFunctionName schemaFunctionName = name.asSchemaFunctionName();
if (catalogMetadata.getSecurityManagement() == SYSTEM) {
systemSecurityMetadata.functionCreated(
session,
new CatalogSchemaFunctionName(catalogHandle.getCatalogName().toString(), schemaFunctionName),
new TrinoPrincipal(PrincipalType.USER, session.getUser()));
}

metadata.createLanguageFunction(session.toConnectorSession(catalogHandle), schemaFunctionName, function, replace);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,16 @@ default void validateEntityKindAndPrivileges(Session session, String entityKind,
*/
Optional<Identity> getFunctionRunAsIdentity(Session session, CatalogSchemaFunctionName functionName);

/**
* Set the owner of the specified function
*/
void setFunctionOwner(Session session, CatalogSchemaFunctionName function, TrinoPrincipal principal);

/**
* A function is created
*/
void functionCreated(Session session, CatalogSchemaFunctionName function, TrinoPrincipal principal);

/**
* A schema was created
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ public class TestingAccessControlManager
extends AccessControlManager
{
private static final BiPredicate<Identity, String> IDENTITY_TABLE_TRUE = (identity, table) -> true;
private static final BiPredicate<Identity, String> IDENTITY_FUNCTION_TRUE = (identity, table) -> true;

private final Set<TestingPrivilege> denyPrivileges = new HashSet<>();
private final Map<RowFilterKey, List<ViewExpression>> rowFilters = new HashMap<>();
Expand All @@ -142,6 +143,7 @@ public class TestingAccessControlManager
private Predicate<String> deniedSchemas = s -> true;
private Predicate<SchemaTableName> deniedTables = s -> true;
private BiPredicate<Identity, String> denyIdentityTable = IDENTITY_TABLE_TRUE;
private BiPredicate<Identity, String> denyIdentityFunction = IDENTITY_FUNCTION_TRUE;

@Inject
public TestingAccessControlManager(
Expand Down Expand Up @@ -216,6 +218,11 @@ public void denyIdentityTable(BiPredicate<Identity, String> denyIdentityTable)
this.denyIdentityTable = requireNonNull(denyIdentityTable, "denyIdentityTable is null");
}

public void denyIdentityFunction(BiPredicate<Identity, String> denyIdentityFunction)
{
this.denyIdentityFunction = requireNonNull(denyIdentityFunction, "denyIdentityFunction is null");
}

@Override
public Set<String> filterCatalogs(SecurityContext securityContext, Set<String> catalogs)
{
Expand Down Expand Up @@ -698,6 +705,9 @@ public void checkCanSelectFromColumns(SecurityContext context, QualifiedObjectNa
@Override
public boolean canExecuteFunction(SecurityContext context, QualifiedObjectName functionName)
{
if (!denyIdentityFunction.test(context.getIdentity(), functionName.asSchemaFunctionName().toString())) {
return false;
}
if (shouldDenyPrivilege(context.getIdentity().getUser(), functionName.toString(), EXECUTE_FUNCTION)) {
return false;
}
Expand All @@ -710,6 +720,9 @@ public boolean canExecuteFunction(SecurityContext context, QualifiedObjectName f
@Override
public boolean canCreateViewWithExecuteFunction(SecurityContext context, QualifiedObjectName functionName)
{
if (!denyIdentityFunction.test(context.getIdentity(), functionName.asSchemaFunctionName().toString())) {
return false;
}
if (shouldDenyPrivilege(context.getIdentity().getUser(), functionName.toString(), GRANT_EXECUTE_FUNCTION)) {
return false;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import io.trino.spi.connector.SchemaTableName;
import io.trino.spi.connector.SchemaTablePrefix;
import io.trino.spi.function.BoundSignature;
import io.trino.spi.function.CatalogSchemaFunctionName;
import io.trino.spi.function.FunctionDependencies;
import io.trino.spi.function.FunctionId;
import io.trino.spi.function.FunctionMetadata;
Expand Down Expand Up @@ -552,6 +553,64 @@ public void testJoinBaseTableWithView()
assertAccessAllowed(viewOwnerSession, "DROP VIEW " + viewName);
}

@Test
public void testFunctionOwners()
{
reset();

String functionOwner = "function_owner";
TrinoPrincipal functionOwnerPrincipal = new TrinoPrincipal(USER, functionOwner);

systemSecurityMetadata.grantRoles(getSession(), Set.of("function_owner_role"), Set.of(functionOwnerPrincipal), false, Optional.empty());
systemSecurityMetadata.setFunctionOwner(
getSession(),
new CatalogSchemaFunctionName("memory", "default", "my_test_function_inner"),
functionOwnerPrincipal);
systemSecurityMetadata.setFunctionOwner(
getSession(),
new CatalogSchemaFunctionName("memory", "default", "my_test_function_outer"),
functionOwnerPrincipal);

Session functionOwnerSession = TestingSession.testSessionBuilder()
.setIdentity(Identity.forUser(functionOwner)
.withEnabledRoles(Set.of("function_owner_role"))
.build())
.setCatalog(getSession().getCatalog())
.setSchema(getSession().getSchema())
.build();
assertAccessAllowed(
functionOwnerSession,
"CREATE FUNCTION memory.default.my_test_function_inner (x integer) RETURNS bigint RETURN x + 42");
assertAccessAllowed(
functionOwnerSession,
"SELECT memory.default.my_test_function_inner(2)");

assertAccessAllowed(
functionOwnerSession,
"CREATE FUNCTION memory.default.my_test_function_outer (x integer) RETURNS bigint RETURN x + memory.default.my_test_function_inner(58)");

assertAccessAllowed(
functionOwnerSession,
"SELECT memory.default.my_test_function_outer(2)");

getQueryRunner().getAccessControl()
.denyIdentityFunction((identity, function) -> !(identity.getEnabledRoles().contains("function_owner_role_without_access") && "default.my_test_function_inner".equals(function)));
systemSecurityMetadata.grantRoles(getSession(), Set.of("function_owner_role_without_access"), Set.of(functionOwnerPrincipal), false, Optional.empty());

assertAccessDenied(
functionOwnerSession,
"SELECT memory.default.my_test_function_outer(2)",
"Cannot execute function memory.default.my_test_function_inner");
assertAccessDenied(
"SELECT memory.default.my_test_function_outer(2)",
"Cannot execute function memory.default.my_test_function_inner");
systemSecurityMetadata.revokeRoles(getSession(), Set.of("function_owner_role_without_access"), Set.of(functionOwnerPrincipal), false, Optional.empty());
assertAccessAllowed(
functionOwnerSession,
"SELECT memory.default.my_test_function_outer(2)");
assertAccessAllowed("SELECT memory.default.my_test_function_outer(2)");
}

@Test
public void testViewFunctionAccessControl()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,14 @@ class TestingSystemSecurityMetadata
private final Set<String> roles = synchronizedSet(new HashSet<>());
private final Set<RoleGrant> roleGrants = synchronizedSet(new HashSet<>());
private final Map<CatalogSchemaTableName, Identity> viewOwners = synchronizedMap(new HashMap<>());
private final Map<CatalogSchemaFunctionName, Identity> functionOwners = synchronizedMap(new HashMap<>());

public void reset()
{
roles.clear();
roleGrants.clear();
viewOwners.clear();
functionOwners.clear();
}

@Override
Expand Down Expand Up @@ -246,9 +248,26 @@ public void setViewOwner(Session session, CatalogSchemaTableName view, TrinoPrin
@Override
public Optional<Identity> getFunctionRunAsIdentity(Session session, CatalogSchemaFunctionName functionName)
{
return Optional.empty();
return Optional.ofNullable(functionOwners.get(functionName))
.map(identity -> Identity.from(identity)
.withEnabledRoles(getRoleGrantsRecursively(new TrinoPrincipal(USER, identity.getUser()))
.stream()
.map(RoleGrant::getRoleName)
.collect(toImmutableSet()))
.build());
}

@Override
public void setFunctionOwner(Session session, CatalogSchemaFunctionName function, TrinoPrincipal principal)
{
checkArgument(principal.getType() == USER, "Only a user can be a function owner");
functionOwners.put(function, Identity.ofUser(principal.getName()));
}

@Override
public void functionCreated(Session session, CatalogSchemaFunctionName function, TrinoPrincipal principal)
{ }

@Override
public void schemaCreated(Session session, CatalogSchemaName schema) {}

Expand Down

0 comments on commit d70bdfa

Please sign in to comment.