From 9164275a546206899a8ec0a19a124c22f4351ef3 Mon Sep 17 00:00:00 2001 From: Fiiranek Date: Wed, 6 Nov 2024 20:01:00 +0100 Subject: [PATCH 01/12] Add ws-related dataclasses and schemas --- starknet_py/net/schemas/rpc/ws.py | 91 ++++++++++++++ starknet_py/net/ws_full_node_client_models.py | 113 ++++++++++++++++++ 2 files changed, 204 insertions(+) create mode 100644 starknet_py/net/schemas/rpc/ws.py create mode 100644 starknet_py/net/ws_full_node_client_models.py diff --git a/starknet_py/net/schemas/rpc/ws.py b/starknet_py/net/schemas/rpc/ws.py new file mode 100644 index 000000000..dfde0eb20 --- /dev/null +++ b/starknet_py/net/schemas/rpc/ws.py @@ -0,0 +1,91 @@ +from marshmallow import Schema, fields, post_load + +from starknet_py.net.schemas.rpc.block import BlockHeaderSchema +from starknet_py.net.ws_full_node_client_models import ( + EventsNotification, + NewHeadsNotification, + NewTransactionStatus, + PendingTransactionsNotification, + ReorgData, + SubscribeResponse, + TransactionStatusNotification, + UnsubscribeResponse, +) + + +class SubscribeResponseSchema(Schema): + subscription_id = fields.Integer(data_key="subscription_id", required=True) + + @post_load + def make_dataclass(self, data, **kwargs) -> SubscribeResponse: + return SubscribeResponse(**data) + + +class NewHeadsNotificationSchema(Schema): + subscription_id = fields.Integer(data_key="subscription_id", required=True) + result = fields.Nested(BlockHeaderSchema(), data_key="result", required=True) + + @post_load + def make_dataclass(self, data, **kwargs) -> NewHeadsNotification: + return NewHeadsNotification(**data) + + +class EventsNotificationSchema(Schema): + subscription_id = fields.Integer(data_key="subscription_id", required=True) + result = fields.List(fields.Dict(), data_key="result", required=True) + + @post_load + def make_dataclass(self, data, **kwargs) -> EventsNotification: + return EventsNotification(**data) + + +class NewTransactionStatusSchema(Schema): + transaction_hash = fields.Integer(data_key="transaction_hash", required=True) + status = fields.Dict(data_key="status", required=True) + + @post_load + def make_dataclass(self, data, **kwargs) -> NewTransactionStatus: + return NewTransactionStatus(**data) + + +class TransactionStatusNotificationSchema(Schema): + subscription_id = fields.Integer(data_key="subscription_id", required=True) + result = fields.Nested( + NewTransactionStatusSchema(), data_key="result", required=True + ) + + @post_load + def make_dataclass(self, data, **kwargs) -> TransactionStatusNotification: + return TransactionStatusNotification(**data) + + +class PendingTransactionsNotificationSchema(Schema): + subscription_id = fields.Integer(data_key="subscription_id", required=True) + result = fields.Dict(data_key="result", required=True) + + @post_load + def make_dataclass(self, data, **kwargs) -> PendingTransactionsNotification: + return PendingTransactionsNotification(**data) + + +class UnsubscribeResponseSchema(Schema): + result = fields.Boolean(data_key="result", required=True) + + @post_load + def make_dataclass(self, data, **kwargs) -> UnsubscribeResponse: + return UnsubscribeResponse(**data) + + +class ReorgDataSchema(Schema): + starting_block_hash = fields.Integer(data_key="starting_block_hash", required=True) + starting_block_number = fields.Integer( + data_key="starting_block_number", required=True, validate=lambda x: x >= 0 + ) + ending_block_hash = fields.Integer(data_key="ending_block_hash", required=True) + ending_block_number = fields.Integer( + data_key="ending_block_number", required=True, validate=lambda x: x >= 0 + ) + + @post_load + def make_dataclass(self, data, **kwargs) -> ReorgData: + return ReorgData(**data) diff --git a/starknet_py/net/ws_full_node_client_models.py b/starknet_py/net/ws_full_node_client_models.py new file mode 100644 index 000000000..31cd5a3d2 --- /dev/null +++ b/starknet_py/net/ws_full_node_client_models.py @@ -0,0 +1,113 @@ +""" +Dataclasses representing responses from Starknet Websocket RPC API. +""" + +from dataclasses import dataclass +from typing import TypeVar, Union + +from typing_extensions import Generic + +from starknet_py.net.client_models import ( + BlockHeader, + EmittedEvent, + TransactionStatusResponse, +) +from starknet_py.net.models import ( + DeclareV1, + DeclareV2, + DeclareV3, + DeployAccountV1, + DeployAccountV3, + InvokeV1, + InvokeV3, +) + +T = TypeVar("T") + + +@dataclass +class SubscribeResponse: + """ + Subscription result. + """ + + subscription_id: int + + +@dataclass +class Notification(Generic[T]): + subscription_id: int + result: T + + +@dataclass +class NewHeadsNotification(Notification[BlockHeader]): + """ + Notification to the client of a new block header. + """ + + +@dataclass +class EventsNotification(Notification[EmittedEvent]): + """ + Notification to the client of a new event. + """ + + +@dataclass +class NewTransactionStatus: + transaction_hash: int + status: TransactionStatusResponse + + +@dataclass +class TransactionStatusNotification(Notification[NewTransactionStatus]): + """ + Notification to the client of a new transaction status. + """ + + +Transaction = Union[ + DeclareV1, + DeclareV2, + DeclareV3, + DeployAccountV1, + DeployAccountV3, + InvokeV1, + InvokeV3, +] + + +@dataclass +class PendingTransactionsNotification(Notification[Union[int, Transaction]]): + """ + Notification to the client of a new pending transaction. + """ + + +@dataclass +class UnsubscribeResponse: + """ + Unsubscription result. + """ + + result: bool + + +@dataclass +class ReorgData: + """ + Data about reorganized blocks, starting and ending block number and hash. + """ + + starting_block_hash: int + starting_block_number: int + ending_block_hash: int + ending_block_number: int + + +@dataclass +class ReorgNotification(Notification[ReorgData]): + """ + Notification of a reorganization of the chain. + """ From 412453750f5502c0fc45ba1ef653616a3e5807f8 Mon Sep 17 00:00:00 2001 From: Fiiranek Date: Wed, 6 Nov 2024 20:01:50 +0100 Subject: [PATCH 02/12] Add `WSClient` and `RpcWSClient` --- starknet_py/net/ws_client.py | 64 ++++++++++++++++++++++++++++++++++++ 1 file changed, 64 insertions(+) create mode 100644 starknet_py/net/ws_client.py diff --git a/starknet_py/net/ws_client.py b/starknet_py/net/ws_client.py new file mode 100644 index 000000000..a0dbc9c12 --- /dev/null +++ b/starknet_py/net/ws_client.py @@ -0,0 +1,64 @@ +import json +from typing import Any, Dict, Optional, Union, cast + +from websockets.asyncio.client import ClientConnection, connect + +from starknet_py.net.http_client import RpcHttpClient + + +class WSClient: + """ + Base class for WebSocket clients. + """ + + def __init__(self, node_url: str): + """ + :param node_url: URL of the node providing the WebSocket API. + """ + self.node_url: str = node_url + self.connection: Union[None, ClientConnection] = None + + async def connect(self): + """Establishes the WebSocket connection.""" + self.connection = await connect(self.node_url) + + async def disconnect(self): + """Closes the WebSocket connection.""" + assert self.connection is not None + await self.connection.close() + + async def send_raw( + self, + payload: Optional[Dict[str, Any]] = None, + ): + assert self.connection is not None + await self.connection.send(json.dumps(payload)) + data = await self.connection.recv() + + return data + + +class RpcWSClient(WSClient): + """ + WebSocket client for the RPC API. + """ + + async def send( + self, + method: str, + params: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Any]: + payload = { + "id": 0, + "jsonrpc": "2.0", + "method": method, + "params": params if params else [], + } + + data = await self.send_raw(payload) + data = cast(Dict, json.loads(data)) + + if "result" not in data: + RpcHttpClient.handle_rpc_error(data) + + return data From 2f1a66078dc1f3677b3b18b3072bdb4522d85ffc Mon Sep 17 00:00:00 2001 From: Fiiranek Date: Thu, 7 Nov 2024 01:16:51 +0100 Subject: [PATCH 03/12] Add `FullNodeWSClient` --- starknet_py/net/full_node_ws_client.py | 231 +++++++++++++++++++++++++ starknet_py/net/schemas/rpc/ws.py | 10 +- starknet_py/net/ws_client.py | 22 ++- 3 files changed, 257 insertions(+), 6 deletions(-) create mode 100644 starknet_py/net/full_node_ws_client.py diff --git a/starknet_py/net/full_node_ws_client.py b/starknet_py/net/full_node_ws_client.py new file mode 100644 index 000000000..7c8a2c722 --- /dev/null +++ b/starknet_py/net/full_node_ws_client.py @@ -0,0 +1,231 @@ +from typing import Any, Callable, Dict, List, Optional, Union, cast + +from starknet_py.net.client_models import BlockHeader, EmittedEvent, Hash, Tag +from starknet_py.net.http_client import RpcHttpClient +from starknet_py.net.schemas.rpc.ws import ( + EventsNotificationSchema, + NewHeadsNotificationSchema, + PendingTransactionsNotificationSchema, + SubscribeResponseSchema, + TransactionStatusNotificationSchema, + UnsubscribeResponseSchema, +) +from starknet_py.net.ws_client import RpcWSClient +from starknet_py.net.ws_full_node_client_models import ( + EventsNotification, + NewHeadsNotification, + NewTransactionStatus, + PendingTransactionsNotification, + SubscribeResponse, + Transaction, + TransactionStatusNotification, + UnsubscribeResponse, +) + +BlockId = Union[int, Hash, Tag] + + +class FullNodeWSClient: + """ + Starknet WebSocket client for RPC API. + """ + + def __init__(self, node_url: str): + """ + :param node_url: URL of the node providing the WebSocket API. + """ + self.node_url: str = node_url + self._rpc_ws_client: RpcWSClient = RpcWSClient(node_url) + self._subscriptions: Dict[int, Callable[[Any], None]] = {} + + async def connect(self): + """ + Establishes the WebSocket connection. + """ + await self._rpc_ws_client.connect() + + async def disconnect(self): + """ + Closes the WebSocket connection. + """ + await self._rpc_ws_client.disconnect() + + async def _subscribe( + self, + handler: Callable[[Any], Any], + method: str, + params: Optional[Dict[str, Any]] = None, + ) -> int: + data = await self._rpc_ws_client.send(method, params) + response = cast( + SubscribeResponse, + SubscribeResponseSchema().load(data), + ) + + self._subscriptions[response.subscription_id] = handler + + return response.subscription_id + + async def listen(self): + """ + Listens for incoming WebSocket messages. + """ + await self._rpc_ws_client.listen(self._handle_received_message) + + def _handle_received_message(self, message: Dict): + print(message) + if "params" not in message: + # TODO(#1498): Possibly move `handle_rpc_error` from `RpcHttpClient` to separate function + RpcHttpClient.handle_rpc_error(message) + + subscription_id = message["params"]["subscription_id"] + + if subscription_id not in self._subscriptions: + return + + handler = self._subscriptions[subscription_id] + method = message["method"] + + if method == "starknet_subscriptionNewHeads": + notification = cast( + NewHeadsNotification, + NewHeadsNotificationSchema().load(message["params"]), + ) + handler(notification.result) + + elif method == "starknet_subscriptionEvents": + notification = cast( + EventsNotification, + EventsNotificationSchema().load(message["params"]), + ) + handler(notification.result) + + elif method == "starknet_subscriptionTransactionStatus": + notification = cast( + TransactionStatusNotification, + TransactionStatusNotificationSchema().load(message["params"]), + ) + handler(notification.result) + + elif method == "starknet_subscriptionPendingTransactions": + notification = cast( + PendingTransactionsNotification, + PendingTransactionsNotificationSchema().load(message["params"]), + ) + handler(notification.result) + + elif method == "starknet_subscriptionReorg": + # TODO(#1498): Implement reorg handling once inconsistencies in spec are resolved + pass + + async def subscribe_new_heads( + self, + handler: Callable[[BlockHeader], Any], + block: Optional[BlockId], + ) -> int: + """ + Creates a WebSocket stream which will fire events for new block headers. + + :param handler: The function to call when a new block header is received. + :param block: The block to get notifications from, default is latest, limited to 1024 blocks back. + :return: The subscription ID. + """ + params = {"block": block} if block else {} + subscription_id = await self._subscribe( + handler, "starknet_subscribeNewHeads", params + ) + + return subscription_id + + async def subscribe_events( + self, + handler: Callable[[EmittedEvent], Any], + from_address: Optional[int] = None, + keys: Optional[List[List[int]]] = None, + block: Optional[BlockId] = None, + ) -> int: + """ + Creates a WebSocket stream which will fire events for new Starknet events with applied filters. + + :param handler: The function to call when a new event is received. + :param from_address: Address which emitted the event. + :param keys: The keys to filter events by. + :param block: The block to get notifications from, default is latest, limited to 1024 blocks back. + :return: The subscription ID. + """ + params = {"from_address": from_address, "keys": keys, "block": block} + # params = {"block": block} + subscription_id = await self._subscribe( + handler, "starknet_subscribeEvents", params + ) + + return subscription_id + + async def subscribe_transaction_status( + self, + handler: Callable[[NewTransactionStatus], Any], + transaction_hash: int, + block: Optional[BlockId], + ) -> int: + """ + Creates a WebSocket stream which will fire events when a transaction status is updated. + + :param handler: The function to call when a new transaction status is received. + :param transaction_hash: The transaction hash to fetch status updates for. + :param block: The block to get notifications from, default is latest, limited to 1024 blocks back. + :return: The subscription ID. + """ + params = {"transaction_hash": transaction_hash, "block": block} + subscription_id = await self._subscribe( + handler, "starknet_subscribeTransactionStatus", params + ) + + return subscription_id + + async def subscribe_pending_transactions( + self, + handler: Callable[[Union[int, Transaction]], Any], + transaction_details: Optional[bool], + sender_address: Optional[List[int]], + ) -> int: + """ + Creates a WebSocket stream which will fire events when a new pending transaction is added. + While there is no mempool, this notifies of transactions in the pending block. + + :param handler: The function to call when a new pending transaction is received. + :param transaction_details: Whether to include transaction details in the notification. + If false, only hash is returned. + :param sender_address: The sender address to filter transactions by. + :return: The subscription ID. + """ + params = { + "transaction_details": transaction_details, + "sender_address": sender_address, + } + subscription_id = await self._subscribe( + handler, "starknet_subscribePendingTransactions", params + ) + + return subscription_id + + async def unsubscribe(self, subscription_id: int) -> bool: + """ + Close a previously opened WebSocket stream, with the corresponding subscription id. + + :param subscription_id: ID of the subscription to close. + :return: True if the unsubscription was successful, False otherwise. + """ + if subscription_id not in self._subscriptions: + return False + + params = {"subscription_id": subscription_id} + res = await self._rpc_ws_client.send("starknet_unsubscribe", params) + + unsubscribe_response = cast( + UnsubscribeResponse, UnsubscribeResponseSchema().load(res) + ) + + if unsubscribe_response: + del self._subscriptions[subscription_id] + + return unsubscribe_response.result diff --git a/starknet_py/net/schemas/rpc/ws.py b/starknet_py/net/schemas/rpc/ws.py index dfde0eb20..2868ed39e 100644 --- a/starknet_py/net/schemas/rpc/ws.py +++ b/starknet_py/net/schemas/rpc/ws.py @@ -1,6 +1,8 @@ from marshmallow import Schema, fields, post_load +from starknet_py.net.schemas.common import Felt from starknet_py.net.schemas.rpc.block import BlockHeaderSchema +from starknet_py.net.schemas.rpc.event import EmittedEventSchema from starknet_py.net.ws_full_node_client_models import ( EventsNotification, NewHeadsNotification, @@ -32,7 +34,7 @@ def make_dataclass(self, data, **kwargs) -> NewHeadsNotification: class EventsNotificationSchema(Schema): subscription_id = fields.Integer(data_key="subscription_id", required=True) - result = fields.List(fields.Dict(), data_key="result", required=True) + result = fields.Nested(EmittedEventSchema(), data_key="result", required=True) @post_load def make_dataclass(self, data, **kwargs) -> EventsNotification: @@ -40,7 +42,7 @@ def make_dataclass(self, data, **kwargs) -> EventsNotification: class NewTransactionStatusSchema(Schema): - transaction_hash = fields.Integer(data_key="transaction_hash", required=True) + transaction_hash = Felt(data_key="transaction_hash", required=True) status = fields.Dict(data_key="status", required=True) @post_load @@ -77,11 +79,11 @@ def make_dataclass(self, data, **kwargs) -> UnsubscribeResponse: class ReorgDataSchema(Schema): - starting_block_hash = fields.Integer(data_key="starting_block_hash", required=True) + starting_block_hash = Felt(data_key="starting_block_hash", required=True) starting_block_number = fields.Integer( data_key="starting_block_number", required=True, validate=lambda x: x >= 0 ) - ending_block_hash = fields.Integer(data_key="ending_block_hash", required=True) + ending_block_hash = Felt(data_key="ending_block_hash", required=True) ending_block_number = fields.Integer( data_key="ending_block_number", required=True, validate=lambda x: x >= 0 ) diff --git a/starknet_py/net/ws_client.py b/starknet_py/net/ws_client.py index a0dbc9c12..f7a22c079 100644 --- a/starknet_py/net/ws_client.py +++ b/starknet_py/net/ws_client.py @@ -1,5 +1,5 @@ import json -from typing import Any, Dict, Optional, Union, cast +from typing import Any, Callable, Dict, Optional, Union, cast from websockets.asyncio.client import ClientConnection, connect @@ -31,12 +31,29 @@ async def send_raw( self, payload: Optional[Dict[str, Any]] = None, ): + """ + Sends a message to the WebSocket server and returns the response. + + :param payload: The message to send. + """ assert self.connection is not None await self.connection.send(json.dumps(payload)) data = await self.connection.recv() return data + async def listen(self, received_message_handler: Callable[[Dict[str, Any]], Any]): + """ + Listens for incoming WebSocket messages. + + :param received_message_handler: The function to call when a message is received. + """ + assert self.connection is not None + + async for message in self.connection: + message = cast(Dict, json.loads(message)) + received_message_handler(message) + class RpcWSClient(WSClient): """ @@ -59,6 +76,7 @@ async def send( data = cast(Dict, json.loads(data)) if "result" not in data: + # TODO(#1498): Possibly move `handle_rpc_error` from `RpcHttpClient` to separate function RpcHttpClient.handle_rpc_error(data) - return data + return data["result"] From 067e5836407b96e575c06621a028793ffe90a6d4 Mon Sep 17 00:00:00 2001 From: Fiiranek Date: Thu, 7 Nov 2024 01:20:49 +0100 Subject: [PATCH 04/12] Cleanup `FullNodeWSClient` --- starknet_py/net/full_node_ws_client.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/starknet_py/net/full_node_ws_client.py b/starknet_py/net/full_node_ws_client.py index 7c8a2c722..02ddb101b 100644 --- a/starknet_py/net/full_node_ws_client.py +++ b/starknet_py/net/full_node_ws_client.py @@ -73,7 +73,6 @@ async def listen(self): await self._rpc_ws_client.listen(self._handle_received_message) def _handle_received_message(self, message: Dict): - print(message) if "params" not in message: # TODO(#1498): Possibly move `handle_rpc_error` from `RpcHttpClient` to separate function RpcHttpClient.handle_rpc_error(message) @@ -154,7 +153,6 @@ async def subscribe_events( :return: The subscription ID. """ params = {"from_address": from_address, "keys": keys, "block": block} - # params = {"block": block} subscription_id = await self._subscribe( handler, "starknet_subscribeEvents", params ) From 3648885395c12b4790af772269e165d3fc3e884c Mon Sep 17 00:00:00 2001 From: Fiiranek Date: Thu, 7 Nov 2024 01:30:21 +0100 Subject: [PATCH 05/12] Refactor `FullNodeWSClient._handle_received_message()` --- starknet_py/net/full_node_ws_client.py | 39 +++++++++++++------------- 1 file changed, 19 insertions(+), 20 deletions(-) diff --git a/starknet_py/net/full_node_ws_client.py b/starknet_py/net/full_node_ws_client.py index 02ddb101b..163acdc54 100644 --- a/starknet_py/net/full_node_ws_client.py +++ b/starknet_py/net/full_node_ws_client.py @@ -85,31 +85,30 @@ def _handle_received_message(self, message: Dict): handler = self._subscriptions[subscription_id] method = message["method"] - if method == "starknet_subscriptionNewHeads": - notification = cast( + method_to_model_mapping = { + "starknet_subscriptionNewHeads": ( NewHeadsNotification, - NewHeadsNotificationSchema().load(message["params"]), - ) - handler(notification.result) - - elif method == "starknet_subscriptionEvents": - notification = cast( + NewHeadsNotificationSchema, + ), + "starknet_subscriptionEvents": ( EventsNotification, - EventsNotificationSchema().load(message["params"]), - ) - handler(notification.result) - - elif method == "starknet_subscriptionTransactionStatus": - notification = cast( + EventsNotificationSchema, + ), + "starknet_subscriptionTransactionStatus": ( TransactionStatusNotification, - TransactionStatusNotificationSchema().load(message["params"]), - ) - handler(notification.result) + TransactionStatusNotificationSchema, + ), + "starknet_subscriptionPendingTransactions": ( + PendingTransactionsNotification, + PendingTransactionsNotificationSchema, + ), + } - elif method == "starknet_subscriptionPendingTransactions": + if method in method_to_model_mapping: + notification, notification_schema = method_to_model_mapping[method] notification = cast( - PendingTransactionsNotification, - PendingTransactionsNotificationSchema().load(message["params"]), + notification, + notification_schema().load(message["params"]), ) handler(notification.result) From 9ea6c90ab534893896a4b116784c3407661475b2 Mon Sep 17 00:00:00 2001 From: Fiiranek Date: Fri, 8 Nov 2024 00:57:19 +0100 Subject: [PATCH 06/12] Add `ReorgNotificationSchema` --- starknet_py/net/full_node_ws_client.py | 11 ++++++----- starknet_py/net/schemas/rpc/ws.py | 10 ++++++++++ 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/starknet_py/net/full_node_ws_client.py b/starknet_py/net/full_node_ws_client.py index 163acdc54..c6ae32c9e 100644 --- a/starknet_py/net/full_node_ws_client.py +++ b/starknet_py/net/full_node_ws_client.py @@ -8,7 +8,7 @@ PendingTransactionsNotificationSchema, SubscribeResponseSchema, TransactionStatusNotificationSchema, - UnsubscribeResponseSchema, + UnsubscribeResponseSchema, ReorgNotificationSchema, ) from starknet_py.net.ws_client import RpcWSClient from starknet_py.net.ws_full_node_client_models import ( @@ -16,6 +16,7 @@ NewHeadsNotification, NewTransactionStatus, PendingTransactionsNotification, + ReorgNotification, SubscribeResponse, Transaction, TransactionStatusNotification, @@ -102,6 +103,10 @@ def _handle_received_message(self, message: Dict): PendingTransactionsNotification, PendingTransactionsNotificationSchema, ), + "starknet_subscriptionReorg": ( + ReorgNotification, + ReorgNotificationSchema, + ), } if method in method_to_model_mapping: @@ -112,10 +117,6 @@ def _handle_received_message(self, message: Dict): ) handler(notification.result) - elif method == "starknet_subscriptionReorg": - # TODO(#1498): Implement reorg handling once inconsistencies in spec are resolved - pass - async def subscribe_new_heads( self, handler: Callable[[BlockHeader], Any], diff --git a/starknet_py/net/schemas/rpc/ws.py b/starknet_py/net/schemas/rpc/ws.py index 2868ed39e..439c321b6 100644 --- a/starknet_py/net/schemas/rpc/ws.py +++ b/starknet_py/net/schemas/rpc/ws.py @@ -9,6 +9,7 @@ NewTransactionStatus, PendingTransactionsNotification, ReorgData, + ReorgNotification, SubscribeResponse, TransactionStatusNotification, UnsubscribeResponse, @@ -91,3 +92,12 @@ class ReorgDataSchema(Schema): @post_load def make_dataclass(self, data, **kwargs) -> ReorgData: return ReorgData(**data) + + +class ReorgNotificationSchema(Schema): + subscription_id = fields.Integer(data_key="subscription_id", required=True) + result = fields.Nested(ReorgDataSchema(), data_key="result", required=True) + + @post_load + def make_dataclass(self, data, **kwargs) -> ReorgNotification: + return ReorgNotification(**data) From e845e508071838300550c66d04cda93685cc22e2 Mon Sep 17 00:00:00 2001 From: Fiiranek Date: Fri, 8 Nov 2024 12:47:47 +0100 Subject: [PATCH 07/12] Add `devnet_ws` fixture --- starknet_py/tests/e2e/fixtures/devnet.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/starknet_py/tests/e2e/fixtures/devnet.py b/starknet_py/tests/e2e/fixtures/devnet.py index a1cbd2376..23ee40277 100644 --- a/starknet_py/tests/e2e/fixtures/devnet.py +++ b/starknet_py/tests/e2e/fixtures/devnet.py @@ -7,6 +7,7 @@ import pytest +from starknet_py.net.full_node_ws_client import FullNodeWSClient from starknet_py.tests.e2e.fixtures.constants import SEPOLIA_RPC_URL @@ -63,6 +64,19 @@ def devnet() -> Generator[str, None, None]: proc.kill() +@pytest.fixture(scope="package") +async def devnet_ws(devnet) -> Generator[str, None, None]: + """ + Connects WebSocket client to devnet, returns its instance and disconnects after the tests. + """ + ws_node_url = devnet.replace("http", "ws") + "/ws" + ws_client = FullNodeWSClient(ws_node_url) + + await ws_client.connect() + yield ws_client + await ws_client.disconnect() + + @pytest.fixture(scope="package") def devnet_forking_mode() -> Generator[str, None, None]: """ From 71ff6a453788095aa997c6e146058fc92fc28ed1 Mon Sep 17 00:00:00 2001 From: Fiiranek Date: Fri, 8 Nov 2024 13:42:52 +0100 Subject: [PATCH 08/12] Add `devnet_ws_client` fixture --- starknet_py/net/full_node_ws_client.py | 3 ++- starknet_py/tests/e2e/fixtures/clients.py | 15 +++++++++++++++ starknet_py/tests/e2e/fixtures/devnet.py | 12 +++--------- 3 files changed, 20 insertions(+), 10 deletions(-) diff --git a/starknet_py/net/full_node_ws_client.py b/starknet_py/net/full_node_ws_client.py index c6ae32c9e..a4ac80089 100644 --- a/starknet_py/net/full_node_ws_client.py +++ b/starknet_py/net/full_node_ws_client.py @@ -6,9 +6,10 @@ EventsNotificationSchema, NewHeadsNotificationSchema, PendingTransactionsNotificationSchema, + ReorgNotificationSchema, SubscribeResponseSchema, TransactionStatusNotificationSchema, - UnsubscribeResponseSchema, ReorgNotificationSchema, + UnsubscribeResponseSchema, ) from starknet_py.net.ws_client import RpcWSClient from starknet_py.net.ws_full_node_client_models import ( diff --git a/starknet_py/tests/e2e/fixtures/clients.py b/starknet_py/tests/e2e/fixtures/clients.py index 42a413518..b978e5063 100644 --- a/starknet_py/tests/e2e/fixtures/clients.py +++ b/starknet_py/tests/e2e/fixtures/clients.py @@ -1,8 +1,23 @@ +from typing import Generator + import pytest from starknet_py.net.full_node_client import FullNodeClient +from starknet_py.net.full_node_ws_client import FullNodeWSClient @pytest.fixture(name="client", scope="package") def create_full_node_client(devnet) -> FullNodeClient: return FullNodeClient(node_url=devnet + "/rpc") + + +@pytest.fixture(scope="package") +async def full_node_ws_client(devnet_ws) -> Generator[str, None, None]: + """ + Connects WebSocket client to devnet, returns its instance and disconnects after the tests. + """ + ws_client = FullNodeWSClient(devnet_ws) + + await ws_client.connect() + yield ws_client + await ws_client.disconnect() diff --git a/starknet_py/tests/e2e/fixtures/devnet.py b/starknet_py/tests/e2e/fixtures/devnet.py index 23ee40277..e37c57d8d 100644 --- a/starknet_py/tests/e2e/fixtures/devnet.py +++ b/starknet_py/tests/e2e/fixtures/devnet.py @@ -7,7 +7,6 @@ import pytest -from starknet_py.net.full_node_ws_client import FullNodeWSClient from starknet_py.tests.e2e.fixtures.constants import SEPOLIA_RPC_URL @@ -65,16 +64,11 @@ def devnet() -> Generator[str, None, None]: @pytest.fixture(scope="package") -async def devnet_ws(devnet) -> Generator[str, None, None]: +def devnet_ws(devnet) -> Generator[str, None, None]: """ - Connects WebSocket client to devnet, returns its instance and disconnects after the tests. + Returns WebSocket address of devnet. """ - ws_node_url = devnet.replace("http", "ws") + "/ws" - ws_client = FullNodeWSClient(ws_node_url) - - await ws_client.connect() - yield ws_client - await ws_client.disconnect() + yield devnet.replace("http", "ws") + "/ws" @pytest.fixture(scope="package") From faa97ec48b41577772739710fbd600096425160c Mon Sep 17 00:00:00 2001 From: Fiiranek Date: Fri, 8 Nov 2024 13:55:47 +0100 Subject: [PATCH 09/12] Add `ws_client` and `full_node_ws_client` fixtures --- starknet_py/tests/e2e/fixtures/clients.py | 28 +++++++++++++++++------ 1 file changed, 21 insertions(+), 7 deletions(-) diff --git a/starknet_py/tests/e2e/fixtures/clients.py b/starknet_py/tests/e2e/fixtures/clients.py index b978e5063..d2dbdcb77 100644 --- a/starknet_py/tests/e2e/fixtures/clients.py +++ b/starknet_py/tests/e2e/fixtures/clients.py @@ -1,9 +1,11 @@ from typing import Generator import pytest +import pytest_asyncio from starknet_py.net.full_node_client import FullNodeClient from starknet_py.net.full_node_ws_client import FullNodeWSClient +from starknet_py.net.ws_client import WSClient @pytest.fixture(name="client", scope="package") @@ -11,13 +13,25 @@ def create_full_node_client(devnet) -> FullNodeClient: return FullNodeClient(node_url=devnet + "/rpc") -@pytest.fixture(scope="package") -async def full_node_ws_client(devnet_ws) -> Generator[str, None, None]: +@pytest_asyncio.fixture(scope="package") +async def ws_client(devnet_ws) -> Generator[WSClient, None, None]: """ - Connects WebSocket client to devnet, returns its instance and disconnects after the tests. + Connects `WSClient` to devnet, returns its instance and disconnects after the tests. """ - ws_client = FullNodeWSClient(devnet_ws) + client = WSClient(devnet_ws) - await ws_client.connect() - yield ws_client - await ws_client.disconnect() + await client.connect() + yield client + await client.disconnect() + + +@pytest_asyncio.fixture(scope="package") +async def full_node_ws_client(devnet_ws) -> Generator[FullNodeWSClient, None, None]: + """ + Connects `FullNodeWSClient` client to devnet, returns its instance and disconnects after the tests. + """ + client = FullNodeWSClient(devnet_ws) + + await client.connect() + yield client + await client.disconnect() From 7b4dd76dcd98ef4510d6cdedb4cf87ab30ae2fc5 Mon Sep 17 00:00:00 2001 From: Fiiranek Date: Fri, 8 Nov 2024 14:27:15 +0100 Subject: [PATCH 10/12] Move `devnet_ws` to separate file --- starknet_py/conftest.py | 1 + starknet_py/net/full_node_ws_client.py | 4 +++- starknet_py/net/ws_client.py | 2 +- starknet_py/tests/e2e/fixtures/devnet.py | 8 -------- 4 files changed, 5 insertions(+), 10 deletions(-) diff --git a/starknet_py/conftest.py b/starknet_py/conftest.py index 23206cbe0..e1a8f0ac5 100644 --- a/starknet_py/conftest.py +++ b/starknet_py/conftest.py @@ -7,6 +7,7 @@ "starknet_py.tests.e2e.fixtures.contracts_v1", "starknet_py.tests.e2e.fixtures.misc", "starknet_py.tests.e2e.fixtures.devnet", + "starknet_py.tests.e2e.fixtures.devnet_ws", "starknet_py.tests.e2e.fixtures.constants", "starknet_py.tests.e2e.client.fixtures.transactions", "starknet_py.tests.e2e.client.fixtures.prepare_network", diff --git a/starknet_py/net/full_node_ws_client.py b/starknet_py/net/full_node_ws_client.py index a4ac80089..1ee52ecfe 100644 --- a/starknet_py/net/full_node_ws_client.py +++ b/starknet_py/net/full_node_ws_client.py @@ -1,6 +1,7 @@ from typing import Any, Callable, Dict, List, Optional, Union, cast from starknet_py.net.client_models import BlockHeader, EmittedEvent, Hash, Tag +from starknet_py.net.client_utils import _clear_none_values from starknet_py.net.http_client import RpcHttpClient from starknet_py.net.schemas.rpc.ws import ( EventsNotificationSchema, @@ -164,7 +165,7 @@ async def subscribe_transaction_status( self, handler: Callable[[NewTransactionStatus], Any], transaction_hash: int, - block: Optional[BlockId], + block: Optional[BlockId] = None, ) -> int: """ Creates a WebSocket stream which will fire events when a transaction status is updated. @@ -175,6 +176,7 @@ async def subscribe_transaction_status( :return: The subscription ID. """ params = {"transaction_hash": transaction_hash, "block": block} + params = _clear_none_values(params) subscription_id = await self._subscribe( handler, "starknet_subscribeTransactionStatus", params ) diff --git a/starknet_py/net/ws_client.py b/starknet_py/net/ws_client.py index f7a22c079..e13f5de0f 100644 --- a/starknet_py/net/ws_client.py +++ b/starknet_py/net/ws_client.py @@ -30,7 +30,7 @@ async def disconnect(self): async def send_raw( self, payload: Optional[Dict[str, Any]] = None, - ): + ) -> Union[str, bytes]: """ Sends a message to the WebSocket server and returns the response. diff --git a/starknet_py/tests/e2e/fixtures/devnet.py b/starknet_py/tests/e2e/fixtures/devnet.py index e37c57d8d..a1cbd2376 100644 --- a/starknet_py/tests/e2e/fixtures/devnet.py +++ b/starknet_py/tests/e2e/fixtures/devnet.py @@ -63,14 +63,6 @@ def devnet() -> Generator[str, None, None]: proc.kill() -@pytest.fixture(scope="package") -def devnet_ws(devnet) -> Generator[str, None, None]: - """ - Returns WebSocket address of devnet. - """ - yield devnet.replace("http", "ws") + "/ws" - - @pytest.fixture(scope="package") def devnet_forking_mode() -> Generator[str, None, None]: """ From 9ebc171438ade7914b78de6d4fbccd4eaa5e9d41 Mon Sep 17 00:00:00 2001 From: Fiiranek Date: Fri, 8 Nov 2024 14:44:20 +0100 Subject: [PATCH 11/12] Refactor `FullNodeWSClient._handle_received_message()`; Fix typechecks --- starknet_py/net/full_node_ws_client.py | 57 +++++++++++++---------- starknet_py/tests/e2e/fixtures/clients.py | 8 ++-- 2 files changed, 36 insertions(+), 29 deletions(-) diff --git a/starknet_py/net/full_node_ws_client.py b/starknet_py/net/full_node_ws_client.py index 1ee52ecfe..2ee07766a 100644 --- a/starknet_py/net/full_node_ws_client.py +++ b/starknet_py/net/full_node_ws_client.py @@ -26,7 +26,10 @@ ) BlockId = Union[int, Hash, Tag] - +HandlerNotification = Union[ + NewHeadsNotification, EventsNotification, TransactionStatusNotification, PendingTransactionsNotification, ReorgNotification +] +Handler = Callable[[HandlerNotification], Any] class FullNodeWSClient: """ @@ -39,7 +42,7 @@ def __init__(self, node_url: str): """ self.node_url: str = node_url self._rpc_ws_client: RpcWSClient = RpcWSClient(node_url) - self._subscriptions: Dict[int, Callable[[Any], None]] = {} + self._subscriptions: Dict[int, Handler] = {} async def connect(self): """ @@ -88,36 +91,40 @@ def _handle_received_message(self, message: Dict): handler = self._subscriptions[subscription_id] method = message["method"] - method_to_model_mapping = { - "starknet_subscriptionNewHeads": ( + if method == "starknet_subscriptionNewHeads": + notification = cast( NewHeadsNotification, - NewHeadsNotificationSchema, - ), - "starknet_subscriptionEvents": ( + NewHeadsNotificationSchema().load(message["params"]), + ) + handler(notification) + + elif method == "starknet_subscriptionEvents": + notification = cast( EventsNotification, - EventsNotificationSchema, - ), - "starknet_subscriptionTransactionStatus": ( + EventsNotificationSchema().load(message["params"]), + ) + handler(notification) + + elif method == "starknet_subscriptionTransactionStatus": + notification = cast( TransactionStatusNotification, - TransactionStatusNotificationSchema, - ), - "starknet_subscriptionPendingTransactions": ( + TransactionStatusNotificationSchema().load(message["params"]), + ) + handler(notification) + + elif method == "starknet_subscriptionPendingTransactions": + notification = cast( PendingTransactionsNotification, - PendingTransactionsNotificationSchema, - ), - "starknet_subscriptionReorg": ( - ReorgNotification, - ReorgNotificationSchema, - ), - } + PendingTransactionsNotificationSchema().load(message["params"]), + ) + handler(notification) - if method in method_to_model_mapping: - notification, notification_schema = method_to_model_mapping[method] + elif method == "starknet_subscriptionReorg": notification = cast( - notification, - notification_schema().load(message["params"]), + ReorgNotification, + ReorgNotificationSchema().load(message["params"]), ) - handler(notification.result) + handler(notification) async def subscribe_new_heads( self, diff --git a/starknet_py/tests/e2e/fixtures/clients.py b/starknet_py/tests/e2e/fixtures/clients.py index d2dbdcb77..e2d7505f8 100644 --- a/starknet_py/tests/e2e/fixtures/clients.py +++ b/starknet_py/tests/e2e/fixtures/clients.py @@ -1,4 +1,4 @@ -from typing import Generator +from typing import Generator, AsyncGenerator import pytest import pytest_asyncio @@ -9,12 +9,12 @@ @pytest.fixture(name="client", scope="package") -def create_full_node_client(devnet) -> FullNodeClient: +def create_full_node_client(devnet: str) -> FullNodeClient: return FullNodeClient(node_url=devnet + "/rpc") @pytest_asyncio.fixture(scope="package") -async def ws_client(devnet_ws) -> Generator[WSClient, None, None]: +async def ws_client(devnet_ws: str) -> AsyncGenerator[WSClient, None]: """ Connects `WSClient` to devnet, returns its instance and disconnects after the tests. """ @@ -26,7 +26,7 @@ async def ws_client(devnet_ws) -> Generator[WSClient, None, None]: @pytest_asyncio.fixture(scope="package") -async def full_node_ws_client(devnet_ws) -> Generator[FullNodeWSClient, None, None]: +async def full_node_ws_client(devnet_ws: str) -> AsyncGenerator[FullNodeWSClient, None]: """ Connects `FullNodeWSClient` client to devnet, returns its instance and disconnects after the tests. """ From ddcfc06e18a818470d30a871e015b666387230cb Mon Sep 17 00:00:00 2001 From: Fiiranek Date: Fri, 8 Nov 2024 14:45:04 +0100 Subject: [PATCH 12/12] Fix formatting and linting --- starknet_py/net/full_node_ws_client.py | 7 ++++++- starknet_py/tests/e2e/fixtures/clients.py | 2 +- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/starknet_py/net/full_node_ws_client.py b/starknet_py/net/full_node_ws_client.py index 2ee07766a..e98449f0b 100644 --- a/starknet_py/net/full_node_ws_client.py +++ b/starknet_py/net/full_node_ws_client.py @@ -27,10 +27,15 @@ BlockId = Union[int, Hash, Tag] HandlerNotification = Union[ - NewHeadsNotification, EventsNotification, TransactionStatusNotification, PendingTransactionsNotification, ReorgNotification + NewHeadsNotification, + EventsNotification, + TransactionStatusNotification, + PendingTransactionsNotification, + ReorgNotification, ] Handler = Callable[[HandlerNotification], Any] + class FullNodeWSClient: """ Starknet WebSocket client for RPC API. diff --git a/starknet_py/tests/e2e/fixtures/clients.py b/starknet_py/tests/e2e/fixtures/clients.py index e2d7505f8..8897c0a85 100644 --- a/starknet_py/tests/e2e/fixtures/clients.py +++ b/starknet_py/tests/e2e/fixtures/clients.py @@ -1,4 +1,4 @@ -from typing import Generator, AsyncGenerator +from typing import AsyncGenerator import pytest import pytest_asyncio