diff --git a/auraxium/event/_client.py b/auraxium/event/_client.py index 01039f9..6152347 100644 --- a/auraxium/event/_client.py +++ b/auraxium/event/_client.py @@ -25,7 +25,7 @@ _EventT = TypeVar('_EventT', bound=Event) _EventT2 = TypeVar('_EventT2', bound=Event) _CallbackT = Union[Callable[[_EventT], None], - Callable[[_EventT], Coroutine[Any, Any, None]]] +Callable[[_EventT], Coroutine[Any, Any, None]]] _log = logging.getLogger('auraxium.ess') @@ -170,6 +170,14 @@ def remove_trigger(self, trigger: Union[Trigger, str], *, _log.info('All triggers have been removed, closing websocket') self.loop.create_task(self.close()) + def _subscribe_all(self): + """Add subscription messages for every registered trigger. + + This will add a subscription message for every trigger currently registered with the client. + Useful for resubscribing to all events after a disconnect. + """ + self._send_queue.extend([trigger.generate_subscription() for trigger in self.triggers]) + async def close(self) -> None: """Gracefully shut down the client. @@ -186,7 +194,7 @@ async def connect(self) -> None: This will continuously loop until :meth:`EventClient.close` is called. - If the WebSocket connection encounters and error, it will be + If the WebSocket connection encounters an error, it will be automatically restarted. Any event payloads received will be passed to @@ -262,9 +270,12 @@ async def _connection_handler(self) -> None: # NOTE: The following "async for" loop will cleanly restart the # connection should it go down. Invoking "continue" manually may be # used to manually force a reconnect if needed. - + connection_failed = False async for websocket in websockets.client.connect(str(url)): _log.info('Connected to %s', url) + if connection_failed: + self._subscribe_all() + connection_failed = False self.websocket = websocket try: @@ -273,6 +284,7 @@ async def _connection_handler(self) -> None: except websockets.exceptions.ConnectionClosed: _log.info('Connection closed, restarting...') + connection_failed = True continue if not self._open: @@ -312,22 +324,22 @@ async def _handle_websocket(self, timeout: float = 0.1) -> None: def trigger(self, event: Type[_EventT], *, name: Optional[str] = None, **kwargs: Any) -> Callable[[_CallbackT[_EventT]], None]: # Single event variant (checks callback argument type) - ... # pragma: no cover + ... # pragma: no cover @overload def trigger(self, event: Type[_EventT], arg1: Type[_EventT], *args: Type[_EventT2], name: Optional[str] = None, **kwargs: Any) -> Callable[ - [_CallbackT[Union[_EventT, _EventT2]]], None]: + [_CallbackT[Union[_EventT, _EventT2]]], None]: # Two event variant (checks callback argument type) - ... # pragma: no cover + ... # pragma: no cover @overload def trigger(self, event: Union[str, Type[Event]], *args: Union[str, Type[Event]], name: Optional[str] = None, **kwargs: Any) -> Callable[[_CallbackT[Event]], None]: # Generic fallback variant (callback argument type not checked) - ... # pragma: no cover + ... # pragma: no cover def trigger(self, event: Union[str, Type[Event]], *args: Union[str, Type[Event]], name: Optional[str] = None,