Skip to content

Commit

Permalink
Merge pull request #41 from adrianomarto/master
Browse files Browse the repository at this point in the history
Added support of unix sockets
  • Loading branch information
Neil authored Nov 13, 2021
2 parents ba055d9 + 20d2a18 commit e55950f
Show file tree
Hide file tree
Showing 5 changed files with 294 additions and 0 deletions.
2 changes: 2 additions & 0 deletions aiorpcx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from .rawsocket import *
from .socks import *
from .session import *
from .unixsocket import *
from .util import *
from .websocket import *

Expand All @@ -18,5 +19,6 @@
rawsocket.__all__ +
socks.__all__ +
session.__all__ +
unixsocket.__all__ +
util.__all__ +
websocket.__all__)
166 changes: 166 additions & 0 deletions aiorpcx/unixsocket.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
# Copyright (c) 2021, Adriano Marto Reis
#
# All rights reserved.
#
# The MIT License (MIT)
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

'''Asyncio protocol abstraction.'''

__all__ = ('connect_us', 'serve_us')


import asyncio
from functools import partial

from aiorpcx.curio import Event, timeout_after, TaskTimeout
from aiorpcx.session import RPCSession, SessionBase, SessionKind


class ConnectionLostError(Exception):
pass


class USTransport(asyncio.Protocol):

def __init__(self, session_factory, framer, kind):
self.session_factory = session_factory
self.loop = asyncio.get_event_loop()
self.session = None
self.kind = kind
self._asyncio_transport = None
self._framer = framer
# Cleared when the send socket is full
self._can_send = Event()
self._can_send.set()
self._closed_event = Event()
self._process_messages_task = None

async def process_messages(self):
try:
await self.session.process_messages(self.receive_message)
except ConnectionLostError:
pass
finally:
self._closed_event.set()

async def receive_message(self):
return await self._framer.receive_message()

def connection_made(self, transport):
'''Called by asyncio when a connection is established.'''
self._asyncio_transport = transport
self.session = self.session_factory(self)
self._framer = self._framer or self.session.default_framer()
self._process_messages_task = self.loop.create_task(self.process_messages())

def connection_lost(self, _exeption):
'''Called by asyncio when the connection closes.
Tear down things done in connection_made.'''
# Release waiting tasks
self._can_send.set()
self._framer.fail(ConnectionLostError())

def data_received(self, data):
'''Called by asyncio when a message comes in.'''
self.session.data_received(data)
self._framer.received_bytes(data)

def pause_writing(self):
'''Called by asyncio the send buffer is full.'''
if not self.is_closing():
self._can_send.clear()
self._asyncio_transport.pause_reading()

def resume_writing(self):
'''Called by asyncio the send buffer has room.'''
if not self._can_send.is_set():
self._can_send.set()
self._asyncio_transport.resume_reading()

# API exposed to session
async def write(self, message):
await self._can_send.wait()
if not self.is_closing():
framed_message = self._framer.frame(message)
self._asyncio_transport.write(framed_message)

async def close(self, force_after):
'''Close the connection and return when closed.'''
if self._asyncio_transport:
self._asyncio_transport.close()
try:
async with timeout_after(force_after):
await self._closed_event.wait()
except TaskTimeout:
await self.abort()
await self._closed_event.wait()

async def abort(self):
if self._asyncio_transport:
self._asyncio_transport.abort()

def is_closing(self):
'''Return True if the connection is closing.'''
return self._closed_event.is_set() or self._asyncio_transport.is_closing()

def proxy(self):
'''Not applicable to unix sockets.'''
return None

def remote_address(self):
'''Not applicable to unix sockets'''
return None


class USClient:

def __init__(self, path=None, *, framer=None, **kwargs):
session_factory = kwargs.pop('session_factory', RPCSession)
self.protocol_factory = partial(USTransport, session_factory, framer,
SessionKind.CLIENT)
self.path = path
self.session = None
self.loop = kwargs.get('loop', asyncio.get_event_loop())
self.kwargs = kwargs

async def create_connection(self):
'''Initiate a connection.'''
return await self.loop.create_unix_connection(
self.protocol_factory, self.path, **self.kwargs)

