Skip to content

Commit

Permalink
Fix celery header issues (#35)
Browse files Browse the repository at this point in the history
  • Loading branch information
caspervdw authored Nov 7, 2023
1 parent 6e8845e commit a60940c
Show file tree
Hide file tree
Showing 6 changed files with 83 additions and 73 deletions.
2 changes: 2 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

- Fix access logging of correlation id.

- Work around celery issues with message headers: use the body (kwargs) instead.


0.8.1 (2023-11-06)
------------------
Expand Down
36 changes: 17 additions & 19 deletions clean_python/celery/base_task.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from contextvars import copy_context
from typing import Optional
from typing import Tuple
from uuid import UUID
from uuid import uuid4

from celery import Task

from clean_python import ctx
from clean_python import Json
from clean_python import Tenant
from clean_python import ValueObject

Expand All @@ -20,35 +22,31 @@ class TaskHeaders(ValueObject):
correlation_id: Optional[UUID]

@classmethod
def from_celery_request(cls, request) -> "TaskHeaders":
if request.headers and HEADER_FIELD in request.headers:
return TaskHeaders(**request.headers[HEADER_FIELD])
def from_kwargs(cls, kwargs: Json) -> Tuple["TaskHeaders", Json]:
if HEADER_FIELD in kwargs:
kwargs = kwargs.copy()
headers = kwargs.pop(HEADER_FIELD)
return TaskHeaders(**headers), kwargs
else:
return TaskHeaders(tenant=None, correlation_id=None)
return TaskHeaders(tenant=None, correlation_id=None), kwargs


class BaseTask(Task):
def apply_async(self, args=None, kwargs=None, **options):
# include correlation_id and tenant in the headers
if "headers" in options:
headers = options.pop("headers")
if headers is None:
headers = {}
else:
headers = headers.copy()
else:
headers = {}
if HEADER_FIELD not in headers:
headers[HEADER_FIELD] = TaskHeaders(
tenant=ctx.tenant, correlation_id=ctx.correlation_id or uuid4()
).model_dump(mode="json")
return super().apply_async(args, kwargs, headers=headers, **options)
# include correlation_id and tenant in the kwargs
# and NOT the headers as that is buggy in celery
# see https://github.com/celery/celery/issues/4875
kwargs = {} if kwargs is None else kwargs.copy()
kwargs[HEADER_FIELD] = TaskHeaders(
tenant=ctx.tenant, correlation_id=ctx.correlation_id or uuid4()
).model_dump(mode="json")
return super().apply_async(args, kwargs, **options)

# Execute the task inside a copy of the current contextvars context
# (contextvars.copy_context), delegating the actual work to
# self._call_with_context with the same positional and keyword arguments.
# NOTE(review): presumably this keeps the ctx.tenant / ctx.correlation_id
# assignments made in _call_with_context scoped to this one invocation --
# confirm that ctx is backed by context variables.
def __call__(self, *args, **kwargs):
return copy_context().run(self._call_with_context, *args, **kwargs)

def _call_with_context(self, *args, **kwargs):
headers = TaskHeaders.from_celery_request(self.request)
headers, kwargs = TaskHeaders.from_kwargs(kwargs)
ctx.tenant = headers.tenant
ctx.correlation_id = headers.correlation_id
return super().__call__(*args, **kwargs)
32 changes: 28 additions & 4 deletions clean_python/celery/celery_task_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,34 @@ def stop(self, task: Task, state: str, result: Any = None):

try:
request = task.request
correlation_id = TaskHeaders.from_celery_request(request).correlation_id
except AttributeError:
request = None

try:
headers, kwargs = TaskHeaders.from_kwargs(request.kwargs)
except AttributeError:
headers = kwargs = None # type: ignore

try:
tenant_id = headers.tenant.id # type: ignore
except AttributeError:
tenant_id = None

try:
correlation_id = headers.correlation_id
except AttributeError:
correlation_id = None

try:
args = json.loads(json.dumps(request.args))
except (AttributeError, TypeError):
args = None

try:
kwargs = json.loads(json.dumps(kwargs))
except TypeError:
kwargs = None

log_dict = {
"tag_suffix": "task_log",
"time": start_time,
Expand All @@ -77,10 +100,11 @@ def stop(self, task: Task, state: str, result: Any = None):
"duration": duration,
"origin": getattr(request, "origin", None),
"retries": getattr(request, "retries", None),
"argsrepr": getattr(request, "argsrepr", None),
"kwargsrepr": getattr(request, "kwargsrepr", None),
"args": args,
"kwargs": kwargs,
"result": result_json,
"correlation_id": str(correlation_id) if correlation_id else None,
"tenant_id": tenant_id,
"correlation_id": None if correlation_id is None else str(correlation_id),
}

return self.gateway.add(log_dict)
Expand Down
19 changes: 11 additions & 8 deletions integration_tests/test_int_celery.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

@pytest.fixture(scope="session")
def celery_parameters():
return {"task_cls": BaseTask}
return {"task_cls": BaseTask, "strict_typing": False}


@pytest.fixture(scope="session")
Expand All @@ -41,7 +41,10 @@ def sleep_task(self: BaseTask, seconds: float, return_value=None, event="success
elif event == "retry":
raise self.retry(countdown=seconds, max_retries=1)
elif event == "context":
return {"tenant": ctx.tenant.id, "correlation_id": str(ctx.correlation_id)}
return {
"tenant_id": ctx.tenant.id,
"correlation_id": str(ctx.correlation_id),
}
else:
raise ValueError(f"Unknown event '{event}'")

Expand Down Expand Up @@ -71,11 +74,12 @@ def test_log_success(celery_task: BaseTask, task_logger: CeleryTaskLogger):
assert log["state"] == "SUCCESS"
assert log["name"] == "testing"
assert log["duration"] > 0.0
assert log["argsrepr"] == "(0.0,)"
assert log["kwargsrepr"] == "{'return_value': 16}"
assert log["args"] == [0.0]
assert log["kwargs"] == {"return_value": 16}
assert log["retries"] == 0
assert log["result"] == {"value": 16}
assert UUID(log["correlation_id"]) # generated
assert log["tenant_id"] is None


def test_log_failure(celery_task: BaseTask, task_logger: CeleryTaskLogger):
Expand All @@ -98,15 +102,14 @@ def custom_context():
ctx.tenant = None


def test_context(celery_task: BaseTask, task_logger: CeleryTaskLogger, custom_context):
def test_context(celery_task: BaseTask, custom_context, task_logger):
result = celery_task.apply_async((0.0,), {"event": "context"}, countdown=1.0)
custom_context.correlation_id = None
custom_context.tenant = None

assert result.get(timeout=10) == {
"tenant": 2,
"tenant_id": 2,
"correlation_id": "b3089ea7-2585-43e5-a63c-ae30a6e9b5e4",
}

(log,) = task_logger.gateway.filter([])
assert log["correlation_id"] == "b3089ea7-2585-43e5-a63c-ae30a6e9b5e4"
assert log["tenant_id"] == 2
53 changes: 18 additions & 35 deletions tests/celery/test_celery_base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,45 +26,28 @@ def temp_context():
ctx.correlation_id = None


def test_apply_async(mocked_apply_async):
BaseTask().apply_async(args="foo", kwargs="bar")
@mock.patch(
"clean_python.celery.base_task.uuid4",
return_value=UUID("479156af-a302-48fc-89ed-8c426abadc4c"),
)
def test_apply_async(uuid4, mocked_apply_async):
BaseTask().apply_async(args=("foo",), kwargs={"a": "bar"})

assert mocked_apply_async.call_count == 1
args, kwargs = mocked_apply_async.call_args
assert args == ("foo", "bar")
assert kwargs["headers"][HEADER_FIELD]["tenant"] is None
UUID(kwargs["headers"][HEADER_FIELD]["correlation_id"]) # generated
(args, kwargs), _ = mocked_apply_async.call_args
assert args == ("foo",)
assert kwargs["a"] == "bar"
assert kwargs[HEADER_FIELD] == {
"tenant": None,
"correlation_id": "479156af-a302-48fc-89ed-8c426abadc4c",
}


def test_apply_async_with_context(mocked_apply_async, temp_context):
BaseTask().apply_async(args="foo", kwargs="bar")
BaseTask().apply_async(args=("foo",), kwargs={"a": "bar"})

assert mocked_apply_async.call_count == 1
_, kwargs = mocked_apply_async.call_args
assert kwargs["headers"][HEADER_FIELD]["tenant"] == temp_context.tenant.model_dump(
mode="json"
)
kwargs["headers"][HEADER_FIELD]["correlation_id"] == str(
temp_context.correlation_id
)


def test_apply_async_headers_extended(mocked_apply_async):
headers = {"baz": 2}
BaseTask().apply_async(args="foo", kwargs="bar", headers=headers)

assert mocked_apply_async.call_count == 1
_, kwargs = mocked_apply_async.call_args
assert kwargs["headers"]["baz"] == 2
assert kwargs["headers"][HEADER_FIELD]["tenant"] is None
UUID(kwargs["headers"][HEADER_FIELD]["correlation_id"]) # generated

assert headers == {"baz": 2} # not changed inplace


def test_apply_async_headers_already_present(mocked_apply_async):
BaseTask().apply_async(args="foo", kwargs="bar", headers={HEADER_FIELD: "foo"})

assert mocked_apply_async.call_count == 1
_, kwargs = mocked_apply_async.call_args
assert kwargs["headers"] == {HEADER_FIELD: "foo"}
(_, kwargs), _ = mocked_apply_async.call_args
assert kwargs["a"] == "bar"
assert kwargs[HEADER_FIELD]["tenant"] == temp_context.tenant.model_dump(mode="json")
assert kwargs[HEADER_FIELD]["correlation_id"] == str(temp_context.correlation_id)
14 changes: 7 additions & 7 deletions tests/celery/test_celery_task_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,11 @@ def test_log_minimal(celery_task_logger: CeleryTaskLogger):
"state": "STAAT",
"duration": None,
"origin": None,
"argsrepr": None,
"kwargsrepr": None,
"args": None,
"kwargs": None,
"result": None,
"time": None,
"tenant_id": None,
"correlation_id": None,
"retries": None,
}
Expand All @@ -52,9 +53,8 @@ def celery_task():
request.id = "abc123"
request.origin = "hostname"
request.retries = 25
request.argsrepr = "[1, 2]"
request.kwargsrepr = "{}"
request.headers = {
request.args = [1, 2]
request.kwargs = {
"clean_python_context": {
"tenant": None,
"correlation_id": "b3089ea7-2585-43e5-a63c-ae30a6e9b5e4",
Expand All @@ -73,8 +73,8 @@ def test_log_with_request(celery_task_logger: CeleryTaskLogger, celery_task):
assert entry["name"] == "task_name"
assert entry["task_id"] == "abc123"
assert entry["retries"] == 25
assert entry["argsrepr"] == "[1, 2]"
assert entry["kwargsrepr"] == "{}"
assert entry["args"] == [1, 2]
assert entry["kwargs"] == {}
assert entry["origin"] == "hostname"
assert entry["correlation_id"] == "b3089ea7-2585-43e5-a63c-ae30a6e9b5e4"

Expand Down

0 comments on commit a60940c

Please sign in to comment.