diff --git a/thriftpy2/contrib/aio/socket.py b/thriftpy2/contrib/aio/socket.py index 8d1e49d..256cd0a 100644 --- a/thriftpy2/contrib/aio/socket.py +++ b/thriftpy2/contrib/aio/socket.py @@ -93,8 +93,7 @@ def __init__(self, host=None, port=None, unix_socket=None, ciphers=ciphers) if cafile or capath: - self.ssl_context.load_verify_locations(cafile=cafile, - capath=capath) + self.ssl_context.load_verify_locations(cafile=cafile, capath=capath) if certfile: self.ssl_context.load_cert_chain(certfile, keyfile=keyfile) @@ -106,58 +105,23 @@ def __init__(self, host=None, port=None, unix_socket=None, self.ssl_context = None self.server_hostname = None - def _init_sock(self): - if self.unix_socket: - _sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - else: - _sock = socket.socket(self.socket_family, socket.SOCK_STREAM) - _sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) - - # socket options - linger = struct.pack('ii', 0, 0) - _sock.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, linger) - _sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) - - self.raw_sock = _sock - - def set_handle(self, sock): - self.raw_sock = sock - - def set_timeout(self, ms): - """Backward compat api, will bind the timeout to both connect_timeout - and socket_timeout. - """ - self.socket_timeout = ms / 1000 if (ms and ms > 0) else None - self.connect_timeout = self.socket_timeout - - if self.raw_sock is not None: - self.raw_sock.settimeout(self.socket_timeout) - - def is_open(self): - return bool(self.raw_sock) - async def open(self): - self._init_sock() - addr = self.unix_socket or (self.host, self.port) - try: - if self.connect_timeout: - self.raw_sock.settimeout(self.connect_timeout) - - self.raw_sock.connect(addr) - - if self.socket_timeout: - self.raw_sock.settimeout(self.socket_timeout) - - kwargs = {'sock': self.raw_sock, 'ssl': self.ssl_context} - if self.server_hostname: - kwargs['server_hostname'] = self.server_hostname - - self.reader, self.writer = await asyncio.wait_for( - self.sock_factory(**kwargs), - self.socket_timeout - ) + if self.unix_socket: + self.reader, self.writer = await asyncio.wait_for( + asyncio.open_unix_connection(addr), self.connect_timeout + ) + else: + self.reader, self.writer = await asyncio.wait_for( + asyncio.open_connection(self.host, self.port, ssl=self.ssl_context), + self.connect_timeout, + ) + sock = self.writer.get_extra_info("socket") + # socket options + linger = struct.pack("ii", 0, 0) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, linger) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) except (socket.error, OSError): raise TTransportException( @@ -166,9 +130,14 @@ async def open(self): async def read(self, sz): try: - buff = await asyncio.wait_for( - self.reader.read(sz), - self.connect_timeout + buff = await asyncio.wait_for(self.reader.read(sz), self.connect_timeout) + except asyncio.TimeoutError: + raise TTransportException( + type=TTransportException.TIMED_OUT, message="TSocket read timed out" + ) + except asyncio.IncompleteReadError as e: + raise TTransportException( + type=TTransportException.END_OF_FILE, message="TSocket read 0 bytes" ) except socket.error as e: if e.errno == errno.ECONNRESET and MAC_OR_BSD: @@ -199,7 +168,6 @@ def close(self): try: self.writer.close() - self.raw_sock.close() self.raw_sock = None except (socket.error, OSError): pass @@ -251,7 +219,7 @@ def __init__(self, host=None, port=None, unix_socket=None, self.ssl_context = ssl_context elif certfile: if not os.access(certfile, os.R_OK): - raise IOError('No such certfile found: %s' % certfile) + raise IOError("No such certfile found: %s" % certfile) self.ssl_context = create_thriftpy_context(server_side=True, ciphers=ciphers)