diff --git a/.dockerignore b/.dockerignore
index 9839db2..55551dc 100644
--- a/.dockerignore
+++ b/.dockerignore
@@ -17,4 +17,4 @@
**/.mypy_cache/
**/.ruff_cache/
**/.vscode/
-**/.moto_recording
\ No newline at end of file
+**/.moto_recording
diff --git a/.github/workflows/build-and-push.yml b/.github/workflows/build-and-push.yml
index 7d8f668..75a6956 100644
--- a/.github/workflows/build-and-push.yml
+++ b/.github/workflows/build-and-push.yml
@@ -2,7 +2,7 @@ name: Build and push docker image and package
on:
release:
- types:
+ types:
- published
workflow_dispatch:
diff --git a/.github/workflows/conventional-commits.yml b/.github/workflows/conventional-commits.yml
index 6ecc0d4..f896c89 100644
--- a/.github/workflows/conventional-commits.yml
+++ b/.github/workflows/conventional-commits.yml
@@ -12,4 +12,4 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- - uses: chanzuckerberg/github-actions/.github/actions/conventional-commits@main
\ No newline at end of file
+ - uses: chanzuckerberg/github-actions/.github/actions/conventional-commits@main
diff --git a/.github/workflows/release-please.yml b/.github/workflows/release-please.yml
index 9c616a6..7435455 100644
--- a/.github/workflows/release-please.yml
+++ b/.github/workflows/release-please.yml
@@ -25,4 +25,4 @@ jobs:
with:
release-type: python
token: ${{ steps.generate_token.outputs.token }}
- bump-minor-pre-major: true
\ No newline at end of file
+ bump-minor-pre-major: true
diff --git a/README.md b/README.md
index 1605a13..6a4c78d 100644
--- a/README.md
+++ b/README.md
@@ -5,7 +5,7 @@ Platformics is a GraphQL API framework that relies on code generation to impleme
The libraries and tools that make Platformics work:
![image](docs/images/platformics_libs.svg)
-
+
### Links to these tools/libraries
- [LinkML](https://linkml.io/) - Schema modeling language
- [FastAPI](https://fastapi.tiangolo.com/) - Async HTTP router
@@ -60,6 +60,7 @@ The version in `pyproject.toml` is managed using [poetry-dynamic-versioning](htt
- [Work with platformics](docs/HOWTO-working-with-platformics.md)
- [Extend the generated API](docs/HOWTO-extend-generated-api.md)
- [Customize Codegen templates](docs/HOWTO-customize-templates.md)
+- [Override Default Authorization Behaviors](docs/HOWTO-override-authorization.md)
## Contributing
This project adheres to the Contributor Covenant code of conduct. By participating, you are expected to uphold this code. Please report unacceptable behavior to opensource@chanzuckerberg.com.
diff --git a/docs/HOWTO-customize-templates.md b/docs/HOWTO-customize-templates.md
index c65b20a..22b8588 100644
--- a/docs/HOWTO-customize-templates.md
+++ b/docs/HOWTO-customize-templates.md
@@ -7,7 +7,7 @@ Platformics supports replacing one codegen template with another, either globall
2. Create a directory that will contain your overriden templates, such as `template_overrides`
3. Copy that template to your overrides directory with the same path relative to `templates` in the platformics repo. For example, if you want to override `platformics/codegen/templates/database/models/class_name.py.j2`, copy it to `template_overrides/database/models/class_name.py.j2`
4. Modify the template as much as you want
-5. When you run codegen, include the overrides folder as a parameter to the codegen tool. For example, update the `codegen` target in the `Makefile` for your project directory to look like:
+5. When you run codegen, include the overrides folder as a parameter to the codegen tool. For example, update the `codegen` target in the `Makefile` for your project directory to look like:
```
$(docker_compose_run) $(APP_CONTAINER) platformics api generate --schemafile ./schema/schema.yaml --template-override-paths template_overrides --output-prefix .
```
@@ -17,7 +17,7 @@ Platformics supports replacing one codegen template with another, either globall
2. Create a directory that will contain your overriden templates, such as `template_overrides`
3. Copy that template to your overrides directory with the same path relative to `templates` in the platformics repo, but the **filename needs to reflect the camel_case class name**. For example, if you want to override `platformics/codegen/templates/database/models/class_name.py.j2`, for a class called `MyData`, copy it to `template_overrides/database/models/my_data.py.j2`
4. Modify the template as much as you want
-5. When you run codegen, include the overrides folder as a parameter to the codegen tool. For example, update the `codegen` target in the `Makefile` for your project directory to look like:
+5. When you run codegen, include the overrides folder as a parameter to the codegen tool. For example, update the `codegen` target in the `Makefile` for your project directory to look like:
```
$(docker_compose_run) $(APP_CONTAINER) platformics api generate --schemafile ./schema/schema.yaml --template-override-paths template_overrides --output-prefix .
```
diff --git a/docs/HOWTO-override-authorization.md b/docs/HOWTO-override-authorization.md
new file mode 100644
index 0000000..81cfbe9
--- /dev/null
+++ b/docs/HOWTO-override-authorization.md
@@ -0,0 +1,116 @@
+# How To: Override Default Authorization
+
+## Auth Principals
+
+By default, Platformics reads user and role information from JWT's with a special structure:
+
+```json
+{
+ "sub": "USERID GOES HERE",
+ "project_claims": {
+ "member": [123, 456],
+ "owner": [789],
+ "viewer": [333]
+ }
+}
+
+```
+
+However, this may not work for every use case - if your application needs to fetch user and role information from some other source (cookies, external databases, etc) then you'll need to replace Platformics' default behavior with your own. This is pretty straightforward though, since Platformics uses dependency injection to allow many of its default behaviors to be customized!
+
+```python
+# your_app/main.py
+
+from platformics.settings import APISettings
+from database import models
+from fastapi import Depends
+from platformics.api.core.deps import get_auth_principal
+from platformics.security.authorization import Principal
+from platformics.graphql_api.core.deps import get_settings, get_user_token
+from platformics.security.token_auth import get_token_claims
+from starlette.requests import Request
+
+...
+
+# Create and run app
+app = get_app(settings, schema, models)
+
+
+# This is a FastAPI Dependency (https://fastapi.tiangolo.com/tutorial/dependencies/) and can
+# depend on any of platformics' built-in dependencies, or any extra dependencies you may choose
+# to define!
+def override_auth_principal(request: Request, settings: APISettings = Depends(get_settings), user_token: typing.optional[str] = Depends(get_user_token)) -> typing.Optional[Principal]:
+ if user_token:
+ claims = get_token_claims(user_token)
+ else:
+ claims = {"sub": "anonymous"}
+
+ # Create an anonymous auth scope if we don't have a logged in user!
+ return Principal(
+ claims["sub"[,
+ roles=["user"],
+ attr={
+ "user_id": claims["sub"],
+ "owner_projects": [],
+ "member_projects": [],
+ "service_identity": [],
+ # This value can be read from a secret or external db or anything you wish.
+ # It's just hardcoded here for brevity.
+ "viewer_projects": [444],
+ },
+ )
+
+# This override ensures that every time the API tries to fetch information about a user and their
+# roles, your code will be called instead of the Platformics built-in functionality.
+app.dependency_overrides[get_auth_principal] = override_auth_principal
+
+...
+```
+
+## Authorized Queries
+
+Platformics generates authorized SQL queries via [Cerbos' SQLAlchemy](https://docs.cerbos.dev/cerbos/latest/recipes/orm/sqlalchemy/index.html) integration by default. If you need to add additional filters to queries, or even skip using Cerbos entirely, you'll need to extend the base `platformics.security.authorization.AuthzClient` class to suit your own needs, and update the app's dependencies to use your modified AuthzClient class instead:
+
+```python
+# your_app/main.py
+import typing
+
+from cerbos.sdk.model import Resource, ResourceDesc
+from platformics.security.authorization import Principal, AuthzClient
+from platformics.settings import APISettings
+from sqlalchemy.sql import Select
+from platformics.graphql_api.core.deps import get_authz_client
+from fastapi import Depends
+
+...
+
+# You can override any subset of the following methods!
+class CustomAuthzClient(AuthzClient):
+ def __init__(self, settings: APISettings):
+ # Set up your class
+ ...
+
+ def can_create(self, resource, principal: Principal) -> bool:
+ # Return a boolean value representing whether the user has permission to create the resource
+ ...
+
+ def can_update(self, resource, principal: Principal) -> bool:
+ # Return a boolean value representing whether the user has permission to update the resource
+ ...
+
+ def get_resource_query(self, principal: Principal, action: AuthzAction, model_cls, relationship) -> Select:
+ # Return a SQLAlchemy query for the given model_cls with security filters already applied
+ ...
+
+ def modify_where_clause(self, principal: Principal, action: AuthzAction, model_cls, where_clauses) -> Select:
+ # Add additional filters to a query before it is executed.
+ ...
+
+def get_customized_authz_client(settings: APISettings = Depends(get_settings)):
+ return CustomAuthzClient(settings)
+
+# This override ensures that every time the API tries to fetch an authorization client
+# roles, your code will be called instead of the Platformics built-in functionality.
+app.dependency_overrides[get_authz_client] = get_customized_authz_client
+
+...
diff --git a/docs/HOWTO-working-with-platformics.md b/docs/HOWTO-working-with-platformics.md
index 939a833..e3ce086 100644
--- a/docs/HOWTO-working-with-platformics.md
+++ b/docs/HOWTO-working-with-platformics.md
@@ -15,7 +15,7 @@ Notable files and subdirectories:
* `files.py` - GQL types, mutations, queries for files
* `codegen/`
* `lib/linkml_wrappers.py` - convenience functions for converting LinkML to generated code
- * `templates/` - all Jinja templates for codegen. Entity-related templates can be overridden with [custom templates](https://github.com/chanzuckerberg/platformics/tree/main/platformics/docs/HOWTO-customize-templates.md).
+ * `templates/` - all Jinja templates for codegen. Entity-related templates can be overridden with [custom templates](https://github.com/chanzuckerberg/platformics/tree/main/platformics/docs/HOWTO-customize-templates.md).
* `generator.py` - script handling all logic of applying Jinja templates to LinkML schema to generate code
* `database/`
* `models/`
@@ -31,7 +31,7 @@ Notable files and subdirectories:
Notable files and subdirectories:
* `api/` - entrypoint for GQL API service
* `helpers/` - generated GQL types and helper functions for GROUPBY queries
- * `types/` - generated GQL types
+ * `types/` - generated GQL types
* `mutations.py` - generated mutations (create, update, delete) for each entity type
* `queries.py` - generated queries (list and aggregate) for each entity type
* `schema.graphql` - GQL format schema
@@ -40,7 +40,7 @@ Notable files and subdirectories:
* `cerbos/` - generated access policies for user actions for each entity type
* `database/` - code related to establishing DB connections / sessions
* `migrations/` - alembic migrations
- * `models/` - generated SQLAlchemy models
+ * `models/` - generated SQLAlchemy models
* `schema/`
* `schema.yaml` - LinkML schema used to codegen entity-related files
* `test_infra/`
@@ -59,7 +59,7 @@ Containers (`test_app/docker-compose.yml`)
* `platformics-db`: Postgres database
* `graphql-api`: API
-When developing on `platformics` itself, running `make dev` will start all of the above containers, then stop the `graphql-api` container and start a new `dev-app` compose service.
+When developing on `platformics` itself, running `make dev` will start all of the above containers, then stop the `graphql-api` container and start a new `dev-app` compose service.
The compose service called `dev-app` has the `platformics` directory in this repo mounted inside the `test_app` application as a sub-module, so it can be edited directly and be debugged via the VSCode debugger.
`graphql-api` and `dev-app` share a port, so the `graphql-api` container is stopped before starting the `dev-app` container.
@@ -77,4 +77,4 @@ For either of these two flows, the main app will be listening on port 9009 and d
### Queries
-To view SQL logs for queries, set `DB_ECHO=true` in `docker-compose.yml`. Run `make start` or `docker compose up -d` to apply the change.
\ No newline at end of file
+To view SQL logs for queries, set `DB_ECHO=true` in `docker-compose.yml`. Run `make start` or `docker compose up -d` to apply the change.
diff --git a/docs/images/platformics_libs.svg b/docs/images/platformics_libs.svg
index 2b97709..0ce11a0 100644
--- a/docs/images/platformics_libs.svg
+++ b/docs/images/platformics_libs.svg
@@ -1 +1 @@
-
\ No newline at end of file
+
diff --git a/platformics/codegen/generator.py b/platformics/codegen/generator.py
index 53a421d..4e43b1b 100644
--- a/platformics/codegen/generator.py
+++ b/platformics/codegen/generator.py
@@ -5,7 +5,9 @@
import logging
import os
+import re
+import jinja2.ext
from jinja2 import Environment, FileSystemLoader
from linkml_runtime.utils.schemaview import SchemaView
@@ -122,6 +124,16 @@ def generate_entity_import_files(
print(f"... wrote {filename}")
+def regex_replace(txt, rgx, val, ignorecase=False, multiline=False):
+ flag = 0
+ if ignorecase:
+ flag |= re.I
+ if multiline:
+ flag |= re.M
+ compiled_rgx = re.compile(rgx, flag)
+ return compiled_rgx.sub(val, txt)
+
+
def generate(schemafile: str, output_prefix: str, render_files: bool, template_override_paths: tuple[str]) -> None:
"""
Launch code generation
@@ -130,7 +142,8 @@ def generate(schemafile: str, output_prefix: str, render_files: bool, template_o
template_paths.append(
os.path.join(os.path.abspath(os.path.dirname(__file__)), "templates/"),
) # default template path
- environment = Environment(loader=FileSystemLoader(template_paths))
+ environment = Environment(loader=FileSystemLoader(template_paths), extensions=[jinja2.ext.loopcontrols])
+ environment.filters["regex_replace"] = regex_replace
view = SchemaView(schemafile)
view.imports_closure()
wrapped_view = ViewWrapper(view)
diff --git a/platformics/codegen/templates/database/migrations/env.py.j2 b/platformics/codegen/templates/database/migrations/env.py.j2
index 1f6583e..7b61cda 100644
--- a/platformics/codegen/templates/database/migrations/env.py.j2
+++ b/platformics/codegen/templates/database/migrations/env.py.j2
@@ -43,6 +43,8 @@ def run_migrations_offline() -> None:
)
with context.begin_transaction():
+ context.get_context()._ensure_version_table() # pylint: disable=protected-access
+ connection.execute(sa.sql.text("LOCK TABLE alembic_version IN ACCESS EXCLUSIVE MODE"))
context.run_migrations()
diff --git a/platformics/codegen/templates/graphql_api/types/class_name.py.j2 b/platformics/codegen/templates/graphql_api/types/class_name.py.j2
index eaab6b1..23023c9 100644
--- a/platformics/codegen/templates/graphql_api/types/class_name.py.j2
+++ b/platformics/codegen/templates/graphql_api/types/class_name.py.j2
@@ -34,14 +34,12 @@ from platformics.graphql_api.types.entities import EntityInterface
from graphql_api.types.{{related_field.related_class.snake_name}} import ({{related_field.related_class.name}}Aggregate, format_{{related_field.related_class.snake_name}}_aggregate_output)
{%- endif %}
{%- endfor %}
-from cerbos.sdk.client import CerbosClient
-from cerbos.sdk.model import Principal, Resource
from fastapi import Depends
from platformics.graphql_api.core.errors import PlatformicsError
-from platformics.graphql_api.core.deps import get_cerbos_client, get_db_session, require_auth_principal, is_system_user
+from platformics.graphql_api.core.deps import get_authz_client, get_db_session, require_auth_principal, is_system_user
from platformics.graphql_api.core.query_input_types import aggregator_map, orderBy, EnumComparators, DatetimeComparators, IntComparators, FloatComparators, StrComparators, UUIDComparators, BoolComparators
from platformics.graphql_api.core.strawberry_extensions import DependencyExtension
-from platformics.security.authorization import CerbosAction, get_resource_query
+from platformics.security.authorization import AuthzAction, AuthzClient, Principal
from sqlalchemy import inspect
from sqlalchemy.engine.row import RowMapping
from sqlalchemy.ext.asyncio import AsyncSession
@@ -417,7 +415,7 @@ Utilities
@strawberry.field(extensions=[DependencyExtension()])
async def resolve_{{ cls.plural_snake_name }}(
session: AsyncSession = Depends(get_db_session, use_cache=False),
- cerbos_client: CerbosClient = Depends(get_cerbos_client),
+ authz_client: AuthzClient = Depends(get_authz_client),
principal: Principal = Depends(require_auth_principal),
where: Optional[{{ cls.name }}WhereClause] = None,
order_by: Optional[list[{{ cls.name }}OrderByClause]] = [],
@@ -430,7 +428,7 @@ async def resolve_{{ cls.plural_snake_name }}(
offset = limit_offset["offset"] if limit_offset and "offset" in limit_offset else None
if offset and not limit:
raise PlatformicsError("Cannot use offset without limit")
- return await get_db_rows(db.{{ cls.name }}, session, cerbos_client, principal, where, order_by, CerbosAction.VIEW, limit, offset) # type: ignore
+ return await get_db_rows(db.{{ cls.name }}, session, authz_client, principal, where, order_by, AuthzAction.VIEW, limit, offset) # type: ignore
def format_{{ cls.snake_name }}_aggregate_output(query_results: Sequence[RowMapping] | RowMapping) -> {{ cls.name }}Aggregate:
@@ -481,7 +479,7 @@ def format_{{ cls.snake_name }}_aggregate_row(row: RowMapping) -> {{ cls.name }}
async def resolve_{{ cls.plural_snake_name }}_aggregate(
info: Info,
session: AsyncSession = Depends(get_db_session, use_cache=False),
- cerbos_client: CerbosClient = Depends(get_cerbos_client),
+ authz_client: AuthzClient = Depends(get_authz_client),
principal: Principal = Depends(require_auth_principal),
where: Optional[{{ cls.name }}WhereClause] = None,
# TODO: add support for groupby, limit/offset
@@ -499,7 +497,7 @@ async def resolve_{{ cls.plural_snake_name }}_aggregate(
if not aggregate_selections:
raise PlatformicsError("No aggregate functions selected")
- rows = await get_aggregate_db_rows(db.{{ cls.name }}, session, cerbos_client, principal, where, aggregate_selections, [], groupby_selections) # type: ignore
+ rows = await get_aggregate_db_rows(db.{{ cls.name }}, session, authz_client, principal, where, aggregate_selections, [], groupby_selections) # type: ignore
aggregate_output = format_{{ cls.snake_name }}_aggregate_output(rows)
return aggregate_output
@@ -508,7 +506,7 @@ async def resolve_{{ cls.plural_snake_name }}_aggregate(
async def create_{{ cls.snake_name }}(
input: {{ cls.name }}CreateInput,
session: AsyncSession = Depends(get_db_session, use_cache=False),
- cerbos_client: CerbosClient = Depends(get_cerbos_client),
+ authz_client: AuthzClient = Depends(get_authz_client),
principal: Principal = Depends(require_auth_principal),
is_system_user: bool = Depends(is_system_user),
) -> db.{{ cls.name }}:
@@ -531,23 +529,12 @@ async def create_{{ cls.snake_name }}(
{%- endif %}
{%- endif %}
-
- {%- for field in cls.create_fields %}
- {%- if field.name == "collection_id" %}
- # Validate that the user can create entities in this collection
- attr = {"collection_id": validated.collection_id, "owner_user_id": int(principal.id)}
- resource = Resource(id="NEW_ID", kind=db.{{ cls.name }}.__tablename__, attr=attr)
- if not cerbos_client.is_allowed("create", principal, resource):
- raise PlatformicsError("Unauthorized: Cannot create entity in this collection")
- {%- endif %}
- {%- endfor %}
-
# Validate that the user can read all of the entities they're linking to.
{%- for field in cls.create_fields %}
{%- if field.is_entity and not field.is_virtual_relationship %}
# Check that {{field.name}} relationship is accessible.
if validated.{{field.name}}_id:
- {{field.name}} = await get_db_rows(db.{{ field.related_class.name }}, session, cerbos_client, principal, {"id": {"_eq": validated.{{field.name}}_id } }, [], CerbosAction.VIEW)
+ {{field.name}} = await get_db_rows(db.{{ field.related_class.name }}, session, authz_client, principal, {"id": {"_eq": validated.{{field.name}}_id } }, [], AuthzAction.VIEW)
if not {{field.name}}:
raise PlatformicsError("Unauthorized: {{field.name}} does not exist")
{%- endif %}
@@ -556,6 +543,11 @@ async def create_{{ cls.snake_name }}(
# Save to DB
params["owner_user_id"] = int(principal.id)
new_entity = db.{{ cls.name }}(**params)
+
+ # Are we actually allowed to create this entity?
+ if not authz_client.can_create(new_entity, principal):
+ raise PlatformicsError("Unauthorized: Cannot create entity")
+
session.add(new_entity)
await session.commit()
return new_entity
@@ -568,7 +560,7 @@ async def update_{{ cls.snake_name }}(
input: {{ cls.name }}UpdateInput,
where: {{ cls.name }}WhereClauseMutations,
session: AsyncSession = Depends(get_db_session, use_cache=False),
- cerbos_client: CerbosClient = Depends(get_cerbos_client),
+ authz_client: AuthzClient = Depends(get_authz_client),
principal: Principal = Depends(require_auth_principal),
is_system_user: bool = Depends(is_system_user),
) -> Sequence[db.{{ cls.name }}]:
@@ -588,7 +580,7 @@ async def update_{{ cls.snake_name }}(
{%- if field.is_entity and not field.is_virtual_relationship %}
# Check that {{field.name}} relationship is accessible.
if validated.{{field.name}}_id:
- {{field.name}} = await get_db_rows(db.{{ field.related_class.name }}, session, cerbos_client, principal, {"id": {"_eq": validated.{{field.name}}_id } }, [], CerbosAction.VIEW)
+ {{field.name}} = await get_db_rows(db.{{ field.related_class.name }}, session, authz_client, principal, {"id": {"_eq": validated.{{field.name}}_id } }, [], AuthzAction.VIEW)
if not {{field.name}}:
raise PlatformicsError("Unauthorized: {{field.name}} does not exist")
{%- if field.type != cls.name %}
@@ -611,21 +603,10 @@ async def update_{{ cls.snake_name }}(
{%- endif %}
# Fetch entities for update, if we have access to them
- entities = await get_db_rows(db.{{ cls.name }}, session, cerbos_client, principal, where, [], CerbosAction.UPDATE)
+ entities = await get_db_rows(db.{{ cls.name }}, session, authz_client, principal, where, [], AuthzAction.UPDATE)
if len(entities) == 0:
raise PlatformicsError("Unauthorized: Cannot update entities")
- {%- for field in self.mutable_fields %}
- {%- if field.name == "collection_id" %}
- # Validate that the user has access to the new collection ID
- if validated.collection_id:
- attr = {"collection_id": validated.collection_id}
- resource = Resource(id="SOME_ID", kind=db.{{ cls.name }}.__tablename__, attr=attr)
- if not cerbos_client.is_allowed(CerbosAction.UPDATE, principal, resource):
- raise PlatformicsError("Unauthorized: Cannot access new collection")
- {%- endif %}
- {%- endfor %}
-
# Update DB
updated_at = datetime.datetime.now()
for entity in entities:
@@ -633,6 +614,10 @@ async def update_{{ cls.snake_name }}(
for key in params:
if params[key] is not None:
setattr(entity, key, params[key])
+
+ if not authz_client.can_update(entity, principal):
+ raise PlatformicsError("Unauthorized: Cannot access new collection")
+
await session.commit()
return entities
{%- endif %}
@@ -642,14 +627,14 @@ async def update_{{ cls.snake_name }}(
async def delete_{{ cls.snake_name }}(
where: {{ cls.name }}WhereClauseMutations,
session: AsyncSession = Depends(get_db_session, use_cache=False),
- cerbos_client: CerbosClient = Depends(get_cerbos_client),
+ authz_client: AuthzClient = Depends(get_authz_client),
principal: Principal = Depends(require_auth_principal),
) -> Sequence[db.{{ cls.name }}]:
"""
Delete {{ cls.name }} objects. Used for mutations (see graphql_api/mutations.py).
"""
# Fetch entities for deletion, if we have access to them
- entities = await get_db_rows(db.{{ cls.name }}, session, cerbos_client, principal, where, [], CerbosAction.DELETE)
+ entities = await get_db_rows(db.{{ cls.name }}, session, authz_client, principal, where, [], AuthzAction.DELETE)
if len(entities) == 0:
raise PlatformicsError("Unauthorized: Cannot delete entities")
diff --git a/platformics/codegen/templates/support/enums.py.j2 b/platformics/codegen/templates/support/enums.py.j2
index 1bdc6bf..27509e9 100644
--- a/platformics/codegen/templates/support/enums.py.j2
+++ b/platformics/codegen/templates/support/enums.py.j2
@@ -13,6 +13,7 @@ import enum
@strawberry.enum
class {{enum.name}}(enum.Enum):
{%- for value in enum.permissible_values %}
- {{value}} = "{{value}}"
+ {#- SQLAlchemy freaks out about spaces in enum values :'( #}
+ {{value | regex_replace('[^0-9A-Za-z]', '_') }} = "{{value | regex_replace('[^0-9A-Za-z]', '_') }}"
{%- endfor %}
{%- endfor %}
diff --git a/platformics/graphql_api/core/deps.py b/platformics/graphql_api/core/deps.py
index 6f0bc3a..7b4cec8 100644
--- a/platformics/graphql_api/core/deps.py
+++ b/platformics/graphql_api/core/deps.py
@@ -2,17 +2,15 @@
import boto3
from botocore.client import Config
-from cerbos.sdk.client import CerbosClient
-from cerbos.sdk.model import Principal
from fastapi import Depends
from mypy_boto3_s3.client import S3Client
from mypy_boto3_sts.client import STSClient
from sqlalchemy.ext.asyncio import AsyncSession
from starlette.requests import Request
-from platformics.graphql_api.core.error_handler import PlatformicsError
from platformics.database.connect import AsyncDB, init_async_db
-from platformics.security.token_auth import get_token_claims
+from platformics.graphql_api.core.error_handler import PlatformicsError
+from platformics.security.authorization import AuthzClient, Principal, hydrate_auth_principal
from platformics.settings import APISettings
@@ -30,7 +28,7 @@ async def get_engine(
settings: APISettings = Depends(get_settings),
) -> typing.AsyncGenerator[AsyncDB, None]:
"""Wrap resolvers in a DB engine"""
- engine = init_async_db(settings.DB_URI, echo=settings.DB_ECHO)
+ engine = init_async_db(settings.DB_URI, echo=settings.DB_ECHO) # type: ignore
try:
yield engine
finally:
@@ -48,8 +46,8 @@ async def get_db_session(
await session.close() # type: ignore
-def get_cerbos_client(settings: APISettings = Depends(get_settings)) -> CerbosClient:
- return CerbosClient(host=settings.CERBOS_URL)
+def get_authz_client(settings: APISettings = Depends(get_settings)) -> AuthzClient:
+ return AuthzClient(settings=settings)
def get_user_token(request: Request) -> typing.Optional[str]:
@@ -70,38 +68,11 @@ def get_auth_principal(
settings: APISettings = Depends(get_settings),
user_token: typing.Optional[str] = Depends(get_user_token),
) -> typing.Optional[Principal]:
- if not user_token:
- return None
try:
- claims = get_token_claims(settings.JWK_PRIVATE_KEY, user_token)
+ principal = hydrate_auth_principal(settings, user_token)
except: # noqa
- return None
-
- if "project_roles" not in claims:
- raise PlatformicsError("Unauthorized")
-
- project_claims = claims["project_roles"]
-
- try:
- for role, project_ids in project_claims.items():
- assert role in ["member", "owner", "viewer"]
- assert isinstance(project_ids, list)
- for item in project_ids:
- assert int(item)
- except Exception:
raise PlatformicsError("Unauthorized") from None
-
- return Principal(
- claims["sub"],
- roles=["user"],
- attr={
- "user_id": int(claims["sub"]),
- "owner_projects": project_claims.get("owner", []),
- "member_projects": project_claims.get("member", []),
- "viewer_projects": project_claims.get("viewer", []),
- "service_identity": claims["service_identity"],
- },
- )
+ return principal
def require_auth_principal(
diff --git a/platformics/graphql_api/core/gql_loaders.py b/platformics/graphql_api/core/gql_loaders.py
index 4d5a228..c6fee78 100644
--- a/platformics/graphql_api/core/gql_loaders.py
+++ b/platformics/graphql_api/core/gql_loaders.py
@@ -2,16 +2,14 @@
from collections import defaultdict
from typing import Any, Mapping, Optional, Sequence, Tuple
-from cerbos.sdk.client import CerbosClient
-from cerbos.sdk.model import Principal
from sqlalchemy.orm import RelationshipProperty
from strawberry.dataloader import DataLoader
import platformics.database.models as db
+from platformics.database.connect import AsyncDB
from platformics.graphql_api.core.errors import PlatformicsError
from platformics.graphql_api.core.query_builder import get_aggregate_db_query, get_db_query, get_db_rows
-from platformics.database.connect import AsyncDB
-from platformics.security.authorization import CerbosAction
+from platformics.security.authorization import AuthzAction, AuthzClient, Principal
E = typing.TypeVar("E", db.File, db.Entity) # type: ignore
T = typing.TypeVar("T")
@@ -38,11 +36,11 @@ class EntityLoader:
_loaders: dict[RelationshipProperty, DataLoader]
_aggregate_loaders: dict[RelationshipProperty, DataLoader]
- def __init__(self, engine: AsyncDB, cerbos_client: CerbosClient, principal: Principal) -> None:
+ def __init__(self, engine: AsyncDB, authz_client: AuthzClient, principal: Principal) -> None:
self._loaders = {}
self._aggregate_loaders = {}
self.engine = engine
- self.cerbos_client = cerbos_client
+ self.authz_client = authz_client
self.principal = principal
async def resolve_nodes(self, cls: Any, node_ids: list[str]) -> Sequence[E]:
@@ -51,7 +49,7 @@ async def resolve_nodes(self, cls: Any, node_ids: list[str]) -> Sequence[E]:
"""
db_session = self.engine.session()
where = {"entity_id": {"_in": node_ids}}
- rows = await get_db_rows(cls, db_session, self.cerbos_client, self.principal, where)
+ rows = await get_db_rows(cls, db_session, self.authz_client, self.principal, where)
await db_session.close()
return rows
@@ -88,8 +86,8 @@ async def load_fn(keys: list[Any]) -> typing.Sequence[Any]:
filters.append(remote.in_(keys))
query = get_db_query(
related_model,
- CerbosAction.VIEW,
- self.cerbos_client,
+ AuthzAction.VIEW,
+ self.authz_client,
self.principal,
where,
order_by, # type: ignore
@@ -161,8 +159,8 @@ async def load_fn(keys: list[Any]) -> typing.Sequence[Any]:
query, group_by = get_aggregate_db_query(
related_model,
- CerbosAction.VIEW,
- self.cerbos_client,
+ AuthzAction.VIEW,
+ self.authz_client,
self.principal,
where,
aggregate_selections,
diff --git a/platformics/graphql_api/core/query_builder.py b/platformics/graphql_api/core/query_builder.py
index 652d605..2a98b52 100644
--- a/platformics/graphql_api/core/query_builder.py
+++ b/platformics/graphql_api/core/query_builder.py
@@ -7,8 +7,6 @@
from typing import Any, Optional, Sequence, Tuple
import strcase
-from cerbos.sdk.client import CerbosClient
-from cerbos.sdk.model import Principal
from sqlalchemy import ColumnElement, and_, distinct, inspect
from sqlalchemy.engine.row import RowMapping
from sqlalchemy.ext.asyncio import AsyncSession
@@ -17,10 +15,10 @@
from typing_extensions import TypedDict
import platformics.database.models as db
+from platformics.database.models.base import Base
from platformics.graphql_api.core.errors import PlatformicsError
from platformics.graphql_api.core.query_input_types import aggregator_map, operator_map, orderBy
-from platformics.database.models.base import Base
-from platformics.security.authorization import CerbosAction, get_resource_query
+from platformics.security.authorization import AuthzAction, AuthzClient, Principal
E = typing.TypeVar("E", db.File, db.Entity)
T = typing.TypeVar("T")
@@ -51,8 +49,8 @@ class IndexedOrderByClause(TypedDict):
def convert_where_clauses_to_sql(
principal: Principal,
- cerbos_client: CerbosClient,
- action: CerbosAction,
+ authz_client: AuthzClient,
+ action: AuthzAction,
query: Select,
sa_model: Base,
where_clause: dict[str, Any],
@@ -102,7 +100,7 @@ def convert_where_clauses_to_sql(
# Unless deleted_at is explicitly set in the where clause OR we are performing a DELETE action,
# we should only return rows where deleted_at is null. This is to ensure that we don't return soft-deleted rows.
# Don't do this for files, since they don't have a deleted_at field.
- if "deleted_at" not in local_where_clauses and action != CerbosAction.DELETE and sa_model.__name__ != "File":
+ if "deleted_at" not in local_where_clauses and action != AuthzAction.DELETE and sa_model.__name__ != "File":
local_where_clauses["deleted_at"] = {"_is_null": True}
for group in group_by: # type: ignore
col = strcase.to_snake(group.name)
@@ -119,13 +117,13 @@ def convert_where_clauses_to_sql(
for join_field, join_info in all_joins.items():
relationship = mapper.relationships[join_field] # type: ignore
related_cls = relationship.mapper.entity
- cerbos_query = get_resource_query(principal, cerbos_client, action, related_cls)
+ secure_query = authz_client.get_resource_query(principal, action, related_cls)
# Get the subquery, nested order_by fields, and nested group_by fields that need to be applied to the current query
subquery, subquery_order_by, subquery_group_by = convert_where_clauses_to_sql(
principal,
- cerbos_client,
+ authz_client,
action,
- cerbos_query,
+ secure_query,
related_cls,
join_info.get("where"), # type: ignore
join_info.get("order_by"),
@@ -182,8 +180,8 @@ def convert_where_clauses_to_sql(
def get_db_query(
model_cls: type[E],
- action: CerbosAction,
- cerbos_client: CerbosClient,
+ action: AuthzAction,
+ authz_client: AuthzClient,
principal: Principal,
# TODO it would be nicer if we could have the WhereClause classes inherit from a BaseWhereClause
# so that these type checks could be smarter, but TypedDict doesn't support type checks like that
@@ -194,14 +192,14 @@ def get_db_query(
Given a model class and a where clause, return a SQLAlchemy query that is limited
based on the where clause, and which entities the user has access to.
"""
- query = get_resource_query(principal, cerbos_client, action, model_cls)
+ query = authz_client.get_resource_query(principal, action, model_cls)
# Add indices to the order_by fields so that we can preserve the order of the fields
if order_by is None:
order_by = []
order_by = [IndexedOrderByClause({"field": x, "index": i}) for i, x in enumerate(order_by)] # type: ignore
query, order_by, _group_by = convert_where_clauses_to_sql(
principal,
- cerbos_client,
+ authz_client,
action,
query,
model_cls, # type: ignore
@@ -220,11 +218,11 @@ def get_db_query(
async def get_db_rows(
model_cls: type[E], # type: ignore
session: AsyncSession,
- cerbos_client: CerbosClient,
+ authz_client: AuthzClient,
principal: Principal,
where: Any,
order_by: Optional[list[dict[str, Any]]] = None,
- action: CerbosAction = CerbosAction.VIEW,
+ action: AuthzAction = AuthzAction.VIEW,
limit: Optional[int] = None,
offset: Optional[int] = None,
) -> typing.Sequence[E]:
@@ -233,7 +231,7 @@ async def get_db_rows(
"""
if order_by is None:
order_by = []
- query = get_db_query(model_cls, action, cerbos_client, principal, where, order_by)
+ query = get_db_query(model_cls, action, authz_client, principal, where, order_by)
if limit:
query = query.limit(limit)
if offset:
@@ -244,8 +242,8 @@ async def get_db_rows(
def get_aggregate_db_query(
model_cls: type[E],
- action: CerbosAction,
- cerbos_client: CerbosClient,
+ action: AuthzAction,
+ authz_client: AuthzClient,
principal: Principal,
where: dict[str, Any],
aggregate: Any,
@@ -264,7 +262,7 @@ def get_aggregate_db_query(
# TODO, this may need to be adjusted, 5 just seemed like a reasonable starting point
if depth >= 5:
raise Exception("Max filter depth exceeded")
- query = get_resource_query(principal, cerbos_client, action, model_cls)
+ query = authz_client.get_resource_query(principal, action, model_cls)
# Deconstruct the aggregate dict and build mappings for the query
aggregate_query_fields = []
if remote is not None:
@@ -291,7 +289,7 @@ def get_aggregate_db_query(
query = query.with_only_columns(*aggregate_query_fields)
query, _order_by, group_by = convert_where_clauses_to_sql(
principal,
- cerbos_client,
+ authz_client,
action,
query,
model_cls, # type: ignore
@@ -308,18 +306,18 @@ def get_aggregate_db_query(
async def get_aggregate_db_rows(
model_cls: type[E], # type: ignore
session: AsyncSession,
- cerbos_client: CerbosClient,
+ authz_client: AuthzClient,
principal: Principal,
where: Any,
aggregate: Any,
order_by: Optional[list[tuple[ColumnElement[Any], ...]]] = None,
group_by: Optional[ColumnElement[Any]] | Optional[list[Any]] = None,
- action: CerbosAction = CerbosAction.VIEW,
+ action: AuthzAction = AuthzAction.VIEW,
) -> Sequence[RowMapping]:
"""
Retrieve aggregate rows from the database, filtered by the where clause and the user's permissions.
"""
- query, group_by = get_aggregate_db_query(model_cls, action, cerbos_client, principal, where, aggregate, group_by)
+ query, group_by = get_aggregate_db_query(model_cls, action, authz_client, principal, where, aggregate, group_by)
if group_by:
query = query.group_by(*group_by) # type: ignore
result = await session.execute(query)
diff --git a/platformics/graphql_api/files.py b/platformics/graphql_api/files.py
index 3bd2901..6592cc5 100644
--- a/platformics/graphql_api/files.py
+++ b/platformics/graphql_api/files.py
@@ -9,23 +9,22 @@
import uuid
from dataclasses import dataclass
-import database.models as db
+import sqlalchemy as sa
import strawberry
import uuid6
-from cerbos.sdk.client import CerbosClient
-from cerbos.sdk.model import Principal
from fastapi import Depends
from mypy_boto3_s3.client import S3Client
from mypy_boto3_sts.client import STSClient
-from sqlalchemy import inspect
+from sqlalchemy.exc import NoResultFound
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.sql import func
from strawberry.scalars import JSON
from strawberry.types import Info
from typing_extensions import TypedDict
+import platformics.database.models as db
from platformics.graphql_api.core.deps import (
- get_cerbos_client,
+ get_authz_client,
get_db_session,
get_s3_client,
get_settings,
@@ -33,12 +32,19 @@
require_auth_principal,
require_system_user,
)
+from platformics.graphql_api.core.errors import PlatformicsError
from platformics.graphql_api.core.query_builder import get_db_rows
-from platformics.graphql_api.core.query_input_types import EnumComparators, IntComparators, StrComparators, UUIDComparators
+from platformics.graphql_api.core.query_input_types import (
+ EnumComparators,
+ IntComparators,
+ StrComparators,
+ UUIDComparators,
+)
from platformics.graphql_api.core.strawberry_extensions import DependencyExtension
from platformics.graphql_api.types.entities import Entity
-from platformics.security.authorization import CerbosAction, get_resource_query
+from platformics.security.authorization import AuthzAction, AuthzClient, Principal
from platformics.settings import APISettings
+from platformics.support import sqlalchemy_helpers
from platformics.support.file_enums import FileAccessProtocol, FileStatus
from platformics.support.format_handlers import get_validator
@@ -139,8 +145,7 @@ async def load_entities(
Dataloader to fetch related entities, given file IDs.
"""
dataloader = info.context["sqlalchemy_loader"]
- mapper = inspect(db.File)
- relationship = mapper.relationships["entity"]
+ relationship = sqlalchemy_helpers.get_relationship(db.File, "entity")
return await dataloader.loader_for(relationship, where).load(root.entity_id) # type:ignore
@@ -234,14 +239,14 @@ class FileWhereClause(TypedDict):
@strawberry.field(extensions=[DependencyExtension()])
async def resolve_files(
session: AsyncSession = Depends(get_db_session, use_cache=False),
- cerbos_client: CerbosClient = Depends(get_cerbos_client),
+ authz_client: AuthzClient = Depends(get_authz_client),
principal: Principal = Depends(require_auth_principal),
where: typing.Optional[FileWhereClause] = None,
) -> typing.Sequence[File]:
"""
Handles files {} GraphQL queries.
"""
- rows = await get_db_rows(db.File, session, cerbos_client, principal, where, [])
+ rows = await get_db_rows(db.File, session, authz_client, principal, where)
return rows # type: ignore
@@ -266,9 +271,9 @@ async def validate_file(
file_size = s3_client.head_object(Bucket=file.namespace, Key=file.path)["ContentLength"]
except: # noqa
- file.status = db.FileStatus.FAILED
+ file.status = FileStatus.FAILED
else:
- file.status = db.FileStatus.SUCCESS
+ file.status = FileStatus.SUCCESS
file.size = file_size
file.updated_at = func.now()
@@ -325,7 +330,7 @@ def generate_multipart_upload_token(
async def mark_upload_complete(
file_id: strawberry.ID,
principal: Principal = Depends(require_auth_principal),
- cerbos_client: CerbosClient = Depends(get_cerbos_client),
+ authz_client: AuthzClient = Depends(get_authz_client),
session: AsyncSession = Depends(get_db_session, use_cache=False),
s3_client: S3Client = Depends(get_s3_client),
) -> db.File:
@@ -333,7 +338,28 @@ async def mark_upload_complete(
Once a file is uploaded, the front-end should make a markUploadComplete mutation
to mark the file as ready for pipeline analysis.
"""
- query = get_resource_query(principal, cerbos_client, CerbosAction.UPDATE, db.File)
+
+ # Get the type of entity that the file is related to
+ try:
+ file_row = (await session.execute(sa.select(db.File).where(db.File.id == file_id))).scalars().one()
+ except NoResultFound:
+ raise PlatformicsError("Unauthorized: cannot update file") from None
+
+ # Fetch the entity if have access to it
+ entity_class, entity = await get_entity_by_id(
+ session,
+ authz_client,
+ principal,
+ AuthzAction.UPDATE,
+ file_row.entity_id,
+ )
+
+ # See if we actually have access to that file.
+ query = authz_client.get_resource_query(
+ principal,
+ AuthzAction.UPDATE,
+ db.File,
+ )
query = query.filter(db.File.id == file_id)
file = (await session.execute(query)).scalars().one()
if not file:
@@ -351,7 +377,7 @@ async def create_file(
entity_field_name: str,
file: FileCreate,
session: AsyncSession = Depends(get_db_session, use_cache=False),
- cerbos_client: CerbosClient = Depends(get_cerbos_client),
+ authz_client: AuthzClient = Depends(get_authz_client),
principal: Principal = Depends(require_auth_principal),
s3_client: S3Client = Depends(get_s3_client),
sts_client: STSClient = Depends(get_sts_client),
@@ -368,7 +394,7 @@ async def create_file(
file,
-1,
session,
- cerbos_client,
+ authz_client,
principal,
s3_client,
sts_client,
@@ -385,7 +411,7 @@ async def upload_file(
file: FileUpload,
expiration: int = 3600,
session: AsyncSession = Depends(get_db_session, use_cache=False),
- cerbos_client: CerbosClient = Depends(get_cerbos_client),
+ authz_client: AuthzClient = Depends(get_authz_client),
principal: Principal = Depends(require_auth_principal),
s3_client: S3Client = Depends(get_s3_client),
sts_client: STSClient = Depends(get_sts_client),
@@ -400,7 +426,7 @@ async def upload_file(
file,
expiration,
session,
- cerbos_client,
+ authz_client,
principal,
s3_client,
sts_client,
@@ -410,13 +436,36 @@ async def upload_file(
return response
+async def get_entity_by_id(
+ session: AsyncSession,
+ authz_client: AuthzClient,
+ principal: Principal,
+ action: AuthzAction,
+ entity_id: strawberry.ID,
+) -> tuple[typing.Type[db.Base], db.Base]:
+ # Fetch the entity if have access to it
+ try:
+ entity_row = (await session.execute(sa.select(db.Entity).where(db.Entity.id == entity_id))).scalars().one()
+ entity_class = sqlalchemy_helpers.get_orm_class_by_name(type(entity_row).__name__)
+ except NoResultFound:
+ raise PlatformicsError("Unauthorized: cannot create file") from None
+
+ query = authz_client.get_resource_query(principal, action, entity_class)
+ query = query.filter(entity_class.entity_id == entity_id)
+ try:
+ entity = (await session.execute(query)).scalars().one()
+ except NoResultFound:
+ raise PlatformicsError("Unauthorized: cannot create file") from None
+ return entity_class, entity
+
+
async def create_or_upload_file(
entity_id: strawberry.ID,
entity_field_name: str,
file: FileCreate | FileUpload,
expiration: int = 3600,
session: AsyncSession = Depends(get_db_session, use_cache=False),
- cerbos_client: CerbosClient = Depends(get_cerbos_client),
+ authz_client: AuthzClient = Depends(get_authz_client),
principal: Principal = Depends(require_auth_principal),
s3_client: S3Client = Depends(get_s3_client),
sts_client: STSClient = Depends(get_sts_client),
@@ -430,11 +479,7 @@ async def create_or_upload_file(
raise Exception("File name should not contain /")
# Fetch the entity if have access to it
- query = get_resource_query(principal, cerbos_client, CerbosAction.UPDATE, db.Entity)
- query = query.filter(db.Entity.id == entity_id)
- entity = (await session.execute(query)).scalars().one()
- if not entity:
- raise Exception("Entity not found")
+ entity_class, entity = await get_entity_by_id(session, authz_client, principal, AuthzAction.UPDATE, entity_id)
# Does that entity type have a column for storing a file ID?
entity_property_name = f"{entity_field_name}_id"
@@ -442,12 +487,17 @@ async def create_or_upload_file(
raise Exception(f"This entity does not have a corresponding file of type {entity_field_name}")
# Unlink the File(s) currently connected to this entity (only commit to DB once add the new File below)
- query = get_resource_query(principal, cerbos_client, CerbosAction.UPDATE, db.File)
- query = query.filter(db.File.entity_id == entity_id)
- query = query.filter(db.File.entity_field_name == entity_field_name)
- current_files = (await session.execute(query)).scalars().all()
- for current_file in current_files:
- current_file.entity_id = None
+ if getattr(entity, entity_property_name):
+ query = authz_client.get_resource_query(
+ principal,
+ AuthzAction.UPDATE,
+ db.File,
+ )
+ query = query.filter(db.File.entity_id == entity_id)
+ query = query.filter(db.File.entity_field_name == entity_field_name)
+ current_files = (await session.execute(query)).scalars().all()
+ for current_file in current_files:
+ current_file.entity_id = None
# Set file parameters based on user inputs
file_id = uuid6.uuid7()
@@ -470,7 +520,7 @@ async def create_or_upload_file(
path=file_path,
file_format=file.file_format,
compression_type=file.compression_type,
- status=db.FileStatus.PENDING,
+ status=FileStatus.PENDING,
)
# Save file to db first
session.add(new_file)
@@ -517,7 +567,7 @@ async def upload_temporary_file(
async def concatenate_files(
ids: typing.Sequence[uuid.UUID],
session: AsyncSession = Depends(get_db_session, use_cache=False),
- cerbos_client: CerbosClient = Depends(get_cerbos_client),
+ authz_client: AuthzClient = Depends(get_authz_client),
principal: Principal = Depends(require_auth_principal),
s3_client: S3Client = Depends(get_s3_client),
settings: APISettings = Depends(get_settings),
@@ -532,8 +582,8 @@ async def concatenate_files(
raise Exception(f"Cannot concatenate more than {FILE_CONCATENATION_MAX} files")
# Get files in question if have access to them
- where = {"id": {"_in": ids}, "status": {"_eq": db.FileStatus.SUCCESS}}
- files = await get_db_rows(db.File, session, cerbos_client, principal, where, [])
+ where = {"id": {"_in": ids}, "status": {"_eq": FileStatus.SUCCESS}}
+ files = await get_db_rows(db.File, session, authz_client, principal, where)
if len(files) < 2:
raise Exception("Need at least 2 valid files to concatenate")
for file in files:
diff --git a/platformics/graphql_api/setup.py b/platformics/graphql_api/setup.py
index 1dc4a56..e4fb231 100644
--- a/platformics/graphql_api/setup.py
+++ b/platformics/graphql_api/setup.py
@@ -5,17 +5,22 @@
import typing
import strawberry
-from cerbos.sdk.client import CerbosClient
-from cerbos.sdk.model import Principal
from fastapi import Depends, FastAPI
from strawberry.fastapi import GraphQLRouter
from strawberry.schema.config import StrawberryConfig
from strawberry.schema.name_converter import HasGraphQLName, NameConverter
-from platformics.graphql_api.core.deps import get_auth_principal, get_cerbos_client, get_db_module, get_engine, get_s3_client
-from platformics.graphql_api.core.gql_loaders import EntityLoader
from platformics.database.connect import AsyncDB
from platformics.database.models.file import File
+from platformics.graphql_api.core.deps import (
+ get_auth_principal,
+ get_authz_client,
+ get_db_module,
+ get_engine,
+ get_s3_client,
+)
+from platformics.graphql_api.core.gql_loaders import EntityLoader
+from platformics.security.authorization import AuthzClient, Principal
from platformics.settings import APISettings
# ------------------------------------------------------------------------------
@@ -25,15 +30,15 @@
def get_context(
engine: AsyncDB = Depends(get_engine),
- db_module: AsyncDB = Depends(get_db_module),
- cerbos_client: CerbosClient = Depends(get_cerbos_client),
+ authz_client: AuthzClient = Depends(get_authz_client),
principal: Principal = Depends(get_auth_principal),
+ db_module: AsyncDB = Depends(get_db_module),
) -> dict[str, typing.Any]:
"""
Defines sqlalchemy_loader, used by dataloaders
"""
return {
- "sqlalchemy_loader": EntityLoader(engine=engine, cerbos_client=cerbos_client, principal=principal),
+ "sqlalchemy_loader": EntityLoader(engine=engine, authz_client=authz_client, principal=principal),
# This is entirely to support automatically resolving Relay Nodes in the EntityInterface
"db_module": db_module,
}
diff --git a/platformics/security/authorization.py b/platformics/security/authorization.py
index 73ea484..72716cc 100644
--- a/platformics/security/authorization.py
+++ b/platformics/security/authorization.py
@@ -2,44 +2,142 @@
from enum import Enum
from cerbos.sdk.client import CerbosClient
-from cerbos.sdk.model import Principal, ResourceDesc
+from cerbos.sdk.model import Principal as CerbosPrincipal
+from cerbos.sdk.model import Resource, ResourceDesc
from sqlalchemy.sql import Select
import platformics.database.models as db
+from platformics.security.token_auth import get_token_claims
+from platformics.settings import APISettings
+from platformics.support import sqlalchemy_helpers
from platformics.thirdparty.cerbos_sqlalchemy.query import get_query
-class CerbosAction(str, Enum):
+class AuthzAction(str, Enum):
VIEW = "view"
CREATE = "create"
UPDATE = "update"
DELETE = "delete"
-def get_resource_query(
- principal: Principal,
- cerbos_client: CerbosClient,
- action: CerbosAction,
- model_cls: typing.Union[type[db.Entity], type[db.File]], # type: ignore
-) -> Select:
- rd = ResourceDesc(model_cls.__tablename__)
- plan = cerbos_client.plan_resources(action, principal, rd)
- if model_cls == db.File: # type: ignore
- attr_map = {
- "request.resource.attr.owner_user_id": db.Entity.owner_user_id, # type: ignore
- "request.resource.attr.collection_id": db.Entity.collection_id, # type: ignore
- }
- joins = [(db.Entity, db.File.entity_id == db.Entity.id)] # type: ignore
- else:
- attr_map = {
- "request.resource.attr.owner_user_id": model_cls.owner_user_id, # type: ignore
- "request.resource.attr.collection_id": model_cls.collection_id, # type: ignore
- }
- joins = []
- query = get_query(
- plan,
- model_cls, # type: ignore
- attr_map, # type: ignore
- joins, # type: ignore
+# TODO - right now this is an unmodified alias of Cerbos' principal, but we can extend it in the future
+# if need be. The properties/methods of this class are *only* referenced in this file!
+class Principal(CerbosPrincipal):
+ pass
+
+
+def hydrate_auth_principal(
+ settings: APISettings,
+ user_token: typing.Optional[str],
+) -> typing.Optional[Principal]:
+ if not user_token:
+ return None
+ try:
+ claims = get_token_claims(settings.JWK_PRIVATE_KEY, user_token)
+ except: # noqa
+ return None
+
+ if "project_roles" not in claims:
+ raise Exception("no project roles in claims")
+
+ project_claims = claims["project_roles"]
+
+ try:
+ for role, project_ids in project_claims.items():
+ assert role in ["member", "owner", "viewer"]
+ assert isinstance(project_ids, list)
+ for item in project_ids:
+ assert int(item)
+ except Exception:
+ return None
+
+ return Principal(
+ claims["sub"],
+ roles=["user"],
+ attr={
+ "user_id": int(claims["sub"]),
+ "owner_projects": project_claims.get("owner", []),
+ "member_projects": project_claims.get("member", []),
+ "viewer_projects": project_claims.get("viewer", []),
+ "service_identity": claims["service_identity"],
+ },
)
- return query
+
+
+class AuthzClient:
+ def __init__(self, settings: APISettings):
+ self.settings = settings
+ self.client = CerbosClient(host=settings.CERBOS_URL)
+
+ # Convert a model object to a dictionary
+ def _obj_to_dict(self, obj):
+ mydict = {}
+ relationships = obj.__mapper__.relationships
+ for col in obj.__mapper__.all_orm_descriptors:
+ # Don't send related fields to cerbos for authz checks
+ if col.key in relationships:
+ continue
+ value = getattr(obj, col.key)
+ if type(value) not in [int, str, bool, float]:
+ # TODO, we probably want to look into a smarter way to serialize fields for cerbos
+ value = str(value)
+ mydict[col.key] = value
+ return mydict
+
+ def can_create(self, resource, principal: Principal) -> bool:
+ resource_type = type(resource).__tablename__
+ attr = self._obj_to_dict(resource)
+ resource = Resource(id="NEW_ID", kind=resource_type, attr=attr)
+ if self.client.is_allowed(AuthzAction.CREATE, principal, resource):
+ return True
+ return False
+
+ def can_update(self, resource, principal: Principal) -> bool:
+ resource_type = type(resource).__tablename__
+ attr = self._obj_to_dict(resource)
+ # TODO - this should send in the actual resource ID instead of a placeholder string
+ # There are two complexities there though: UUID's don't natively serialize to json,
+ # so they cannot be sent in cerbos perms checks, and we need to find/use the table's
+ # primary key instead of a hardcoded column name.
+ resource = Resource(id="resource_id", kind=resource_type, attr=attr)
+ if self.client.is_allowed(AuthzAction.UPDATE, principal, resource):
+ return True
+ return False
+
+ # Get a SQLAlchemy model with authz filters already applied
+ def get_resource_query(
+ self,
+ principal: Principal,
+ action: AuthzAction,
+ model_cls: type[db.Base], # type: ignore
+ ) -> Select:
+ rd = ResourceDesc(model_cls.__tablename__)
+ plan = self.client.plan_resources(action, principal, rd)
+
+ attr_map = {}
+ joins = []
+ if model_cls == db.File: # type: ignore
+ for col in sqlalchemy_helpers.model_class_cols(db.Entity):
+ attr_map[f"request.resource.attr.{col.key}"] = getattr(db.Entity, col.key)
+ joins = [(db.Entity, db.File.entity_id == db.Entity.id)] # type: ignore
+ else:
+ # Send all non-relationship columns to cerbos to make decisions
+ for col in sqlalchemy_helpers.model_class_cols(model_cls):
+ attr_map[f"request.resource.attr.{col.key}"] = getattr(model_cls, col.key)
+ query = get_query(
+ plan,
+ model_cls, # type: ignore
+ attr_map, # type: ignore
+ joins, # type: ignore
+ )
+ return query
+
+ # An opportunity to modify SQL WHERE clauses before they get sent to the DB.
+ def modify_where_clause(
+ self,
+ principal: Principal,
+ action: AuthzAction,
+ model_cls: type[db.Base], # type: ignore
+ where_clauses: dict[str, typing.Any],
+ ):
+ pass
diff --git a/platformics/support/sqlalchemy_helpers.py b/platformics/support/sqlalchemy_helpers.py
new file mode 100644
index 0000000..aba6fcc
--- /dev/null
+++ b/platformics/support/sqlalchemy_helpers.py
@@ -0,0 +1,43 @@
+from sqlalchemy import inspect
+from sqlalchemy.orm import ColumnProperty
+from sqlalchemy_utils import get_primary_keys
+
+from platformics.database.models.base import Base
+
+
+def model_class_cols(model_cls):
+ cols = []
+ relationships = model_cls.__mapper__.relationships
+ for col in model_cls.__mapper__.all_orm_descriptors:
+ # Don't send related fields to cerbos for authz checks
+ if col.key in relationships:
+ continue
+ cols.append(col)
+ return cols
+
+
+def get_primary_key(model) -> tuple[str, ColumnProperty]:
+ pks = get_primary_keys(model)
+ if len(pks) != 1:
+ raise Exception(f"Expected exactly one primary key for {model.__name__}")
+ for k, v in pks.items():
+ return k, v
+ raise Exception("PK definition missing")
+
+
+def get_relationship(cls, field):
+ mapper = inspect(cls)
+ relationship = mapper.relationships[field]
+ return relationship
+
+
+# TODO FIXME THIS IS TOO OPEN. THIS SHOULD BE LOCKED DOWN TO ONLY ACCEPTABLE TYPES.
+def get_orm_class_by_name(class_name: str) -> Base:
+ for mapper in Base.registry.mappers:
+ cls = mapper.class_
+ if cls.__name__ == class_name:
+ # Don't allow abstract classes to be manipulated directly
+ # if cls.abstract:
+ # continue
+ return cls
+ raise Exception("Invalid class name")
diff --git a/test_app/.vscode/settings.json b/test_app/.vscode/settings.json
index 9b38853..a3a1838 100644
--- a/test_app/.vscode/settings.json
+++ b/test_app/.vscode/settings.json
@@ -4,4 +4,4 @@
],
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true
-}
\ No newline at end of file
+}
diff --git a/test_app/Makefile b/test_app/Makefile
index 1b0bd4e..970577b 100644
--- a/test_app/Makefile
+++ b/test_app/Makefile
@@ -60,6 +60,7 @@ clean: ## Remove all codegen'd artifacts.
rm -rf cerbos
rm -rf support
rm -rf database
+ rm -rf validators
rm -rf .moto_recording
rm -rf test_infra
$(docker_compose) --profile '*' down
diff --git a/test_app/bin/init_moto.sh b/test_app/bin/init_moto.sh
index 68d2fde..0571096 100755
--- a/test_app/bin/init_moto.sh
+++ b/test_app/bin/init_moto.sh
@@ -3,7 +3,7 @@
# Script to initialize moto server; runs inside the motoserver container
# Launch moto server
-moto_server --host 0.0.0.0 --port $MOTO_PORT &
+moto_server --host 0.0.0.0 --port $MOTO_PORT &
# Initialize data once server is ready
sleep 1 && curl -X POST "http://localhost:${MOTO_PORT}/moto-api/recorder/replay-recording"
diff --git a/test_app/tests/test_authz_overrides.py b/test_app/tests/test_authz_overrides.py
new file mode 100644
index 0000000..2796ea2
--- /dev/null
+++ b/test_app/tests/test_authz_overrides.py
@@ -0,0 +1,163 @@
+"""
+Test that our principal-generation and authz client functionality can be overriden the way the docs say they acan.
+"""
+
+import datetime
+import pytest
+import sqlalchemy as sa
+from platformics.database.connect import SyncDB
+from platformics.security.authorization import AuthzClient
+from platformics.graphql_api.core.deps import get_settings, get_authz_client
+from fastapi import Depends
+from conftest import GQLTestClient, SessionStorage
+from test_infra.factories.sample import SampleFactory
+from fastapi import FastAPI
+from platformics.security.authorization import Principal
+from platformics.settings import APISettings
+from platformics.graphql_api.core.deps import (
+ get_auth_principal,
+)
+
+date_now = datetime.datetime.now()
+
+
+@pytest.mark.asyncio
+async def test_principal_override(
+ api_test_schema: FastAPI,
+ sync_db: SyncDB,
+ gql_client: GQLTestClient,
+) -> None:
+ """
+ Test that we can override the way auth principals get generated. Our tests
+ use this functionality under the hood so we know it works, but since this
+ interface is now *explicitly* documented, it's a breaking change to alter the
+ interface.
+ """
+
+ def custom_auth_principal():
+ return Principal(
+ "user123",
+ roles=["user"],
+ attr={
+ "user_id": "user123",
+ "owner_projects": [],
+ "member_projects": [],
+ "service_identity": [],
+ # This value can be read from a secret or external db or anything you wish.
+ # It's just hardcoded here for brevity.
+ "viewer_projects": [444],
+ },
+ )
+
+ api_test_schema.dependency_overrides[get_auth_principal] = custom_auth_principal
+
+ user_id = 12345
+ secondary_user_id = 67890
+ project_id = 444
+
+ # Create mock data
+ with sync_db.session() as session:
+ SessionStorage.set_session(session)
+ SampleFactory.create_batch(
+ 2,
+ collection_location="San Francisco, CA",
+ collection_date=date_now,
+ owner_user_id=user_id,
+ collection_id=project_id,
+ )
+ SampleFactory.create_batch(
+ 6,
+ collection_location="Mountain View, CA",
+ collection_date=date_now,
+ owner_user_id=user_id,
+ collection_id=project_id,
+ )
+ SampleFactory.create_batch(
+ 4,
+ collection_location="Phoenix, AZ",
+ collection_date=date_now,
+ owner_user_id=secondary_user_id,
+ collection_id=9999,
+ )
+
+ # Fetch all samples
+ query = """
+ query MyQuery {
+ samples {
+ id,
+ collectionLocation
+ }
+ }
+ """
+ output = await gql_client.query(query, user_id=user_id, member_projects=[project_id])
+ locations = [sample["collectionLocation"] for sample in output["data"]["samples"]]
+ assert "San Francisco, CA" in locations
+ assert "Mountain View, CA" in locations
+ assert "Phoenix, AZ" not in locations
+
+
+class CustomAuthzClient(AuthzClient):
+ def get_resource_query(self, principal, action, model_cls, relationship):
+ query = sa.select(model_cls).where(model_cls.name.in_(["apple", "asparagus"]))
+ return query
+
+
+def custom_authz_client(settings: APISettings = Depends(get_settings)) -> AuthzClient:
+ return AuthzClient(settings=settings)
+
+
+@pytest.mark.asyncio
+async def test_authz_client_override(
+ api_test_schema: FastAPI,
+ sync_db: SyncDB,
+ gql_client: GQLTestClient,
+) -> None:
+ """
+ Test that we can override the way auth principals get generated. Our tests
+ use this functionality under the hood so we know it works, but since this
+ interface is now *explicitly* documented, it's a breaking change to alter the
+ interface.
+ """
+
+ api_test_schema.dependency_overrides[get_authz_client] = custom_authz_client
+
+ user_id = 12345
+ secondary_user_id = 67890
+ project_id = 444
+
+ # Create mock data
+ with sync_db.session() as session:
+ SessionStorage.set_session(session)
+ SampleFactory.create(
+ name="bananas",
+ collection_date=date_now,
+ owner_user_id=user_id,
+ collection_id=project_id,
+ )
+ SampleFactory.create(
+ name="apples",
+ collection_date=date_now,
+ owner_user_id=user_id,
+ collection_id=project_id,
+ )
+ SampleFactory.create(
+ name="asparagus",
+ collection_date=date_now,
+ owner_user_id=secondary_user_id,
+ collection_id=project_id,
+ )
+
+ # Fetch all samples
+ query = """
+ query MyQuery {
+ samples {
+ id,
+ name
+ }
+ }
+ """
+ output = await gql_client.query(query, user_id=user_id, member_projects=[project_id])
+ names = [sample["name"] for sample in output["data"]["samples"]]
+ assert "apples" in names
+ assert "asparagus" in names
+ assert "banana" not in names
diff --git a/test_app/tests/test_error_handling.py b/test_app/tests/test_error_handling.py
index fd92f35..05a0d61 100644
--- a/test_app/tests/test_error_handling.py
+++ b/test_app/tests/test_error_handling.py
@@ -32,7 +32,7 @@ async def test_unauthorized_error(
# Make sure we haven't masked expected errors.
assert output["data"] is None
- assert "Unauthorized: Cannot create entity in this collection" in output["errors"][0]["message"]
+ assert output["errors"][0]["message"].startswith("Unauthorized: Cannot create entity")
@pytest.mark.asyncio