Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow adding health check callback to Dapr App run in gRPC #723

Merged
merged 2 commits into from
Jun 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions .github/scripts/automerge.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,8 @@

def fetch_pulls(mergeable_state, labels={'automerge'}):
return [pr for pr in repo.get_pulls(state='open', sort='created')
# noqa: E502
if (not pr.draft) and (pr.mergeable_state == mergeable_state) and \
(not labels or len(labels.intersection({label.name for label in pr.labels})) > 0)]
if (not pr.draft and pr.mergeable_state == mergeable_state
and (not labels or len(labels.intersection({label.name for label in pr.labels})) > 0))]


def is_approved(pr):
Expand Down
2 changes: 1 addition & 1 deletion examples/pubsub-simple/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ sleep: 15

```bash
# 2. Start Publisher
dapr run --app-id python-publisher --app-protocol grpc --dapr-grpc-port=3500 python3 publisher.py
dapr run --app-id python-publisher --app-protocol grpc --dapr-grpc-port=3500 --enable-app-health-check python3 publisher.py
```

<!-- END_STEP -->
Expand Down
7 changes: 7 additions & 0 deletions examples/pubsub-simple/subscriber.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,4 +83,11 @@ def mytopic_wildcard(event: v1.Event) -> TopicEventResponse:
return TopicEventResponse('success')


# Example of an unhealthy status
# def unhealthy():
# raise ValueError("Not healthy")
# app.register_health_check(unhealthy)

app.register_health_check(lambda: print('Healthy'))

app.run(50051)
32 changes: 32 additions & 0 deletions ext/dapr-ext-grpc/dapr/ext/grpc/_health_servicer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import grpc
from typing import Callable, Optional

from dapr.proto import appcallback_service_v1
from dapr.proto.runtime.v1.appcallback_pb2 import HealthCheckResponse

HealthCheckCallable = Optional[Callable[[], None]]


class _HealthCheckServicer(appcallback_service_v1.AppCallbackHealthCheckServicer):
"""The implementation of HealthCheck Server.

:class:`App` provides useful decorators to register method, topic, input bindings.
"""

def __init__(self):
self._health_check_cb: Optional[HealthCheckCallable] = None

def register_health_check(self, cb: HealthCheckCallable) -> None:
if not cb:
raise ValueError('health check callback must be defined')

Check warning on line 21 in ext/dapr-ext-grpc/dapr/ext/grpc/_health_servicer.py

View check run for this annotation

Codecov / codecov/patch

ext/dapr-ext-grpc/dapr/ext/grpc/_health_servicer.py#L21

Added line #L21 was not covered by tests
self._health_check_cb = cb

def HealthCheck(self, request, context):
"""Health check."""

if not self._health_check_cb:
context.set_code(grpc.StatusCode.UNIMPLEMENTED) # type: ignore
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
self._health_check_cb()
return HealthCheckResponse()
16 changes: 15 additions & 1 deletion ext/dapr-ext-grpc/dapr/ext/grpc/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
from typing import Dict, Optional

from dapr.conf import settings
from dapr.ext.grpc._servicier import _CallbackServicer, Rule # type: ignore
from dapr.ext.grpc._servicer import _CallbackServicer, Rule # type: ignore
from dapr.ext.grpc._health_servicer import _HealthCheckServicer # type: ignore
from dapr.proto import appcallback_service_v1


Expand All @@ -43,6 +44,7 @@ def __init__(self, max_grpc_message_length: Optional[int] = None, **kwargs):
kwargs: arguments to grpc.server()
"""
self._servicer = _CallbackServicer()
self._health_check_servicer = _HealthCheckServicer()
if not kwargs:
options = []
if max_grpc_message_length is not None:
Expand All @@ -56,6 +58,9 @@ def __init__(self, max_grpc_message_length: Optional[int] = None, **kwargs):
else:
self._server = grpc.server(**kwargs) # type: ignore
appcallback_service_v1.add_AppCallbackServicer_to_server(self._servicer, self._server)
appcallback_service_v1.add_AppCallbackHealthCheckServicer_to_server(
self._health_check_servicer, self._server
)

def __del__(self):
self.stop()
Expand All @@ -64,6 +69,15 @@ def add_external_service(self, servicer_callback, external_servicer):
"""Adds an external gRPC service to the same server"""
servicer_callback(external_servicer, self._server)

def register_health_check(self, health_check_callback):
"""Adds a health check callback

