Skip to content

Commit

Permalink
Add FullNodeWSClient
Browse files Browse the repository at this point in the history
  • Loading branch information
franciszekjob committed Nov 7, 2024
1 parent 4124537 commit 2f1a660
Show file tree
Hide file tree
Showing 3 changed files with 257 additions and 6 deletions.
231 changes: 231 additions & 0 deletions starknet_py/net/full_node_ws_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,231 @@
from typing import Any, Callable, Dict, List, Optional, Union, cast

from starknet_py.net.client_models import BlockHeader, EmittedEvent, Hash, Tag
from starknet_py.net.http_client import RpcHttpClient
from starknet_py.net.schemas.rpc.ws import (
EventsNotificationSchema,
NewHeadsNotificationSchema,
PendingTransactionsNotificationSchema,
SubscribeResponseSchema,
TransactionStatusNotificationSchema,
UnsubscribeResponseSchema,
)
from starknet_py.net.ws_client import RpcWSClient
from starknet_py.net.ws_full_node_client_models import (
EventsNotification,
NewHeadsNotification,
NewTransactionStatus,
PendingTransactionsNotification,
SubscribeResponse,
Transaction,
TransactionStatusNotification,
UnsubscribeResponse,
)

BlockId = Union[int, Hash, Tag]


class FullNodeWSClient:
"""
Starknet WebSocket client for RPC API.
"""

def __init__(self, node_url: str):
"""
:param node_url: URL of the node providing the WebSocket API.
"""
self.node_url: str = node_url
self._rpc_ws_client: RpcWSClient = RpcWSClient(node_url)
self._subscriptions: Dict[int, Callable[[Any], None]] = {}

async def connect(self):
"""
Establishes the WebSocket connection.
"""
await self._rpc_ws_client.connect()

async def disconnect(self):
"""
Closes the WebSocket connection.
"""
await self._rpc_ws_client.disconnect()

async def _subscribe(
self,
handler: Callable[[Any], Any],
method: str,
params: Optional[Dict[str, Any]] = None,
) -> int:
data = await self._rpc_ws_client.send(method, params)
response = cast(
SubscribeResponse,
SubscribeResponseSchema().load(data),
)

self._subscriptions[response.subscription_id] = handler

return response.subscription_id

async def listen(self):
"""
Listens for incoming WebSocket messages.
"""
await self._rpc_ws_client.listen(self._handle_received_message)

def _handle_received_message(self, message: Dict):
print(message)
if "params" not in message:
# TODO(#1498): Possibly move `handle_rpc_error` from `RpcHttpClient` to separate function
RpcHttpClient.handle_rpc_error(message)

subscription_id = message["params"]["subscription_id"]

if subscription_id not in self._subscriptions:
return

handler = self._subscriptions[subscription_id]
method = message["method"]

if method == "starknet_subscriptionNewHeads":
notification = cast(
NewHeadsNotification,
NewHeadsNotificationSchema().load(message["params"]),
)
handler(notification.result)

elif method == "starknet_subscriptionEvents":
notification = cast(
EventsNotification,
EventsNotificationSchema().load(message["params"]),
)
handler(notification.result)

elif method == "starknet_subscriptionTransactionStatus":
notification = cast(
TransactionStatusNotification,
TransactionStatusNotificationSchema().load(message["params"]),
)
handler(notification.result)

elif method == "starknet_subscriptionPendingTransactions":
notification = cast(
PendingTransactionsNotification,
PendingTransactionsNotificationSchema().load(message["params"]),
)
handler(notification.result)

elif method == "starknet_subscriptionReorg":
# TODO(#1498): Implement reorg handling once inconsistencies in spec are resolved
pass

async def subscribe_new_heads(
self,
handler: Callable[[BlockHeader], Any],
block: Optional[BlockId],
) -> int:
"""
Creates a WebSocket stream which will fire events for new block headers.
:param handler: The function to call when a new block header is received.
:param block: The block to get notifications from, default is latest, limited to 1024 blocks back.
:return: The subscription ID.
"""
params = {"block": block} if block else {}
subscription_id = await self._subscribe(
handler, "starknet_subscribeNewHeads", params
)

return subscription_id

async def subscribe_events(
self,
handler: Callable[[EmittedEvent], Any],
from_address: Optional[int] = None,
keys: Optional[List[List[int]]] = None,
block: Optional[BlockId] = None,
) -> int:
"""
Creates a WebSocket stream which will fire events for new Starknet events with applied filters.
:param handler: The function to call when a new event is received.
:param from_address: Address which emitted the event.
:param keys: The keys to filter events by.
:param block: The block to get notifications from, default is latest, limited to 1024 blocks back.
:return: The subscription ID.
"""
params = {"from_address": from_address, "keys": keys, "block": block}
# params = {"block": block}
subscription_id = await self._subscribe(
handler, "starknet_subscribeEvents", params
)

return subscription_id

