Skip to content

Commit

Permalink
Add ssl echo test
Browse files Browse the repository at this point in the history
  • Loading branch information
taras committed Aug 16, 2024
1 parent c408692 commit 0bc6fa6
Show file tree
Hide file tree
Showing 7 changed files with 134 additions and 40 deletions.
33 changes: 23 additions & 10 deletions examples/echo_client_benchmark.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import argparse
import asyncio
import os
import ssl

from logging import getLogger
from ssl import SSLContext

import websockets
import aiohttp
Expand All @@ -19,7 +21,16 @@
}


async def picows_main(endpoint: str, msg: bytes, duration: int):
def create_client_ssl_context():
ssl_context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
ssl_context.load_default_certs(ssl.Purpose.SERVER_AUTH)
ssl_context.check_hostname = False
ssl_context.hostname_checks_common_name = False
ssl_context.verify_mode = ssl.CERT_NONE
return ssl_context


async def picows_main(endpoint: str, msg: bytes, duration: int, ssl_context):
class PicowsClientListener(WSListener):
def __init__(self):
super().__init__()
Expand Down Expand Up @@ -52,12 +63,12 @@ def on_ws_frame(self, transport: WSTransport, frame: WSFrame):
else:
self._transport.send(WSMsgType.BINARY, msg)

(_, client) = await ws_connect(endpoint, PicowsClientListener, "client")
(_, client) = await ws_connect(endpoint, PicowsClientListener, "client", ssl=ssl_context)
await client._transport.wait_until_closed()


async def websockets_main(endpoint: str, msg: bytes, duration: int):
async with websockets.connect(endpoint) as websocket:
async def websockets_main(endpoint: str, msg: bytes, duration: int, ssl_context):
async with websockets.connect(endpoint, ssl=ssl_context) as websocket:
await websocket.send(msg)
start_time = time()
cnt = 0
Expand All @@ -73,9 +84,9 @@ async def websockets_main(endpoint: str, msg: bytes, duration: int):
RPS[f"websockets({websockets.__version__})"] = int(cnt / duration)


async def aiohttp_main(url: str, data: bytes, duration: int) -> None:
async def aiohttp_main(url: str, data: bytes, duration: int, ssl_context) -> None:
async with ClientSession() as session:
async with session.ws_connect(url) as ws:
async with session.ws_connect(url, ssl_context=ssl_context) as ws:
# send request
cnt = 0
start_time = time()
Expand Down Expand Up @@ -124,15 +135,17 @@ async def aiohttp_main(url: str, data: bytes, duration: int) -> None:
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
loop_name = "uvloop"

ssl_context = create_client_ssl_context() if args.url.startswith("wss://") else None

try:
from examples.picows_client_cython import picows_main_cython
asyncio.get_event_loop().run_until_complete(picows_main_cython(args.url, msg, duration))
asyncio.get_event_loop().run_until_complete(picows_main_cython(args.url, msg, duration, ssl_context))
except ImportError:
pass

asyncio.run(picows_main(args.url, msg, duration))
asyncio.run(aiohttp_main(args.url, msg, duration))
asyncio.run(websockets_main(args.url, msg, duration))
asyncio.run(picows_main(args.url, msg, duration, ssl_context))
asyncio.run(aiohttp_main(args.url, msg, duration, ssl_context))
asyncio.run(websockets_main(args.url, msg, duration, ssl_context))

for k, v in RPS.items():
print(k, v)
Expand Down
14 changes: 12 additions & 2 deletions examples/echo_server.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import asyncio
import os
import pathlib
import ssl
from logging import getLogger, INFO, basicConfig
from ssl import SSLContext

from picows import WSFrame, WSTransport, ws_create_server, WSListener, WSMsgType

_logger = getLogger(__name__)
Expand All @@ -17,8 +21,14 @@ def on_ws_frame(self, transport: WSTransport, frame: WSFrame):


async def async_main():
url = "ws://127.0.0.1:9001"
server = await ws_create_server(url, PicowsServerListener, "server")
url = "wss://127.0.0.1:9001"
ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
ssl_context.load_cert_chain(pathlib.Path(__file__).parent.parent / "tests" / "picows_test.crt",
pathlib.Path(__file__).parent.parent / "tests" / "picows_test.key")
ssl_context.check_hostname = False
ssl_context.hostname_checks_common_name = False
ssl_context.verify_mode = ssl.CERT_NONE
server = await ws_create_server(url, PicowsServerListener, "server", ssl_context=ssl_context)
_logger.info("Server started on %s", url)
server_task = asyncio.get_running_loop().create_task(server.serve_forever())
await server_task
Expand Down
4 changes: 2 additions & 2 deletions examples/picows_client_cython.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ cdef class PicowsClientListener(WSListener):
self._transport.send(WSMsgType.BINARY, self._data)