The below example adds a basic health check to check Dapr gRPC is running

@app.register_health_check(lambda: None)
"""
self._health_check_servicer.register_health_check(health_check_callback)

def run(self, app_port: Optional[int] = None, listen_address: Optional[str] = None) -> None:
"""Starts app gRPC server and waits until :class:`App`.stop() is called.

Expand Down
14 changes: 14 additions & 0 deletions ext/dapr-ext-grpc/tests/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,17 @@
'AppTests.test_subscribe_decorator.<locals>.handle_dead_letter',
str(subscription_map['pubsub:topic2:']),
)

def test_register_health_check(self):
def health_check_cb():
pass

Check warning on line 80 in ext/dapr-ext-grpc/tests/test_app.py

View check run for this annotation

Codecov / codecov/patch

ext/dapr-ext-grpc/tests/test_app.py#L80

Added line #L80 was not covered by tests

self._app.register_health_check(health_check_cb)
registered_cb = self._app._health_check_servicer._health_check_cb
self.assertIn(
'AppTests.test_register_health_check.<locals>.health_check_cb', str(registered_cb)
)

def test_no_health_check(self):
registered_cb = self._app._health_check_servicer._health_check_cb
self.assertIsNone(registered_cb)
20 changes: 20 additions & 0 deletions ext/dapr-ext-grpc/tests/test_health_servicer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import unittest
from unittest.mock import MagicMock

from dapr.ext.grpc._health_servicer import _HealthCheckServicer


class OnInvokeTests(unittest.TestCase):
def setUp(self):
self._health_servicer = _HealthCheckServicer()

def test_healthcheck_cb_called(self):
health_cb = MagicMock()
self._health_servicer.register_health_check(health_cb)
self._health_servicer.HealthCheck(None, MagicMock())
health_cb.assert_called_once()

def test_no_healthcheck_cb(self):
with self.assertRaises(NotImplementedError) as exception_context:
self._health_servicer.HealthCheck(None, MagicMock())
self.assertIn('Method not implemented!', exception_context.exception.args[0])
48 changes: 24 additions & 24 deletions ext/dapr-ext-grpc/tests/test_servicier.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,18 @@

from dapr.clients.grpc._request import InvokeMethodRequest
from dapr.clients.grpc._response import InvokeMethodResponse, TopicEventResponse
from dapr.ext.grpc._servicier import _CallbackServicer
from dapr.ext.grpc._servicer import _CallbackServicer
from dapr.proto import common_v1, appcallback_v1

from google.protobuf.any_pb2 import Any as GrpcAny


class OnInvokeTests(unittest.TestCase):
def setUp(self):
self._servicier = _CallbackServicer()
self._servicer = _CallbackServicer()

def _on_invoke(self, method_name, method_cb):
self._servicier.register_method(method_name, method_cb)
self._servicer.register_method(method_name, method_cb)

# fake context
fake_context = MagicMock()
Expand All @@ -39,7 +39,7 @@ def _on_invoke(self, method_name, method_cb):
('key2', 'value1'),
)

