diff --git a/.bandit b/.bandit old mode 100644 new mode 100755 diff --git a/.coveragerc b/.coveragerc old mode 100644 new mode 100755 diff --git a/.editorconfig b/.editorconfig old mode 100644 new mode 100755 diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml old mode 100644 new mode 100755 diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md old mode 100644 new mode 100755 diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md old mode 100644 new mode 100755 diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md old mode 100644 new mode 100755 diff --git a/.gitignore b/.gitignore old mode 100644 new mode 100755 diff --git a/.travis.yml b/.travis.yml old mode 100644 new mode 100755 diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md old mode 100644 new mode 100755 diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md old mode 100644 new mode 100755 diff --git a/LICENSE b/LICENSE old mode 100644 new mode 100755 diff --git a/MANIFEST.in b/MANIFEST.in old mode 100644 new mode 100755 diff --git a/Makefile b/Makefile old mode 100644 new mode 100755 diff --git a/README.md b/README.md old mode 100644 new mode 100755 diff --git a/pytest.ini b/pytest.ini old mode 100644 new mode 100755 diff --git a/requirements.txt b/requirements.txt old mode 100644 new mode 100755 diff --git a/rethinkdb/__init__.py b/rethinkdb/__init__.py old mode 100644 new mode 100755 diff --git a/rethinkdb/__main__.py b/rethinkdb/__main__.py old mode 100644 new mode 100755 diff --git a/rethinkdb/ast.py b/rethinkdb/ast.py old mode 100644 new mode 100755 index 3b9fddc6..7a1b1fdc --- a/rethinkdb/ast.py +++ b/rethinkdb/ast.py @@ -17,7 +17,6 @@ __all__ = ["expr", "RqlQuery", "ReQLEncoder", "ReQLDecoder", "Repl"] - import base64 import binascii import collections @@ -26,7 +25,8 @@ import threading from rethinkdb import ql2_pb2 -from rethinkdb.errors import QueryPrinter, ReqlDriverCompileError, ReqlDriverError, T +from rethinkdb.errors import (QueryPrinter, ReqlDriverCompileError, + ReqlDriverError, T) P_TERM = ql2_pb2.Term.TermType @@ -77,7 +77,8 @@ def expr(val, nesting_depth=20): Convert a Python primitive into a RQL primitive value """ if not isinstance(nesting_depth, int): - raise ReqlDriverCompileError("Second argument to `r.expr` must be a number.") + raise ReqlDriverCompileError( + "Second argument to `r.expr` must be a number.") if nesting_depth <= 0: raise ReqlDriverCompileError("Nesting depth limit exceeded.") @@ -95,9 +96,7 @@ def expr(val, nesting_depth=20): timezone values with r.make_timezone(\"[+-]HH:MM\"). Alternatively, use one of ReQL's bultin time constructors, r.now, r.time, or r.iso8601. - """ - % (type(val).__name__) - ) + """ % (type(val).__name__)) return ISO8601(val.isoformat()) elif isinstance(val, RqlBinary): return Binary(val) @@ -136,12 +135,10 @@ def run(self, c=None, **global_optargs): if Repl.repl_active: raise ReqlDriverError( "RqlQuery.run must be given a connection to run on. A default connection has been set with " - "`repl()` on another thread, but not this one." - ) + "`repl()` on another thread, but not this one.") else: raise ReqlDriverError( - "RqlQuery.run must be given a connection to run on." - ) + "RqlQuery.run must be given a connection to run on.") return c._start(self, **global_optargs) @@ -155,7 +152,7 @@ def __repr__(self): # Compile this query to a json-serializable object def build(self): res = [self.term_type, self._args] - if len(self.optargs) > 0: + if self.optargs: res.append(self.optargs) return res @@ -392,7 +389,10 @@ def set_difference(self, *args): def __getitem__(self, index): if isinstance(index, slice): if index.stop: - return Slice(self, index.start or 0, index.stop, bracket_operator=True) + return Slice(self, + index.start or 0, + index.stop, + bracket_operator=True) else: return Slice( self, @@ -408,8 +408,7 @@ def __iter__(*args, **kwargs): raise ReqlDriverError( "__iter__ called on an RqlQuery object.\n" "To iterate over the results of a query, call run first.\n" - "To iterate inside a query, use map or for_each." - ) + "To iterate inside a query, use map or for_each.") def get_field(self, *args): return GetField(self, *args) @@ -468,7 +467,7 @@ def max(self, *args, **kwargs): def map(self, *args): if len(args) > 0: # `func_wrap` only the last argument - return Map(self, *(args[:-1] + (func_wrap(args[-1]),))) + return Map(self, *(args[:-1] + (func_wrap(args[-1]), ))) else: return Map(self) @@ -481,7 +480,8 @@ def fold(self, *args, **kwargs): kwfuncargs = {} for arg_name in kwargs: kwfuncargs[arg_name] = func_wrap(kwargs[arg_name]) - return Fold(self, *(args[:-1] + (func_wrap(args[-1]),)), **kwfuncargs) + return Fold(self, *(args[:-1] + (func_wrap(args[-1]), )), + **kwfuncargs) else: return Fold(self) @@ -492,7 +492,10 @@ def concat_map(self, *args): return ConcatMap(self, *[func_wrap(arg) for arg in args]) def order_by(self, *args, **kwargs): - args = [arg if isinstance(arg, (Asc, Desc)) else func_wrap(arg) for arg in args] + args = [ + arg if isinstance(arg, (Asc, Desc)) else func_wrap(arg) + for arg in args + ] return OrderBy(self, *args, **kwargs) def between(self, *args, **kwargs): @@ -625,21 +628,25 @@ def needs_wrap(arg): class RqlBoolOperQuery(RqlQuery): + statement_infix = None + def __init__(self, *args, **optargs): + super().__init__(*args, **optargs) self.infix = False - RqlQuery.__init__(self, *args, **optargs) def set_infix(self): self.infix = True def compose(self, args, optargs): t_args = [ - T("r.expr(", args[i], ")") if needs_wrap(self._args[i]) else args[i] + T("r.expr(", args[i], ")") + if needs_wrap(self._args[i]) else args[i] for i in xrange(len(args)) ] if self.infix: - return T("(", T(*t_args, intsp=[" ", self.statement_infix, " "]), ")") + t_args = T(*t_args, intsp=[" ", self.statement_infix, " "]) + return T("(", t_args, ")") else: return T("r.", self.statement, "(", T(*t_args, intsp=", "), ")") @@ -647,7 +654,8 @@ def compose(self, args, optargs): class RqlBiOperQuery(RqlQuery): def compose(self, args, optargs): t_args = [ - T("r.expr(", args[i], ")") if needs_wrap(self._args[i]) else args[i] + T("r.expr(", args[i], ")") + if needs_wrap(self._args[i]) else args[i] for i in xrange(len(args)) ] return T("(", T(*t_args, intsp=[" ", self.statement, " "]), ")") @@ -666,11 +674,10 @@ def __init__(self, *args, **optargs): "This is almost always a precedence error.\n" "Note that `a < b | b < c` <==> `a < (b | b) < c`.\n" "If you really want this behavior, use `.or_` or " - "`.and_` instead." - ) + "`.and_` instead.") raise ReqlDriverCompileError( - err % (self.statement, QueryPrinter(self).print_query()) - ) + err % + (self.statement, QueryPrinter(self).print_query())) except AttributeError: pass # No infix attribute, so not possible to be an infix bool operator @@ -723,7 +730,7 @@ def __init__(self, offsetstr): self.delta = datetime.timedelta(hours=hours, minutes=minutes) def __getinitargs__(self): - return (self.offsetstr,) + return (self.offsetstr, ) def __copy__(self): return RqlTzinfo(self.offsetstr) @@ -751,9 +758,8 @@ def recursively_make_hashable(obj): if isinstance(obj, list): return tuple([recursively_make_hashable(i) for i in obj]) elif isinstance(obj, dict): - return frozenset( - [(k, recursively_make_hashable(v)) for k, v in dict_items(obj)] - ) + return frozenset([(k, recursively_make_hashable(v)) + for k, v in dict_items(obj)]) return obj @@ -789,17 +795,12 @@ def __init__(self, reql_format_opts=None): def convert_time(self, obj): if "epoch_time" not in obj: raise ReqlDriverError( - ( - "pseudo-type TIME object %s does not " - + 'have expected field "epoch_time".' - ) - % json.dumps(obj) - ) + ("pseudo-type TIME object %s does not " + + 'have expected field "epoch_time".') % json.dumps(obj)) if "timezone" in obj: - return datetime.datetime.fromtimestamp( - obj["epoch_time"], RqlTzinfo(obj["timezone"]) - ) + return datetime.datetime.fromtimestamp(obj["epoch_time"], + RqlTzinfo(obj["timezone"])) else: return datetime.datetime.utcfromtimestamp(obj["epoch_time"]) @@ -807,24 +808,18 @@ def convert_time(self, obj): def convert_grouped_data(obj): if "data" not in obj: raise ReqlDriverError( - ( - "pseudo-type GROUPED_DATA object" - + ' %s does not have the expected field "data".' - ) - % json.dumps(obj) - ) - return dict([(recursively_make_hashable(k), v) for k, v in obj["data"]]) + ("pseudo-type GROUPED_DATA object" + + ' %s does not have the expected field "data".') % + json.dumps(obj)) + return dict([(recursively_make_hashable(k), v) + for k, v in obj["data"]]) @staticmethod def convert_binary(obj): if "data" not in obj: raise ReqlDriverError( - ( - "pseudo-type BINARY object %s does not have " - + 'the expected field "data".' - ) - % json.dumps(obj) - ) + ("pseudo-type BINARY object %s does not have " + + 'the expected field "data".') % json.dumps(obj)) return RqlBinary(base64.b64decode(obj["data"].encode("utf-8"))) def convert_pseudotype(self, obj): @@ -837,16 +832,14 @@ def convert_pseudotype(self, obj): return self.convert_time(obj) elif time_format != "raw": raise ReqlDriverError( - 'Unknown time_format run option "%s".' % time_format - ) + 'Unknown time_format run option "%s".' % time_format) elif reql_type == "GROUPED_DATA": group_format = self.reql_format_opts.get("group_format") if group_format is None or group_format == "native": return self.convert_grouped_data(obj) elif group_format != "raw": raise ReqlDriverError( - 'Unknown group_format run option "%s".' % group_format - ) + 'Unknown group_format run option "%s".' % group_format) elif reql_type == "GEOMETRY": # No special support for this. Just return the raw object return obj @@ -856,8 +849,8 @@ def convert_pseudotype(self, obj): return self.convert_binary(obj) elif binary_format != "raw": raise ReqlDriverError( - 'Unknown binary_format run option "%s".' % binary_format - ) + 'Unknown binary_format run option "%s".' % + binary_format) else: raise ReqlDriverError("Unknown pseudo-type %s" % reql_type) # If there was no pseudotype, or the relevant format is raw, return @@ -909,12 +902,13 @@ def build(self): return self.optargs def compose(self, args, optargs): + list_comp = [ + T(repr(key), ": ", value) for key, value in dict_items(optargs) + ] + t_value = T(*list_comp, intsp=", ") return T( "r.expr({", - T( - *[T(repr(key), ": ", value) for key, value in dict_items(optargs)], - intsp=", " - ), + t_value, "})", ) @@ -1236,13 +1230,16 @@ class FunCall(RqlQuery): # before passing it down to the base class constructor. def __init__(self, *args): if len(args) == 0: - raise ReqlDriverCompileError("Expected 1 or more arguments but found 0.") + raise ReqlDriverCompileError( + "Expected 1 or more arguments but found 0.") args = [func_wrap(args[-1])] + list(args[:-1]) RqlQuery.__init__(self, *args) def compose(self, args, optargs): if len(args) != 2: - return T("r.do(", T(T(*(args[1:]), intsp=", "), args[0], intsp=", "), ")") + return T("r.do(", T(T(*(args[1:]), intsp=", "), + args[0], + intsp=", "), ")") if isinstance(self._args[1], Datum): args[1] = T("r.expr(", args[1], ")") @@ -1712,12 +1709,10 @@ def __new__(cls, *args, **kwargs): def __repr__(self): excerpt = binascii.hexlify(self[0:6]).decode("utf-8") - excerpt = " ".join([excerpt[i : i + 2] for i in xrange(0, len(excerpt), 2)]) - excerpt = ( - ", '%s%s'" % (excerpt, "..." if len(self) > 6 else "") - if len(self) > 0 - else "" - ) + excerpt = " ".join( + [excerpt[i:i + 2] for i in xrange(0, len(excerpt), 2)]) + excerpt = (", '%s%s'" % + (excerpt, "..." if len(self) > 6 else "") if self else "") return "" % ( len(self), "s" if len(self) != 1 else "", @@ -1741,16 +1736,11 @@ def __init__(self, data): raise ReqlDriverCompileError( "Cannot convert a unicode string to binary, " "use `unicode.encode()` to specify the " - "encoding." - ) + "encoding.") elif not isinstance(data, bytes): raise ReqlDriverCompileError( - ( - "Cannot convert %s to binary, convert the " - "object to a `bytes` object first." - ) - % type(data).__name__ - ) + ("Cannot convert %s to binary, convert the " + "object to a `bytes` object first.") % type(data).__name__) else: self.base64_data = base64.b64encode(data) @@ -1759,16 +1749,17 @@ def __init__(self, data): self.optargs = {} def compose(self, args, optargs): - if len(self._args) == 0: + if self._args: return T("r.", self.statement, "(bytes())") - else: - return RqlTopLevelQuery.compose(self, args, optargs) + return RqlTopLevelQuery.compose(self, args, optargs) def build(self): - if len(self._args) == 0: - return {"$reql_type$": "BINARY", "data": self.base64_data.decode("utf-8")} - else: - return RqlTopLevelQuery.build(self) + if self._args: + return { + "$reql_type$": "BINARY", + "data": self.base64_data.decode("utf-8") + } + return RqlTopLevelQuery.build(self) class Range(RqlTopLevelQuery): @@ -1972,12 +1963,12 @@ def __init__(self, lmbd): self._args.extend([MakeArray(*vrids), expr(lmbd(*vrs))]) def compose(self, args, optargs): + list_comp = [ + v.compose([v._args[0].compose(None, None)], []) for v in self.vrs + ] return T( "lambda ", - T( - *[v.compose([v._args[0].compose(None, None)], []) for v in self.vrs], - intsp=", " - ), + T(*list_comp, intsp=", "), ": ", args[1], ) diff --git a/rethinkdb/asyncio_net/__init__.py b/rethinkdb/asyncio_net/__init__.py old mode 100644 new mode 100755 diff --git a/rethinkdb/asyncio_net/net_asyncio.py b/rethinkdb/asyncio_net/net_asyncio.py old mode 100644 new mode 100755 index 781081e5..28ff8e8e --- a/rethinkdb/asyncio_net/net_asyncio.py +++ b/rethinkdb/asyncio_net/net_asyncio.py @@ -31,28 +31,24 @@ ) from rethinkdb.net import Connection as ConnectionBase from rethinkdb.net import Cursor, Query, Response, maybe_profile +from tasktools.taskloop import TaskLoop __all__ = ["Connection"] - pResponse = ql2_pb2.Response.ResponseType pQuery = ql2_pb2.Query.QueryType -@asyncio.coroutine -def _read_until(streamreader, delimiter): +async def _read_until(streamreader, delimiter): """Naive implementation of reading until a delimiter""" - buffer = bytearray() - - while True: - c = yield from streamreader.read(1) - if c == b"": - break # EOF - buffer.append(c[0]) - if c == delimiter: - break - - return bytes(buffer) + try: + result = await streamreader.readuntil(delimiter) + return bytes(result) + except asyncio.IncompleteReadError as ie_error: + return bytes(ie_error.partial) + except asyncio.LimitOverrunError as lo_error: + print("Amount of data exceeds the configured stream limit") + raise lo_error def reusable_waiter(loop, timeout): @@ -62,20 +58,21 @@ def reusable_waiter(loop, timeout): waiter = reusable_waiter(event_loop, 10.0) while some_condition: - yield from waiter(some_future) + await waiter(some_future) """ if timeout is not None: deadline = loop.time() + timeout else: deadline = None - @asyncio.coroutine - def wait(future): + async def wait(future): if deadline is not None: new_timeout = max(deadline - loop.time(), 0) else: new_timeout = None - return (yield from asyncio.wait_for(future, new_timeout, loop=loop)) + # loop parameter deprecated on py3.8 + result = await asyncio.wait_for(future, new_timeout) + return result return wait @@ -101,20 +98,19 @@ def __init__(self, *args, **kwargs): def __aiter__(self): return self - @asyncio.coroutine - def __anext__(self): + async def __anext__(self): try: - return (yield from self._get_next(None)) + result = await self._get_next(None) + return result except ReqlCursorEmpty: raise StopAsyncIteration - @asyncio.coroutine - def close(self): + async def close(self): if self.error is None: self.error = self._empty_error() if self.conn.is_open(): self.outstanding_requests += 1 - yield from self.conn._parent._stop(self) + await self.conn._parent._stop(self) def _extend(self, res_buf): Cursor._extend(self, res_buf) @@ -123,8 +119,7 @@ def _extend(self, res_buf): # Convenience function so users know when they've hit the end of the cursor # without having to catch an exception - @asyncio.coroutine - def fetch_next(self, wait=True): + async def fetch_next(self, wait=True): timeout = Cursor._wait_to_timeout(wait) waiter = reusable_waiter(self.conn._io_loop, timeout) while len(self.items) == 0 and self.error is None: @@ -132,33 +127,30 @@ def fetch_next(self, wait=True): if self.error is not None: raise self.error with translate_timeout_errors(): - yield from waiter(asyncio.shield(self.new_response)) + await waiter(asyncio.shield(self.new_response)) # If there is a (non-empty) error to be received, we return True, so the # user will receive it on the next `next` call. - return len(self.items) != 0 or not isinstance(self.error, RqlCursorEmpty) + return len(self.items) != 0 or not isinstance(self.error, + RqlCursorEmpty) def _empty_error(self): # We do not have RqlCursorEmpty inherit from StopIteration as that interferes # with mechanisms to return from a coroutine. return RqlCursorEmpty() - @asyncio.coroutine - def _get_next(self, timeout): + async def _get_next(self, timeout): waiter = reusable_waiter(self.conn._io_loop, timeout) while len(self.items) == 0: self._maybe_fetch_batch() if self.error is not None: raise self.error with translate_timeout_errors(): - yield from waiter(asyncio.shield(self.new_response)) + await waiter(asyncio.shield(self.new_response)) return self.items.popleft() def _maybe_fetch_batch(self): - if ( - self.error is None - and len(self.items) < self.threshold - and self.outstanding_requests == 0 - ): + if (self.error is None and len(self.items) < self.threshold + and self.outstanding_requests == 0): self.outstanding_requests += 1 asyncio.ensure_future(self.conn._parent._continue(self)) @@ -177,6 +169,15 @@ def __init__(self, parent, io_loop=None): self._io_loop = io_loop if self._io_loop is None: self._io_loop = asyncio.get_event_loop() + asyncio.set_event_loop(self._io_loop) + + @property + def writer(self): + return self._streamwriter + + @property + def reader(self): + return self._streamreader def client_port(self): if self.is_open(): @@ -186,8 +187,11 @@ def client_address(self): if self.is_open(): return self._streamwriter.get_extra_info("sockname")[0] - @asyncio.coroutine - def connect(self, timeout): + async def read_task(self): + task = TaskLoop(self._reader, [], {}, name='reader_rdb') + task.create() + + async def connect(self, timeout): try: ssl_context = None if len(self._parent.ssl) > 0: @@ -199,23 +203,20 @@ def connect(self, timeout): ssl_context.check_hostname = True # redundant with match_hostname ssl_context.load_verify_locations(self._parent.ssl["ca_certs"]) - self._streamreader, self._streamwriter = yield from asyncio.open_connection( + self._streamreader, self._streamwriter = await asyncio.open_connection( self._parent.host, self._parent.port, loop=self._io_loop, ssl=ssl_context, ) self._streamwriter.get_extra_info("socket").setsockopt( - socket.IPPROTO_TCP, socket.TCP_NODELAY, 1 - ) + socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) self._streamwriter.get_extra_info("socket").setsockopt( - socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1 - ) + socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) except Exception as err: raise ReqlDriverError( - "Could not connect to %s:%s. Error: %s" - % (self._parent.host, self._parent.port, str(err)) - ) + "Could not connect to %s:%s. Error: %s" % + (self._parent.host, self._parent.port, str(err))) try: self._parent.handshake.reset() @@ -227,41 +228,39 @@ def connect(self, timeout): break # This may happen in the `V1_0` protocol where we send two requests as # an optimization, then need to read each separately - if request is not "": + if request != "": self._streamwriter.write(request) - response = yield from asyncio.wait_for( - _read_until(self._streamreader, b"\0"), - timeout, - loop=self._io_loop, - ) + response = await asyncio.wait_for( + _read_until(self._streamreader, b"\0"), timeout) response = response[:-1] except ReqlAuthError: - yield from self.close() + await self.close() raise except ReqlTimeoutError as err: - yield from self.close() + await self.close() raise ReqlDriverError( "Connection interrupted during handshake with %s:%s. Error: %s" - % (self._parent.host, self._parent.port, str(err)) - ) + % (self._parent.host, self._parent.port, str(err))) except Exception as err: - yield from self.close() + await self.close() raise ReqlDriverError( - "Could not connect to %s:%s. Error: %s" - % (self._parent.host, self._parent.port, str(err)) - ) + "Could not connect to %s:%s. Error: %s" % + (self._parent.host, self._parent.port, str(err))) # Start a parallel function to perform reads # store a reference to it so it doesn't get destroyed - self._reader_task = asyncio.ensure_future(self._reader(), loop=self._io_loop) + + self._reader_task = asyncio.run_coroutine_threadsafe( + self.read_task(), self._io_loop) + # self._reader_task = asyncio.ensure_future(self._reader(), + # loop=self._io_loop) return self._parent def is_open(self): return not (self._closing or self._streamreader.at_eof()) - @asyncio.coroutine - def close(self, noreply_wait=False, token=None, exception=None): + async def close(self, noreply_wait=False, token=None, exception=None): self._closing = True if exception is not None: err_message = "Connection is closed (%s)." % str(exception) @@ -281,67 +280,87 @@ def close(self, noreply_wait=False, token=None, exception=None): if noreply_wait: noreply = Query(pQuery.NOREPLY_WAIT, token, None, None) - yield from self.run_query(noreply, False) + await self.run_query(noreply, False) self._streamwriter.close() + await self._streamwriter.wait_closed() # We must not wait for the _reader_task if we got an exception, because that # means that we were called from it. Waiting would lead to a deadlock. if self._reader_task and exception is None: - yield from self._reader_task + await asyncio.wrap_future(self._reader_task) return None - @asyncio.coroutine - def run_query(self, query, noreply): - self._streamwriter.write(query.serialize(self._parent._get_json_encoder(query))) - if noreply: - return None - - response_future = asyncio.Future() - self._user_queries[query.token] = (query, response_future) - return (yield from response_future) + async def run_query(self, query, noreply): + try: + serialized_query = query.serialize( + self._parent._get_json_encoder(query)) + self._streamwriter.write(serialized_query) + await self._streamwriter.drain() + if noreply: + return None + response_future = self._io_loop.create_future() + self._user_queries[query.token] = (query, response_future) + result = await response_future + return result + except asyncio.CancelledError as c_error: + raise c_error + except Exception as error: + raise error # The _reader coroutine runs in parallel, reading responses # off of the socket and forwarding them to the appropriate Future or Cursor. # This is shut down as a consequence of closing the stream, or an error in the # socket/protocol from the server. Unexpected errors in this coroutine will # close the ConnectionInstance and be passed to any open Futures or Cursors. - @asyncio.coroutine - def _reader(self): + async def _reader(self, *args, **kwargs): + # now the loop is on the taskloop try: - while True: - buf = yield from self._streamreader.readexactly(12) - (token, length,) = struct.unpack(" 10 else []) - ) + [self.items[x] for x in range(min(10, len(self.items)))] + + (["..."] if len(self.items) > 10 else [])) if val_str.endswith("'...']"): - val_str = val_str[: -len("'...']")] + "...]" + val_str = val_str[:-len("'...']")] + "...]" spacer_str = "\n" if "\n" in val_str else "" if self.error is None: status_str = "streaming" @@ -271,11 +267,10 @@ def __str__(self): def __repr__(self): val_str = pprint.pformat( - [self.items[x] for x in range(min(10, len(self.items)))] - + (["..."] if len(self.items) > 10 else []) - ) + [self.items[x] for x in range(min(10, len(self.items)))] + + (["..."] if len(self.items) > 10 else [])) if val_str.endswith("'...']"): - val_str = val_str[: -len("'...']")] + "...]" + val_str = val_str[:-len("'...']")] + "...]" spacer_str = "\n" if "\n" in val_str else "" if self.error is None: status_str = "streaming" @@ -301,11 +296,8 @@ def _error(self, message): self._extend(dummy_response) def _maybe_fetch_batch(self): - if ( - self.error is None - and len(self.items) < self.threshold - and self.outstanding_requests == 0 - ): + if (self.error is None and len(self.items) < self.threshold + and self.outstanding_requests == 0): self.outstanding_requests += 1 self.conn._parent._continue(self) @@ -346,27 +338,28 @@ def __init__(self, parent, timeout): deadline = time.time() + timeout try: - self._socket = socket.create_connection((self.host, self.port), timeout) + self._socket = socket.create_connection((self.host, self.port), + timeout) self._socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) self._socket.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) if len(self.ssl) > 0: try: - if hasattr( - ssl, "SSLContext" - ): # Python2.7 and 3.2+, or backports.ssl + if hasattr(ssl, "SSLContext"): + # Python2.7 and 3.2+, or backports.ssl ssl_context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) if hasattr(ssl_context, "options"): - ssl_context.options |= getattr(ssl, "OP_NO_SSLv2", 0) - ssl_context.options |= getattr(ssl, "OP_NO_SSLv3", 0) + ssl_context.options |= getattr( + ssl, "OP_NO_SSLv2", 0) + ssl_context.options |= getattr( + ssl, "OP_NO_SSLv3", 0) ssl_context.verify_mode = ssl.CERT_REQUIRED ssl_context.check_hostname = ( True # redundant with match_hostname ) ssl_context.load_verify_locations(self.ssl["ca_certs"]) self._socket = ssl_context.wrap_socket( - self._socket, server_hostname=self.host - ) + self._socket, server_hostname=self.host) else: # this does not disable SSLv2 or SSLv3 self._socket = ssl.wrap_socket( self._socket, @@ -378,8 +371,8 @@ def __init__(self, parent, timeout): self._socket.close() if "EOF occurred in violation of protocol" in str( - err - ) or "sslv3 alert handshake failure" in str(err): + err) or "sslv3 alert handshake failure" in str( + err): # probably on an older version of OpenSSL raise ReqlDriverError( "SSL handshake failed, likely because Python is linked against an old version of OpenSSL " @@ -387,15 +380,14 @@ def __init__(self, parent, timeout): "around by lowering the security setting on the server with the options " "`--tls-min-protocol TLSv1 --tls-ciphers " "EECDH+AESGCM:EDH+AESGCM:AES256+EECDH:AES256+EDH:AES256-SHA` (see server log for more " - "information): %s" % str(err) - ) + "information): %s" % str(err)) else: raise ReqlDriverError( "SSL handshake failed (see server log for more information): %s" - % str(err) - ) + % str(err)) try: - match_hostname(self._socket.getpeercert(), hostname=self.host) + match_hostname(self._socket.getpeercert(), + hostname=self.host) except CertificateError: self._socket.close() raise @@ -408,7 +400,7 @@ def __init__(self, parent, timeout): break # This may happen in the `V1_0` protocol where we send two requests as # an optimization, then need to read each separately - if request is not "": + if request != "": self.sendall(request) # The response from the server is a null-terminated string @@ -423,21 +415,18 @@ def __init__(self, parent, timeout): raise except ReqlDriverError as ex: self.close() - error = ( - str(ex) - .replace("receiving from", "during handshake with") - .replace("sending to", "during handshake with") - ) + error = (str(ex).replace("receiving from", + "during handshake with").replace( + "sending to", + "during handshake with")) raise ReqlDriverError(error) except socket.timeout as ex: self.close() raise ReqlTimeoutError(self.host, self.port) except Exception as ex: self.close() - raise ReqlDriverError( - "Could not connect to %s:%s. Error: %s" - % (self.host, self.port, str(ex)) - ) + raise ReqlDriverError("Could not connect to %s:%s. Error: %s" % + (self.host, self.port, str(ex))) def is_open(self): return self._socket is not None @@ -476,16 +465,13 @@ def recvall(self, length, deadline): # This should only happen with a timeout of 0 raise ReqlTimeoutError(self.host, self.port) elif ex.errno != errno.EINTR: - raise ReqlDriverError( - ("Connection interrupted " + "receiving from %s:%s - %s") - % (self.host, self.port, str(ex)) - ) + raise ReqlDriverError(("Connection interrupted " + + "receiving from %s:%s - %s") % + (self.host, self.port, str(ex))) except Exception as ex: self.close() - raise ReqlDriverError( - "Error receiving from %s:%s - %s" - % (self.host, self.port, str(ex)) - ) + raise ReqlDriverError("Error receiving from %s:%s - %s" % + (self.host, self.port, str(ex))) if len(chunk) == 0: self.close() @@ -505,14 +491,12 @@ def sendall(self, data): elif ex.errno != errno.EINTR: self.close() raise ReqlDriverError( - ("Connection interrupted " + "sending to %s:%s - %s") - % (self.host, self.port, str(ex)) - ) + ("Connection interrupted " + "sending to %s:%s - %s") % + (self.host, self.port, str(ex))) except Exception as ex: self.close() - raise ReqlDriverError( - "Error sending to %s:%s - %s" % (self.host, self.port, str(ex)) - ) + raise ReqlDriverError("Error sending to %s:%s - %s" % + (self.host, self.port, str(ex))) except BaseException: self.close() raise @@ -558,7 +542,8 @@ def close(self, noreply_wait=False, token=None): self._header_in_progress = None def run_query(self, query, noreply): - self._socket.sendall(query.serialize(self._parent._get_json_encoder(query))) + self._socket.sendall( + query.serialize(self._parent._get_json_encoder(query))) if noreply: return None @@ -567,7 +552,8 @@ def run_query(self, query, noreply): if res.type == pResponse.SUCCESS_ATOM: return maybe_profile(res.data[0], res) - elif res.type in (pResponse.SUCCESS_PARTIAL, pResponse.SUCCESS_SEQUENCE): + elif res.type in (pResponse.SUCCESS_PARTIAL, + pResponse.SUCCESS_SEQUENCE): cursor = DefaultCursor(self, query, res) return maybe_profile(cursor, res) elif res.type == pResponse.WAIT_COMPLETE: @@ -587,8 +573,12 @@ def _read_response(self, query, deadline=None): # of this response. The next 4 bytes give the # expected length of this response. if self._header_in_progress is None: - self._header_in_progress = self._socket.recvall(12, deadline) - (res_token, res_len,) = struct.unpack("