Skip to content

Commit

Permalink
Customize server negotiation behaviour, pass extra arguments directly…
Browse files Browse the repository at this point in the history
… to asyncio.create_server
  • Loading branch information
taras committed Aug 19, 2024
1 parent 9ab0422 commit cd5fb5d
Show file tree
Hide file tree
Showing 7 changed files with 230 additions and 121 deletions.
11 changes: 8 additions & 3 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,14 @@ Echo server
transport.disconnect()
async def main():
url = "ws://127.0.0.1:9001"
server = await ws_create_server(url, ServerClientListener)
print(f"Server started on {url}")
def listener_factory(r: WSUpgradeRequest):
# Routing can be implemented here by analyzing request content
return ServerClientListener()
server: asyncio.Server = await ws_create_server(listener_factory, "127.0.0.1", 9001)
for s in server.sockets:
print(f"Server started on {s.getsockname()}")
await server.serve_forever()
if __name__ == '__main__':
Expand Down
23 changes: 23 additions & 0 deletions docs/source/reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,29 @@ Classes

Size of the payload.

.. autoclass:: WSUpgradeRequest
:members:

.. py:attribute:: method
:type: bytes

Request method. b"GET", b"POST", etc

.. py:attribute:: path
:type: bytes

Request path. For example b"/ws"

.. py:attribute:: version
:type: bytes

HTTP version. For example b"HTTP/1.1"

.. py:attribute:: headers
:type: Dict[str, str]

Request headers. header names are always in lowercase

.. autoclass:: WSListener
:members:

Expand Down
20 changes: 11 additions & 9 deletions examples/echo_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@
from logging import getLogger, INFO, basicConfig
from ssl import SSLContext

from picows import WSFrame, WSTransport, ws_create_server, WSListener, WSMsgType
from picows import WSFrame, WSTransport, ws_create_server, WSListener, WSMsgType, WSUpgradeRequest

_logger = getLogger(__name__)


class PicowsServerListener(WSListener):
class ServerClientListener(WSListener):
def on_ws_connected(self, transport: WSTransport):
self._transport = transport

Expand All @@ -21,23 +21,25 @@ def on_ws_frame(self, transport: WSTransport, frame: WSFrame):


async def async_main():
url = "ws://127.0.0.1:9001"
url_ssl = "wss://127.0.0.1:9002"
def listener_factory(r: WSUpgradeRequest):
return ServerClientListener()

plain_server = await ws_create_server(url, PicowsServerListener,
plain_server = await ws_create_server(listener_factory,
"127.0.0.1", 9001,
websocket_handshake_timeout=0.5)
_logger.info("Server started on %s", url)
_logger.info("Server started on %s", plain_server.sockets[0].getsockname())

ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
ssl_context.load_cert_chain(pathlib.Path(__file__).parent.parent / "tests" / "picows_test.crt",
pathlib.Path(__file__).parent.parent / "tests" / "picows_test.key")
ssl_context.check_hostname = False
ssl_context.hostname_checks_common_name = False
ssl_context.verify_mode = ssl.CERT_NONE
ssl_server = await ws_create_server(url_ssl, PicowsServerListener,
ssl_context=ssl_context,
ssl_server = await ws_create_server(listener_factory,
"127.0.0.1", 9002,
ssl=ssl_context,
websocket_handshake_timeout=0.5)
_logger.info("Server started on %s", url_ssl)
_logger.info("Server started on %s", ssl_server.sockets[0].getsockname())

await asyncio.gather(plain_server.serve_forever(), ssl_server.serve_forever())

Expand Down
2 changes: 2 additions & 0 deletions picows/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
WSFrame,
WSTransport,
WSListener,
WSUpgradeRequest,
ws_connect,
ws_create_server,
PICOWS_DEBUG_LL
Expand All @@ -16,6 +17,7 @@
'WSFrame',
'WSTransport',
'WSListener',
'WSUpgradeRequest',
'ws_connect',
'ws_create_server',
'PICOWS_DEBUG_LL'
Expand Down
19 changes: 16 additions & 3 deletions picows/picows.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,22 @@ cdef class MemoryBuffer:
cdef resize(self, Py_ssize_t new_size)


cdef class WSUpgradeRequest:
cdef:
readonly bytes method
readonly bytes path
readonly bytes version
readonly dict headers


cdef class WSFrame:
cdef:
char* payload_ptr
size_t payload_size
readonly size_t tail_size
readonly WSMsgType msg_type
readonly uint8_t fin
readonly uint8_t rsv1
readonly uint8_t last_in_buffer

cpdef bytes get_payload_as_bytes(self)
Expand All @@ -73,6 +82,7 @@ cdef class WSTransport:
readonly object underlying_transport #: asyncio.Transport

object _logger #: Logger
bint _log_debug_enabled
object _disconnected_future #: asyncio.Future
MemoryBuffer _write_buf
bint _is_client_side
Expand All @@ -84,9 +94,12 @@ cdef class WSTransport:
cpdef send_close(self, WSCloseCode close_code=*, close_message=*)
cpdef disconnect(self)

cdef send_http_handshake(self, bytes ws_path, bytes host_port, bytes websocket_key_b64)
cdef send_http_handshake_response(self, bytes accept_val)
cdef mark_disconnected(self)
cdef _send_http_handshake(self, bytes ws_path, bytes host_port, bytes websocket_key_b64)
cdef _send_http_handshake_response(self, bytes accept_val)
cdef _send_bad_request(self, str error)
cdef _send_not_found(self, WSUpgradeRequest r)
cdef _send_internal_error(self, WSUpgradeRequest r, str error)
cdef _mark_disconnected(self)

cdef bytes _prepare_frame_in_external_buffer(self, WSMsgType msg_type, uint8_t* msg_ptr, size_t msg_length)
cdef bytes _prepare_frame(self, WSMsgType msg_type, message)
Expand Down
Loading

0 comments on commit cd5fb5d

Please sign in to comment.