return self._servicier.OnInvoke(
return self._servicer.OnInvoke(
common_v1.InvokeRequest(method=method_name, data=GrpcAny()),
fake_context,
)
Expand Down Expand Up @@ -93,18 +93,18 @@ def method_cb(request: InvokeMethodRequest):

class TopicSubscriptionTests(unittest.TestCase):
def setUp(self):
self._servicier = _CallbackServicer()
self._servicer = _CallbackServicer()
self._topic1_method = Mock()
self._topic2_method = Mock()
self._topic3_method = Mock()
self._topic3_method.return_value = TopicEventResponse('success')
self._topic4_method = Mock()

self._servicier.register_topic('pubsub1', 'topic1', self._topic1_method, {'session': 'key'})
self._servicier.register_topic('pubsub1', 'topic3', self._topic3_method, {'session': 'key'})
self._servicier.register_topic('pubsub2', 'topic2', self._topic2_method, {'session': 'key'})
self._servicier.register_topic('pubsub2', 'topic3', self._topic3_method, {'session': 'key'})
self._servicier.register_topic(
self._servicer.register_topic('pubsub1', 'topic1', self._topic1_method, {'session': 'key'})
self._servicer.register_topic('pubsub1', 'topic3', self._topic3_method, {'session': 'key'})
self._servicer.register_topic('pubsub2', 'topic2', self._topic2_method, {'session': 'key'})
self._servicer.register_topic('pubsub2', 'topic3', self._topic3_method, {'session': 'key'})
self._servicer.register_topic(
'pubsub3',
'topic4',
self._topic4_method,
Expand All @@ -121,12 +121,12 @@ def setUp(self):

def test_duplicated_topic(self):
with self.assertRaises(ValueError):
self._servicier.register_topic(
self._servicer.register_topic(
'pubsub1', 'topic1', self._topic1_method, {'session': 'key'}
)

def test_list_topic_subscription(self):
resp = self._servicier.ListTopicSubscriptions(None, None)
resp = self._servicer.ListTopicSubscriptions(None, None)
self.assertEqual('pubsub1', resp.subscriptions[0].pubsub_name)
self.assertEqual('topic1', resp.subscriptions[0].topic)
self.assertEqual({'session': 'key'}, resp.subscriptions[0].metadata)
Expand All @@ -143,23 +143,23 @@ def test_list_topic_subscription(self):
self.assertEqual({'session': 'key'}, resp.subscriptions[4].metadata)

def test_topic_event(self):
self._servicier.OnTopicEvent(
self._servicer.OnTopicEvent(
appcallback_v1.TopicEventRequest(pubsub_name='pubsub1', topic='topic1'),
self.fake_context,
)

self._topic1_method.assert_called_once()

def test_topic3_event_called_once(self):
self._servicier.OnTopicEvent(
self._servicer.OnTopicEvent(
appcallback_v1.TopicEventRequest(pubsub_name='pubsub1', topic='topic3'),
self.fake_context,
)

self._topic3_method.assert_called_once()

def test_topic3_event_response(self):
response = self._servicier.OnTopicEvent(
response = self._servicer.OnTopicEvent(
appcallback_v1.TopicEventRequest(pubsub_name='pubsub1', topic='topic3'),
self.fake_context,
)
Expand All @@ -169,7 +169,7 @@ def test_topic3_event_response(self):
)

def test_disable_topic_validation(self):
self._servicier.OnTopicEvent(
self._servicer.OnTopicEvent(
appcallback_v1.TopicEventRequest(pubsub_name='pubsub3', topic='should_be_ignored'),
self.fake_context,
)
Expand All @@ -178,20 +178,20 @@ def test_disable_topic_validation(self):

def test_non_registered_topic(self):
with self.assertRaises(NotImplementedError):
self._servicier.OnTopicEvent(
self._servicer.OnTopicEvent(
appcallback_v1.TopicEventRequest(pubsub_name='pubsub1', topic='topic_non_existed'),
self.fake_context,
)


class BindingTests(unittest.TestCase):
def setUp(self):
self._servicier = _CallbackServicer()
self._servicer = _CallbackServicer()
self._binding1_method = Mock()
self._binding2_method = Mock()

self._servicier.register_binding('binding1', self._binding1_method)
self._servicier.register_binding('binding2', self._binding2_method)
self._servicer.register_binding('binding1', self._binding1_method)
self._servicer.register_binding('binding2', self._binding2_method)

# fake context
self.fake_context = MagicMock()
Expand All @@ -202,15 +202,15 @@ def setUp(self):

def test_duplicated_binding(self):
with self.assertRaises(ValueError):
self._servicier.register_binding('binding1', self._binding1_method)
self._servicer.register_binding('binding1', self._binding1_method)

def test_list_bindings(self):
resp = self._servicier.ListInputBindings(None, None)
resp = self._servicer.ListInputBindings(None, None)
self.assertEqual('binding1', resp.bindings[0])
self.assertEqual('binding2', resp.bindings[1])

def test_binding_event(self):
self._servicier.OnBindingEvent(
self._servicer.OnBindingEvent(
appcallback_v1.BindingEventRequest(name='binding1'),
self.fake_context,
)
Expand All @@ -219,7 +219,7 @@ def test_binding_event(self):

def test_non_registered_binding(self):
with self.assertRaises(NotImplementedError):
self._servicier.OnBindingEvent(
self._servicer.OnBindingEvent(
appcallback_v1.BindingEventRequest(name='binding3'),
self.fake_context,
)
Expand Down