-
Notifications
You must be signed in to change notification settings - Fork 25
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #41 from adrianomarto/master
Added support of unix sockets
- Loading branch information
Showing
5 changed files
with
294 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,166 @@ | ||
# Copyright (c) 2021, Adriano Marto Reis | ||
# | ||
# All rights reserved. | ||
# | ||
# The MIT License (MIT) | ||
# | ||
# Permission is hereby granted, free of charge, to any person obtaining | ||
# a copy of this software and associated documentation files (the | ||
# "Software"), to deal in the Software without restriction, including | ||
# without limitation the rights to use, copy, modify, merge, publish, | ||
# distribute, sublicense, and/or sell copies of the Software, and to | ||
# permit persons to whom the Software is furnished to do so, subject to | ||
# the following conditions: | ||
# | ||
# The above copyright notice and this permission notice shall be | ||
# included in all copies or substantial portions of the Software. | ||
# | ||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, | ||
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF | ||
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND | ||
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE | ||
# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION | ||
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION | ||
# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. | ||
|
||
'''Asyncio protocol abstraction.''' | ||
|
||
__all__ = ('connect_us', 'serve_us') | ||
|
||
|
||
import asyncio | ||
from functools import partial | ||
|
||
from aiorpcx.curio import Event, timeout_after, TaskTimeout | ||
from aiorpcx.session import RPCSession, SessionBase, SessionKind | ||
|
||
|
||
class ConnectionLostError(Exception): | ||
pass | ||
|
||
|
||
class USTransport(asyncio.Protocol): | ||
|
||
def __init__(self, session_factory, framer, kind): | ||
self.session_factory = session_factory | ||
self.loop = asyncio.get_event_loop() | ||
self.session = None | ||
self.kind = kind | ||
self._asyncio_transport = None | ||
self._framer = framer | ||
# Cleared when the send socket is full | ||
self._can_send = Event() | ||
self._can_send.set() | ||
self._closed_event = Event() | ||
self._process_messages_task = None | ||
|
||
async def process_messages(self): | ||
try: | ||
await self.session.process_messages(self.receive_message) | ||
except ConnectionLostError: | ||
pass | ||
finally: | ||
self._closed_event.set() | ||
|
||
async def receive_message(self): | ||
return await self._framer.receive_message() | ||
|
||
def connection_made(self, transport): | ||
'''Called by asyncio when a connection is established.''' | ||
self._asyncio_transport = transport | ||
self.session = self.session_factory(self) | ||
self._framer = self._framer or self.session.default_framer() | ||
self._process_messages_task = self.loop.create_task(self.process_messages()) | ||
|
||
def connection_lost(self, _exeption): | ||
'''Called by asyncio when the connection closes. | ||
Tear down things done in connection_made.''' | ||
# Release waiting tasks | ||
self._can_send.set() | ||
self._framer.fail(ConnectionLostError()) | ||
|
||
def data_received(self, data): | ||
'''Called by asyncio when a message comes in.''' | ||
self.session.data_received(data) | ||
self._framer.received_bytes(data) | ||
|
||
def pause_writing(self): | ||
'''Called by asyncio the send buffer is full.''' | ||
if not self.is_closing(): | ||
self._can_send.clear() | ||
self._asyncio_transport.pause_reading() | ||
|
||
def resume_writing(self): | ||
'''Called by asyncio the send buffer has room.''' | ||
if not self._can_send.is_set(): | ||
self._can_send.set() | ||
self._asyncio_transport.resume_reading() | ||
|
||
# API exposed to session | ||
async def write(self, message): | ||
await self._can_send.wait() | ||
if not self.is_closing(): | ||
framed_message = self._framer.frame(message) | ||
self._asyncio_transport.write(framed_message) | ||
|
||
async def close(self, force_after): | ||
'''Close the connection and return when closed.''' | ||
if self._asyncio_transport: | ||
self._asyncio_transport.close() | ||
try: | ||
async with timeout_after(force_after): | ||
await self._closed_event.wait() | ||
except TaskTimeout: | ||
await self.abort() | ||
await self._closed_event.wait() | ||
|
||
async def abort(self): | ||
if self._asyncio_transport: | ||
self._asyncio_transport.abort() | ||
|
||
def is_closing(self): | ||
'''Return True if the connection is closing.''' | ||
return self._closed_event.is_set() or self._asyncio_transport.is_closing() | ||
|
||
def proxy(self): | ||
'''Not applicable to unix sockets.''' | ||
return None | ||
|
||
def remote_address(self): | ||
'''Not applicable to unix sockets''' | ||
return None | ||
|
||
|
||
class USClient: | ||
|
||
def __init__(self, path=None, *, framer=None, **kwargs): | ||
session_factory = kwargs.pop('session_factory', RPCSession) | ||
self.protocol_factory = partial(USTransport, session_factory, framer, | ||
SessionKind.CLIENT) | ||
self.path = path | ||
self.session = None | ||
self.loop = kwargs.get('loop', asyncio.get_event_loop()) | ||
self.kwargs = kwargs | ||
|
||
async def create_connection(self): | ||
'''Initiate a connection.''' | ||
return await self.loop.create_unix_connection( | ||
self.protocol_factory, self.path, **self.kwargs) | ||
|
||
async def __aenter__(self): | ||
_transport, protocol = await self.create_connection() | ||
self.session = protocol.session | ||
assert isinstance(self.session, SessionBase) | ||
return self.session | ||
|
||
async def __aexit__(self, _type, _value, _traceback): | ||
await self.session.close() | ||
|
||
|
||
async def serve_us(session_factory, path=None, *, framer=None, loop=None, **kwargs): | ||
loop = loop or asyncio.get_event_loop() | ||
protocol_factory = partial(USTransport, session_factory, framer, SessionKind.SERVER) | ||
return await loop.create_unix_server(protocol_factory, path, **kwargs) | ||
|
||
|
||
connect_us = USClient |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
import asyncio | ||
import aiorpcx | ||
|
||
|
||
async def main(path): | ||
async with aiorpcx.connect_us(path) as session: | ||
# A good request with standard argument passing | ||
result = await session.send_request('echo', ["Howdy"]) | ||
print(result) | ||
# A good request with named argument passing | ||
result = await session.send_request('echo', {'message': "Hello with a named argument"}) | ||
print(result) | ||
|
||
# aiorpcX transparently handles erroneous calls server-side, returning appropriate | ||
# errors. This in turn causes an exception to be raised in the client. | ||
for bad_args in ( | ||
['echo'], | ||
['echo', {}], | ||
['foo'], | ||
# This causes an error running the server's buggy request handler. | ||
# aiorpcX catches the problem, returning an 'internal server error' to the | ||
# client, and continues serving | ||
['sum', [2, 4, "b"]] | ||
): | ||
try: | ||
await session.send_request(*bad_args) | ||
except Exception as exc: | ||
print(repr(exc)) | ||
|
||
# Batch requests | ||
async with session.send_batch() as batch: | ||
batch.add_request('echo', ["Me again"]) | ||
batch.add_notification('ping') | ||
batch.add_request('sum', list(range(50))) | ||
|
||
for n, result in enumerate(batch.results, start=1): | ||
print(f'batch result #{n}: {result}') | ||
|
||
|
||
asyncio.get_event_loop().run_until_complete(main('/tmp/test.sock')) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
import asyncio | ||
import aiorpcx | ||
|
||
|
||
# Handlers are declared as normal python functions. aiorpcx automatically checks RPC | ||
# arguments, including named arguments, and returns errors as appropriate | ||
async def handle_echo(message): | ||
return message | ||
|
||
async def handle_sum(*values): | ||
return sum(values, 0) | ||
|
||
|
||
handlers = { | ||
'echo': handle_echo, | ||
'sum': handle_sum, | ||
} | ||
|
||
|
||
class ServerSession(aiorpcx.RPCSession): | ||
|
||
def __init__(self, *args, **kwargs): | ||
super().__init__(*args, **kwargs) | ||
print('connected') | ||
|
||
async def connection_lost(self): | ||
await super().connection_lost() | ||
print('disconnected') | ||
|
||
async def handle_request(self, request): | ||
handler = handlers.get(request.method) | ||
coro = aiorpcx.handler_invocation(handler, request)() | ||
return await coro | ||
|
||
|
||
loop = asyncio.get_event_loop() | ||
loop.run_until_complete(aiorpcx.serve_us(ServerSession, '/tmp/test.sock')) | ||
loop.run_forever() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
import sys | ||
import asyncio | ||
import pytest | ||
import tempfile | ||
from os import path | ||
from aiorpcx import * | ||
from test_session import MyServerSession | ||
|
||
if sys.platform.startswith("win"): | ||
pytest.skip("skipping tests not compatible with Windows platform", allow_module_level=True) | ||
|
||
@pytest.fixture | ||
def us_server(event_loop): | ||
with tempfile.TemporaryDirectory() as tmp_folder: | ||
socket_path = path.join(tmp_folder, 'test.socket') | ||
coro = serve_us(MyServerSession, socket_path, loop=event_loop) | ||
server = event_loop.run_until_complete(coro) | ||
yield socket_path | ||
if hasattr(asyncio, 'all_tasks'): | ||
tasks = asyncio.all_tasks(event_loop) | ||
else: | ||
tasks = asyncio.Task.all_tasks(loop=event_loop) | ||
async def close_all(): | ||
server.close() | ||
await server.wait_closed() | ||
if tasks: | ||
await asyncio.wait(tasks) | ||
event_loop.run_until_complete(close_all()) | ||
|
||
|
||
class TestUSTransport: | ||
|
||
@pytest.mark.asyncio | ||
async def test_send_request(self, us_server): | ||
async with connect_us(us_server) as session: | ||
assert await session.send_request('echo', [23]) == 23 | ||
|
||
@pytest.mark.asyncio | ||
async def test_is_closing(self, us_server): | ||
async with connect_us(us_server) as session: | ||
assert not session.is_closing() | ||
await session.close() | ||
assert session.is_closing() | ||
|
||
async with connect_us(us_server) as session: | ||
assert not session.is_closing() | ||
await session.abort() | ||
assert session.is_closing() |