From 2db1f4d2131d9902af15e00312e1673d0b3d0a61 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Janek=20Nouvertn=C3=A9?= Date: Sat, 11 Jan 2025 20:24:15 +0100 Subject: [PATCH] fix(ASGI mounts): Prevent accidental scope overrides by mounted ASGI apps (#3945) --- .../application_hooks/after_exception_hook.py | 2 +- .../application_hooks/before_send_hook.py | 2 +- .../using_application_state.py | 2 +- docs/examples/routing/mount_custom_app.py | 2 +- .../routing/mounting_starlette_app.py | 2 +- docs/usage/applications.rst | 8 ++-- litestar/app.py | 7 ++- litestar/connection/base.py | 6 +-- litestar/handlers/asgi_handlers.py | 22 +++++++++- litestar/middleware/_internal/cors.py | 2 +- .../_internal/exceptions/middleware.py | 4 +- litestar/middleware/csrf.py | 2 +- litestar/middleware/logging.py | 4 +- litestar/middleware/rate_limit.py | 2 +- litestar/middleware/response_cache.py | 2 +- litestar/routes/asgi.py | 21 ++++++++- litestar/routes/http.py | 4 +- litestar/testing/client/base.py | 1 + litestar/testing/request_factory.py | 3 +- litestar/types/asgi_types.py | 3 +- litestar/utils/scope/__init__.py | 2 +- tests/conftest.py | 1 + tests/unit/test_app.py | 18 +++++++- .../test_asgi_handlers/test_handle_asgi.py | 43 ++++++++++++++++++- 24 files changed, 135 insertions(+), 30 deletions(-) diff --git a/docs/examples/application_hooks/after_exception_hook.py b/docs/examples/application_hooks/after_exception_hook.py index f9df46b478..292d6b709e 100644 --- a/docs/examples/application_hooks/after_exception_hook.py +++ b/docs/examples/application_hooks/after_exception_hook.py @@ -19,7 +19,7 @@ def my_handler() -> None: async def after_exception_handler(exc: Exception, scope: "Scope") -> None: """Hook function that will be invoked after each exception.""" - state = scope["app"].state + state = Litestar.from_scope(scope).state if not hasattr(state, "error_count"): state.error_count = 1 else: diff --git a/docs/examples/application_hooks/before_send_hook.py b/docs/examples/application_hooks/before_send_hook.py index 50d5b04d46..a77cd84323 100644 --- a/docs/examples/application_hooks/before_send_hook.py +++ b/docs/examples/application_hooks/before_send_hook.py @@ -24,7 +24,7 @@ async def before_send_hook_handler(message: Message, scope: Scope) -> None: """ if message["type"] == "http.response.start": headers = MutableScopeHeaders.from_message(message=message) - headers["My Header"] = scope["app"].state.message + headers["My Header"] = Litestar.from_scope(scope).state.message def on_startup(app: Litestar) -> None: diff --git a/docs/examples/application_state/using_application_state.py b/docs/examples/application_state/using_application_state.py index 5aabf4e013..0dc6162c31 100644 --- a/docs/examples/application_state/using_application_state.py +++ b/docs/examples/application_state/using_application_state.py @@ -20,7 +20,7 @@ def middleware_factory(*, app: "ASGIApp") -> "ASGIApp": """A middleware can access application state via `scope`.""" async def my_middleware(scope: "Scope", receive: "Receive", send: "Send") -> None: - state = scope["app"].state + state = Litestar.from_scope(scope).state logger.info("state value in middleware: %s", state.value) await app(scope, receive, send) diff --git a/docs/examples/routing/mount_custom_app.py b/docs/examples/routing/mount_custom_app.py index 2088bebf81..718827968f 100644 --- a/docs/examples/routing/mount_custom_app.py +++ b/docs/examples/routing/mount_custom_app.py @@ -8,7 +8,7 @@ from litestar.types import Receive, Scope, Send -@asgi("/some/sub-path", is_mount=True) +@asgi("/some/sub-path", is_mount=True, copy_scope=True) async def my_asgi_app(scope: "Scope", receive: "Receive", send: "Send") -> None: """ Args: diff --git a/docs/examples/routing/mounting_starlette_app.py b/docs/examples/routing/mounting_starlette_app.py index ba52012eea..a989b62988 100644 --- a/docs/examples/routing/mounting_starlette_app.py +++ b/docs/examples/routing/mounting_starlette_app.py @@ -15,7 +15,7 @@ async def index(request: "Request") -> JSONResponse: return JSONResponse({"forwarded_path": request.url.path}) -starlette_app = asgi(path="/some/sub-path", is_mount=True)( +starlette_app = asgi(path="/some/sub-path", is_mount=True, copy_scope=True)( Starlette( routes=[ Route("/", index), diff --git a/docs/usage/applications.rst b/docs/usage/applications.rst index 63cd97ed94..cc735b89a7 100644 --- a/docs/usage/applications.rst +++ b/docs/usage/applications.rst @@ -110,10 +110,10 @@ is accessible. :ref:`reserved keyword arguments `. It is important to understand in this context that the application instance is injected into the ASGI ``scope`` mapping -for each connection (i.e. request or websocket connection) as ``scope["app"]``. This makes the application -accessible wherever the scope mapping is available, e.g. in middleware, on :class:`~.connection.request.Request` and -:class:`~.connection.websocket.WebSocket` instances (accessible as ``request.app`` / ``socket.app``), and many -other places. +for each connection (i.e. request or websocket connection) as ``scope["litestar_app"]``, and can be retrieved using +:meth:`~.Litestar.from_scope`. This makes the application accessible wherever the scope mapping is available, +e.g. in middleware, on :class:`~.connection.request.Request` and :class:`~.connection.websocket.WebSocket` instances +(accessible as ``request.app`` / ``socket.app``), and many other places. Therefore, :paramref:`~.app.Litestar.state` offers an easy way to share contextual data between disparate parts of the application, as seen below: diff --git a/litestar/app.py b/litestar/app.py index 399c01dc03..0e51a9f0e7 100644 --- a/litestar/app.py +++ b/litestar/app.py @@ -610,10 +610,15 @@ async def __call__( await self.asgi_router.lifespan(receive=receive, send=send) # type: ignore[arg-type] return - scope["app"] = self + scope["app"] = scope["litestar_app"] = self scope.setdefault("state", {}) await self.asgi_handler(scope, receive, self._wrap_send(send=send, scope=scope)) # type: ignore[arg-type] + @classmethod + def from_scope(cls, scope: Scope) -> Litestar: + """Retrieve the Litestar application from the current ASGI scope""" + return scope["litestar_app"] + async def _call_lifespan_hook(self, hook: LifespanHook) -> None: ret = hook(self) if inspect.signature(hook).parameters else hook() # type: ignore[call-arg] diff --git a/litestar/connection/base.py b/litestar/connection/base.py index 6c80e96522..b7e53cc5de 100644 --- a/litestar/connection/base.py +++ b/litestar/connection/base.py @@ -100,7 +100,7 @@ def app(self) -> Litestar: Returns: The :class:`Litestar ` application instance """ - return self.scope["app"] + return self.scope["litestar_app"] @property def route_handler(self) -> HandlerT: @@ -321,7 +321,7 @@ def url_for(self, name: str, **path_parameters: Any) -> str: Returns: A string representing the absolute url of the route handler. """ - litestar_instance = self.scope["app"] + litestar_instance = self.scope["litestar_app"] url_path = litestar_instance.route_reverse(name, **path_parameters) return make_absolute_url(url_path, self.base_url) @@ -339,7 +339,7 @@ def url_for_static_asset(self, name: str, file_path: str) -> str: Returns: A string representing absolute url to the asset. """ - litestar_instance = self.scope["app"] + litestar_instance = self.scope["litestar_app"] url_path = litestar_instance.url_for_static_asset(name, file_path) return make_absolute_url(url_path, self.base_url) diff --git a/litestar/handlers/asgi_handlers.py b/litestar/handlers/asgi_handlers.py index 91f35172d3..7857480704 100644 --- a/litestar/handlers/asgi_handlers.py +++ b/litestar/handlers/asgi_handlers.py @@ -1,5 +1,6 @@ from __future__ import annotations +import warnings from typing import TYPE_CHECKING, Any, Mapping, Sequence from litestar.exceptions import ImproperlyConfiguredException @@ -11,6 +12,7 @@ if TYPE_CHECKING: + from litestar import Litestar from litestar.types import ( ExceptionHandlersMap, Guard, @@ -24,7 +26,7 @@ class ASGIRouteHandler(BaseRouteHandler): Use this decorator to decorate ASGI applications. """ - __slots__ = ("is_mount", "is_static") + __slots__ = ("copy_scope", "is_mount", "is_static") def __init__( self, @@ -37,6 +39,7 @@ def __init__( is_mount: bool = False, is_static: bool = False, signature_namespace: Mapping[str, Any] | None = None, + copy_scope: bool | None = None, **kwargs: Any, ) -> None: """Initialize ``ASGIRouteHandler``. @@ -58,10 +61,14 @@ def __init__( are used to deliver static files. signature_namespace: A mapping of names to types for use in forward reference resolution during signature modelling. type_encoders: A mapping of types to callables that transform them into types supported for serialization. + copy_scope: Copy the ASGI 'scope' before calling the mounted application. Should be set to 'True' unless + side effects via scope mutations by the mounted ASGI application are intentional **kwargs: Any additional kwarg - will be set in the opt dictionary. """ self.is_mount = is_mount or is_static self.is_static = is_static + self.copy_scope = copy_scope + super().__init__( path, exception_handlers=exception_handlers, @@ -72,6 +79,19 @@ def __init__( **kwargs, ) + def on_registration(self, app: Litestar) -> None: + super().on_registration(app) + + if self.copy_scope is None: + warnings.warn( + f"{self}: 'copy_scope' not set for ASGI handler. Leaving 'copy_scope' unset will warn about mounted " + "ASGI applications modifying the scope. Set 'copy_scope=True' to ensure calling into mounted ASGI apps " + "does not cause any side effects via scope mutations, or set 'copy_scope=False' if those mutations are " + "desired. 'copy'scope' will default to 'True' in Litestar 3", + category=DeprecationWarning, + stacklevel=1, + ) + def _validate_handler_function(self) -> None: """Validate the route handler function once it's set by inspecting its return annotations.""" super()._validate_handler_function() diff --git a/litestar/middleware/_internal/cors.py b/litestar/middleware/_internal/cors.py index 623eb9b2eb..0608a134b6 100644 --- a/litestar/middleware/_internal/cors.py +++ b/litestar/middleware/_internal/cors.py @@ -44,7 +44,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: origin = headers.get("origin") if scope["type"] == ScopeType.HTTP and scope["method"] == HttpMethod.OPTIONS and origin: - request = scope["app"].request_class(scope=scope, receive=receive, send=send) + request = scope["litestar_app"].request_class(scope=scope, receive=receive, send=send) asgi_response = self._create_preflight_response(origin=origin, request_headers=headers).to_asgi_response( app=None, request=request ) diff --git a/litestar/middleware/_internal/exceptions/middleware.py b/litestar/middleware/_internal/exceptions/middleware.py index 3801d995cf..6536f3b10f 100644 --- a/litestar/middleware/_internal/exceptions/middleware.py +++ b/litestar/middleware/_internal/exceptions/middleware.py @@ -135,7 +135,7 @@ def __init__( @staticmethod def _get_debug_scope(scope: Scope) -> bool: - return scope["app"].debug + return scope["litestar_app"].debug async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: """ASGI-callable. @@ -161,7 +161,7 @@ async def capture_response_started(event: Message) -> None: if scope_state.response_started: raise LitestarException("Exception caught after response started") from e - litestar_app = scope["app"] + litestar_app = scope["litestar_app"] if litestar_app.logging_config and (logger := litestar_app.logger): self.handle_exception_logging(logger=logger, logging_config=litestar_app.logging_config, scope=scope) diff --git a/litestar/middleware/csrf.py b/litestar/middleware/csrf.py index 0fef4fbe42..794135d67e 100644 --- a/litestar/middleware/csrf.py +++ b/litestar/middleware/csrf.py @@ -106,7 +106,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: await self.app(scope, receive, send) return - request: Request[Any, Any, Any] = scope["app"].request_class(scope=scope, receive=receive) + request: Request[Any, Any, Any] = scope["litestar_app"].request_class(scope=scope, receive=receive) content_type, _ = request.content_type csrf_cookie = request.cookies.get(self.config.cookie_name) existing_csrf_token = request.headers.get(self.config.header_name) diff --git a/litestar/middleware/logging.py b/litestar/middleware/logging.py index c909dfed23..2998657a10 100644 --- a/litestar/middleware/logging.py +++ b/litestar/middleware/logging.py @@ -100,7 +100,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: None """ if not hasattr(self, "logger"): - self.logger = scope["app"].get_logger(self.config.logger_name) + self.logger = scope["litestar_app"].get_logger(self.config.logger_name) self.is_struct_logger = structlog_installed and repr(self.logger).startswith(" None: Returns: None """ - extracted_data = await self.extract_request_data(request=scope["app"].request_class(scope, receive)) + extracted_data = await self.extract_request_data(request=scope["litestar_app"].request_class(scope, receive)) self.log_message(values=extracted_data) def log_response(self, scope: Scope) -> None: diff --git a/litestar/middleware/rate_limit.py b/litestar/middleware/rate_limit.py index 11a6653924..b9bdfaf7df 100644 --- a/litestar/middleware/rate_limit.py +++ b/litestar/middleware/rate_limit.py @@ -67,7 +67,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: Returns: None """ - app = scope["app"] + app = scope["litestar_app"] request: Request[Any, Any, Any] = app.request_class(scope) store = self.config.get_store_from_app(app) if await self.should_check_request(request=request): diff --git a/litestar/middleware/response_cache.py b/litestar/middleware/response_cache.py index a2c4bf7e71..8077097206 100644 --- a/litestar/middleware/response_cache.py +++ b/litestar/middleware/response_cache.py @@ -51,7 +51,7 @@ async def wrapped_send(message: Message) -> None: if messages and message["type"] == HTTP_RESPONSE_BODY and not message.get("more_body"): key = (route_handler.cache_key_builder or self.config.key_builder)(Request(scope)) - store = self.config.get_store_from_app(scope["app"]) + store = self.config.get_store_from_app(scope["litestar_app"]) await store.set(key, encode_msgpack(messages), expires_in=expires_in) await send(message) diff --git a/litestar/routes/asgi.py b/litestar/routes/asgi.py index a8564d0e61..736ef5de4f 100644 --- a/litestar/routes/asgi.py +++ b/litestar/routes/asgi.py @@ -1,9 +1,11 @@ from __future__ import annotations +import warnings from typing import TYPE_CHECKING, Any from litestar.connection import ASGIConnection from litestar.enums import ScopeType +from litestar.exceptions import LitestarWarning from litestar.routes.base import BaseRoute if TYPE_CHECKING: @@ -51,4 +53,21 @@ async def handle(self, scope: Scope, receive: Receive, send: Send) -> None: connection = ASGIConnection["ASGIRouteHandler", Any, Any, Any](scope=scope, receive=receive) await self.route_handler.authorize_connection(connection=connection) - await self.route_handler.fn(scope=scope, receive=receive, send=send) + handler_scope = scope.copy() + copy_scope = self.route_handler.copy_scope + + await self.route_handler.fn( + scope=handler_scope if copy_scope is True else scope, + receive=receive, + send=send, + ) + + if copy_scope is None and handler_scope != scope: + warnings.warn( + f"{self.route_handler}: Mounted ASGI app {self.route_handler.fn} modified 'scope' with 'copy_scope' " + "set to 'None'. Set 'copy_scope=True' to avoid mutating the original scope or set 'copy_scope=False' " + "if mutating the scope from within the mounted ASGI app is intentional. Note: 'copy_scope' will " + "default to 'True' by default in Litestar 3", + category=LitestarWarning, + stacklevel=1, + ) diff --git a/litestar/routes/http.py b/litestar/routes/http.py index 6ed17cfaa2..80368e852e 100644 --- a/litestar/routes/http.py +++ b/litestar/routes/http.py @@ -154,7 +154,9 @@ async def _call_handler_function( route_handler=route_handler, parameter_model=parameter_model, request=request ) - response: ASGIApp = await route_handler.to_response(app=scope["app"], data=response_data, request=request) + response: ASGIApp = await route_handler.to_response( + app=scope["litestar_app"], data=response_data, request=request + ) if cleanup_group: await cleanup_group.cleanup() diff --git a/litestar/testing/client/base.py b/litestar/testing/client/base.py index eeec983b79..f65630f830 100644 --- a/litestar/testing/client/base.py +++ b/litestar/testing/client/base.py @@ -54,6 +54,7 @@ def fake_asgi_connection(app: ASGIApp, cookies: dict[str, str]) -> ASGIConnectio "http_version": "1.1", "extensions": {"http.response.template": {}}, "app": app, # type: ignore[typeddict-item] + "litestar_app": app, "state": {}, "path_params": {}, "route_handler": None, diff --git a/litestar/testing/request_factory.py b/litestar/testing/request_factory.py index e25b6b0956..290b04fb84 100644 --- a/litestar/testing/request_factory.py +++ b/litestar/testing/request_factory.py @@ -75,7 +75,7 @@ def __init__( """Initialize ``RequestFactory`` Args: - app: An instance of :class:`Litestar ` to set as ``request.scope["app"]``. + app: An instance of :class:`Litestar ` to set as ``request.scope["litestar_app"]``. server: The server's domain. port: The server's port. root_path: Root path for the server. @@ -175,6 +175,7 @@ def _create_scope( path=path, headers=[], app=self.app, + litestar_app=self.app, session=session, user=user, auth=auth, diff --git a/litestar/types/asgi_types.py b/litestar/types/asgi_types.py index abbcf2ae77..c7eb56b6b1 100644 --- a/litestar/types/asgi_types.py +++ b/litestar/types/asgi_types.py @@ -124,7 +124,8 @@ class HeaderScope(TypedDict): class BaseScope(HeaderScope): """Base ASGI-scope.""" - app: Litestar + app: Litestar # deprecated + litestar_app: Litestar asgi: ASGIVersion auth: Any client: tuple[str, int] | None diff --git a/litestar/utils/scope/__init__.py b/litestar/utils/scope/__init__.py index e5757d3983..bc707eff3a 100644 --- a/litestar/utils/scope/__init__.py +++ b/litestar/utils/scope/__init__.py @@ -24,7 +24,7 @@ def get_serializer_from_scope(scope: Scope) -> Serializer: A serializer function """ route_handler = scope["route_handler"] - app = scope["app"] + app = scope["litestar_app"] if hasattr(route_handler, "resolve_type_encoders"): type_encoders = route_handler.resolve_type_encoders() diff --git a/tests/conftest.py b/tests/conftest.py index de243a37e2..91d88cc800 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -211,6 +211,7 @@ def inner( ) -> Scope: scope = { "app": app, + "litestar_app": app, "asgi": asgi or {"spec_version": "2.0", "version": "3.0"}, "auth": auth, "type": type, diff --git a/tests/unit/test_app.py b/tests/unit/test_app.py index 084e0572fa..7fce815ed7 100644 --- a/tests/unit/test_app.py +++ b/tests/unit/test_app.py @@ -243,7 +243,7 @@ def handler() -> Dict[str, str]: async def before_send_hook_handler(message: Message, scope: Scope) -> None: if message["type"] == "http.response.start": headers = MutableScopeHeaders(message) - headers.add("My Header", scope["app"].state.message) + headers.add("My Header", Litestar.from_scope(scope).state.message) def on_startup(app: Litestar) -> None: app.state.message = "value injected during send" @@ -466,3 +466,19 @@ def my_route_handler() -> None: ... with create_test_client(my_route_handler, path="/abc") as client: response = client.get("/abc") assert response.status_code == HTTP_200_OK + + +def test_from_scope() -> None: + mock = MagicMock() + + @get() + def handler(scope: Scope) -> None: + mock(Litestar.from_scope(scope)) + return + + app = Litestar(route_handlers=[handler]) + + with TestClient(app) as client: + client.get("/") + + mock.assert_called_once_with(app) diff --git a/tests/unit/test_handlers/test_asgi_handlers/test_handle_asgi.py b/tests/unit/test_handlers/test_asgi_handlers/test_handle_asgi.py index 84d9320c98..2f60f8ac90 100644 --- a/tests/unit/test_handlers/test_asgi_handlers/test_handle_asgi.py +++ b/tests/unit/test_handlers/test_asgi_handlers/test_handle_asgi.py @@ -1,9 +1,14 @@ -from litestar import Controller, MediaType, asgi +from unittest.mock import MagicMock + +import pytest + +from litestar import Controller, Litestar, MediaType, asgi from litestar.enums import ScopeType +from litestar.exceptions import LitestarWarning from litestar.response.base import ASGIResponse from litestar.status_codes import HTTP_200_OK from litestar.testing import create_test_client -from litestar.types import Receive, Scope, Send +from litestar.types import ASGIApp, Receive, Scope, Send def test_handle_asgi() -> None: @@ -51,3 +56,37 @@ async def root_asgi_handler( response = client.get("/asgi") assert response.status_code == HTTP_200_OK assert response.text == "/asgi" + + +def test_copy_scope_not_set_warns_on_modification() -> None: + @asgi(is_mount=True) + async def handler(scope: "Scope", receive: "Receive", send: "Send") -> None: + scope["foo"] = "" # type: ignore[typeddict-unknown-key] + await ASGIResponse()(scope, receive, send) + + with create_test_client([handler]) as client: + with pytest.warns(LitestarWarning, match="modified 'scope' with 'copy_scope' set to 'None'"): + response = client.get("/") + assert response.status_code == HTTP_200_OK + + +@pytest.mark.parametrize("copy_scope, expected_value", [(True, None), (False, "foo")]) +def test_copy_scope(copy_scope: bool, expected_value: "str | None") -> None: + mock = MagicMock() + + def middleware_factory(app: Litestar) -> ASGIApp: + async def middleware(scope: "Scope", receive: "Receive", send: "Send") -> None: + await app(scope, receive, send) + mock(scope.get("foo")) + + return middleware + + @asgi(is_mount=True, copy_scope=copy_scope) + async def handler(scope: "Scope", receive: "Receive", send: "Send") -> None: + scope["foo"] = "foo" # type: ignore[typeddict-unknown-key] + await ASGIResponse()(scope, receive, send) + + with create_test_client([handler], middleware=[middleware_factory]) as client: + client.get("/") + + mock.assert_called_once_with(expected_value)