async def picows_main_cython(url: str, data: bytes, duration: int):
async def picows_main_cython(url: str, data: bytes, duration: int, ssl_context):
cdef PicowsClientListener client
(_, client) = await ws_connect("ws://127.0.0.1:9001", lambda: PicowsClientListener(data, duration), "client")
(_, client) = await ws_connect(url, lambda: PicowsClientListener(data, duration), "client", ssl=ssl_context)
await client._transport.wait_until_closed()
26 changes: 14 additions & 12 deletions picows/picows.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@ import hashlib
import logging
import os
import socket
import ssl
import struct
import urllib.parse
from ssl import SSLContext
from typing import cast, Tuple, Optional, Callable
from typing import cast, Tuple, Optional, Callable, Union

cimport cython

Expand Down Expand Up @@ -607,11 +608,11 @@ cdef class WSProtocol:
if not self._handshake_complete_future.done():
self._handshake_complete_future.set_result(None)

self.transport.mark_disconnected()

if self._handshake_timeout_handle is not None:
self._handshake_timeout_handle.cancel()

self.transport.mark_disconnected()

def eof_received(self) -> bool:
self._logger.debug("WS eof received")
# Returning False here means that the transport should close itself
Expand Down Expand Up @@ -964,13 +965,13 @@ cdef class WSProtocol:

def _handshake_timeout_callback(self):
self._logger.info("Handshake timeout, the client hasn't requested upgrade within required time, close connection")
self.transport.close()
self.transport.disconnect()


