Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Vendorize fastapi-utls.cbv #17205

Merged
merged 5 commits into from
Dec 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion lib/galaxy/dependencies/pinned-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@ ecdsa==0.18.0 ; python_version >= "3.8" and python_version < "3.12"
edam-ontology==1.25.2 ; python_version >= "3.8" and python_version < "3.12"
email-validator==2.1.0.post1 ; python_version >= "3.8" and python_version < "3.12"
exceptiongroup==1.2.0 ; python_version >= "3.8" and python_version < "3.11"
fastapi-utils==0.2.1 ; python_version >= "3.8" and python_version < "3.12"
fastapi==0.98.0 ; python_version >= "3.8" and python_version < "3.12"
filelock==3.13.1 ; python_version >= "3.8" and python_version < "3.12"
frozenlist==1.4.1 ; python_version >= "3.8" and python_version < "3.12"
Expand Down
2 changes: 1 addition & 1 deletion lib/galaxy/webapps/galaxy/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
APIKeyHeader,
APIKeyQuery,
)
from fastapi_utils.cbv import cbv
from pydantic import ValidationError
from pydantic.main import BaseModel
from starlette.datastructures import Headers
Expand Down Expand Up @@ -71,6 +70,7 @@
from galaxy.structured_app import StructuredApp
from galaxy.web.framework.decorators import require_admin_message
from galaxy.webapps.base.controller import BaseAPIController
from galaxy.webapps.galaxy.api.cbv import cbv
from galaxy.work.context import (
GalaxyAbstractRequest,
GalaxyAbstractResponse,
Expand Down
119 changes: 119 additions & 0 deletions lib/galaxy/webapps/galaxy/api/cbv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
"""
Original implementation by David Montague (@dmontagu)
https://github.com/dmontagu/fastapi-utils
"""
from __future__ import annotations

import inspect
from collections.abc import Callable
from typing import (
Any,
get_type_hints,
TypeVar,
)

from fastapi import (
APIRouter,
Depends,
)
from pydantic.typing import is_classvar
from starlette.routing import (
Route,
WebSocketRoute,
)

T = TypeVar("T")

CBV_CLASS_KEY = "__cbv_class__"


def cbv(router: APIRouter) -> Callable[[type[T]], type[T]]:
"""
This function returns a decorator that converts the decorated into a class-based view for the provided router.

Any methods of the decorated class that are decorated as endpoints using the router provided to this function
will become endpoints in the router. The first positional argument to the methods (typically `self`)
will be populated with an instance created using FastAPI's dependency-injection.

For more detail, review the documentation at
https://fastapi-utils.davidmontague.xyz/user-guide/class-based-views/#the-cbv-decorator
"""

def decorator(cls: type[T]) -> type[T]:
return _cbv(router, cls)

return decorator


def _cbv(router: APIRouter, cls: type[T]) -> type[T]:
"""
Replaces any methods of the provided class `cls` that are endpoints of routes in `router` with updated
function calls that will properly inject an instance of `cls`.
"""
_init_cbv(cls)
cbv_router = APIRouter()
function_members = inspect.getmembers(cls, inspect.isfunction)
functions_set = {func for _, func in function_members}
cbv_routes = [
route
for route in router.routes
if isinstance(route, (Route, WebSocketRoute)) and route.endpoint in functions_set
]
for route in cbv_routes:
router.routes.remove(route)
_update_cbv_route_endpoint_signature(cls, route)
cbv_router.routes.append(route)
router.include_router(cbv_router)
return cls


def _init_cbv(cls: type[Any]) -> None:
"""
Idempotently modifies the provided `cls`, performing the following modifications:
* The `__init__` function is updated to set any class-annotated dependencies as instance attributes
* The `__signature__` attribute is updated to indicate to FastAPI what arguments should be passed to the initializer
"""
if getattr(cls, CBV_CLASS_KEY, False): # pragma: no cover
return # Already initialized
old_init: Callable[..., Any] = cls.__init__
old_signature = inspect.signature(old_init)
old_parameters = list(old_signature.parameters.values())[1:] # drop `self` parameter
new_parameters = [
x for x in old_parameters if x.kind not in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
]
dependency_names: list[str] = []
for name, hint in get_type_hints(cls).items():
if is_classvar(hint):
continue
parameter_kwargs = {"default": getattr(cls, name, Ellipsis)}
dependency_names.append(name)
new_parameters.append(
inspect.Parameter(name=name, kind=inspect.Parameter.KEYWORD_ONLY, annotation=hint, **parameter_kwargs)
)
new_signature = old_signature.replace(parameters=new_parameters)

def new_init(self: Any, *args: Any, **kwargs: Any) -> None:
for dep_name in dependency_names:
dep_value = kwargs.pop(dep_name)
setattr(self, dep_name, dep_value)
old_init(self, *args, **kwargs)

setattr(cls, "__signature__", new_signature) # noqa: B010
setattr(cls, "__init__", new_init) # noqa: B010
setattr(cls, CBV_CLASS_KEY, True)


def _update_cbv_route_endpoint_signature(cls: type[Any], route: Route | WebSocketRoute) -> None:
"""
Fixes the endpoint signature for a cbv route to ensure FastAPI performs dependency injection properly.
"""
old_endpoint = route.endpoint
old_signature = inspect.signature(old_endpoint)
old_parameters: list[inspect.Parameter] = list(old_signature.parameters.values())
old_first_parameter = old_parameters[0]
new_first_parameter = old_first_parameter.replace(default=Depends(cls))
new_parameters = [new_first_parameter] + [
parameter.replace(kind=inspect.Parameter.KEYWORD_ONLY) for parameter in old_parameters[1:]
]
new_signature = old_signature.replace(parameters=new_parameters)
setattr(route.endpoint, "__signature__", new_signature) # noqa: B010
1 change: 0 additions & 1 deletion packages/web_apps/setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ install_requires =
Babel
Cheetah3!=3.2.6.post2
fastapi>=0.71.0,!=0.89.0,<0.99
fastapi-utils
gunicorn
gxformat2
importlib-resources;python_version<'3.9'
Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ docutils = "!=0.17, !=0.17.1"
dparse = "*"
edam-ontology = "*"
fastapi = ">=0.71.0, !=0.89.0, <0.99" # https://github.com/tiangolo/fastapi/issues/4041 https://github.com/tiangolo/fastapi/issues/5861
fastapi-utils = "*"
fs = "*"
future = "*"
galaxy_sequence_utils = "*"
Expand Down
98 changes: 98 additions & 0 deletions test/unit/webapps/api/test_cbv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
"""
Original implementation by David Montague (@dmontagu)
https://github.com/dmontagu/fastapi-utils
"""
from __future__ import annotations

from typing import (
Any,
ClassVar,
Optional,
)

from fastapi import (
APIRouter,
Depends,
FastAPI,
)
from starlette.testclient import TestClient

from galaxy.webapps.galaxy.api.cbv import cbv


def test_cbv() -> None:
router = APIRouter()

def dependency() -> int:
return 1

@cbv(router)
class CBV:
x: int = Depends(dependency)
cx: ClassVar[int] = 1
cy: ClassVar[int]

def __init__(self, z: int = Depends(dependency)):
self.y = 1
self.z = z

@router.get("/", response_model=int)
def f(self) -> int:
return self.cx + self.x + self.y + self.z

@router.get("/classvar", response_model=bool)
def g(self) -> bool:
return hasattr(self, "cy")

app = FastAPI()
app.include_router(router)
client = TestClient(app)
response_1 = client.get("/")
assert response_1.status_code == 200
assert response_1.content == b"4"

response_2 = client.get("/classvar")
assert response_2.status_code == 200
assert response_2.content == b"false"


def test_method_order_preserved() -> None:
router = APIRouter()

@cbv(router)
class TestCBV:
@router.get("/test")
def get_test(self) -> int:
return 1

@router.get("/{item_id}")
def get_item(self) -> int: # Alphabetically before `get_test`
return 2

app = FastAPI()
app.include_router(router)

assert TestClient(app).get("/test").json() == 1
assert TestClient(app).get("/other").json() == 2


def test_multiple_decorators() -> None:
router = APIRouter()

@cbv(router)
class RootHandler:
@router.get("/items/?")
@router.get("/items/{item_path:path}")
@router.get("/database/{item_path:path}")
def root(self, item_path: Optional[str] = None, item_query: Optional[str] = None) -> Any: # noqa: UP007
if item_path:
return {"item_path": item_path}
if item_query:
return {"item_query": item_query}
return []

client = TestClient(router)

assert client.get("/items").json() == []
assert client.get("/items/1").json() == {"item_path": "1"}
assert client.get("/database/abc").json() == {"item_path": "abc"}
Loading