diff --git a/README.md b/README.md index e3bcbc9..435de4d 100644 --- a/README.md +++ b/README.md @@ -12,12 +12,14 @@ * [Client](#client) * [authenticate](#authenticate) * [close](#close) + * [close_and_wait](#close_and_wait) * [connect](#connect) * [connect_pool](#connect_pool) * [get_default_scope](#get_default_scope) * [get_event_loop](#get_event_loop) * [get_rooms](#get_rooms) * [is_connected](#is_connected) + * [is_websocket](#is_websocket) * [query](#query) * [reconnect](#reconnect) * [run](#run) @@ -31,6 +33,7 @@ * [emit](#emit) * [no_join](#no_join) * [Failed packages](#failed-packages) + * [WebSockets](#websockets) --------------------------------------- ## Installation @@ -71,8 +74,7 @@ async def hello_world(): finally: # the will close the client in a nice way - client.close() - await client.wait_closed() + await client.close_and_wait() # run the hello world example asyncio.get_event_loop().run_until_complete(hello_world()) @@ -148,6 +150,16 @@ This method will return immediately so the connection may not be closed yet after a call to `close()`. Use the [wait_closed()](#wait_closed) method after calling this method if this is required. +### close_and_wait + +```python +async Client().close_and_wait() -> None +``` + +Close and wait for the the connection to be closed. + +This is equivalent of combining [close()](#close)) and [wait_closed()](#wait_closed). + ### connect ```python @@ -167,11 +179,12 @@ connection before using the connection. #### Args - *host (str)*: - A hostname, IP address, FQDN to connect to. + A hostname, IP address, FQDN or URI _(for WebSockets)_ to connect to. - *port (int, optional)*: Integer value between 0 and 65535 and should be the port number where a ThingsDB node is listening to for client connections. - Defaults to 9200. + Defaults to 9200. For WebSocket connections the port must be + provided with the URI _(see host argument)_. - *timeout (int, optional)*: Can be be used to control the maximum time the client will attempt to create a connection. The timeout may be set to @@ -207,9 +220,9 @@ to perform the authentication. ```python await connect_pool([ - 'node01.local', # address as string - 'node02.local', # port will default to 9200 - ('node03.local', 9201), # ..or with an explicit port + 'node01.local', # address or WebSocket URI as string + 'node02.local', # port will default to 9200 or ignored for URI + ('node03.local', 9201), # ..or with an explicit port (ignored for URI) ], "admin", "pass") ``` @@ -217,7 +230,8 @@ await connect_pool([ - *pool (list of addresses)*: Should be an iterable with node address strings, or tuples - with `address` and `port` combinations in a tuple or list. + with `address` and `port` combinations in a tuple or list. For WebSockets, + the address must be an URI with the port included. (e.g: `"ws://host:9270"`) - *\*auth (str or (str, str))*: Argument `auth` can be be either a string with a token or a tuple with username and password. (the latter may be provided @@ -282,6 +296,18 @@ Can be used to check if the client is connected. #### Returns `True` when the client is connected else `False`. + +### is_websocket + +```python +Client().is_websocket() -> bool +``` + +Can be used to check if the client is using a WebSocket connection. + +#### Returns +`True` when the client is connected else `False`. + ### query ```python @@ -595,3 +621,52 @@ set_package_fail_file('/tmp/thingsdb-invalid-data.mp') # When a package is received which fails to unpack, the data from this package # will be stored to file. ``` + + +## WebSockets + +Since ThingsDB 1.6 has received WebSocket support. The Python client is able to use the WebSockets protocol by providing the `host` as URI. +For WebSocket connections,the `port` argument will be ignored and must be specified with the URI instead. + +Default the `websockets` package is **not included** when installing this connector. + +If you want to use WebSockets, make sure to install the package: + +``` +pip install websockets +``` + +For example: + +```python +import asyncio +from thingsdb.client import Client + +async def hello_world(): + client = Client() + + # replace `ws://localhost:9270` with your URI + await client.connect('ws://localhost:9270') + + # for a secure connection, use wss:// and provide an SSL context, example: + # (ssl can be set either to True or False, or an SSLContext) + # + # await client.connect('wss://localhost:9270', ssl=True) + + try: + # replace `admin` and `pass` with your username and password + # or use a valid token string + await client.authenticate('admin', 'pass') + + # perform the hello world code... + print(await client.query(''' + "Hello World!"; + ''') + + finally: + # the will close the client in a nice way + await client.close_and_wait() + +# run the hello world example +asyncio.get_event_loop().run_until_complete(hello_world()) +``` \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index cefb916..a887750 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,4 @@ msgpack>=0.6.2 deprecation +# Optional package: +# websockets \ No newline at end of file diff --git a/test_thingsdb.py b/test_thingsdb.py index 7c69b5e..a211b24 100644 --- a/test_thingsdb.py +++ b/test_thingsdb.py @@ -20,8 +20,7 @@ async def async_test_playground(self): self.assertEqual(data, want) finally: - client.close() - await client.wait_closed() + await client.close_and_wait() def test_playground(self): loop = asyncio.get_event_loop() diff --git a/thingsdb/client/client.py b/thingsdb/client/client.py index 0bb4d3c..b770848 100644 --- a/thingsdb/client/client.py +++ b/thingsdb/client/client.py @@ -7,7 +7,7 @@ from typing import Optional, Union, Any from concurrent.futures import CancelledError from .buildin import Buildin -from .protocol import Proto, Protocol +from .protocol import Proto, Protocol, ProtocolWS from ..exceptions import NodeError, AuthError from ..util import strip_code @@ -85,7 +85,7 @@ def is_connected(self) -> bool: Returns: bool: `True` when the client is connected else `False`. """ - return bool(self._protocol and self._protocol.transport) + return bool(self._protocol and self._protocol.is_connected()) def set_default_scope(self, scope: str) -> None: """Set the default scope. @@ -119,9 +119,9 @@ def close(self) -> None: closed yet after a call to `close()`. Use the `wait_closed()` method after calling this method if this is required. """ - if self._protocol and self._protocol.transport: - self._reconnect = False - self._protocol.transport.close() + self._reconnect = False + if self._protocol: + self._protocol.close() def connection_info(self) -> str: """Returns the current connection info as a string. @@ -134,7 +134,7 @@ def connection_info(self) -> str: """ if not self.is_connected(): return 'disconnected' - socket = self._protocol.transport.get_extra_info('socket', None) + socket = self._protocol.info() if socket is None: return 'unknown_addr' addr, port = socket.getpeername()[:2] @@ -205,11 +205,13 @@ def connect( Args: host (str): - A hostname, IP address, FQDN to connect to. + A hostname, IP address, FQDN or URI (for WebSockets) to connect + to. port (int, optional): Integer value between 0 and 65535 and should be the port number where a ThingsDB node is listening to for client connections. - Defaults to 9200. + Defaults to 9200. For WebSocket connections the port must be + provided with the URI (see host argument). timeout (int, optional): Can be be used to control the maximum time the client will attempt to create a connection. The timeout may be set to @@ -250,8 +252,19 @@ async def wait_closed(self) -> None: Can be used after calling the `close()` method to determine when the connection is actually closed. """ - if self._protocol and self._protocol.close_future: - await self._protocol.close_future + if self._protocol and self._protocol.is_closing(): + await self._protocol.wait_closed() + + async def close_and_wait(self) -> None: + """Close and wait for the connection to be closed. + + This is equivalent to calling close() and await wait_closed() + """ + if self._protocol: + await self._protocol.close_and_wait() + + def is_websocket(self) -> bool: + return self._protocol.__class__ is ProtocolWS async def authenticate( self, @@ -538,20 +551,32 @@ def _auth_check(auth): ) return auth + @staticmethod + def _is_websocket_host(host): + return host.startswith('ws://') or host.startswith('wss://') + async def _connect(self, timeout=5): host, port = self._pool[self._pool_idx] try: - conn = self._loop.create_connection( - lambda: Protocol( + if self._is_websocket_host(host): + conn = ProtocolWS( on_connection_lost=self._on_connection_lost, - on_event=self._on_event, - loop=self._loop), - host=host, - port=port, - ssl=self._ssl) - _, self._protocol = await asyncio.wait_for( - conn, - timeout=timeout) + on_event=self._on_event).connect(uri=host, ssl=self._ssl) + self._protocol = await asyncio.wait_for( + conn, + timeout=timeout) + else: + conn = self._loop.create_connection( + lambda: Protocol( + on_connection_lost=self._on_connection_lost, + on_event=self._on_event, + loop=self._loop), + host=host, + port=port, + ssl=self._ssl) + _, self._protocol = await asyncio.wait_for( + conn, + timeout=timeout) finally: self._pool_idx += 1 self._pool_idx %= len(self._pool) @@ -614,15 +639,17 @@ async def _reconnect_loop(self): await self._authenticate(timeout=5) await self._rejoin() except Exception as e: + name = host if self._is_websocket_host(host) else \ + f'{host}:{port}' logging.error( - f'Connecting to {host}:{port} failed: ' + f'Connecting to {name} failed: ' f'{e}({e.__class__.__name__}), ' f'Try next connect in {wait_time} seconds' ) else: - if protocol and protocol.transport: + if protocol and protocol.is_connected(): # make sure the `old` connection will be dropped - self._loop.call_later(10.0, protocol.transport.close) + self._loop.call_later(10.0, protocol.close) break await asyncio.sleep(wait_time) diff --git a/thingsdb/client/package.py b/thingsdb/client/package.py index b51310b..02e6862 100644 --- a/thingsdb/client/package.py +++ b/thingsdb/client/package.py @@ -33,6 +33,18 @@ def __init__(self, barray: bytearray) -> None: self.total = self.__class__.st_package.size + self.length self.data = None + def _handle_fail_file(self, message: bytes): + if _fail_file: + try: + with open(_fail_file, 'wb') as f: + f.write( + message[self.__class__.st_package.size:self.total]) + except Exception: + logging.exception('') + else: + logging.warning( + f'Wrote the content from {self} to `{_fail_file}`') + def extract_data_from(self, barray: bytearray) -> None: try: self.data = msgpack.unpackb( @@ -40,19 +52,20 @@ def extract_data_from(self, barray: bytearray) -> None: raw=False) \ if self.length else None except Exception as e: - if _fail_file: - try: - with open(_fail_file, 'wb') as f: - f.write( - barray[self.__class__.st_package.size:self.total]) - except Exception: - logging.exception('') - else: - logging.warning( - f'Wrote the content from {self} to `{_fail_file}`') + self._handle_fail_file(barray) raise e finally: del barray[:self.total] + def read_data_from(self, message: bytes) -> None: + try: + self.data = msgpack.unpackb( + message[self.__class__.st_package.size:self.total], + raw=False) \ + if self.length else None + except Exception as e: + self._handle_fail_file(message) + raise e + def __repr__(self) -> str: return ''.format(self) diff --git a/thingsdb/client/protocol.py b/thingsdb/client/protocol.py index c848540..fcdac48 100644 --- a/thingsdb/client/protocol.py +++ b/thingsdb/client/protocol.py @@ -2,6 +2,8 @@ import asyncio import logging import msgpack +from abc import abstractmethod +from ssl import SSLContext from typing import Optional, Any, Callable from .package import Package from ..exceptions import AssertionError @@ -22,11 +24,16 @@ from ..exceptions import RequestTimeoutError from ..exceptions import ResultTooLargeError from ..exceptions import SyntaxError -from ..exceptions import ThingsDBError from ..exceptions import TypeError from ..exceptions import ValueError from ..exceptions import WriteUVError from ..exceptions import ZeroDivisionError +try: + import websockets + from websockets.client import connect, WebSocketClientProtocol + from websockets.exceptions import ConnectionClosed +except ImportError: + pass class Proto(enum.IntEnum): @@ -130,7 +137,110 @@ def proto_unkown(f, d): f.set_exception(TypeError('unknown package type received ({})'.format(d))) -class Protocol(asyncio.Protocol): +class _Protocol: + def __init__( + self, + on_connection_lost: Callable[[asyncio.Protocol, Exception], None], + on_event: Callable[[Package], None],): + self._requests = {} + self._pid = 0 + self._on_connection_lost = on_connection_lost + self._on_event = on_event + + async def _timer(self, pid: int, timeout: Optional[int]) -> None: + await asyncio.sleep(timeout) + try: + future, task = self._requests.pop(pid) + except KeyError: + logging.error('Timed out package Id not found: {}'.format( + self._data_package.pid)) + return None + + future.set_exception(TimeoutError( + 'request timed out on package Id {}'.format(pid))) + + def _on_response(self, pkg: Package) -> None: + try: + future, task = self._requests.pop(pkg.pid) + except KeyError: + logging.error('Received package id not found: {}'.format(pkg.pid)) + return None + + # cancel the timeout task + if task is not None: + task.cancel() + + if future.cancelled(): + return + + _PROTO_RESPONSE_MAP.get(pkg.tp, proto_unkown)(future, pkg.data) + + def _handle_package(self, pkg: Package): + tp = pkg.tp + if tp in _PROTO_RESPONSE_MAP: + self._on_response(pkg) + elif tp in _PROTO_EVENTS: + try: + self._on_event(pkg) + except Exception: + logging.exception('') + else: + logging.error(f'Unsupported package type received: {tp}') + + def write( + self, + tp: Proto, + data: Any = None, + is_bin: bool = False, + timeout: Optional[int] = None + ) -> asyncio.Future: + """Write data to ThingsDB. + This will create a new PID and returns a Future which will be + set when a response is received from ThingsDB, or time-out is reached. + """ + self._pid += 1 + self._pid %= 0x10000 # pid is handled as uint16_t + + data = data if is_bin else b'' if data is None else \ + msgpack.packb(data, use_bin_type=True) + + header = Package.st_package.pack( + len(data), + self._pid, + tp, + tp ^ 0xff) + + self._write(header + data) + + task = asyncio.ensure_future( + self._timer(self._pid, timeout)) if timeout else None + + future = asyncio.Future() + self._requests[self._pid] = (future, task) + return future + + @abstractmethod + def _write(self, data: Any): + ... + + @abstractmethod + def close(self): + ... + + @abstractmethod + def is_closing(self) -> bool: + ... + + @abstractmethod + async def wait_closed(self): + ... + + @abstractmethod + async def close_and_wait(self): + ... + + +class Protocol(_Protocol, asyncio.Protocol): def __init__( self, @@ -138,15 +248,12 @@ def __init__( on_event: Callable[[Package], None], loop: Optional[asyncio.AbstractEventLoop] = None ): + super().__init__(on_connection_lost, on_event) self._buffered_data = bytearray() self.package = None self.transport = None self.loop = asyncio.get_event_loop() if loop is None else loop self.close_future = None - self._requests = {} - self._pid = 0 - self._on_connection_lost = on_connection_lost - self._on_event = on_event def connection_made(self, transport: asyncio.Transport) -> None: ''' @@ -198,78 +305,104 @@ def data_received(self, data: bytes) -> None: f'Exception above came from package: {self.package}') self._buffered_data.clear() else: - tp = self.package.tp - if tp in _PROTO_RESPONSE_MAP: - self._on_response(self.package) - elif tp in _PROTO_EVENTS: - try: - self._on_event(self.package) - except Exception: - logging.exception('') - else: - logging.error(f'Unsupported package type received: {tp}') + self._handle_package(self.package) self.package = None - def write( - self, - tp: Proto, - data: Any = None, - is_bin: bool = False, - timeout: Optional[int] = None - ) -> asyncio.Future: - """Write data to ThingsDB. - This will create a new PID and returns a Future which will be - set when a response is received from ThingsDB, or time-out is reached. - """ + def _write(self, data: Any): if self.transport is None: raise ConnectionError('no connection') + self.transport.write(data) - self._pid += 1 - self._pid %= 0x10000 # pid is handled as uint16_t + def close(self): + if self.transport: + self.transport.close() - data = data if is_bin else b'' if data is None else \ - msgpack.packb(data, use_bin_type=True) + def is_closing(self) -> bool: + return self.close_future is not None - header = Package.st_package.pack( - len(data), - self._pid, - tp, - tp ^ 0xff) + async def wait_closed(self): + await self.close_future - self.transport.write(header + data) + async def close_and_wait(self): + self.close() + await self.close_future - task = asyncio.ensure_future( - self._timer(self._pid, timeout)) if timeout else None + def info(self): + return self.transport.get_extra_info('socket', None) - future = asyncio.Future() - self._requests[self._pid] = (future, task) - return future + def is_connected(self) -> bool: + return self.transport is not None - async def _timer(self, pid: int, timeout: Optional[int]) -> None: - await asyncio.sleep(timeout) + +class ProtocolWS(_Protocol): + """More a wrapper than a true protocol.""" + def __init__( + self, + on_connection_lost: Callable[[asyncio.Protocol, Exception], None], + on_event: Callable[[Package], None], + ): + super().__init__(on_connection_lost, on_event) try: - future, task = self._requests.pop(pid) - except KeyError: - logging.error('Timed out package Id not found: {}'.format( - self._data_package.pid)) - return None + assert type(websockets).__name__ == 'module' + except Exception: + raise ImportError( + 'missing `websockets` module; ' + 'please install the `websockets` module: ' + '\n\n pip install websockets\n\n') + self._proto: WebSocketClientProtocol = None + self._is_closing = False + + async def connect(self, uri, ssl: SSLContext): + self._proto = await connect(uri, ssl=ssl) + asyncio.create_task(self._recv_loop()) + self._is_closing = False + return self + + async def _recv_loop(self): + try: + while True: + data = await self._proto.recv() + pkg = None + try: + pkg = Package(data) + pkg.read_data_from(data) + except Exception: + logging.exception('') + # empty the byte-array to recover from this error + if pkg: + logging.error( + f'Exception above came from package: {pkg}') + else: + self._handle_package(pkg) - future.set_exception(TimeoutError( - 'request timed out on package Id {}'.format(pid))) + except ConnectionClosed as exc: + self._proto = None + self._on_connection_lost(self, exc) - def _on_response(self, pkg: Package) -> None: - try: - future, task = self._requests.pop(pkg.pid) - except KeyError: - logging.error('Received package id not found: {}'.format(pkg.pid)) - return None + def _write(self, data: Any): + if self._proto is None: + raise ConnectionError('no connection') + asyncio.create_task(self._proto.send(data)) - # cancel the timeout task - if task is not None: - task.cancel() + def close(self): + self._is_closing = True + if self._proto: + asyncio.create_task(self._proto.close()) - if future.cancelled(): - return + def is_closing(self) -> bool: + self._is_closing - _PROTO_RESPONSE_MAP.get(pkg.tp, proto_unkown)(future, pkg.data) + async def wait_closed(self): + if self._proto: + await self._proto.wait_closed() + + async def close_and_wait(self): + if self._proto: + await self._proto.close() + + def info(self): + return self._proto.transport.get_extra_info('socket', None) + + def is_connected(self) -> bool: + return self._proto is not None diff --git a/thingsdb/version.py b/thingsdb/version.py index 887a342..1a72d32 100644 --- a/thingsdb/version.py +++ b/thingsdb/version.py @@ -1 +1 @@ -__version__ = '1.0.7' +__version__ = '1.1.0'