Skip to content

Commit

Permalink
feat: Draft router class
Browse files Browse the repository at this point in the history
  • Loading branch information
empicano committed Jun 2, 2024
1 parent f1a6139 commit 9d48222
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 15 deletions.
2 changes: 2 additions & 0 deletions aiomqtt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
)
from .exceptions import MqttCodeError, MqttError, MqttReentrantError
from .message import Message
from .router import Router
from .topic import Topic, TopicLike, Wildcard, WildcardLike

# These are placeholders that are managed by poetry-dynamic-versioning
Expand All @@ -19,6 +20,7 @@
"__version_tuple__",
"Client",
"Message",
"Router",
"ProtocolVersion",
"ProxySettings",
"TLSParameters",
Expand Down
16 changes: 16 additions & 0 deletions aiomqtt/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@

from .exceptions import MqttCodeError, MqttConnectError, MqttError, MqttReentrantError
from .message import Message
from .router import Router
from .types import (
P,
PayloadType,
Expand Down Expand Up @@ -134,6 +135,7 @@ class Client:
password: The password to authenticate with.
logger: Custom logger instance.
identifier: The client identifier. Generated automatically if ``None``.
routers: A list of routers to route messages to.
queue_type: The class to use for the queue. The default is
``asyncio.Queue``, which stores messages in FIFO order. For LIFO order,
you can use ``asyncio.LifoQueue``; For priority order you can subclass
Expand Down Expand Up @@ -186,6 +188,7 @@ def __init__( # noqa: C901, PLR0912, PLR0913, PLR0915
password: str | None = None,
logger: logging.Logger | None = None,
identifier: str | None = None,
routers: list[Router] | None = None,
queue_type: type[asyncio.Queue[Message]] | None = None,
protocol: ProtocolVersion | None = None,
will: Will | None = None,
Expand Down Expand Up @@ -250,6 +253,11 @@ def __init__( # noqa: C901, PLR0912, PLR0913, PLR0915
if protocol is None:
protocol = ProtocolVersion.V311

# List of routers with message handlers
if routers is None:
routers = []
self._routers = routers

# Create the underlying paho-mqtt client instance
self._client: mqtt.Client = mqtt.Client(
callback_api_version=CallbackAPIVersion.VERSION1,
Expand Down Expand Up @@ -453,6 +461,14 @@ async def publish( # noqa: PLR0913
# Wait for confirmation
await self._wait_for(confirmation.wait(), timeout=timeout)

async def route(self, message: Message) -> None:
"""Route a message to the appropriate handler."""
for router in self._routers:
for wildcard, handler in router._handlers.items():
with contextlib.suppress(ValueError):
# If we get a ValueError, we know that the topic doesn't match
await handler(message, self, *message.topic.extract(wildcard))

async def _messages(self) -> AsyncGenerator[Message, None]:
"""Async generator that yields messages from the underlying message queue."""
while True:
Expand Down
11 changes: 11 additions & 0 deletions aiomqtt/router.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
class Router:
def __init__(self) -> None:
self._handlers = {}

def match(self, *args: str):
"""Add a new handler with one or multiple wildcards to the router."""
def decorator(func):
for wildcard in args:
self._handlers[wildcard] = func
return func
return decorator
43 changes: 28 additions & 15 deletions aiomqtt/topic.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,21 @@ def matches(self, wildcard: WildcardLike) -> bool:
Returns:
True if the topic matches the wildcard, False otherwise.
"""
try:
self.extract(wildcard)
return True
except ValueError:
return False

def extract(self, wildcard: WildcardLike) -> list[str]:
"""Extract the wildcard values from the topic.
Args:
wildcard: The wildcard to match against.
Returns:
A list of wildcard values extracted from the topic.
"""
if not isinstance(wildcard, Wildcard):
wildcard = Wildcard(wildcard)
# Split topics into levels to compare them one by one
Expand All @@ -98,21 +113,19 @@ def matches(self, wildcard: WildcardLike) -> bool:
# Shared subscriptions use the topic structure: $share/<group_id>/<topic>
wildcard_levels = wildcard_levels[2:]

def recurse(tl: list[str], wl: list[str]) -> bool:
"""Recursively match topic levels with wildcard levels."""
if not tl:
if not wl or wl[0] == "#":
return True
return False
if not wl:
return False
if wl[0] == "#":
return True
if tl[0] == wl[0] or wl[0] == "+":
return recurse(tl[1:], wl[1:])
return False

return recurse(topic_levels, wildcard_levels)
# Extract wildcard values from the topic
arguments = []
for index, level in enumerate(wildcard_levels):
if level == "#":
return arguments + topic_levels[index:]
if len(topic_levels) == index:
raise ValueError("Topic does not match wildcard")
if level != "+" and level != topic_levels[index]:
raise ValueError("Topic does not match wildcard")
arguments.append(topic_levels[index])
if len(topic_levels) > index + 1:
raise ValueError("Topic does not match wildcard")
return arguments


TopicLike: TypeAlias = "str | Topic"

0 comments on commit 9d48222

Please sign in to comment.