async def __aenter__(self):
_transport, protocol = await self.create_connection()
self.session = protocol.session
assert isinstance(self.session, SessionBase)
return self.session

async def __aexit__(self, _type, _value, _traceback):
await self.session.close()


async def serve_us(session_factory, path=None, *, framer=None, loop=None, **kwargs):
loop = loop or asyncio.get_event_loop()
protocol_factory = partial(USTransport, session_factory, framer, SessionKind.SERVER)
return await loop.create_unix_server(protocol_factory, path, **kwargs)


connect_us = USClient
40 changes: 40 additions & 0 deletions examples/client_us.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import asyncio
import aiorpcx


async def main(path):
async with aiorpcx.connect_us(path) as session:
# A good request with standard argument passing
result = await session.send_request('echo', ["Howdy"])
print(result)
# A good request with named argument passing
result = await session.send_request('echo', {'message': "Hello with a named argument"})
print(result)

# aiorpcX transparently handles erroneous calls server-side, returning appropriate
# errors. This in turn causes an exception to be raised in the client.
for bad_args in (
['echo'],
['echo', {}],
['foo'],
# This causes an error running the server's buggy request handler.
# aiorpcX catches the problem, returning an 'internal server error' to the
# client, and continues serving
['sum', [2, 4, "b"]]
):
try:
await session.send_request(*bad_args)
except Exception as exc:
print(repr(exc))

# Batch requests
async with session.send_batch() as batch:
batch.add_request('echo', ["Me again"])
batch.add_notification('ping')
batch.add_request('sum', list(range(50)))

for n, result in enumerate(batch.results, start=1):
print(f'batch result #{n}: {result}')


asyncio.get_event_loop().run_until_complete(main('/tmp/test.sock'))
38 changes: 38 additions & 0 deletions examples/server_us.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import asyncio
import aiorpcx


# Handlers are declared as normal python functions. aiorpcx automatically checks RPC
# arguments, including named arguments, and returns errors as appropriate
async def handle_echo(message):
return message

async def handle_sum(*values):
return sum(values, 0)


handlers = {
'echo': handle_echo,
'sum': handle_sum,
}


class ServerSession(aiorpcx.RPCSession):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
print('connected')

async def connection_lost(self):
await super().connection_lost()
print('disconnected')

async def handle_request(self, request):
handler = handlers.get(request.method)
coro = aiorpcx.handler_invocation(handler, request)()
return await coro


loop = asyncio.get_event_loop()
loop.run_until_complete(aiorpcx.serve_us(ServerSession, '/tmp/test.sock'))
loop.run_forever()
48 changes: 48 additions & 0 deletions tests/test_unixsocket.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import sys
import asyncio
import pytest
import tempfile
from os import path
from aiorpcx import *
from test_session import MyServerSession

if sys.platform.startswith("win"):
pytest.skip("skipping tests not compatible with Windows platform", allow_module_level=True)

@pytest.fixture
def us_server(event_loop):
with tempfile.TemporaryDirectory() as tmp_folder:
socket_path = path.join(tmp_folder, 'test.socket')
coro = serve_us(MyServerSession, socket_path, loop=event_loop)
server = event_loop.run_until_complete(coro)
yield socket_path
if hasattr(asyncio, 'all_tasks'):
tasks = asyncio.all_tasks(event_loop)
else:
tasks = asyncio.Task.all_tasks(loop=event_loop)
async def close_all():
server.close()
await server.wait_closed()
if tasks:
await asyncio.wait(tasks)
event_loop.run_until_complete(close_all())


class TestUSTransport:

@pytest.mark.asyncio
async def test_send_request(self, us_server):
async with connect_us(us_server) as session:
assert await session.send_request('echo', [23]) == 23

@pytest.mark.asyncio
async def test_is_closing(self, us_server):
async with connect_us(us_server) as session:
assert not session.is_closing()
await session.close()
assert session.is_closing()

async with connect_us(us_server) as session:
assert not session.is_closing()
await session.abort()
assert session.is_closing()

0 comments on commit e55950f

Please sign in to comment.