Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use last policy #829

Merged
merged 12 commits into from
Jan 8, 2025
8 changes: 8 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,19 @@ Write the date in place of the "Unreleased" in the case a new version is release
- `docker-compose.yml` now uses the healthcheck endpoint `/healthz`
- In client, support specifying API key expiration time as string with
units, like ``"7d"` or `"10m"`.
- Fix bug where access policies were not applied to child nodes during request
- Add metadata-based access control to SimpleAccessPolicy
- Add example test of metadata-based allowed_scopes which requires the path to the target node

### Fixed

- Bug in Python client resulted in error when accessing data sources on a
just-created object.
- Fix bug where access policies were not applied to child nodes during request

### Changed

- Change access policy API to be async for filters and allowed_scopes

## 2024-12-09

Expand Down
122 changes: 119 additions & 3 deletions tiled/_tests/test_access_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,73 @@

import numpy
import pytest
from starlette.status import HTTP_403_FORBIDDEN
from fastapi import HTTPException
from starlette.status import HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND

from ..access_policies import (
ALL_SCOPES,
PUBLIC_SCOPES,
SimpleAccessPolicy,
SpecialUsers,
)
from ..adapters.array import ArrayAdapter
from ..adapters.mapping import MapAdapter
from ..client import Context, from_context
from ..client.utils import ClientError
from ..server.app import build_app_from_config
from ..server.core import NoEntry
from .utils import enter_username_password, fail_with_status_code

arr = numpy.ones((5, 5))
arr_zeros = numpy.zeros((5, 5))
arr_ad = ArrayAdapter.from_array(arr)


class EntryBasedAccessPolicy(SimpleAccessPolicy):
nmaytan marked this conversation as resolved.
Show resolved Hide resolved
"""
This example access policy demonstrates how the metadata on some nested child node
can be efficiently consulted and incorporated in logic that determines access scopes.
In this test example, the metadata on the node quite literally lists some scopes that
it should not allow. In realistic examples it could be incorporated in site-specific logic.
"""

async def allowed_scopes(self, node, principal, path_parts):
# If this is being called, filter_access has let us get this far.
if principal is SpecialUsers.public:
allowed = PUBLIC_SCOPES
elif principal.type == "service":
allowed = self.scopes
else:
allowed = self.scopes

if self._get_id(principal) in self.admins:
allowed = ALL_SCOPES
else:
# Allowed scopes will be filtered based on some metadata of the target entry
try:
for i, segment in enumerate(path_parts):
if hasattr(node, "lookup_adapter"):
node = await node.lookup_adapter(path_parts[i:])
if node is None:
raise NoEntry(path_parts)
break
else:
try:
node = node[segment]
except (KeyError, TypeError):
raise NoEntry(path_parts)
except NoEntry:
raise HTTPException(
status_code=HTTP_404_NOT_FOUND,
detail=f"No such entry: {path_parts}",
)
remove_scope = node.metadata().get("remove_scope", None)
if remove_scope in allowed:
allowed = allowed.copy()
allowed.remove(remove_scope)
return allowed


def tree_a(access_policy=None):
return MapAdapter({"A1": arr_ad, "A2": arr_ad}, access_policy=access_policy)

Expand Down Expand Up @@ -48,10 +103,10 @@ def context(tmpdir_module):
"access_control": {
"access_policy": "tiled.access_policies:SimpleAccessPolicy",
"args": {
"access_lists": {"alice": ["a", "c", "d", "e"]},
"access_lists": {"alice": ["a", "c", "d", "e", "g", "h"]},
"provider": "toy",
"admins": ["admin"],
"public": ["f"],
"public": ["f", "g"],
},
},
"trees": [
Expand Down Expand Up @@ -128,6 +183,34 @@ def context(tmpdir_module):
},
},
{"tree": ArrayAdapter.from_array(arr), "path": "/f"},
{
"tree": "tiled.catalog:in_memory",
"args": {"writable_storage": tmpdir_module / "g"},
"path": "/g",
"access_control": {
"access_policy": "tiled.access_policies:SimpleAccessPolicy",
"args": {
"provider": "toy",
"key": "project",
"access_lists": {"alice": ["projectA"], "bob": ["projectB"]},
"admins": ["admin"],
"public": ["projectC"],
},
},
},
{
"tree": "tiled.catalog:in_memory",
"args": {"writable_storage": tmpdir_module / "h"},
"path": "/h",
"access_control": {
"access_policy": "tiled._tests.test_access_control:EntryBasedAccessPolicy",
"args": {
"provider": "toy",
"access_lists": {"alice": ["x", "y"]},
"admins": ["admin"],
},
},
},
],
}
app = build_app_from_config(config)
Expand All @@ -138,19 +221,37 @@ def context(tmpdir_module):
admin_client[k].write_array(arr, key="A1")
admin_client[k].write_array(arr, key="A2")
admin_client[k].write_array(arr, key="x")
for k, v in {"A3": "projectA", "A4": "projectB", "r": "projectC"}.items():
admin_client["g"].write_array(arr, key=k, metadata={"project": v})
for k, v in {"x": "write:data", "y": None}.items():
admin_client["h"].write_array(arr, key=k, metadata={"remove_scope": v})
yield context


def test_entry_based_scopes(context, enter_username_password):
with enter_username_password("alice", "secret1"):
alice_client = from_context(context, username="alice")
with pytest.raises(ClientError, match="Not enough permissions"):
alice_client["h"]["x"].write(arr_zeros)
alice_client["h"]["y"].write(arr_zeros)


def test_top_level_access_control(context, enter_username_password):
with enter_username_password("alice", "secret1"):
alice_client = from_context(context, username="alice")
assert "a" in alice_client
assert "A2" in alice_client["a"]
assert "A1" not in alice_client["a"]
assert "b" not in alice_client
assert "g" in alice_client
assert "A3" in alice_client["g"]
assert "A4" not in alice_client["g"]
alice_client["a"]["A2"]
alice_client["g"]["A3"]
with pytest.raises(KeyError):
alice_client["b"]
with pytest.raises(KeyError):
alice_client["g"]["A4"]

with enter_username_password("bob", "secret2"):
bob_client = from_context(context, username="bob")
Expand All @@ -159,6 +260,8 @@ def test_top_level_access_control(context, enter_username_password):
bob_client["a"]
with pytest.raises(KeyError):
bob_client["b"]
with pytest.raises(KeyError):
bob_client["g"]["A3"]
alice_client.logout()

# Make sure clearing default identity works without raising an error.
Expand All @@ -177,6 +280,7 @@ def test_access_control_with_api_key_auth(context, enter_username_password):
context.api_key = key_info["secret"]
client = from_context(context)
client["a"]["A2"]
client["g"]["A3"]
finally:
# Clean up Context, which is a module-scopae fixture shared with other tests.
context.api_key = None
Expand All @@ -194,7 +298,11 @@ def test_node_export(enter_username_password, context, buffer):
assert "A2" in exported_dict["contents"]["a"]["contents"]
assert "A1" not in exported_dict["contents"]["a"]["contents"]
assert "b" not in exported_dict
assert "g" in exported_dict["contents"]
assert "A3" in exported_dict["contents"]["g"]["contents"]
assert "A4" not in exported_dict["contents"]["g"]["contents"]
exported_dict["contents"]["a"]["contents"]["A2"]
exported_dict["contents"]["g"]["contents"]["A3"]


def test_create_and_update_allowed(enter_username_password, context):
Expand All @@ -206,8 +314,13 @@ def test_create_and_update_allowed(enter_username_password, context):
alice_client["c"]["x"].update_metadata(metadata={"added_key": 3})
assert alice_client["c"]["x"].metadata["added_key"] == 3

alice_client["g"]["A3"].metadata
alice_client["g"]["A3"].update_metadata(metadata={"added_key": 9})
assert alice_client["g"]["A3"].metadata["added_key"] == 9

# Create
alice_client["c"].write_array([1, 2, 3])
alice_client["g"].write_array([4, 5, 6], metadata={"project": "projectA"})
alice_client.logout()


Expand All @@ -233,8 +346,11 @@ def test_public_access(context):
for key in ["a", "b", "c", "d", "e"]:
assert key not in public_client
public_client["f"].read()
public_client["g"]["r"].read()
with pytest.raises(KeyError):
public_client["a", "A1"]
with pytest.raises(KeyError):
public_client["g", "A3"]


def test_service_principal_access(tmpdir):
Expand Down
35 changes: 24 additions & 11 deletions tiled/_tests/test_protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import dask.dataframe
import numpy
import pandas
import pytest
import sparse
from numpy.typing import NDArray
from pytest_mock import MockFixture
Expand Down Expand Up @@ -367,7 +368,7 @@ def test_tableadapter_protocol(mocker: MockFixture) -> None:
mock_call6.assert_called_once_with("abc")


class CustomAccessPolicy:
class CustomAccessPolicy(AccessPolicy):
ALL = ALL_ACCESS

def __init__(self, scopes: Optional[Scopes] = None) -> None:
Expand All @@ -376,27 +377,38 @@ def __init__(self, scopes: Optional[Scopes] = None) -> None:
def _get_id(self, principal: Principal) -> None:
return None

def allowed_scopes(self, node: BaseAdapter, principal: Principal) -> Scopes:
async def allowed_scopes(
self, node: BaseAdapter, principal: Principal, path_parts: List[Any]
) -> Scopes:
allowed = self.scopes
somemetadata = node.metadata() # noqa: 841
return allowed

def filters(
self, node: BaseAdapter, principal: Principal, scopes: Scopes
async def filters(
self,
node: BaseAdapter,
principal: Principal,
scopes: Scopes,
path_parts: List[Any],
) -> Filters:
queries: Filters = []
somespecs = node.specs() # noqa: 841
return queries


def accesspolicy_protocol_functions(
policy: AccessPolicy, node: BaseAdapter, principal: Principal, scopes: Scopes
async def accesspolicy_protocol_functions(
policy: AccessPolicy,
node: BaseAdapter,
principal: Principal,
scopes: Scopes,
path_parts: List[Any],
) -> None:
policy.allowed_scopes(node, principal)
policy.filters(node, principal, scopes)
await policy.allowed_scopes(node, principal, path_parts)
await policy.filters(node, principal, scopes, path_parts)


def test_accesspolicy_protocol(mocker: MockFixture) -> None:
@pytest.mark.asyncio # type: ignore
async def test_accesspolicy_protocol(mocker: MockFixture) -> None:
mock_call = mocker.patch.object(CustomAwkwardAdapter, "metadata")
mock_call2 = mocker.patch.object(CustomAwkwardAdapter, "specs")

Expand All @@ -410,11 +422,12 @@ def test_accesspolicy_protocol(mocker: MockFixture) -> None:
uuid="12345678124123412345678123456781", type=PrincipalType.user
)
scopes = {"abc"}
path_parts = ["wx", "yz"]

anyawkwardadapter = CustomAwkwardAdapter(container, structure, metadata=metadata)

accesspolicy_protocol_functions(
anyaccesspolicy, anyawkwardadapter, principal, scopes
await accesspolicy_protocol_functions(
anyaccesspolicy, anyawkwardadapter, principal, scopes, path_parts
)
mock_call.assert_called_once()
mock_call2.assert_called_once()
20 changes: 12 additions & 8 deletions tiled/access_policies.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from .queries import KeysFilter
from functools import partial

from .queries import In, KeysFilter
from .scopes import SCOPES
from .utils import Sentinel, SpecialUsers, import_object

Expand All @@ -11,10 +13,10 @@
class DummyAccessPolicy:
"Impose no access restrictions."

def allowed_scopes(self, node, principal):
async def allowed_scopes(self, node, principal, path_parts):
return ALL_SCOPES

def filters(self, node, principal, scopes):
async def filters(self, node, principal, scopes, path_parts):
return []


Expand All @@ -35,10 +37,11 @@ class SimpleAccessPolicy:
ALL = ALL_ACCESS

def __init__(
self, access_lists, *, provider, scopes=None, public=None, admins=None
self, access_lists, *, provider, key=None, scopes=None, public=None, admins=None
):
self.access_lists = {}
self.provider = provider
self.key = key
self.scopes = scopes if (scopes is not None) else ALL_SCOPES
self.public = set(public or [])
self.admins = set(admins or [])
Expand All @@ -61,7 +64,7 @@ def _get_id(self, principal):
)
return id

def allowed_scopes(self, node, principal):
async def allowed_scopes(self, node, principal, path_parts):
# If this is being called, filter_access has let us get this far.
if principal is SpecialUsers.public:
allowed = PUBLIC_SCOPES
Expand All @@ -76,10 +79,11 @@ def allowed_scopes(self, node, principal):
allowed = self.scopes
return allowed

def filters(self, node, principal, scopes):
async def filters(self, node, principal, scopes, path_parts):
queries = []
query_filter = KeysFilter if not self.key else partial(In, self.key)
if principal is SpecialUsers.public:
queries.append(KeysFilter(self.public))
queries.append(query_filter(self.public))
else:
# Services have no identities; just use the uuid.
if principal.type == "service":
Expand All @@ -101,5 +105,5 @@ def filters(self, node, principal, scopes):
f"Unexpected access_list {access_list} of type {type(access_list)}. "
f"Expected iterable or {self.ALL}, instance of {type(self.ALL)}."
)
queries.append(KeysFilter(allowed))
queries.append(query_filter(allowed))
return queries
12 changes: 9 additions & 3 deletions tiled/adapters/protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,11 +134,17 @@ def __getitem__(self, key: str) -> ArrayAdapter:

class AccessPolicy(Protocol):
@abstractmethod
def allowed_scopes(self, node: BaseAdapter, principal: Principal) -> Scopes:
async def allowed_scopes(
self, node: BaseAdapter, principal: Principal, path_parts: List[Any]
) -> Scopes:
pass

@abstractmethod
def filters(
self, node: BaseAdapter, principal: Principal, scopes: Scopes
async def filters(
self,
node: BaseAdapter,
principal: Principal,
scopes: Scopes,
path_parts: List[Any],
) -> Filters:
pass
Loading
Loading