async def ws_connect(str url: str,
ws_listener_factory: Callable[[], WSListener],
str logger_name: str,
ssl: Optional[SSLContext]=None,
ssl: Optional[Union[bool, SSLContext]]=None,
bint disconnect_on_exception: bool=True,
ssl_handshake_timeout: int=5,
ssl_shutdown_timeout: int=5,
Expand Down Expand Up @@ -1004,10 +1005,11 @@ async def ws_connect(str url: str,
url_parts = urllib.parse.urlparse(url, allow_fragments=False)

if url_parts.scheme == "wss":
ssl = ssl or True
if ssl is None:
ssl = True
port = url_parts.port or 443
elif url_parts.scheme == "ws":
ssl_context = None
ssl = None
ssl_handshake_timeout = None
ssl_shutdown_timeout = None
port = url_parts.port or 80
Expand All @@ -1033,7 +1035,8 @@ async def ws_connect(str url: str,

async def ws_create_server(str url,
ws_listener_factory,
str logger_name, ssl_context=None,
str logger_name,
ssl_context=None,
disconnect_on_exception=True,
ssl_handshake_timeout: int=5,
ssl_shutdown_timeout: int=5,
Expand Down Expand Up @@ -1074,7 +1077,8 @@ async def ws_create_server(str url,
url_parts = urllib.parse.urlparse(url, allow_fragments=False)

if url_parts.scheme == "wss":
ssl_context = ssl_context or True
if ssl_context is None:
ssl_context = SSLContext(ssl.PROTOCOL_TLS_SERVER)
port = url_parts.port or 443
elif url_parts.scheme == "ws":
ssl_context = None
Expand All @@ -1089,13 +1093,11 @@ async def ws_create_server(str url,

cdef WSProtocol ws_protocol

server = await asyncio.get_running_loop().create_server(
return await asyncio.get_running_loop().create_server(
ws_protocol_factory,
host=url_parts.hostname, port=port,
ssl=ssl_context,
ssl_handshake_timeout=ssl_handshake_timeout,
ssl_shutdown_timeout=ssl_shutdown_timeout,
reuse_port=reuse_port,
start_serving=start_serving)

return server
23 changes: 23 additions & 0 deletions tests/picows_test.crt
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
-----BEGIN CERTIFICATE-----
MIID2TCCAsGgAwIBAgIUFYHtE43RqOZaGuqr8co1+KUFhA8wDQYJKoZIhvcNAQEL
BQAwfDELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxDTALBgNVBAcM
BENpdHkxITAfBgNVBAoMGEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDERMA8GA1UE
CwwISmFuaXRvcnMxEzARBgNVBAMMCnBpY293cy5jb20wHhcNMjQwODE2MDk0MzAz
WhcNMjUwODE2MDk0MzAzWjB8MQswCQYDVQQGEwJBVTETMBEGA1UECAwKU29tZS1T
dGF0ZTENMAsGA1UEBwwEQ2l0eTEhMB8GA1UECgwYSW50ZXJuZXQgV2lkZ2l0cyBQ
dHkgTHRkMREwDwYDVQQLDAhKYW5pdG9yczETMBEGA1UEAwwKcGljb3dzLmNvbTCC
ASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAKVXEHQtsfk5cgedGZ0LuTSa
I1WsB/3zJRYHXnlpA1hULM0vaWoFQoNrpUKv5qJ6i3idUiGH1odpU+qMtL8I+6+y
PdRk1iGH01vWFDh28MEL1+ONYKdDeKPRz+rBV8T/518jpIVnjX0j/acRXYBqTBZV
lh+dTfJ8hcJ0g00lTDci3ypa+NUHcMdDi4pL9iPALWZYjwTmLMVnZt+jwd2a7VmC
Rvpo5ZDyc1ZeHFaN2jlJD1By/VlDZg5wXHu2DWlz7ppXxLc5z7z+f/fgDQsm2Qf9
JE159yxYLuZHJfzUYyVbut2Q+o0ChQcUy5MC3PyRSgT2BPV5RtTnbf11Ue5SBXEC
AwEAAaNTMFEwHQYDVR0OBBYEFEe4nb2hkXM7JBi5R7jR6FqLY0/EMB8GA1UdIwQY
MBaAFEe4nb2hkXM7JBi5R7jR6FqLY0/EMA8GA1UdEwEB/wQFMAMBAf8wDQYJKoZI
hvcNAQELBQADggEBAIeHCG8voBxOmp5pF48BH+c7fgVYQcaEIY0qZG8lpvPJUtSw
nMjy7GLAzmmaeUPi4e3zZFEQHepOAVPZfM+vSoKGmwO/ngcSPl1pA+w5yyq9h0xx
ix2PxX7sn9sI/LCdalhz75E4ao7dL1uf/9MrJew4619aUn5tlqt5Zt29XFq4Mp1A
oQFUNv+AbgSgOyHtpY/aCRjNDyjqgSPRnBRQTghnKL6FMS7jzjYx1mnzZd4wUJU+
SNc3jPM8JrncbrW1jvpR6FITUeTwycl6OEAk19wG1AAUeqFbcQ4JQbjJ4My5FNOL
dHPhBiI55PsoMILscEEiit5seQVa7mLPcXekmjM=
-----END CERTIFICATE-----
28 changes: 28 additions & 0 deletions tests/picows_test.key
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
-----BEGIN PRIVATE KEY-----
MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQClVxB0LbH5OXIH
nRmdC7k0miNVrAf98yUWB155aQNYVCzNL2lqBUKDa6VCr+aieot4nVIhh9aHaVPq
jLS/CPuvsj3UZNYhh9Nb1hQ4dvDBC9fjjWCnQ3ij0c/qwVfE/+dfI6SFZ419I/2n
EV2AakwWVZYfnU3yfIXCdINNJUw3It8qWvjVB3DHQ4uKS/YjwC1mWI8E5izFZ2bf
o8Hdmu1Zgkb6aOWQ8nNWXhxWjdo5SQ9Qcv1ZQ2YOcFx7tg1pc+6aV8S3Oc+8/n/3
4A0LJtkH/SRNefcsWC7mRyX81GMlW7rdkPqNAoUHFMuTAtz8kUoE9gT1eUbU5239
dVHuUgVxAgMBAAECggEAIKnRi1IXrcemelCT5c2SGg01W9Bvh8DIG9D63ftb+NAw
RZzPM220lSfn7wO3CXQyJ3LU7eDbBcdOF7ESaAcLSctpoZMGJnuhyqvBNOJyb4C6
dq7QYFrY82wYqiTmBPoVKFtvPheWARniG2Y31Y2qWGMyFC7MXlIxTpgb9Fqg8uc0
LfcTeSPn6E4gSECjUxTyxWpr7facCwCNYrrkrCLyat74cXHPiGYPOLWx0HtuY8yW
wZZL9oT4x0S4kmUSSpQMw7oPgbtLQf7d2m+Z8BEOMHFqHzUkgNwILVo4DBGZzfKO
egXhJ9pEAJlJ42MZg9mtDF8efOAm5f2mPw4TNay67QKBgQDeKbTKb5YepkwdqfA6
j3TTxPjLLfA/sLbIVKnaj9zErr6Dx9xT6Bw1+bP9mDmTaZuY3oRAhbHoNmnjnsvf
8BgY9G5NfeTA5CikFiKGJallPY6RlELEDLYCfGRFLOWsCE8FclRWUUgrfJY4i4vQ
MrZe0nuAuvOpfn5OU3a9w2/frwKBgQC+hcuA+8TIB9CQVUb1soqgWdTzyAqMk637
WfQfM33s6rjwxA2RZfkRO76I/+k/kd4sE1uOj7qN1dwzZfzZd0GMsMBBHA1w0Ox7
bGsqU4G32+tFEBzE3qYqua6SJl3rsteXGHXcX7ngylrvFBmlBW1b5uImME5z3vOb
yNV3NIOU3wKBgBRZd6DvVa3bB6/T6BhFGatoKG3b+FytICD7eE93y/4MD5Fcljbt
VOAwzibVcbip/MGk6DJMzL37dfmOixgpEtv+T7gzZuewPnTBPkpRWtHWMJ/vF6qD
i4xwvnKDqUn3vN0/2q/JZDXvhIcLaTQZ4RCQcRWaikUlPAaKqJ67Lx0rAoGBAL4S
cPIfOzRsR3CXAxH/qzlKJZ+HxK52bq5CEcBG+Kwxh4v7q6WQ3CiLOA0pciPPfJzw
OvlA/tadsu88IkM6LJUViNfsCqSwahzADzHM2a75of/mkSz/Czu4vyZjTHPmmhrN
dlgC0Ego2QuHPAZcIbv73UZIDxyeIt8aP4yLQXJ1AoGBAKYoVI/LEdVCZqBH9Kza
Hh6kRnWIzyh5ZBzAw5nQzP75HUzcOE3uZ2LDkkK3dJCPgXrNNHoKLlIjuYuaWZgb
KKTqKsDu855+w0JGO91HIsVmNZI9fH8/1s4cM86YvjOtkikyp9ZdEmR9YMbzYqVn
SMxDi9VWK7Vi3fBdd0dCBBHr
-----END PRIVATE KEY-----
46 changes: 32 additions & 14 deletions tests/test_echo.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,35 @@
import asyncio
import base64
import itertools
import os
import pathlib
import ssl

import picows
import pytest
import async_timeout

URL = "ws://127.0.0.1:9001"
URL_SSL = "wss://127.0.0.1:9002"


def create_server_ssl_context():
ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
ssl_context.load_cert_chain(pathlib.Path(__file__).parent / "picows_test.crt",
pathlib.Path(__file__).parent / "picows_test.key")
ssl_context.check_hostname = False
ssl_context.hostname_checks_common_name = False
ssl_context.verify_mode = ssl.CERT_NONE
return ssl_context


def create_client_ssl_context():
ssl_context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
ssl_context.load_default_certs(ssl.Purpose.SERVER_AUTH)
ssl_context.check_hostname = False
ssl_context.hostname_checks_common_name = False
ssl_context.verify_mode = ssl.CERT_NONE
return ssl_context


class BinaryFrame:
Expand All @@ -33,25 +56,22 @@ def __init__(self, frame: picows.WSFrame):
self.fin = frame.fin


#@pytest.fixture(scope="module")
@pytest.fixture
async def echo_server():
@pytest.fixture(params=[URL, URL_SSL])
async def echo_server(request):
class PicowsServerListener(picows.WSListener):
def on_ws_connected(self, transport: picows.WSTransport):
print("echo_server:on_ws_connected")
self._transport = transport

def on_ws_frame(self, transport: picows.WSTransport, frame: picows.WSFrame):
print("echo_server:on_ws_frame")
self._transport.send(frame.msg_type, frame.get_payload_as_bytes())
if frame.msg_type == picows.WSMsgType.CLOSE:
self._transport.send_close(frame.get_close_code(), frame.get_close_message())
self._transport.disconnect()

server = await picows.ws_create_server(URL, PicowsServerListener, "server")
server = await picows.ws_create_server(request.param, PicowsServerListener, "server",
ssl_context=create_server_ssl_context())
task = asyncio.create_task(server.serve_forever())
print("initiated module level echo server")
yield server
yield request.param

# Teardown server
task.cancel()
Expand All @@ -60,11 +80,8 @@ def on_ws_frame(self, transport: picows.WSTransport, frame: picows.WSFrame):
except:
pass

print("stopped module level echo server")


# @pytest.fixture(scope="module")
@pytest.fixture
@pytest.fixture()
async def echo_client(echo_server):
class PicowsClientListener(picows.WSListener):
transport: picows.WSTransport
Expand All @@ -86,7 +103,8 @@ async def get_message(self):
async with async_timeout.timeout(1):
return await self.msg_queue.get()

(_, client) = await picows.ws_connect(URL, PicowsClientListener, "client")
(_, client) = await picows.ws_connect(echo_server, PicowsClientListener, "client",
ssl=create_client_ssl_context())
yield client

# Teardown client
Expand All @@ -99,7 +117,7 @@ async def get_message(self):
client.transport.disconnect()


@pytest.mark.parametrize("msg_size", [32, 1024, 20000])
@pytest.mark.parametrize("msg_size", [256, 1024])
async def test_echo(echo_client, msg_size):
msg = os.urandom(msg_size)
echo_client.transport.send(picows.WSMsgType.BINARY, msg)
Expand Down

0 comments on commit 0bc6fa6

Please sign in to comment.