From 2f1a66078dc1f3677b3b18b3072bdb4522d85ffc Mon Sep 17 00:00:00 2001 From: Fiiranek Date: Thu, 7 Nov 2024 01:16:51 +0100 Subject: [PATCH] Add `FullNodeWSClient` --- starknet_py/net/full_node_ws_client.py | 231 +++++++++++++++++++++++++ starknet_py/net/schemas/rpc/ws.py | 10 +- starknet_py/net/ws_client.py | 22 ++- 3 files changed, 257 insertions(+), 6 deletions(-) create mode 100644 starknet_py/net/full_node_ws_client.py diff --git a/starknet_py/net/full_node_ws_client.py b/starknet_py/net/full_node_ws_client.py new file mode 100644 index 000000000..7c8a2c722 --- /dev/null +++ b/starknet_py/net/full_node_ws_client.py @@ -0,0 +1,231 @@ +from typing import Any, Callable, Dict, List, Optional, Union, cast + +from starknet_py.net.client_models import BlockHeader, EmittedEvent, Hash, Tag +from starknet_py.net.http_client import RpcHttpClient +from starknet_py.net.schemas.rpc.ws import ( + EventsNotificationSchema, + NewHeadsNotificationSchema, + PendingTransactionsNotificationSchema, + SubscribeResponseSchema, + TransactionStatusNotificationSchema, + UnsubscribeResponseSchema, +) +from starknet_py.net.ws_client import RpcWSClient +from starknet_py.net.ws_full_node_client_models import ( + EventsNotification, + NewHeadsNotification, + NewTransactionStatus, + PendingTransactionsNotification, + SubscribeResponse, + Transaction, + TransactionStatusNotification, + UnsubscribeResponse, +) + +BlockId = Union[int, Hash, Tag] + + +class FullNodeWSClient: + """ + Starknet WebSocket client for RPC API. + """ + + def __init__(self, node_url: str): + """ + :param node_url: URL of the node providing the WebSocket API. + """ + self.node_url: str = node_url + self._rpc_ws_client: RpcWSClient = RpcWSClient(node_url) + self._subscriptions: Dict[int, Callable[[Any], None]] = {} + + async def connect(self): + """ + Establishes the WebSocket connection. + """ + await self._rpc_ws_client.connect() + + async def disconnect(self): + """ + Closes the WebSocket connection. + """ + await self._rpc_ws_client.disconnect() + + async def _subscribe( + self, + handler: Callable[[Any], Any], + method: str, + params: Optional[Dict[str, Any]] = None, + ) -> int: + data = await self._rpc_ws_client.send(method, params) + response = cast( + SubscribeResponse, + SubscribeResponseSchema().load(data), + ) + + self._subscriptions[response.subscription_id] = handler + + return response.subscription_id + + async def listen(self): + """ + Listens for incoming WebSocket messages. + """ + await self._rpc_ws_client.listen(self._handle_received_message) + + def _handle_received_message(self, message: Dict): + print(message) + if "params" not in message: + # TODO(#1498): Possibly move `handle_rpc_error` from `RpcHttpClient` to separate function + RpcHttpClient.handle_rpc_error(message) + + subscription_id = message["params"]["subscription_id"] + + if subscription_id not in self._subscriptions: + return + + handler = self._subscriptions[subscription_id] + method = message["method"] + + if method == "starknet_subscriptionNewHeads": + notification = cast( + NewHeadsNotification, + NewHeadsNotificationSchema().load(message["params"]), + ) + handler(notification.result) + + elif method == "starknet_subscriptionEvents": + notification = cast( + EventsNotification, + EventsNotificationSchema().load(message["params"]), + ) + handler(notification.result) + + elif method == "starknet_subscriptionTransactionStatus": + notification = cast( + TransactionStatusNotification, + TransactionStatusNotificationSchema().load(message["params"]), + ) + handler(notification.result) + + elif method == "starknet_subscriptionPendingTransactions": + notification = cast( + PendingTransactionsNotification, + PendingTransactionsNotificationSchema().load(message["params"]), + ) + handler(notification.result) + + elif method == "starknet_subscriptionReorg": + # TODO(#1498): Implement reorg handling once inconsistencies in spec are resolved + pass + + async def subscribe_new_heads( + self, + handler: Callable[[BlockHeader], Any], + block: Optional[BlockId], + ) -> int: + """ + Creates a WebSocket stream which will fire events for new block headers. + + :param handler: The function to call when a new block header is received. + :param block: The block to get notifications from, default is latest, limited to 1024 blocks back. + :return: The subscription ID. + """ + params = {"block": block} if block else {} + subscription_id = await self._subscribe( + handler, "starknet_subscribeNewHeads", params + ) + + return subscription_id + + async def subscribe_events( + self, + handler: Callable[[EmittedEvent], Any], + from_address: Optional[int] = None, + keys: Optional[List[List[int]]] = None, + block: Optional[BlockId] = None, + ) -> int: + """ + Creates a WebSocket stream which will fire events for new Starknet events with applied filters. + + :param handler: The function to call when a new event is received. + :param from_address: Address which emitted the event. + :param keys: The keys to filter events by. + :param block: The block to get notifications from, default is latest, limited to 1024 blocks back. + :return: The subscription ID. + """ + params = {"from_address": from_address, "keys": keys, "block": block} + # params = {"block": block} + subscription_id = await self._subscribe( + handler, "starknet_subscribeEvents", params + ) + + return subscription_id + + async def subscribe_transaction_status( + self, + handler: Callable[[NewTransactionStatus], Any], + transaction_hash: int, + block: Optional[BlockId], + ) -> int: + """ + Creates a WebSocket stream which will fire events when a transaction status is updated. + + :param handler: The function to call when a new transaction status is received. + :param transaction_hash: The transaction hash to fetch status updates for. + :param block: The block to get notifications from, default is latest, limited to 1024 blocks back. + :return: The subscription ID. + """ + params = {"transaction_hash": transaction_hash, "block": block} + subscription_id = await self._subscribe( + handler, "starknet_subscribeTransactionStatus", params + ) + + return subscription_id + + async def subscribe_pending_transactions( + self, + handler: Callable[[Union[int, Transaction]], Any], + transaction_details: Optional[bool], + sender_address: Optional[List[int]], + ) -> int: + """ + Creates a WebSocket stream which will fire events when a new pending transaction is added. + While there is no mempool, this notifies of transactions in the pending block. + + :param handler: The function to call when a new pending transaction is received. + :param transaction_details: Whether to include transaction details in the notification. + If false, only hash is returned. + :param sender_address: The sender address to filter transactions by. + :return: The subscription ID. + """ + params = { + "transaction_details": transaction_details, + "sender_address": sender_address, + } + subscription_id = await self._subscribe( + handler, "starknet_subscribePendingTransactions", params + ) + + return subscription_id + + async def unsubscribe(self, subscription_id: int) -> bool: + """ + Close a previously opened WebSocket stream, with the corresponding subscription id. + + :param subscription_id: ID of the subscription to close. + :return: True if the unsubscription was successful, False otherwise. + """ + if subscription_id not in self._subscriptions: + return False + + params = {"subscription_id": subscription_id} + res = await self._rpc_ws_client.send("starknet_unsubscribe", params) + + unsubscribe_response = cast( + UnsubscribeResponse, UnsubscribeResponseSchema().load(res) + ) + + if unsubscribe_response: + del self._subscriptions[subscription_id] + + return unsubscribe_response.result diff --git a/starknet_py/net/schemas/rpc/ws.py b/starknet_py/net/schemas/rpc/ws.py index dfde0eb20..2868ed39e 100644 --- a/starknet_py/net/schemas/rpc/ws.py +++ b/starknet_py/net/schemas/rpc/ws.py @@ -1,6 +1,8 @@ from marshmallow import Schema, fields, post_load +from starknet_py.net.schemas.common import Felt from starknet_py.net.schemas.rpc.block import BlockHeaderSchema +from starknet_py.net.schemas.rpc.event import EmittedEventSchema from starknet_py.net.ws_full_node_client_models import ( EventsNotification, NewHeadsNotification, @@ -32,7 +34,7 @@ def make_dataclass(self, data, **kwargs) -> NewHeadsNotification: class EventsNotificationSchema(Schema): subscription_id = fields.Integer(data_key="subscription_id", required=True) - result = fields.List(fields.Dict(), data_key="result", required=True) + result = fields.Nested(EmittedEventSchema(), data_key="result", required=True) @post_load def make_dataclass(self, data, **kwargs) -> EventsNotification: @@ -40,7 +42,7 @@ def make_dataclass(self, data, **kwargs) -> EventsNotification: class NewTransactionStatusSchema(Schema): - transaction_hash = fields.Integer(data_key="transaction_hash", required=True) + transaction_hash = Felt(data_key="transaction_hash", required=True) status = fields.Dict(data_key="status", required=True) @post_load @@ -77,11 +79,11 @@ def make_dataclass(self, data, **kwargs) -> UnsubscribeResponse: class ReorgDataSchema(Schema): - starting_block_hash = fields.Integer(data_key="starting_block_hash", required=True) + starting_block_hash = Felt(data_key="starting_block_hash", required=True) starting_block_number = fields.Integer( data_key="starting_block_number", required=True, validate=lambda x: x >= 0 ) - ending_block_hash = fields.Integer(data_key="ending_block_hash", required=True) + ending_block_hash = Felt(data_key="ending_block_hash", required=True) ending_block_number = fields.Integer( data_key="ending_block_number", required=True, validate=lambda x: x >= 0 ) diff --git a/starknet_py/net/ws_client.py b/starknet_py/net/ws_client.py index a0dbc9c12..f7a22c079 100644 --- a/starknet_py/net/ws_client.py +++ b/starknet_py/net/ws_client.py @@ -1,5 +1,5 @@ import json -from typing import Any, Dict, Optional, Union, cast +from typing import Any, Callable, Dict, Optional, Union, cast from websockets.asyncio.client import ClientConnection, connect @@ -31,12 +31,29 @@ async def send_raw( self, payload: Optional[Dict[str, Any]] = None, ): + """ + Sends a message to the WebSocket server and returns the response. + + :param payload: The message to send. + """ assert self.connection is not None await self.connection.send(json.dumps(payload)) data = await self.connection.recv() return data + async def listen(self, received_message_handler: Callable[[Dict[str, Any]], Any]): + """ + Listens for incoming WebSocket messages. + + :param received_message_handler: The function to call when a message is received. + """ + assert self.connection is not None + + async for message in self.connection: + message = cast(Dict, json.loads(message)) + received_message_handler(message) + class RpcWSClient(WSClient): """ @@ -59,6 +76,7 @@ async def send( data = cast(Dict, json.loads(data)) if "result" not in data: + # TODO(#1498): Possibly move `handle_rpc_error` from `RpcHttpClient` to separate function RpcHttpClient.handle_rpc_error(data) - return data + return data["result"]