diff --git a/examples/echo_client_benchmark.py b/examples/echo_client_benchmark.py index f64f99c..9c03974 100644 --- a/examples/echo_client_benchmark.py +++ b/examples/echo_client_benchmark.py @@ -1,8 +1,10 @@ import argparse import asyncio import os +import ssl from logging import getLogger +from ssl import SSLContext import websockets import aiohttp @@ -19,7 +21,16 @@ } -async def picows_main(endpoint: str, msg: bytes, duration: int): +def create_client_ssl_context(): + ssl_context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH) + ssl_context.load_default_certs(ssl.Purpose.SERVER_AUTH) + ssl_context.check_hostname = False + ssl_context.hostname_checks_common_name = False + ssl_context.verify_mode = ssl.CERT_NONE + return ssl_context + + +async def picows_main(endpoint: str, msg: bytes, duration: int, ssl_context): class PicowsClientListener(WSListener): def __init__(self): super().__init__() @@ -52,12 +63,12 @@ def on_ws_frame(self, transport: WSTransport, frame: WSFrame): else: self._transport.send(WSMsgType.BINARY, msg) - (_, client) = await ws_connect(endpoint, PicowsClientListener, "client") + (_, client) = await ws_connect(endpoint, PicowsClientListener, "client", ssl=ssl_context) await client._transport.wait_until_closed() -async def websockets_main(endpoint: str, msg: bytes, duration: int): - async with websockets.connect(endpoint) as websocket: +async def websockets_main(endpoint: str, msg: bytes, duration: int, ssl_context): + async with websockets.connect(endpoint, ssl=ssl_context) as websocket: await websocket.send(msg) start_time = time() cnt = 0 @@ -73,9 +84,9 @@ async def websockets_main(endpoint: str, msg: bytes, duration: int): RPS[f"websockets({websockets.__version__})"] = int(cnt / duration) -async def aiohttp_main(url: str, data: bytes, duration: int) -> None: +async def aiohttp_main(url: str, data: bytes, duration: int, ssl_context) -> None: async with ClientSession() as session: - async with session.ws_connect(url) as ws: + async with session.ws_connect(url, ssl_context=ssl_context) as ws: # send request cnt = 0 start_time = time() @@ -124,15 +135,17 @@ async def aiohttp_main(url: str, data: bytes, duration: int) -> None: asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) loop_name = "uvloop" + ssl_context = create_client_ssl_context() if args.url.startswith("wss://") else None + try: from examples.picows_client_cython import picows_main_cython - asyncio.get_event_loop().run_until_complete(picows_main_cython(args.url, msg, duration)) + asyncio.get_event_loop().run_until_complete(picows_main_cython(args.url, msg, duration, ssl_context)) except ImportError: pass - asyncio.run(picows_main(args.url, msg, duration)) - asyncio.run(aiohttp_main(args.url, msg, duration)) - asyncio.run(websockets_main(args.url, msg, duration)) + asyncio.run(picows_main(args.url, msg, duration, ssl_context)) + asyncio.run(aiohttp_main(args.url, msg, duration, ssl_context)) + asyncio.run(websockets_main(args.url, msg, duration, ssl_context)) for k, v in RPS.items(): print(k, v) diff --git a/examples/echo_server.py b/examples/echo_server.py index a182c73..0dc0c49 100644 --- a/examples/echo_server.py +++ b/examples/echo_server.py @@ -1,6 +1,10 @@ import asyncio import os +import pathlib +import ssl from logging import getLogger, INFO, basicConfig +from ssl import SSLContext + from picows import WSFrame, WSTransport, ws_create_server, WSListener, WSMsgType _logger = getLogger(__name__) @@ -17,8 +21,14 @@ def on_ws_frame(self, transport: WSTransport, frame: WSFrame): async def async_main(): - url = "ws://127.0.0.1:9001" - server = await ws_create_server(url, PicowsServerListener, "server") + url = "wss://127.0.0.1:9001" + ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) + ssl_context.load_cert_chain(pathlib.Path(__file__).parent.parent / "tests" / "picows_test.crt", + pathlib.Path(__file__).parent.parent / "tests" / "picows_test.key") + ssl_context.check_hostname = False + ssl_context.hostname_checks_common_name = False + ssl_context.verify_mode = ssl.CERT_NONE + server = await ws_create_server(url, PicowsServerListener, "server", ssl_context=ssl_context) _logger.info("Server started on %s", url) server_task = asyncio.get_running_loop().create_task(server.serve_forever()) await server_task diff --git a/examples/picows_client_cython.pyx b/examples/picows_client_cython.pyx index fbbccc8..2100394 100644 --- a/examples/picows_client_cython.pyx +++ b/examples/picows_client_cython.pyx @@ -74,7 +74,7 @@ cdef class PicowsClientListener(WSListener): self._transport.send(WSMsgType.BINARY, self._data) -async def picows_main_cython(url: str, data: bytes, duration: int): +async def picows_main_cython(url: str, data: bytes, duration: int, ssl_context): cdef PicowsClientListener client - (_, client) = await ws_connect("ws://127.0.0.1:9001", lambda: PicowsClientListener(data, duration), "client") + (_, client) = await ws_connect(url, lambda: PicowsClientListener(data, duration), "client", ssl=ssl_context) await client._transport.wait_until_closed() diff --git a/picows/picows.pyx b/picows/picows.pyx index 74db569..052fded 100644 --- a/picows/picows.pyx +++ b/picows/picows.pyx @@ -5,10 +5,11 @@ import hashlib import logging import os import socket +import ssl import struct import urllib.parse from ssl import SSLContext -from typing import cast, Tuple, Optional, Callable +from typing import cast, Tuple, Optional, Callable, Union cimport cython @@ -607,11 +608,11 @@ cdef class WSProtocol: if not self._handshake_complete_future.done(): self._handshake_complete_future.set_result(None) - self.transport.mark_disconnected() - if self._handshake_timeout_handle is not None: self._handshake_timeout_handle.cancel() + self.transport.mark_disconnected() + def eof_received(self) -> bool: self._logger.debug("WS eof received") # Returning False here means that the transport should close itself @@ -964,13 +965,13 @@ cdef class WSProtocol: def _handshake_timeout_callback(self): self._logger.info("Handshake timeout, the client hasn't requested upgrade within required time, close connection") - self.transport.close() + self.transport.disconnect() async def ws_connect(str url: str, ws_listener_factory: Callable[[], WSListener], str logger_name: str, - ssl: Optional[SSLContext]=None, + ssl: Optional[Union[bool, SSLContext]]=None, bint disconnect_on_exception: bool=True, ssl_handshake_timeout: int=5, ssl_shutdown_timeout: int=5, @@ -1004,10 +1005,11 @@ async def ws_connect(str url: str, url_parts = urllib.parse.urlparse(url, allow_fragments=False) if url_parts.scheme == "wss": - ssl = ssl or True + if ssl is None: + ssl = True port = url_parts.port or 443 elif url_parts.scheme == "ws": - ssl_context = None + ssl = None ssl_handshake_timeout = None ssl_shutdown_timeout = None port = url_parts.port or 80 @@ -1033,7 +1035,8 @@ async def ws_connect(str url: str, async def ws_create_server(str url, ws_listener_factory, - str logger_name, ssl_context=None, + str logger_name, + ssl_context=None, disconnect_on_exception=True, ssl_handshake_timeout: int=5, ssl_shutdown_timeout: int=5, @@ -1074,7 +1077,8 @@ async def ws_create_server(str url, url_parts = urllib.parse.urlparse(url, allow_fragments=False) if url_parts.scheme == "wss": - ssl_context = ssl_context or True + if ssl_context is None: + ssl_context = SSLContext(ssl.PROTOCOL_TLS_SERVER) port = url_parts.port or 443 elif url_parts.scheme == "ws": ssl_context = None @@ -1089,7 +1093,7 @@ async def ws_create_server(str url, cdef WSProtocol ws_protocol - server = await asyncio.get_running_loop().create_server( + return await asyncio.get_running_loop().create_server( ws_protocol_factory, host=url_parts.hostname, port=port, ssl=ssl_context, @@ -1097,5 +1101,3 @@ async def ws_create_server(str url, ssl_shutdown_timeout=ssl_shutdown_timeout, reuse_port=reuse_port, start_serving=start_serving) - - return server diff --git a/tests/picows_test.crt b/tests/picows_test.crt new file mode 100644 index 0000000..446c091 --- /dev/null +++ b/tests/picows_test.crt @@ -0,0 +1,23 @@ +-----BEGIN CERTIFICATE----- +MIID2TCCAsGgAwIBAgIUFYHtE43RqOZaGuqr8co1+KUFhA8wDQYJKoZIhvcNAQEL +BQAwfDELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxDTALBgNVBAcM +BENpdHkxITAfBgNVBAoMGEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDERMA8GA1UE +CwwISmFuaXRvcnMxEzARBgNVBAMMCnBpY293cy5jb20wHhcNMjQwODE2MDk0MzAz +WhcNMjUwODE2MDk0MzAzWjB8MQswCQYDVQQGEwJBVTETMBEGA1UECAwKU29tZS1T +dGF0ZTENMAsGA1UEBwwEQ2l0eTEhMB8GA1UECgwYSW50ZXJuZXQgV2lkZ2l0cyBQ +dHkgTHRkMREwDwYDVQQLDAhKYW5pdG9yczETMBEGA1UEAwwKcGljb3dzLmNvbTCC +ASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAKVXEHQtsfk5cgedGZ0LuTSa +I1WsB/3zJRYHXnlpA1hULM0vaWoFQoNrpUKv5qJ6i3idUiGH1odpU+qMtL8I+6+y +PdRk1iGH01vWFDh28MEL1+ONYKdDeKPRz+rBV8T/518jpIVnjX0j/acRXYBqTBZV +lh+dTfJ8hcJ0g00lTDci3ypa+NUHcMdDi4pL9iPALWZYjwTmLMVnZt+jwd2a7VmC +Rvpo5ZDyc1ZeHFaN2jlJD1By/VlDZg5wXHu2DWlz7ppXxLc5z7z+f/fgDQsm2Qf9 +JE159yxYLuZHJfzUYyVbut2Q+o0ChQcUy5MC3PyRSgT2BPV5RtTnbf11Ue5SBXEC +AwEAAaNTMFEwHQYDVR0OBBYEFEe4nb2hkXM7JBi5R7jR6FqLY0/EMB8GA1UdIwQY +MBaAFEe4nb2hkXM7JBi5R7jR6FqLY0/EMA8GA1UdEwEB/wQFMAMBAf8wDQYJKoZI +hvcNAQELBQADggEBAIeHCG8voBxOmp5pF48BH+c7fgVYQcaEIY0qZG8lpvPJUtSw +nMjy7GLAzmmaeUPi4e3zZFEQHepOAVPZfM+vSoKGmwO/ngcSPl1pA+w5yyq9h0xx +ix2PxX7sn9sI/LCdalhz75E4ao7dL1uf/9MrJew4619aUn5tlqt5Zt29XFq4Mp1A +oQFUNv+AbgSgOyHtpY/aCRjNDyjqgSPRnBRQTghnKL6FMS7jzjYx1mnzZd4wUJU+ +SNc3jPM8JrncbrW1jvpR6FITUeTwycl6OEAk19wG1AAUeqFbcQ4JQbjJ4My5FNOL +dHPhBiI55PsoMILscEEiit5seQVa7mLPcXekmjM= +-----END CERTIFICATE----- diff --git a/tests/picows_test.key b/tests/picows_test.key new file mode 100644 index 0000000..b78089f --- /dev/null +++ b/tests/picows_test.key @@ -0,0 +1,28 @@ +-----BEGIN PRIVATE KEY----- +MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQClVxB0LbH5OXIH +nRmdC7k0miNVrAf98yUWB155aQNYVCzNL2lqBUKDa6VCr+aieot4nVIhh9aHaVPq +jLS/CPuvsj3UZNYhh9Nb1hQ4dvDBC9fjjWCnQ3ij0c/qwVfE/+dfI6SFZ419I/2n +EV2AakwWVZYfnU3yfIXCdINNJUw3It8qWvjVB3DHQ4uKS/YjwC1mWI8E5izFZ2bf +o8Hdmu1Zgkb6aOWQ8nNWXhxWjdo5SQ9Qcv1ZQ2YOcFx7tg1pc+6aV8S3Oc+8/n/3 +4A0LJtkH/SRNefcsWC7mRyX81GMlW7rdkPqNAoUHFMuTAtz8kUoE9gT1eUbU5239 +dVHuUgVxAgMBAAECggEAIKnRi1IXrcemelCT5c2SGg01W9Bvh8DIG9D63ftb+NAw +RZzPM220lSfn7wO3CXQyJ3LU7eDbBcdOF7ESaAcLSctpoZMGJnuhyqvBNOJyb4C6 +dq7QYFrY82wYqiTmBPoVKFtvPheWARniG2Y31Y2qWGMyFC7MXlIxTpgb9Fqg8uc0 +LfcTeSPn6E4gSECjUxTyxWpr7facCwCNYrrkrCLyat74cXHPiGYPOLWx0HtuY8yW +wZZL9oT4x0S4kmUSSpQMw7oPgbtLQf7d2m+Z8BEOMHFqHzUkgNwILVo4DBGZzfKO +egXhJ9pEAJlJ42MZg9mtDF8efOAm5f2mPw4TNay67QKBgQDeKbTKb5YepkwdqfA6 +j3TTxPjLLfA/sLbIVKnaj9zErr6Dx9xT6Bw1+bP9mDmTaZuY3oRAhbHoNmnjnsvf +8BgY9G5NfeTA5CikFiKGJallPY6RlELEDLYCfGRFLOWsCE8FclRWUUgrfJY4i4vQ +MrZe0nuAuvOpfn5OU3a9w2/frwKBgQC+hcuA+8TIB9CQVUb1soqgWdTzyAqMk637 +WfQfM33s6rjwxA2RZfkRO76I/+k/kd4sE1uOj7qN1dwzZfzZd0GMsMBBHA1w0Ox7 +bGsqU4G32+tFEBzE3qYqua6SJl3rsteXGHXcX7ngylrvFBmlBW1b5uImME5z3vOb +yNV3NIOU3wKBgBRZd6DvVa3bB6/T6BhFGatoKG3b+FytICD7eE93y/4MD5Fcljbt +VOAwzibVcbip/MGk6DJMzL37dfmOixgpEtv+T7gzZuewPnTBPkpRWtHWMJ/vF6qD +i4xwvnKDqUn3vN0/2q/JZDXvhIcLaTQZ4RCQcRWaikUlPAaKqJ67Lx0rAoGBAL4S +cPIfOzRsR3CXAxH/qzlKJZ+HxK52bq5CEcBG+Kwxh4v7q6WQ3CiLOA0pciPPfJzw +OvlA/tadsu88IkM6LJUViNfsCqSwahzADzHM2a75of/mkSz/Czu4vyZjTHPmmhrN +dlgC0Ego2QuHPAZcIbv73UZIDxyeIt8aP4yLQXJ1AoGBAKYoVI/LEdVCZqBH9Kza +Hh6kRnWIzyh5ZBzAw5nQzP75HUzcOE3uZ2LDkkK3dJCPgXrNNHoKLlIjuYuaWZgb +KKTqKsDu855+w0JGO91HIsVmNZI9fH8/1s4cM86YvjOtkikyp9ZdEmR9YMbzYqVn +SMxDi9VWK7Vi3fBdd0dCBBHr +-----END PRIVATE KEY----- diff --git a/tests/test_echo.py b/tests/test_echo.py index c219473..8419da5 100644 --- a/tests/test_echo.py +++ b/tests/test_echo.py @@ -1,12 +1,35 @@ import asyncio import base64 +import itertools import os +import pathlib +import ssl import picows import pytest import async_timeout URL = "ws://127.0.0.1:9001" +URL_SSL = "wss://127.0.0.1:9002" + + +def create_server_ssl_context(): + ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) + ssl_context.load_cert_chain(pathlib.Path(__file__).parent / "picows_test.crt", + pathlib.Path(__file__).parent / "picows_test.key") + ssl_context.check_hostname = False + ssl_context.hostname_checks_common_name = False + ssl_context.verify_mode = ssl.CERT_NONE + return ssl_context + + +def create_client_ssl_context(): + ssl_context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH) + ssl_context.load_default_certs(ssl.Purpose.SERVER_AUTH) + ssl_context.check_hostname = False + ssl_context.hostname_checks_common_name = False + ssl_context.verify_mode = ssl.CERT_NONE + return ssl_context class BinaryFrame: @@ -33,25 +56,22 @@ def __init__(self, frame: picows.WSFrame): self.fin = frame.fin -#@pytest.fixture(scope="module") -@pytest.fixture -async def echo_server(): +@pytest.fixture(params=[URL, URL_SSL]) +async def echo_server(request): class PicowsServerListener(picows.WSListener): def on_ws_connected(self, transport: picows.WSTransport): - print("echo_server:on_ws_connected") self._transport = transport def on_ws_frame(self, transport: picows.WSTransport, frame: picows.WSFrame): - print("echo_server:on_ws_frame") self._transport.send(frame.msg_type, frame.get_payload_as_bytes()) if frame.msg_type == picows.WSMsgType.CLOSE: self._transport.send_close(frame.get_close_code(), frame.get_close_message()) self._transport.disconnect() - server = await picows.ws_create_server(URL, PicowsServerListener, "server") + server = await picows.ws_create_server(request.param, PicowsServerListener, "server", + ssl_context=create_server_ssl_context()) task = asyncio.create_task(server.serve_forever()) - print("initiated module level echo server") - yield server + yield request.param # Teardown server task.cancel() @@ -60,11 +80,8 @@ def on_ws_frame(self, transport: picows.WSTransport, frame: picows.WSFrame): except: pass - print("stopped module level echo server") - -# @pytest.fixture(scope="module") -@pytest.fixture +@pytest.fixture() async def echo_client(echo_server): class PicowsClientListener(picows.WSListener): transport: picows.WSTransport @@ -86,7 +103,8 @@ async def get_message(self): async with async_timeout.timeout(1): return await self.msg_queue.get() - (_, client) = await picows.ws_connect(URL, PicowsClientListener, "client") + (_, client) = await picows.ws_connect(echo_server, PicowsClientListener, "client", + ssl=create_client_ssl_context()) yield client # Teardown client @@ -99,7 +117,7 @@ async def get_message(self): client.transport.disconnect() -@pytest.mark.parametrize("msg_size", [32, 1024, 20000]) +@pytest.mark.parametrize("msg_size", [256, 1024]) async def test_echo(echo_client, msg_size): msg = os.urandom(msg_size) echo_client.transport.send(picows.WSMsgType.BINARY, msg)