diff --git a/starknet_py/conftest.py b/starknet_py/conftest.py index 23206cbe0..e1a8f0ac5 100644 --- a/starknet_py/conftest.py +++ b/starknet_py/conftest.py @@ -7,6 +7,7 @@ "starknet_py.tests.e2e.fixtures.contracts_v1", "starknet_py.tests.e2e.fixtures.misc", "starknet_py.tests.e2e.fixtures.devnet", + "starknet_py.tests.e2e.fixtures.devnet_ws", "starknet_py.tests.e2e.fixtures.constants", "starknet_py.tests.e2e.client.fixtures.transactions", "starknet_py.tests.e2e.client.fixtures.prepare_network", diff --git a/starknet_py/net/full_node_ws_client.py b/starknet_py/net/full_node_ws_client.py new file mode 100644 index 000000000..e98449f0b --- /dev/null +++ b/starknet_py/net/full_node_ws_client.py @@ -0,0 +1,244 @@ +from typing import Any, Callable, Dict, List, Optional, Union, cast + +from starknet_py.net.client_models import BlockHeader, EmittedEvent, Hash, Tag +from starknet_py.net.client_utils import _clear_none_values +from starknet_py.net.http_client import RpcHttpClient +from starknet_py.net.schemas.rpc.ws import ( + EventsNotificationSchema, + NewHeadsNotificationSchema, + PendingTransactionsNotificationSchema, + ReorgNotificationSchema, + SubscribeResponseSchema, + TransactionStatusNotificationSchema, + UnsubscribeResponseSchema, +) +from starknet_py.net.ws_client import RpcWSClient +from starknet_py.net.ws_full_node_client_models import ( + EventsNotification, + NewHeadsNotification, + NewTransactionStatus, + PendingTransactionsNotification, + ReorgNotification, + SubscribeResponse, + Transaction, + TransactionStatusNotification, + UnsubscribeResponse, +) + +BlockId = Union[int, Hash, Tag] +HandlerNotification = Union[ + NewHeadsNotification, + EventsNotification, + TransactionStatusNotification, + PendingTransactionsNotification, + ReorgNotification, +] +Handler = Callable[[HandlerNotification], Any] + + +class FullNodeWSClient: + """ + Starknet WebSocket client for RPC API. + """ + + def __init__(self, node_url: str): + """ + :param node_url: URL of the node providing the WebSocket API. + """ + self.node_url: str = node_url + self._rpc_ws_client: RpcWSClient = RpcWSClient(node_url) + self._subscriptions: Dict[int, Handler] = {} + + async def connect(self): + """ + Establishes the WebSocket connection. + """ + await self._rpc_ws_client.connect() + + async def disconnect(self): + """ + Closes the WebSocket connection. + """ + await self._rpc_ws_client.disconnect() + + async def _subscribe( + self, + handler: Callable[[Any], Any], + method: str, + params: Optional[Dict[str, Any]] = None, + ) -> int: + data = await self._rpc_ws_client.send(method, params) + response = cast( + SubscribeResponse, + SubscribeResponseSchema().load(data), + ) + + self._subscriptions[response.subscription_id] = handler + + return response.subscription_id + + async def listen(self): + """ + Listens for incoming WebSocket messages. + """ + await self._rpc_ws_client.listen(self._handle_received_message) + + def _handle_received_message(self, message: Dict): + if "params" not in message: + # TODO(#1498): Possibly move `handle_rpc_error` from `RpcHttpClient` to separate function + RpcHttpClient.handle_rpc_error(message) + + subscription_id = message["params"]["subscription_id"] + + if subscription_id not in self._subscriptions: + return + + handler = self._subscriptions[subscription_id] + method = message["method"] + + if method == "starknet_subscriptionNewHeads": + notification = cast( + NewHeadsNotification, + NewHeadsNotificationSchema().load(message["params"]), + ) + handler(notification) + + elif method == "starknet_subscriptionEvents": + notification = cast( + EventsNotification, + EventsNotificationSchema().load(message["params"]), + ) + handler(notification) + + elif method == "starknet_subscriptionTransactionStatus": + notification = cast( + TransactionStatusNotification, + TransactionStatusNotificationSchema().load(message["params"]), + ) + handler(notification) + + elif method == "starknet_subscriptionPendingTransactions": + notification = cast( + PendingTransactionsNotification, + PendingTransactionsNotificationSchema().load(message["params"]), + ) + handler(notification) + + elif method == "starknet_subscriptionReorg": + notification = cast( + ReorgNotification, + ReorgNotificationSchema().load(message["params"]), + ) + handler(notification) + + async def subscribe_new_heads( + self, + handler: Callable[[BlockHeader], Any], + block: Optional[BlockId], + ) -> int: + """ + Creates a WebSocket stream which will fire events for new block headers. + + :param handler: The function to call when a new block header is received. + :param block: The block to get notifications from, default is latest, limited to 1024 blocks back. + :return: The subscription ID. + """ + params = {"block": block} if block else {} + subscription_id = await self._subscribe( + handler, "starknet_subscribeNewHeads", params + ) + + return subscription_id + + async def subscribe_events( + self, + handler: Callable[[EmittedEvent], Any], + from_address: Optional[int] = None, + keys: Optional[List[List[int]]] = None, + block: Optional[BlockId] = None, + ) -> int: + """ + Creates a WebSocket stream which will fire events for new Starknet events with applied filters. + + :param handler: The function to call when a new event is received. + :param from_address: Address which emitted the event. + :param keys: The keys to filter events by. + :param block: The block to get notifications from, default is latest, limited to 1024 blocks back. + :return: The subscription ID. + """ + params = {"from_address": from_address, "keys": keys, "block": block} + subscription_id = await self._subscribe( + handler, "starknet_subscribeEvents", params + ) + + return subscription_id + + async def subscribe_transaction_status( + self, + handler: Callable[[NewTransactionStatus], Any], + transaction_hash: int, + block: Optional[BlockId] = None, + ) -> int: + """ + Creates a WebSocket stream which will fire events when a transaction status is updated. + + :param handler: The function to call when a new transaction status is received. + :param transaction_hash: The transaction hash to fetch status updates for. + :param block: The block to get notifications from, default is latest, limited to 1024 blocks back. + :return: The subscription ID. + """ + params = {"transaction_hash": transaction_hash, "block": block} + params = _clear_none_values(params) + subscription_id = await self._subscribe( + handler, "starknet_subscribeTransactionStatus", params + ) + + return subscription_id + + async def subscribe_pending_transactions( + self, + handler: Callable[[Union[int, Transaction]], Any], + transaction_details: Optional[bool], + sender_address: Optional[List[int]], + ) -> int: + """ + Creates a WebSocket stream which will fire events when a new pending transaction is added. + While there is no mempool, this notifies of transactions in the pending block. + + :param handler: The function to call when a new pending transaction is received. + :param transaction_details: Whether to include transaction details in the notification. + If false, only hash is returned. + :param sender_address: The sender address to filter transactions by. + :return: The subscription ID. + """ + params = { + "transaction_details": transaction_details, + "sender_address": sender_address, + } + subscription_id = await self._subscribe( + handler, "starknet_subscribePendingTransactions", params + ) + + return subscription_id + + async def unsubscribe(self, subscription_id: int) -> bool: + """ + Close a previously opened WebSocket stream, with the corresponding subscription id. + + :param subscription_id: ID of the subscription to close. + :return: True if the unsubscription was successful, False otherwise. + """ + if subscription_id not in self._subscriptions: + return False + + params = {"subscription_id": subscription_id} + res = await self._rpc_ws_client.send("starknet_unsubscribe", params) + + unsubscribe_response = cast( + UnsubscribeResponse, UnsubscribeResponseSchema().load(res) + ) + + if unsubscribe_response: + del self._subscriptions[subscription_id] + + return unsubscribe_response.result diff --git a/starknet_py/net/schemas/rpc/ws.py b/starknet_py/net/schemas/rpc/ws.py new file mode 100644 index 000000000..439c321b6 --- /dev/null +++ b/starknet_py/net/schemas/rpc/ws.py @@ -0,0 +1,103 @@ +from marshmallow import Schema, fields, post_load + +from starknet_py.net.schemas.common import Felt +from starknet_py.net.schemas.rpc.block import BlockHeaderSchema +from starknet_py.net.schemas.rpc.event import EmittedEventSchema +from starknet_py.net.ws_full_node_client_models import ( + EventsNotification, + NewHeadsNotification, + NewTransactionStatus, + PendingTransactionsNotification, + ReorgData, + ReorgNotification, + SubscribeResponse, + TransactionStatusNotification, + UnsubscribeResponse, +) + + +class SubscribeResponseSchema(Schema): + subscription_id = fields.Integer(data_key="subscription_id", required=True) + + @post_load + def make_dataclass(self, data, **kwargs) -> SubscribeResponse: + return SubscribeResponse(**data) + + +class NewHeadsNotificationSchema(Schema): + subscription_id = fields.Integer(data_key="subscription_id", required=True) + result = fields.Nested(BlockHeaderSchema(), data_key="result", required=True) + + @post_load + def make_dataclass(self, data, **kwargs) -> NewHeadsNotification: + return NewHeadsNotification(**data) + + +class EventsNotificationSchema(Schema): + subscription_id = fields.Integer(data_key="subscription_id", required=True) + result = fields.Nested(EmittedEventSchema(), data_key="result", required=True) + + @post_load + def make_dataclass(self, data, **kwargs) -> EventsNotification: + return EventsNotification(**data) + + +class NewTransactionStatusSchema(Schema): + transaction_hash = Felt(data_key="transaction_hash", required=True) + status = fields.Dict(data_key="status", required=True) + + @post_load + def make_dataclass(self, data, **kwargs) -> NewTransactionStatus: + return NewTransactionStatus(**data) + + +class TransactionStatusNotificationSchema(Schema): + subscription_id = fields.Integer(data_key="subscription_id", required=True) + result = fields.Nested( + NewTransactionStatusSchema(), data_key="result", required=True + ) + + @post_load + def make_dataclass(self, data, **kwargs) -> TransactionStatusNotification: + return TransactionStatusNotification(**data) + + +class PendingTransactionsNotificationSchema(Schema): + subscription_id = fields.Integer(data_key="subscription_id", required=True) + result = fields.Dict(data_key="result", required=True) + + @post_load + def make_dataclass(self, data, **kwargs) -> PendingTransactionsNotification: + return PendingTransactionsNotification(**data) + + +class UnsubscribeResponseSchema(Schema): + result = fields.Boolean(data_key="result", required=True) + + @post_load + def make_dataclass(self, data, **kwargs) -> UnsubscribeResponse: + return UnsubscribeResponse(**data) + + +class ReorgDataSchema(Schema): + starting_block_hash = Felt(data_key="starting_block_hash", required=True) + starting_block_number = fields.Integer( + data_key="starting_block_number", required=True, validate=lambda x: x >= 0 + ) + ending_block_hash = Felt(data_key="ending_block_hash", required=True) + ending_block_number = fields.Integer( + data_key="ending_block_number", required=True, validate=lambda x: x >= 0 + ) + + @post_load + def make_dataclass(self, data, **kwargs) -> ReorgData: + return ReorgData(**data) + + +class ReorgNotificationSchema(Schema): + subscription_id = fields.Integer(data_key="subscription_id", required=True) + result = fields.Nested(ReorgDataSchema(), data_key="result", required=True) + + @post_load + def make_dataclass(self, data, **kwargs) -> ReorgNotification: + return ReorgNotification(**data) diff --git a/starknet_py/net/ws_client.py b/starknet_py/net/ws_client.py new file mode 100644 index 000000000..e13f5de0f --- /dev/null +++ b/starknet_py/net/ws_client.py @@ -0,0 +1,82 @@ +import json +from typing import Any, Callable, Dict, Optional, Union, cast + +from websockets.asyncio.client import ClientConnection, connect + +from starknet_py.net.http_client import RpcHttpClient + + +class WSClient: + """ + Base class for WebSocket clients. + """ + + def __init__(self, node_url: str): + """ + :param node_url: URL of the node providing the WebSocket API. + """ + self.node_url: str = node_url + self.connection: Union[None, ClientConnection] = None + + async def connect(self): + """Establishes the WebSocket connection.""" + self.connection = await connect(self.node_url) + + async def disconnect(self): + """Closes the WebSocket connection.""" + assert self.connection is not None + await self.connection.close() + + async def send_raw( + self, + payload: Optional[Dict[str, Any]] = None, + ) -> Union[str, bytes]: + """ + Sends a message to the WebSocket server and returns the response. + + :param payload: The message to send. + """ + assert self.connection is not None + await self.connection.send(json.dumps(payload)) + data = await self.connection.recv() + + return data + + async def listen(self, received_message_handler: Callable[[Dict[str, Any]], Any]): + """ + Listens for incoming WebSocket messages. + + :param received_message_handler: The function to call when a message is received. + """ + assert self.connection is not None + + async for message in self.connection: + message = cast(Dict, json.loads(message)) + received_message_handler(message) + + +class RpcWSClient(WSClient): + """ + WebSocket client for the RPC API. + """ + + async def send( + self, + method: str, + params: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Any]: + payload = { + "id": 0, + "jsonrpc": "2.0", + "method": method, + "params": params if params else [], + } + + data = await self.send_raw(payload) + data = cast(Dict, json.loads(data)) + + if "result" not in data: + # TODO(#1498): Possibly move `handle_rpc_error` from `RpcHttpClient` to separate function + RpcHttpClient.handle_rpc_error(data) + + return data["result"] diff --git a/starknet_py/net/ws_full_node_client_models.py b/starknet_py/net/ws_full_node_client_models.py new file mode 100644 index 000000000..31cd5a3d2 --- /dev/null +++ b/starknet_py/net/ws_full_node_client_models.py @@ -0,0 +1,113 @@ +""" +Dataclasses representing responses from Starknet Websocket RPC API. +""" + +from dataclasses import dataclass +from typing import TypeVar, Union + +from typing_extensions import Generic + +from starknet_py.net.client_models import ( + BlockHeader, + EmittedEvent, + TransactionStatusResponse, +) +from starknet_py.net.models import ( + DeclareV1, + DeclareV2, + DeclareV3, + DeployAccountV1, + DeployAccountV3, + InvokeV1, + InvokeV3, +) + +T = TypeVar("T") + + +@dataclass +class SubscribeResponse: + """ + Subscription result. + """ + + subscription_id: int + + +@dataclass +class Notification(Generic[T]): + subscription_id: int + result: T + + +@dataclass +class NewHeadsNotification(Notification[BlockHeader]): + """ + Notification to the client of a new block header. + """ + + +@dataclass +class EventsNotification(Notification[EmittedEvent]): + """ + Notification to the client of a new event. + """ + + +@dataclass +class NewTransactionStatus: + transaction_hash: int + status: TransactionStatusResponse + + +@dataclass +class TransactionStatusNotification(Notification[NewTransactionStatus]): + """ + Notification to the client of a new transaction status. + """ + + +Transaction = Union[ + DeclareV1, + DeclareV2, + DeclareV3, + DeployAccountV1, + DeployAccountV3, + InvokeV1, + InvokeV3, +] + + +@dataclass +class PendingTransactionsNotification(Notification[Union[int, Transaction]]): + """ + Notification to the client of a new pending transaction. + """ + + +@dataclass +class UnsubscribeResponse: + """ + Unsubscription result. + """ + + result: bool + + +@dataclass +class ReorgData: + """ + Data about reorganized blocks, starting and ending block number and hash. + """ + + starting_block_hash: int + starting_block_number: int + ending_block_hash: int + ending_block_number: int + + +@dataclass +class ReorgNotification(Notification[ReorgData]): + """ + Notification of a reorganization of the chain. + """ diff --git a/starknet_py/tests/e2e/fixtures/clients.py b/starknet_py/tests/e2e/fixtures/clients.py index 42a413518..8897c0a85 100644 --- a/starknet_py/tests/e2e/fixtures/clients.py +++ b/starknet_py/tests/e2e/fixtures/clients.py @@ -1,8 +1,37 @@ +from typing import AsyncGenerator + import pytest +import pytest_asyncio from starknet_py.net.full_node_client import FullNodeClient +from starknet_py.net.full_node_ws_client import FullNodeWSClient +from starknet_py.net.ws_client import WSClient @pytest.fixture(name="client", scope="package") -def create_full_node_client(devnet) -> FullNodeClient: +def create_full_node_client(devnet: str) -> FullNodeClient: return FullNodeClient(node_url=devnet + "/rpc") + + +@pytest_asyncio.fixture(scope="package") +async def ws_client(devnet_ws: str) -> AsyncGenerator[WSClient, None]: + """ + Connects `WSClient` to devnet, returns its instance and disconnects after the tests. + """ + client = WSClient(devnet_ws) + + await client.connect() + yield client + await client.disconnect() + + +@pytest_asyncio.fixture(scope="package") +async def full_node_ws_client(devnet_ws: str) -> AsyncGenerator[FullNodeWSClient, None]: + """ + Connects `FullNodeWSClient` client to devnet, returns its instance and disconnects after the tests. + """ + client = FullNodeWSClient(devnet_ws) + + await client.connect() + yield client + await client.disconnect()