From d70636ba2e8987e8e78e292a704a264121b8082f Mon Sep 17 00:00:00 2001 From: Adam Hopkins Date: Mon, 31 Oct 2022 13:22:30 +0200 Subject: [PATCH] Add GenericCreator for loading SSL certs in processes (#2578) --- sanic/worker/loader.py | 37 +++++++++++++++++++------------------ tests/test_tls.py | 27 +++++++++++++++++++++++++++ tests/worker/test_loader.py | 4 ++++ 3 files changed, 50 insertions(+), 18 deletions(-) diff --git a/sanic/worker/loader.py b/sanic/worker/loader.py index abdcc9870a..344593dbdc 100644 --- a/sanic/worker/loader.py +++ b/sanic/worker/loader.py @@ -5,18 +5,10 @@ from importlib import import_module from pathlib import Path -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Dict, - Optional, - Type, - Union, - cast, -) +from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union, cast -from sanic.http.tls.creators import CertCreator, MkcertCreator, TrustmeCreator +from sanic.http.tls.context import process_to_context +from sanic.http.tls.creators import MkcertCreator, TrustmeCreator if TYPE_CHECKING: @@ -106,21 +98,30 @@ def load(self) -> SanicApp: class CertLoader: - _creator_class: Type[CertCreator] + _creators = { + "mkcert": MkcertCreator, + "trustme": TrustmeCreator, + } def __init__(self, ssl_data: Dict[str, Union[str, os.PathLike]]): - creator_name = ssl_data.get("creator") - if creator_name not in ("mkcert", "trustme"): + self._ssl_data = ssl_data + + creator_name = cast(str, ssl_data.get("creator")) + + self._creator_class = self._creators.get(creator_name) + if not creator_name: + return + + if not self._creator_class: raise RuntimeError(f"Unknown certificate creator: {creator_name}") - elif creator_name == "mkcert": - self._creator_class = MkcertCreator - elif creator_name == "trustme": - self._creator_class = TrustmeCreator self._key = ssl_data["key"] self._cert = ssl_data["cert"] self._localhost = cast(str, ssl_data["localhost"]) def load(self, app: SanicApp): + if not self._creator_class: + return process_to_context(self._ssl_data) + creator = self._creator_class(app, self._key, self._cert) return creator.generate_cert(self._localhost) diff --git a/tests/test_tls.py b/tests/test_tls.py index 6c369f9284..497b67deaf 100644 --- a/tests/test_tls.py +++ b/tests/test_tls.py @@ -4,6 +4,7 @@ import subprocess from contextlib import contextmanager +from multiprocessing import Event from pathlib import Path from unittest.mock import Mock, patch from urllib.parse import urlparse @@ -636,3 +637,29 @@ def test_sanic_ssl_context_create(): assert sanic_context is context assert isinstance(sanic_context, SanicSSLContext) + + +def test_ssl_in_multiprocess_mode(app: Sanic, caplog): + + ssl_dict = {"cert": localhost_cert, "key": localhost_key} + event = Event() + + @app.main_process_start + async def main_start(app: Sanic): + app.shared_ctx.event = event + + @app.after_server_start + async def shutdown(app): + app.shared_ctx.event.set() + app.stop() + + assert not event.is_set() + with caplog.at_level(logging.INFO): + app.run(ssl=ssl_dict) + assert event.is_set() + + assert ( + "sanic.root", + logging.INFO, + "Goin' Fast @ https://127.0.0.1:8000", + ) in caplog.record_tuples diff --git a/tests/worker/test_loader.py b/tests/worker/test_loader.py index 6f953c54d1..d0d04e9aab 100644 --- a/tests/worker/test_loader.py +++ b/tests/worker/test_loader.py @@ -86,6 +86,10 @@ def test_input_is_module(): @patch("sanic.worker.loader.TrustmeCreator") @patch("sanic.worker.loader.MkcertCreator") def test_cert_loader(MkcertCreator: Mock, TrustmeCreator: Mock, creator: str): + CertLoader._creators = { + "mkcert": MkcertCreator, + "trustme": TrustmeCreator, + } MkcertCreator.return_value = MkcertCreator TrustmeCreator.return_value = TrustmeCreator data = {