async def subscribe_transaction_status(
self,
handler: Callable[[NewTransactionStatus], Any],
transaction_hash: int,
block: Optional[BlockId],
) -> int:
"""
Creates a WebSocket stream which will fire events when a transaction status is updated.
:param handler: The function to call when a new transaction status is received.
:param transaction_hash: The transaction hash to fetch status updates for.
:param block: The block to get notifications from, default is latest, limited to 1024 blocks back.
:return: The subscription ID.
"""
params = {"transaction_hash": transaction_hash, "block": block}
subscription_id = await self._subscribe(
handler, "starknet_subscribeTransactionStatus", params
)

return subscription_id

async def subscribe_pending_transactions(
self,
handler: Callable[[Union[int, Transaction]], Any],
transaction_details: Optional[bool],
sender_address: Optional[List[int]],
) -> int:
"""
Creates a WebSocket stream which will fire events when a new pending transaction is added.
While there is no mempool, this notifies of transactions in the pending block.
:param handler: The function to call when a new pending transaction is received.
:param transaction_details: Whether to include transaction details in the notification.
If false, only hash is returned.
:param sender_address: The sender address to filter transactions by.
:return: The subscription ID.
"""
params = {
"transaction_details": transaction_details,
"sender_address": sender_address,
}
subscription_id = await self._subscribe(
handler, "starknet_subscribePendingTransactions", params
)

return subscription_id

async def unsubscribe(self, subscription_id: int) -> bool:
"""
Close a previously opened WebSocket stream, with the corresponding subscription id.
:param subscription_id: ID of the subscription to close.
:return: True if the unsubscription was successful, False otherwise.
"""
if subscription_id not in self._subscriptions:
return False

params = {"subscription_id": subscription_id}
res = await self._rpc_ws_client.send("starknet_unsubscribe", params)

unsubscribe_response = cast(
UnsubscribeResponse, UnsubscribeResponseSchema().load(res)
)

if unsubscribe_response:
del self._subscriptions[subscription_id]

return unsubscribe_response.result
10 changes: 6 additions & 4 deletions starknet_py/net/schemas/rpc/ws.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from marshmallow import Schema, fields, post_load

from starknet_py.net.schemas.common import Felt
from starknet_py.net.schemas.rpc.block import BlockHeaderSchema
from starknet_py.net.schemas.rpc.event import EmittedEventSchema
from starknet_py.net.ws_full_node_client_models import (
EventsNotification,
NewHeadsNotification,
Expand Down Expand Up @@ -32,15 +34,15 @@ def make_dataclass(self, data, **kwargs) -> NewHeadsNotification:

class EventsNotificationSchema(Schema):
subscription_id = fields.Integer(data_key="subscription_id", required=True)
result = fields.List(fields.Dict(), data_key="result", required=True)
result = fields.Nested(EmittedEventSchema(), data_key="result", required=True)

@post_load
def make_dataclass(self, data, **kwargs) -> EventsNotification:
return EventsNotification(**data)


class NewTransactionStatusSchema(Schema):
transaction_hash = fields.Integer(data_key="transaction_hash", required=True)
transaction_hash = Felt(data_key="transaction_hash", required=True)
status = fields.Dict(data_key="status", required=True)

@post_load
Expand Down Expand Up @@ -77,11 +79,11 @@ def make_dataclass(self, data, **kwargs) -> UnsubscribeResponse:


class ReorgDataSchema(Schema):
starting_block_hash = fields.Integer(data_key="starting_block_hash", required=True)
starting_block_hash = Felt(data_key="starting_block_hash", required=True)
starting_block_number = fields.Integer(
data_key="starting_block_number", required=True, validate=lambda x: x >= 0
)
ending_block_hash = fields.Integer(data_key="ending_block_hash", required=True)
ending_block_hash = Felt(data_key="ending_block_hash", required=True)
ending_block_number = fields.Integer(
data_key="ending_block_number", required=True, validate=lambda x: x >= 0
)
Expand Down
22 changes: 20 additions & 2 deletions starknet_py/net/ws_client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import json
from typing import Any, Dict, Optional, Union, cast
from typing import Any, Callable, Dict, Optional, Union, cast

from websockets.asyncio.client import ClientConnection, connect

Expand Down Expand Up @@ -31,12 +31,29 @@ async def send_raw(
self,
payload: Optional[Dict[str, Any]] = None,
):
"""
Sends a message to the WebSocket server and returns the response.
:param payload: The message to send.
"""
assert self.connection is not None
await self.connection.send(json.dumps(payload))
data = await self.connection.recv()

return data

async def listen(self, received_message_handler: Callable[[Dict[str, Any]], Any]):
"""
Listens for incoming WebSocket messages.
:param received_message_handler: The function to call when a message is received.
"""
assert self.connection is not None

async for message in self.connection:
message = cast(Dict, json.loads(message))
received_message_handler(message)


class RpcWSClient(WSClient):
"""
Expand All @@ -59,6 +76,7 @@ async def send(
data = cast(Dict, json.loads(data))

if "result" not in data:
# TODO(#1498): Possibly move `handle_rpc_error` from `RpcHttpClient` to separate function
RpcHttpClient.handle_rpc_error(data)

return data
return data["result"]

0 comments on commit 2f1a660

Please sign in to comment.