diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 8a79509..8287982 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -42,6 +42,15 @@ jobs: ports: - 9000:9000 + rabbitmq: + image: "rabbitmq:3-alpine" + env: + RABBITMQ_DEFAULT_USER: "cleanpython" + RABBITMQ_DEFAULT_PASS: "cleanpython" + RABBITMQ_DEFAULT_VHOST: "cleanpython" + ports: + - "5672:5672" + steps: - uses: actions/checkout@v3 diff --git a/CHANGES.md b/CHANGES.md index 6d4f724..384e023 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,9 +1,11 @@ # Changelog of clean-python -## 0.16.6 (unreleased) +## 0.17.0 (unreleased) ---------------------- -- Nothing changed yet. +- Added a `celery.CeleryConfig` with an `apply` method that properly sets up celery + without making the tasks depend on the config. Also added integration tests that + confirm the forwarding of context (tenant and correlation id). ## 0.16.5 (2024-09-12) diff --git a/clean_python/celery/__init__.py b/clean_python/celery/__init__.py index 5e3a36b..c61790a 100644 --- a/clean_python/celery/__init__.py +++ b/clean_python/celery/__init__.py @@ -1,3 +1,4 @@ from .base_task import * # NOQA from .celery_task_logger import * # NOQA +from .config import * # NOQA from .kubernetes import * # NOQA diff --git a/clean_python/celery/base_task.py b/clean_python/celery/base_task.py index 973b88c..62a7cfa 100644 --- a/clean_python/celery/base_task.py +++ b/clean_python/celery/base_task.py @@ -3,40 +3,32 @@ from uuid import uuid4 from celery import Task +from celery.worker.request import Request as CeleryRequest from clean_python import ctx -from clean_python import Json +from clean_python import Id from clean_python import Tenant from clean_python import ValueObject __all__ = ["BaseTask"] -HEADER_FIELD = "clean_python_context" - - class TaskHeaders(ValueObject): - tenant: Tenant | None - correlation_id: UUID | None + tenant_id: Id | None = None + # avoid conflict with celery's own correlation_id: + 
x_correlation_id: UUID | None = None @classmethod - def from_kwargs(cls, kwargs: Json) -> tuple["TaskHeaders", Json]: - if HEADER_FIELD in kwargs: - kwargs = kwargs.copy() - headers = kwargs.pop(HEADER_FIELD) - return TaskHeaders(**headers), kwargs - else: - return TaskHeaders(tenant=None, correlation_id=None), kwargs + def from_celery_request(cls, request: CeleryRequest) -> "TaskHeaders": + return cls(**request.headers) class BaseTask(Task): def apply_async(self, args=None, kwargs=None, **options): - # include correlation_id and tenant in the kwargs - # and NOT the headers as that is buggy in celery # see https://github.com/celery/celery/issues/4875 - kwargs = {} if kwargs is None else kwargs.copy() - kwargs[HEADER_FIELD] = TaskHeaders( - tenant=ctx.tenant, correlation_id=ctx.correlation_id or uuid4() + options["headers"] = TaskHeaders( + tenant_id=ctx.tenant.id if ctx.tenant else None, + x_correlation_id=ctx.correlation_id or uuid4(), ).model_dump(mode="json") return super().apply_async(args, kwargs, **options) @@ -44,7 +36,9 @@ def __call__(self, *args, **kwargs): return copy_context().run(self._call_with_context, *args, **kwargs) def _call_with_context(self, *args, **kwargs): - headers, kwargs = TaskHeaders.from_kwargs(kwargs) - ctx.tenant = headers.tenant - ctx.correlation_id = headers.correlation_id + headers = TaskHeaders.from_celery_request(self.request) + ctx.tenant = ( + Tenant(id=headers.tenant_id, name="") if headers.tenant_id else None + ) + ctx.correlation_id = headers.x_correlation_id or uuid4() return super().__call__(*args, **kwargs) diff --git a/clean_python/celery/celery_task_logger.py b/clean_python/celery/celery_task_logger.py index ecb606a..87204b1 100644 --- a/clean_python/celery/celery_task_logger.py +++ b/clean_python/celery/celery_task_logger.py @@ -66,17 +66,17 @@ def stop(self, task: Task, state: str, result: Any = None): request = None try: - headers, kwargs = TaskHeaders.from_kwargs(request.kwargs) + headers = 
TaskHeaders.from_celery_request(request) except (AttributeError, TypeError): - headers = kwargs = None # type: ignore + headers = None try: - tenant_id = headers.tenant.id # type: ignore + tenant_id = headers.tenant_id # type: ignore except AttributeError: tenant_id = None try: - correlation_id = headers.correlation_id + correlation_id = headers.x_correlation_id # type: ignore except AttributeError: correlation_id = None @@ -86,8 +86,8 @@ def stop(self, task: Task, state: str, result: Any = None): argsrepr = None try: - kwargsrepr = json.dumps(kwargs) - except TypeError: + kwargsrepr = json.dumps(request.kwargs) + except (AttributeError, TypeError): kwargsrepr = None log_dict = { diff --git a/clean_python/celery/config.py b/clean_python/celery/config.py new file mode 100644 index 0000000..56838b9 --- /dev/null +++ b/clean_python/celery/config.py @@ -0,0 +1,31 @@ +from celery import Celery +from celery import current_app + +from clean_python import Json +from clean_python import ValueObject +from clean_python.celery import BaseTask + +__all__ = ["CeleryConfig"] + + +class CeleryConfig(ValueObject): + timezone: str = "Europe/Amsterdam" + broker_url: str + broker_transport_options: Json = {"socket_timeout": 2} + broker_connection_retry_on_startup: bool = True + result_backend: str | None = None + worker_prefetch_multiplier: int = 1 + task_always_eager: bool = False + task_eager_propagates: bool = False + task_acks_late: bool = True + task_default_queue: str = "default" + task_default_priority: int = 0 + task_queue_max_priority: int = 10 + task_track_started: bool = True + + def apply(self, strict_typing: bool = True) -> Celery: + app = current_app if current_app else Celery() + app.task_cls = BaseTask + app.strict_typing = strict_typing + app.config_from_object(self) + return app diff --git a/docker-compose.yaml b/docker-compose.yaml index 9f3cea6..a4b533c 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -1,5 +1,3 @@ -version: "3.8" - services: postgres: 
@@ -18,8 +16,27 @@ services: MINIO_ROOT_PASSWORD: cleanpython ports: - "9000:9000" + healthcheck: + test: ["CMD-SHELL", "pg_isready -U postgres -d postgres"] + interval: 2s + retries: 10 + timeout: 1s fluentbit: image: fluent/fluent-bit:1.9 ports: - "24224:24224" + + rabbitmq: + image: "rabbitmq:3-alpine" + environment: + RABBITMQ_DEFAULT_USER: "cleanpython" + RABBITMQ_DEFAULT_PASS: "cleanpython" + RABBITMQ_DEFAULT_VHOST: "cleanpython" + ports: + - "5672:5672" + healthcheck: + test: rabbitmq-diagnostics check_port_connectivity + interval: 10s + timeout: 1s + retries: 5 diff --git a/integration_tests/celery_example/__init__.py b/integration_tests/celery_example/__init__.py new file mode 100644 index 0000000..cf1d64d --- /dev/null +++ b/integration_tests/celery_example/__init__.py @@ -0,0 +1,22 @@ +import os +from pathlib import Path + +from clean_python.celery import CeleryConfig +from clean_python.celery import CeleryTaskLogger +from clean_python.celery import set_task_logger +from clean_python.testing.debugger import setup_debugger + +from .logger import MultilineJsonFileGateway +from .tasks import sleep_task # NOQA + +app = CeleryConfig( + broker_url="amqp://cleanpython:cleanpython@localhost/cleanpython", + result_backend="rpc://", +).apply() +# the file path is set from the test fixture +logging_path = os.environ.get("CLEAN_PYTHON_TEST_LOGGING") +if logging_path: + set_task_logger(CeleryTaskLogger(MultilineJsonFileGateway(Path(logging_path)))) +debug_port = os.environ.get("CLEAN_PYTHON_TEST_DEBUG") +if debug_port: + setup_debugger(port=int(debug_port)) diff --git a/integration_tests/celery_example/logger.py b/integration_tests/celery_example/logger.py new file mode 100644 index 0000000..796a61d --- /dev/null +++ b/integration_tests/celery_example/logger.py @@ -0,0 +1,34 @@ +import json +from pathlib import Path + +from clean_python import Filter +from clean_python import Json +from clean_python import PageOptions +from clean_python import SyncGateway + +__all__ 
= ["MultilineJsonFileGateway"] + + +class MultilineJsonFileGateway(SyncGateway): + def __init__(self, path: Path) -> None: + self.path = path + + def clear(self): + if self.path.exists(): + self.path.unlink() + + def filter( + self, filters: list[Filter], params: PageOptions | None = None + ) -> list[Json]: + assert not filters + assert not params + if not self.path.exists(): + return [] + with self.path.open("r") as f: + return [json.loads(line) for line in f] + + def add(self, item: Json) -> Json: + with self.path.open("a") as f: + f.write(json.dumps(item)) + f.write("\n") + return item diff --git a/integration_tests/celery_example/tasks.py b/integration_tests/celery_example/tasks.py new file mode 100644 index 0000000..02a537a --- /dev/null +++ b/integration_tests/celery_example/tasks.py @@ -0,0 +1,34 @@ +import time + +from celery import shared_task +from celery import Task +from celery.exceptions import Ignore +from celery.exceptions import Reject + +from clean_python import ctx + + +@shared_task(bind=True, name="testing") +def sleep_task(self: Task, seconds: float, return_value=None, event="success"): + event = event.lower() + if event == "success": + time.sleep(int(seconds)) + elif event == "crash": + import ctypes + + ctypes.string_at(0) # segfault + elif event == "ignore": + raise Ignore() + elif event == "reject": + raise Reject() + elif event == "retry": + raise self.retry(countdown=seconds, max_retries=1) + elif event == "context": + return { + "tenant_id": ctx.tenant.id, + "correlation_id": str(ctx.correlation_id), + } + else: + raise ValueError(f"Unknown event '{event}'") + + return {"value": return_value} diff --git a/integration_tests/conftest.py b/integration_tests/conftest.py index 4e0dc63..fa794ea 100644 --- a/integration_tests/conftest.py +++ b/integration_tests/conftest.py @@ -4,7 +4,10 @@ import io import multiprocessing import os +import signal +import subprocess import time +from pathlib import Path from urllib.error import URLError from 
urllib.request import urlopen @@ -13,6 +16,8 @@ import uvicorn from botocore.exceptions import ClientError +from .celery_example import MultilineJsonFileGateway + def pytest_sessionstart(session): """ @@ -102,6 +107,37 @@ async def fastapi_example_app(): p.terminate() +@pytest.fixture(scope="session") +def celery_worker(tmp_path_factory): + log_file = str(tmp_path_factory.mktemp("pytest-celery") / "celery.log") + p = subprocess.Popen( + [ + "celery", + "-A", + "integration_tests.celery_example", + "worker", + "-c", + "1", + # "-P", enable when using the debugger + # "solo" + ], + start_new_session=True, + stdout=subprocess.PIPE, + # optionally add "CLEAN_PYTHON_TEST_DEBUG": "5679" to enable debugging + env={"CLEAN_PYTHON_TEST_LOGGING": log_file, **os.environ}, + ) + try: + yield MultilineJsonFileGateway(Path(log_file)) + finally: + p.send_signal(signal.SIGQUIT) + + +@pytest.fixture +def celery_task_logs(celery_worker): + celery_worker.clear() + return celery_worker + + @pytest.fixture(scope="session") def s3_settings(s3_url): minio_settings = { diff --git a/integration_tests/test_int_celery.py b/integration_tests/test_int_celery.py index a10c93e..656a7b5 100644 --- a/integration_tests/test_int_celery.py +++ b/integration_tests/test_int_celery.py @@ -3,56 +3,18 @@ from uuid import UUID import pytest -from celery.exceptions import Ignore -from celery.exceptions import Reject +from billiard.exceptions import WorkerLostError +from celery.exceptions import MaxRetriesExceededError from clean_python import ctx from clean_python import InMemorySyncGateway +from clean_python import SyncGateway from clean_python import Tenant -from clean_python.celery import BaseTask from clean_python.celery import CeleryTaskLogger from clean_python.celery import set_task_logger - -@pytest.fixture(scope="session") -def celery_parameters(): - return {"task_cls": BaseTask, "strict_typing": False} - - -@pytest.fixture(scope="session") -def celery_worker_parameters(): - return 
{"shutdown_timeout": 10} - - -@pytest.fixture -def celery_task(celery_app, celery_worker): - @celery_app.task(bind=True, base=BaseTask, name="testing") - def sleep_task(self: BaseTask, seconds: float, return_value=None, event="success"): - event = event.lower() - if event == "success": - time.sleep(int(seconds)) - elif event == "crash": - import ctypes - - ctypes.string_at(0) # segfault - elif event == "ignore": - raise Ignore() - elif event == "reject": - raise Reject() - elif event == "retry": - raise self.retry(countdown=seconds, max_retries=1) - elif event == "context": - return { - "tenant_id": ctx.tenant.id, - "correlation_id": str(ctx.correlation_id), - } - else: - raise ValueError(f"Unknown event '{event}'") - - return {"value": return_value} - - celery_worker.reload() - return sleep_task +from .celery_example import app +from .celery_example import sleep_task @pytest.fixture @@ -63,19 +25,45 @@ def task_logger(): set_task_logger(None) -def test_log_success(celery_task: BaseTask, task_logger: CeleryTaskLogger): - result = celery_task.delay(0.0, return_value=16) +@pytest.mark.usefixtures("celery_worker") +def test_run_task(): + result = sleep_task.delay(0.01, return_value=16) + + assert result.get(timeout=10) == {"value": 16} + + +@pytest.fixture +def custom_context(): + ctx.correlation_id = UUID("b3089ea7-2585-43e5-a63c-ae30a6e9b5e4") + ctx.tenant = Tenant(id=2, name="custom") + yield ctx + ctx.correlation_id = None + ctx.tenant = None + + +@pytest.mark.usefixtures("celery_worker") +def test_context(custom_context): + result = sleep_task.delay(0.01, event="context") + + assert result.get(timeout=10) == { + "tenant_id": custom_context.tenant.id, + "correlation_id": str(custom_context.correlation_id), + } + + +def test_log_success(celery_task_logs: SyncGateway): + result = sleep_task.delay(0.01, return_value=16) assert result.get(timeout=10) == {"value": 16} - (log,) = task_logger.gateway.filter([]) + (log,) = celery_task_logs.filter([]) assert 0.0 < 
(time.time() - log["time"]) < 1.0 assert log["tag_suffix"] == "task_log" assert log["task_id"] == result.id assert log["state"] == "SUCCESS" assert log["name"] == "testing" assert log["duration"] > 0.0 - assert json.loads(log["argsrepr"]) == [0.0] + assert json.loads(log["argsrepr"]) == [0.01] assert json.loads(log["kwargsrepr"]) == {"return_value": 16} assert log["retries"] == 0 assert log["result"] == {"value": 16} @@ -83,34 +71,64 @@ def test_log_success(celery_task: BaseTask, task_logger: CeleryTaskLogger): assert log["tenant_id"] is None -def test_log_failure(celery_task: BaseTask, task_logger: CeleryTaskLogger): - result = celery_task.delay(0.0, event="failure") +def test_log_failure(celery_task_logs: SyncGateway): + result = sleep_task.delay(0.01, event="failure") with pytest.raises(ValueError): assert result.get(timeout=10) - (log,) = task_logger.gateway.filter([]) + (log,) = celery_task_logs.filter([]) assert log["state"] == "FAILURE" assert log["result"]["traceback"].startswith("Traceback") +def test_log_crash(celery_task_logs: SyncGateway): + result = sleep_task.delay(0.01, event="crash") + + with pytest.raises(WorkerLostError): + assert result.get(timeout=10) + + (log,) = celery_task_logs.filter([]) + assert log["state"] == "FAILURE" + assert "SIGSEGV" in log["result"]["traceback"] + + +def test_log_context(celery_task_logs: SyncGateway, custom_context): + result = sleep_task.delay(0.01, return_value=16) + + assert result.get(timeout=10) == {"value": 16} + + (log,) = celery_task_logs.filter([]) + assert log["correlation_id"] == str(custom_context.correlation_id) + assert log["tenant_id"] == custom_context.tenant.id + + +def test_log_retry_propagates_context(celery_task_logs: SyncGateway, custom_context): + result = sleep_task.delay(0.01, event="retry") + + with pytest.raises(MaxRetriesExceededError): + result.get(timeout=10) + + (log,) = celery_task_logs.filter([]) + assert log["state"] == "FAILURE" + assert log["retries"] == 1 + assert 
log["correlation_id"] == str(custom_context.correlation_id) + assert log["tenant_id"] == custom_context.tenant.id + + @pytest.fixture -def custom_context(): - ctx.correlation_id = UUID("b3089ea7-2585-43e5-a63c-ae30a6e9b5e4") - ctx.tenant = Tenant(id=2, name="custom") - yield ctx - ctx.correlation_id = None - ctx.tenant = None +def celery_eager(): + app.conf.task_always_eager = True + yield + app.conf.task_always_eager = False -def test_context(celery_task: BaseTask, custom_context, task_logger): - result = celery_task.apply_async((0.0,), {"event": "context"}, countdown=1.0) +@pytest.mark.usefixtures("celery_eager") +def test_eager_mode_with_context(custom_context): + result = sleep_task.delay(0.01, event="context") - assert result.get(timeout=10) == { - "tenant_id": 2, - "correlation_id": "b3089ea7-2585-43e5-a63c-ae30a6e9b5e4", + assert result.__class__.__name__ == "EagerResult" + assert result.get() == { + "tenant_id": custom_context.tenant.id, + "correlation_id": str(custom_context.correlation_id), } - - (log,) = task_logger.gateway.filter([]) - assert log["correlation_id"] == "b3089ea7-2585-43e5-a63c-ae30a6e9b5e4" - assert log["tenant_id"] == 2 diff --git a/tests/celery/test_celery_base_task.py b/tests/celery/test_celery_base_task.py index ab2d5f8..2f0f9f5 100644 --- a/tests/celery/test_celery_base_task.py +++ b/tests/celery/test_celery_base_task.py @@ -8,7 +8,6 @@ from clean_python import ctx from clean_python import Tenant from clean_python.celery import BaseTask -from clean_python.celery.base_task import HEADER_FIELD @pytest.fixture @@ -34,12 +33,12 @@ def test_apply_async(uuid4, mocked_apply_async): BaseTask().apply_async(args=("foo",), kwargs={"a": "bar"}) assert mocked_apply_async.call_count == 1 - (args, kwargs), _ = mocked_apply_async.call_args + (args, kwargs), options = mocked_apply_async.call_args assert args == ("foo",) assert kwargs["a"] == "bar" - assert kwargs[HEADER_FIELD] == { - "tenant": None, - "correlation_id": 
"479156af-a302-48fc-89ed-8c426abadc4c", + assert options["headers"] == { + "tenant_id": None, + "x_correlation_id": "479156af-a302-48fc-89ed-8c426abadc4c", } @@ -47,7 +46,7 @@ def test_apply_async_with_context(mocked_apply_async, temp_context): BaseTask().apply_async(args=("foo",), kwargs={"a": "bar"}) assert mocked_apply_async.call_count == 1 - (_, kwargs), _ = mocked_apply_async.call_args + (_, kwargs), options = mocked_apply_async.call_args assert kwargs["a"] == "bar" - assert kwargs[HEADER_FIELD]["tenant"] == temp_context.tenant.model_dump(mode="json") - assert kwargs[HEADER_FIELD]["correlation_id"] == str(temp_context.correlation_id) + assert options["headers"]["tenant_id"] == temp_context.tenant.id + assert options["headers"]["x_correlation_id"] == str(temp_context.correlation_id) diff --git a/tests/celery/test_celery_task_logger.py b/tests/celery/test_celery_task_logger.py index b3402c9..b9378f5 100644 --- a/tests/celery/test_celery_task_logger.py +++ b/tests/celery/test_celery_task_logger.py @@ -54,12 +54,12 @@ def celery_task(): request.origin = "hostname" request.retries = 25 request.args = [1, 2] - request.kwargs = { - "clean_python_context": { - "tenant": {"id": 15, "name": "foo"}, - "correlation_id": "b3089ea7-2585-43e5-a63c-ae30a6e9b5e4", - } + request.kwargs = {} + request.headers = { + "tenant_id": 15, + "x_correlation_id": "b3089ea7-2585-43e5-a63c-ae30a6e9b5e4", } + task = mock.Mock() task.name = "task_name" task.request = request @@ -98,11 +98,12 @@ def test_log_with_result( assert entry["result"] == expected -def test_log_with_request_no_args_kwargs( +def test_log_with_request_no_args_kwargs_no_headers( celery_task_logger: CeleryTaskLogger, celery_task ): celery_task.request.args = None celery_task.request.kwargs = None + celery_task.request.headers = None celery_task_logger.stop(celery_task, "STAAT") (entry,) = celery_task_logger.gateway.filter([])