diff --git a/.github/workflows/install-edgedb.sh b/.github/workflows/install-edgedb.sh index 0469c99f..f6b9e734 100755 --- a/.github/workflows/install-edgedb.sh +++ b/.github/workflows/install-edgedb.sh @@ -5,15 +5,20 @@ shopt -s nullglob srv="https://packages.edgedb.com" -curl -fL "${srv}/dist/x86_64-unknown-linux-musl/edgedb-cli" \ - > "/usr/local/bin/edgedb" +curl -fL "${srv}/dist/$(uname -m)-unknown-linux-musl/edgedb-cli" \ + > "/usr/bin/edgedb" -chmod +x "/usr/local/bin/edgedb" +chmod +x "/usr/bin/edgedb" -useradd --shell /bin/bash edgedb +if command -v useradd >/dev/null 2>&1; then + useradd --shell /bin/bash edgedb +else + # musllinux/alpine doesn't have useradd + adduser -s /bin/bash -D edgedb +fi su -l edgedb -c "edgedb server install" ln -s $(su -l edgedb -c "edgedb server info --latest --bin-path") \ - "/usr/local/bin/edgedb-server" + "/usr/bin/edgedb-server" edgedb-server --version diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index c9643910..e9e125ce 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -5,7 +5,8 @@ on: branches: - "master" - "ci" - - "[0-9]+.[0-9x]+*" + - "release/[0-9]+.x" + - "release/[0-9]+.[0-9]+.x" paths: - "edgedb/_version.py" @@ -39,7 +40,7 @@ jobs: mkdir -p dist/ echo "${VERSION}" > dist/VERSION - - uses: actions/upload-artifact@v2 + - uses: actions/upload-artifact@v3 with: name: dist path: dist/ @@ -52,7 +53,7 @@ jobs: PIP_DISABLE_PIP_VERSION_CHECK: 1 steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 with: fetch-depth: 50 submodules: true @@ -65,19 +66,43 @@ jobs: pip install -U setuptools wheel pip python setup.py sdist - - uses: actions/upload-artifact@v2 + - uses: actions/upload-artifact@v3 with: name: dist path: dist/*.tar.* - build-wheels: + build-wheels-matrix: needs: validate-release-request + runs-on: ubuntu-latest + outputs: + include: ${{ steps.set-matrix.outputs.include }} + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-python@v4 + with: + python-version: "3.x" + - run: pip install cibuildwheel==2.19.2 + - id: set-matrix + # Cannot test on Musl distros yet. 
+ env: + CIBW_SKIP: "cp312-*" + run: | + MATRIX_INCLUDE=$( + { + cibuildwheel --print-build-identifiers --platform linux --arch x86_64,aarch64 | grep cp | grep many | jq -nRc '{"only": inputs, "os": "ubuntu-latest"}' \ + && cibuildwheel --print-build-identifiers --platform macos --arch x86_64,arm64 | grep cp | jq -nRc '{"only": inputs, "os": "macos-latest"}' \ + && cibuildwheel --print-build-identifiers --platform windows --arch AMD64 | grep cp | jq -nRc '{"only": inputs, "os": "windows-2019"}' + } | jq -sc + ) + echo "include=$MATRIX_INCLUDE" >> $GITHUB_OUTPUT + build-wheels: + needs: build-wheels-matrix runs-on: ${{ matrix.os }} + name: Build ${{ matrix.only }} strategy: + fail-fast: false matrix: - os: [ubuntu-latest, macos-latest, windows-2019] - cibw_python: ["cp37-*", "cp38-*", "cp39-*", "cp310-*"] - cibw_arch: ["auto64"] + include: ${{ fromJson(needs.build-wheels-matrix.outputs.include) }} defaults: run: @@ -87,32 +112,32 @@ jobs: PIP_DISABLE_PIP_VERSION_CHECK: 1 steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 with: fetch-depth: 50 submodules: true - name: Setup WSL - if: ${{ steps.release.outputs.version == 0 && matrix.os == 'windows-2019' }} - uses: vampire/setup-wsl@v1 + if: ${{ matrix.os == 'windows-2019' }} + uses: vampire/setup-wsl@v2 with: wsl-shell-user: edgedb additional-packages: ca-certificates curl + - name: Set up QEMU + if: runner.os == 'Linux' + uses: docker/setup-qemu-action@v2 + - name: Install EdgeDB uses: edgedb/setup-edgedb@v1 - - uses: pypa/cibuildwheel@v2.3.1 + - uses: pypa/cibuildwheel@v2.19.2 + with: + only: ${{ matrix.only }} env: CIBW_BUILD_VERBOSITY: 1 - CIBW_BUILD: ${{ matrix.cibw_python }} - # Cannot test on Musl distros yet. - CIBW_SKIP: "*-musllinux*" - CIBW_ARCHS: ${{ matrix.cibw_arch }} - # EdgeDB doesn't run on CentOS 6, so use 2014 as baseline - CIBW_MANYLINUX_X86_64_IMAGE: "quay.io/pypa/manylinux2014_x86_64" CIBW_BEFORE_ALL_LINUX: > .github/workflows/install-edgedb.sh CIBW_TEST_EXTRAS: "test" @@ -126,7 +151,7 @@ jobs: && chmod -R go+rX "$(dirname $(dirname $(dirname $PY)))" && su -l edgedb -c "EDGEDB_PYTHON_TEST_CODEGEN_CMD=$CODEGEN $PY {project}/tests/__init__.py" - - uses: actions/upload-artifact@v2 + - uses: actions/upload-artifact@v3 with: name: dist path: wheelhouse/*.whl @@ -134,14 +159,16 @@ jobs: publish: needs: [build-sdist, build-wheels] runs-on: ubuntu-latest + permissions: + contents: write steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 with: fetch-depth: 5 submodules: false - - uses: actions/download-artifact@v2 + - uses: actions/download-artifact@v3 with: name: dist path: dist/ diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 5a66c55d..050de749 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -23,15 +23,27 @@ jobs: PIP_DISABLE_PIP_VERSION_CHECK: 1 strategy: + fail-fast: false matrix: python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"] edgedb-version: [stable , nightly] - os: [ubuntu-latest, macos-latest, windows-2019] + os: [ubuntu-20.04, ubuntu-latest, macos-latest, windows-2019] loop: [asyncio, uvloop] exclude: # uvloop does not support windows - loop: uvloop os: windows-2019 + # Python 3.7 on ubuntu-22.04 has a broken OpenSSL 3.0 + - python-version: 3.7 + os: ubuntu-latest + - python-version: 3.8 + os: ubuntu-20.04 + - python-version: 3.9 + os: ubuntu-20.04 + - python-version: 3.10 + os: ubuntu-20.04 + - python-version: 3.11 + os: ubuntu-20.04 steps: - uses: actions/checkout@v2 @@ -70,7 +82,7 @@ jobs: server-version: ${{ 
matrix.edgedb-version }} - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 + uses: actions/setup-python@v4 if: steps.release.outputs.version == 0 with: python-version: ${{ matrix.python-version }} diff --git a/docs/api/asyncio_client.rst b/docs/api/asyncio_client.rst index 54afccf5..7435dafa 100644 --- a/docs/api/asyncio_client.rst +++ b/docs/api/asyncio_client.rst @@ -15,6 +15,7 @@ Client .. py:function:: create_async_client(dsn=None, *, \ host=None, port=None, \ user=None, password=None, \ + secret_key=None, \ database=None, \ timeout=60, \ concurrency=None) @@ -44,26 +45,16 @@ Client the :ref:`DSN Specification `. :param host: - Database host address as one of the following: - - - an IP address or a domain name; - - an absolute path to the directory containing the database - server Unix-domain socket (not supported on Windows); - - a sequence of any of the above, in which case the addresses - will be tried in order, and the host of the first successful - connection will be used for the whole connection pool. + Database host address as an IP address or a domain name; If not specified, the following will be tried, in order: - host address(es) parsed from the *dsn* argument, - the value of the ``EDGEDB_HOST`` environment variable, - - on Unix, common directories used for EdgeDB Unix-domain - sockets: ``"/run/edgedb"`` and ``"/var/run/edgedb"``, - ``"localhost"``. :param port: - Port number to connect to at the server host - (or Unix-domain socket file extension). If multiple host + Port number to connect to at the server host. If multiple host addresses were specified, this parameter may specify a sequence of port numbers of the same length as the host sequence, or it may specify a single port number to be used for all host @@ -95,6 +86,14 @@ Client other users and applications may be able to read it without needing specific privileges. + :param secret_key: + Secret key to be used for authentication, if the server requires one. + If not specified, the value parsed from the *dsn* argument is used, + or the value of the ``EDGEDB_SECRET_KEY`` environment variable. + Note that the use of the environment variable is discouraged as + other users and applications may be able to read it without needing + specific privileges. + :param float timeout: Connection timeout in seconds. @@ -558,8 +557,8 @@ transaction), so we have to redo all the work done. Generally it's recommended to not execute any long running code within the transaction unless absolutely necessary. -Transactions allocate expensive server resources and having -too many concurrently running long-running transactions will +Transactions allocate expensive server resources, and having +too many concurrent long-running transactions will negatively impact the performance of the DB server. To rollback a transaction that is in progress raise an exception. diff --git a/docs/api/blocking_client.rst b/docs/api/blocking_client.rst index f50456c9..eca92240 100644 --- a/docs/api/blocking_client.rst +++ b/docs/api/blocking_client.rst @@ -15,6 +15,7 @@ Client .. py:function:: create_client(dsn=None, *, \ host=None, port=None, \ user=None, password=None, \ + secret_key=None, \ database=None, \ timeout=60, \ concurrency=None) @@ -44,26 +45,16 @@ Client the :ref:`DSN Specification `. 
:param host: - Database host address as one of the following: - - - an IP address or a domain name; - - an absolute path to the directory containing the database - server Unix-domain socket (not supported on Windows); - - a sequence of any of the above, in which case the addresses - will be tried in order, and the host of the first successful - connection will be used for the whole connection pool. + Database host address as an IP address or a domain name; If not specified, the following will be tried, in order: - host address(es) parsed from the *dsn* argument, - the value of the ``EDGEDB_HOST`` environment variable, - - on Unix, common directories used for EdgeDB Unix-domain - sockets: ``"/run/edgedb"`` and ``"/var/run/edgedb"``, - ``"localhost"``. :param port: - Port number to connect to at the server host - (or Unix-domain socket file extension). If multiple host + Port number to connect to at the server host. If multiple host addresses were specified, this parameter may specify a sequence of port numbers of the same length as the host sequence, or it may specify a single port number to be used for all host @@ -95,6 +86,14 @@ Client other users and applications may be able to read it without needing specific privileges. + :param secret_key: + Secret key to be used for authentication, if the server requires one. + If not specified, the value parsed from the *dsn* argument is used, + or the value of the ``EDGEDB_SECRET_KEY`` environment variable. + Note that the use of the environment variable is discouraged as + other users and applications may be able to read it without needing + specific privileges. + :param float timeout: Connection timeout in seconds. diff --git a/docs/index.rst b/docs/index.rst index 1b04d26e..1a647fb6 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -32,6 +32,10 @@ and :ref:`asyncio ` implementations. EdgeDB Python types documentation. +* :ref:`edgedb-python-codegen` + + Python code generation command-line tool documentation. + * :ref:`edgedb-python-advanced` Advanced usages of the state and optional customization. diff --git a/docs/usage.rst b/docs/usage.rst index a531a61b..575fdfe3 100644 --- a/docs/usage.rst +++ b/docs/usage.rst @@ -78,8 +78,8 @@ types and vice versa. See :ref:`edgedb-python-datatypes` for details. Client connection pools ----------------------- -For server-type type applications that handle frequent requests and need -the database connection for a short period time while handling a request, +For server-type applications that handle frequent requests and need +the database connection for a short period of time while handling a request, the use of a connection pool is recommended. Both :py:class:`edgedb.Client` and :py:class:`edgedb.AsyncIOClient` come with such a pool. 
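Illustrative note (not part of the patch): the client documentation above adds the ``secret_key`` argument and the ``EDGEDB_SECRET_KEY`` environment variable. A minimal blocking-API sketch using the new argument follows; the DSN and key are placeholder values, assuming a server that requires secret-key authentication.

    import edgedb

    # Hypothetical connection values, for illustration only.
    client = edgedb.create_client(
        dsn="edgedb://db.example.com:5656/main",
        secret_key="<paste secret key here>",  # or export EDGEDB_SECRET_KEY
    )
    print(client.query_single("select 1 + 1"))
    client.close()

As the docs above note, passing the key explicitly (or via a credentials file) is preferred over the environment variable, which other users and applications may be able to read.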
diff --git a/edgedb/__init__.py b/edgedb/__init__.py index bf565921..5ba24795 100644 --- a/edgedb/__init__.py +++ b/edgedb/__init__.py @@ -25,7 +25,7 @@ Tuple, NamedTuple, EnumValue, RelativeDuration, DateDuration, ConfigMemory ) from edgedb.datatypes.datatypes import Set, Object, Array, Link, LinkSet -from edgedb.datatypes.range import Range +from edgedb.datatypes.range import Range, MultiRange from .abstract import ( Executor, AsyncIOExecutor, ReadOnlyExecutor, AsyncIOReadOnlyExecutor, @@ -137,6 +137,7 @@ DuplicateFunctionDefinitionError, DuplicateConstraintDefinitionError, DuplicateCastDefinitionError, + DuplicateMigrationError, SessionTimeoutError, IdleSessionTimeoutError, QueryTimeoutError, @@ -147,6 +148,7 @@ DivisionByZeroError, NumericOutOfRangeError, AccessPolicyError, + QueryAssertionError, IntegrityError, ConstraintViolationError, CardinalityViolationError, @@ -155,11 +157,13 @@ TransactionConflictError, TransactionSerializationError, TransactionDeadlockError, + WatchError, ConfigurationError, AccessError, AuthenticationError, AvailabilityError, BackendUnavailableError, + ServerOfflineError, BackendError, UnsupportedBackendFeatureError, LogMessage, @@ -234,6 +238,7 @@ "DuplicateFunctionDefinitionError", "DuplicateConstraintDefinitionError", "DuplicateCastDefinitionError", + "DuplicateMigrationError", "SessionTimeoutError", "IdleSessionTimeoutError", "QueryTimeoutError", @@ -244,6 +249,7 @@ "DivisionByZeroError", "NumericOutOfRangeError", "AccessPolicyError", + "QueryAssertionError", "IntegrityError", "ConstraintViolationError", "CardinalityViolationError", @@ -252,11 +258,13 @@ "TransactionConflictError", "TransactionSerializationError", "TransactionDeadlockError", + "WatchError", "ConfigurationError", "AccessError", "AuthenticationError", "AvailabilityError", "BackendUnavailableError", + "ServerOfflineError", "BackendError", "UnsupportedBackendFeatureError", "LogMessage", diff --git a/edgedb/_testbase.py b/edgedb/_testbase.py index a228615f..47667add 100644 --- a/edgedb/_testbase.py +++ b/edgedb/_testbase.py @@ -135,7 +135,7 @@ def _start_cluster(*, cleanup_atexit=True): stderr=subprocess.STDOUT, ) - for _ in range(250): + for _ in range(600): try: with open(status_file, 'rb') as f: for line in f: @@ -171,6 +171,11 @@ def _start_cluster(*, cleanup_atexit=True): client = edgedb.create_client(password='test', **con_args) client.ensure_connected() + client.execute(""" + # Set session_idle_transaction_timeout to 5 minutes. + CONFIGURE INSTANCE SET session_idle_transaction_timeout := + '5 minutes'; + """) _default_cluster = { 'proc': p, 'client': client, diff --git a/edgedb/_version.py b/edgedb/_version.py index 275c66aa..101d8e88 100644 --- a/edgedb/_version.py +++ b/edgedb/_version.py @@ -28,4 +28,4 @@ # supported platforms, publish the packages on PyPI, merge the PR # to the target branch, create a Git tag pointing to the commit. 
-__version__ = '1.0.0' +__version__ = '1.6.1' diff --git a/edgedb/asyncio_client.py b/edgedb/asyncio_client.py index b30171e0..c03c6423 100644 --- a/edgedb/asyncio_client.py +++ b/edgedb/asyncio_client.py @@ -18,6 +18,7 @@ import asyncio +import contextlib import logging import socket import ssl @@ -273,11 +274,12 @@ def _warn_on_long_close(self): class AsyncIOIteration(transaction.BaseTransaction, abstract.AsyncIOExecutor): - __slots__ = ("_managed",) + __slots__ = ("_managed", "_locked") def __init__(self, retry, client, iteration): super().__init__(retry, client, iteration) self._managed = False + self._locked = False async def __aenter__(self): if self._managed: @@ -287,8 +289,9 @@ async def __aenter__(self): return self async def __aexit__(self, extype, ex, tb): - self._managed = False - return await self._exit(extype, ex) + with self._exclusive(): + self._managed = False + return await self._exit(extype, ex) async def _ensure_transaction(self): if not self._managed: @@ -298,6 +301,27 @@ async def _ensure_transaction(self): ) await super()._ensure_transaction() + async def _query(self, query_context: abstract.QueryContext): + with self._exclusive(): + return await super()._query(query_context) + + async def _execute(self, execute_context: abstract.ExecuteContext) -> None: + with self._exclusive(): + await super()._execute(execute_context) + + @contextlib.contextmanager + def _exclusive(self): + if self._locked: + raise errors.InterfaceError( + "concurrent queries within the same transaction " + "are not allowed" + ) + self._locked = True + try: + yield + finally: + self._locked = False + class AsyncIORetry(transaction.BaseRetry): @@ -378,6 +402,7 @@ def create_async_client( credentials_file: str = None, user: str = None, password: str = None, + secret_key: str = None, database: str = None, tls_ca: str = None, tls_ca_file: str = None, @@ -397,6 +422,7 @@ def create_async_client( credentials_file=credentials_file, user=user, password=password, + secret_key=secret_key, database=database, tls_ca=tls_ca, tls_ca_file=tls_ca_file, diff --git a/edgedb/base_client.py b/edgedb/base_client.py index 94d85972..331562e6 100644 --- a/edgedb/base_client.py +++ b/edgedb/base_client.py @@ -186,6 +186,10 @@ async def privileged_execute( qc=execute_context.cache.query_cache, output_format=protocol.OutputFormat.NONE, allow_capabilities=enums.Capability.ALL, + state=( + execute_context.state.as_dict() + if execute_context.state else None + ), ) def is_in_transaction(self) -> bool: @@ -670,6 +674,7 @@ def __init__( credentials_file: str = None, user: str = None, password: str = None, + secret_key: str = None, database: str = None, tls_ca: str = None, tls_ca_file: str = None, @@ -687,6 +692,7 @@ def __init__( "credentials_file": credentials_file, "user": user, "password": password, + "secret_key": secret_key, "database": database, "timeout": timeout, "tls_ca": tls_ca, diff --git a/edgedb/blocking_client.py b/edgedb/blocking_client.py index 4d35c43f..7eb761b9 100644 --- a/edgedb/blocking_client.py +++ b/edgedb/blocking_client.py @@ -17,6 +17,7 @@ # +import contextlib import datetime import queue import socket @@ -61,14 +62,10 @@ async def connect_addr(self, addr, timeout): raise TimeoutError # Upgrade to TLS - if self._params.ssl_ctx.check_hostname: - server_hostname = addr[0] - else: - server_hostname = None sock.settimeout(time_left) try: sock = self._params.ssl_ctx.wrap_socket( - sock, server_hostname=server_hostname + sock, server_hostname=addr[0] ) except ssl.CertificateError as e: raise 
con_utils.wrap_error(e) from e @@ -152,8 +149,8 @@ async def raw_query(self, query_context: abstract.QueryContext): time.monotonic() - self._protocol.last_active_timestamp > self._ping_wait_time ): - await self._protocol._sync() - except errors.ClientConnectionError: + await self._protocol.ping() + except (errors.IdleSessionTimeoutError, errors.ClientConnectionError): await self.connect() return await super().raw_query(query_context) @@ -275,22 +272,25 @@ async def close(self, timeout=None): class Iteration(transaction.BaseTransaction, abstract.Executor): - __slots__ = ("_managed",) + __slots__ = ("_managed", "_lock") def __init__(self, retry, client, iteration): super().__init__(retry, client, iteration) self._managed = False + self._lock = threading.Lock() def __enter__(self): - if self._managed: - raise errors.InterfaceError( - 'cannot enter context: already in a `with` block') - self._managed = True - return self + with self._exclusive(): + if self._managed: + raise errors.InterfaceError( + 'cannot enter context: already in a `with` block') + self._managed = True + return self def __exit__(self, extype, ex, tb): - self._managed = False - return self._client._iter_coroutine(self._exit(extype, ex)) + with self._exclusive(): + self._managed = False + return self._client._iter_coroutine(self._exit(extype, ex)) async def _ensure_transaction(self): if not self._managed: @@ -301,10 +301,24 @@ async def _ensure_transaction(self): await super()._ensure_transaction() def _query(self, query_context: abstract.QueryContext): - return self._client._iter_coroutine(super()._query(query_context)) + with self._exclusive(): + return self._client._iter_coroutine(super()._query(query_context)) def _execute(self, execute_context: abstract.ExecuteContext) -> None: - self._client._iter_coroutine(super()._execute(execute_context)) + with self._exclusive(): + self._client._iter_coroutine(super()._execute(execute_context)) + + @contextlib.contextmanager + def _exclusive(self): + if not self._lock.acquire(blocking=False): + raise errors.InterfaceError( + "concurrent queries within the same transaction " + "are not allowed" + ) + try: + yield + finally: + self._lock.release() class Retry(transaction.BaseRetry): @@ -401,6 +415,7 @@ def create_client( credentials_file: str = None, user: str = None, password: str = None, + secret_key: str = None, database: str = None, tls_ca: str = None, tls_ca_file: str = None, @@ -420,6 +435,7 @@ def create_client( credentials_file=credentials_file, user=user, password=password, + secret_key=secret_key, database=database, tls_ca=tls_ca, tls_ca_file=tls_ca_file, diff --git a/edgedb/codegen/cli.py b/edgedb/codegen/cli.py index 1a229bae..e8149ae6 100644 --- a/edgedb/codegen/cli.py +++ b/edgedb/codegen/cli.py @@ -23,11 +23,21 @@ from . import generator -parser = argparse.ArgumentParser( +class ColoredArgumentParser(argparse.ArgumentParser): + def error(self, message): + c = generator.C + self.exit( + 2, + f"{c.BOLD}{c.FAIL}error:{c.ENDC} " + f"{c.BOLD}{message:s}{c.ENDC}\n", + ) + + +parser = ColoredArgumentParser( description="Generate Python code for .edgeql files." 
) parser.add_argument("--dsn") -parser.add_argument("--credentials_file", metavar="PATH") +parser.add_argument("--credentials-file", metavar="PATH") parser.add_argument("-I", "--instance", metavar="NAME") parser.add_argument("-H", "--host") parser.add_argument("-P", "--port") @@ -42,9 +52,15 @@ ) parser.add_argument( "--file", - action="store_true", + action="append", + nargs="?", help="Generate a single file instead of one per .edgeql file.", ) +parser.add_argument( + "--dir", + action="append", + help="Only search .edgeql files under specified directories.", +) parser.add_argument( "--target", choices=["blocking", "async"], diff --git a/edgedb/codegen/generator.py b/edgedb/codegen/generator.py index f357d50a..d2c42699 100644 --- a/edgedb/codegen/generator.py +++ b/edgedb/codegen/generator.py @@ -29,8 +29,10 @@ from edgedb import abstract from edgedb import describe from edgedb.con_utils import find_edgedb_project_dir +from edgedb.color import get_color +C = get_color() SYS_VERSION_INFO = os.getenv("EDGEDB_PYTHON_CODEGEN_PY_VER") if SYS_VERSION_INFO: SYS_VERSION_INFO = tuple(map(int, SYS_VERSION_INFO.split(".")))[:2] @@ -88,13 +90,20 @@ def __get_validators__(cls): """ +def print_msg(msg): + print(msg, file=sys.stderr) + + +def print_error(msg): + print_msg(f"{C.BOLD}{C.FAIL}error: {C.ENDC}{C.BOLD}{msg}{C.ENDC}") + + def _get_conn_args(args: argparse.Namespace): if args.password_from_stdin: if args.password: - print( + print_error( "--password and --password-from-stdin are " "mutually exclusive", - file=sys.stderr, ) sys.exit(22) if sys.stdin.isatty(): @@ -104,7 +113,7 @@ def _get_conn_args(args: argparse.Namespace): else: password = args.password if args.dsn and args.instance: - print("--dsn and --instance are mutually exclusive", file=sys.stderr) + print_error("--dsn and --instance are mutually exclusive") sys.exit(22) return dict( dsn=args.dsn or args.instance, @@ -133,9 +142,23 @@ def __init__(self, args: argparse.Namespace): "codegen must be run under an EdgeDB project dir" ) sys.exit(2) - print(f"Found EdgeDB project: {self._project_dir}", file=sys.stderr) + print_msg(f"Found EdgeDB project: {C.BOLD}{self._project_dir}{C.ENDC}") self._client = edgedb.create_client(**_get_conn_args(args)) - self._file_mode = args.file + self._single_mode_files = args.file + self._search_dirs = [] + for search_dir in args.dir or []: + search_dir = pathlib.Path(search_dir).absolute() + if ( + search_dir == self._project_dir + or self._project_dir in search_dir.parents + ): + self._search_dirs.append(search_dir) + else: + print( + f"--dir '{search_dir}' is not under " + f"the project directory: {self._project_dir}" + ) + sys.exit(1) self._method_names = set() self._describe_results = [] @@ -161,15 +184,20 @@ def run(self): print(f"Failed to connect to EdgeDB instance: {e}") sys.exit(61) with self._client: - self._process_dir(self._project_dir) + if self._search_dirs: + for search_dir in self._search_dirs: + self._process_dir(search_dir) + else: + self._process_dir(self._project_dir) for target, suffix, is_async in SUFFIXES: if target in self._targets: self._async = is_async - if self._file_mode: + if self._single_mode_files: self._generate_single_file(suffix) else: self._generate_files(suffix) self._new_file() + print_msg(f"{C.GREEN}{C.BOLD}Done.{C.ENDC}") def _process_dir(self, dir_: pathlib.Path): for file_or_dir in dir_.iterdir(): @@ -184,13 +212,13 @@ def _process_dir(self, dir_: pathlib.Path): self._process_file(file_or_dir) def _process_file(self, source: pathlib.Path): - print(f"Processing 
{source}", file=sys.stderr) + print_msg(f"{C.BOLD}Processing{C.ENDC} {C.BLUE}{source}{C.ENDC}") with source.open() as f: query = f.read() name = source.stem - if self._file_mode: + if self._single_mode_files: if name in self._method_names: - print(f"Conflict method names: {name}", file=sys.stderr) + print_error(f"Conflict method names: {name}") sys.exit(17) self._method_names.add(name) dr = self._client._describe_query(query, inject_type_names=True) @@ -199,7 +227,7 @@ def _process_file(self, source: pathlib.Path): def _generate_files(self, suffix: str): for name, source, query, dr in self._describe_results: target = source.parent / f"{name}{suffix}" - print(f"Generating {target}", file=sys.stderr) + print_msg(f"{C.BOLD}Generating{C.ENDC} {C.BLUE}{target}{C.ENDC}") self._new_file() content = self._generate(name, query, dr) buf = io.StringIO() @@ -210,8 +238,7 @@ def _generate_files(self, suffix: str): f.write(buf.getvalue()) def _generate_single_file(self, suffix: str): - target = self._project_dir / f"{FILE_MODE_OUTPUT_FILE}{suffix}" - print(f"Generating {target}", file=sys.stderr) + print_msg(f"{C.BOLD}Generating single file output...{C.ENDC}") buf = io.StringIO() output = [] sources = [] @@ -225,8 +252,15 @@ def _generate_single_file(self, suffix: str): if i < len(output) - 1: print(file=buf) print(file=buf) - with target.open("w") as f: - f.write(buf.getvalue()) + + for target in self._single_mode_files: + if target: + target = pathlib.Path(target).absolute() + else: + target = self._project_dir / f"{FILE_MODE_OUTPUT_FILE}{suffix}" + print_msg(f"{C.BOLD}Writing{C.ENDC} {C.BLUE}{target}{C.ENDC}") + with target.open("w") as f: + f.write(buf.getvalue()) def _write_comments( self, f: io.TextIOBase, src: typing.List[pathlib.Path] @@ -467,8 +501,8 @@ def _generate_code( buf = io.StringIO() self._imports.add("enum") print(f"class {rv}(enum.Enum):", file=buf) - for member in type_.members: - print(f'{INDENT}{member.upper()} = "{member}"', file=buf) + for member, member_id in self._to_unique_idents(type_.members): + print(f'{INDENT}{member_id.upper()} = "{member}"', file=buf) self._defs[rv] = buf.getvalue().strip() elif isinstance(type_, describe.RangeType): @@ -510,10 +544,7 @@ def _find_name(self, name: str) -> str: name = new break else: - print( - f"Failed to find a unique name for: {name}", - file=sys.stderr, - ) + print_error(f"Failed to find a unique name for: {name}") sys.exit(17) self._names.add(name) return name @@ -524,3 +555,37 @@ def _snake_to_camel(self, name: str) -> str: return "".join(map(str.title, parts)) else: return name + + def _to_unique_idents( + self, names: typing.Iterable[typing.Tuple[str, str]] + ) -> typing.Iterator[str]: + dedup = set() + for name in names: + if name.isidentifier(): + name_id = name + sep = name.endswith("_") + else: + sep = True + result = [] + for i, c in enumerate(name): + if c.isdigit(): + if i == 0: + result.append("e_") + result.append(c) + sep = False + elif c.isidentifier(): + result.append(c) + sep = c == "_" + elif not sep: + result.append("_") + sep = True + name_id = "".join(result) + rv = name_id + if not sep: + name_id = name_id + "_" + i = 1 + while rv in dedup: + rv = f"{name_id}{i}" + i += 1 + dedup.add(rv) + yield name, rv diff --git a/edgedb/color.py b/edgedb/color.py new file mode 100644 index 00000000..1e95c7bf --- /dev/null +++ b/edgedb/color.py @@ -0,0 +1,60 @@ +import os +import sys +import warnings + +COLOR = None + + +class Color: + HEADER = "" + BLUE = "" + CYAN = "" + GREEN = "" + WARNING = "" + FAIL = "" + ENDC = "" + 
BOLD = "" + UNDERLINE = "" + + +def get_color() -> Color: + global COLOR + + if COLOR is None: + COLOR = Color() + if type(USE_COLOR) is bool: + use_color = USE_COLOR + else: + try: + use_color = USE_COLOR() + except Exception: + use_color = False + if use_color: + COLOR.HEADER = '\033[95m' + COLOR.BLUE = '\033[94m' + COLOR.CYAN = '\033[96m' + COLOR.GREEN = '\033[92m' + COLOR.WARNING = '\033[93m' + COLOR.FAIL = '\033[91m' + COLOR.ENDC = '\033[0m' + COLOR.BOLD = '\033[1m' + COLOR.UNDERLINE = '\033[4m' + + return COLOR + + +try: + USE_COLOR = { + "default": lambda: sys.stderr.isatty(), + "auto": lambda: sys.stderr.isatty(), + "enabled": True, + "disabled": False, + }[ + os.getenv("EDGEDB_COLOR_OUTPUT", "default") + ] +except KeyError: + warnings.warn( + "EDGEDB_COLOR_OUTPUT can only be one of: " + "default, auto, enabled or disabled" + ) + USE_COLOR = False diff --git a/edgedb/con_utils.py b/edgedb/con_utils.py index d17cf32a..285790fc 100644 --- a/edgedb/con_utils.py +++ b/edgedb/con_utils.py @@ -17,6 +17,8 @@ # +import base64 +import binascii import errno import json import os @@ -71,6 +73,19 @@ HUMAN_US_RE = re.compile( r'((?:(?:\s|^)-\s*)?\d*\.?\d*)\s*(?i:us(\s|\d|\.|$)|microseconds?(\s|$))', ) +INSTANCE_NAME_RE = re.compile( + r'^(\w(?:-?\w)*)$', + re.ASCII, +) +CLOUD_INSTANCE_NAME_RE = re.compile( + r'^([A-Za-z0-9](?:-?[A-Za-z0-9])*)/([A-Za-z0-9](?:-?[A-Za-z0-9])*)$', + re.ASCII, +) +DSN_RE = re.compile( + r'^[a-z]+://', + re.IGNORECASE, +) +DOMAIN_LABEL_MAX_LENGTH = 63 class ClientConfiguration(typing.NamedTuple): @@ -175,6 +190,9 @@ class ResolvedConnectConfig: _password = None _password_source = None + _secret_key = None + _secret_key_source = None + _tls_ca_data = None _tls_ca_data_source = None @@ -183,6 +201,9 @@ class ResolvedConnectConfig: _wait_until_available = None + _cloud_profile = None + _cloud_profile_source = None + server_settings = {} def _set_param(self, param, value, source, validator=None): @@ -211,6 +232,9 @@ def set_user(self, user, source): def set_password(self, password, source): self._set_param('password', password, source) + def set_secret_key(self, secret_key, source): + self._set_param('secret_key', secret_key, source) + def set_tls_ca_data(self, ca_data, source): self._set_param('tls_ca_data', ca_data, source) @@ -256,6 +280,10 @@ def user(self): def password(self): return self._password + @property + def secret_key(self): + return self._secret_key + @property def tls_security(self): tls_security = self._tls_security or 'default' @@ -491,6 +519,7 @@ def _parse_connect_dsn_and_args( credentials_file, user, password, + secret_key, database, tls_ca, tls_ca_file, @@ -500,11 +529,10 @@ def _parse_connect_dsn_and_args( ): resolved_config = ResolvedConnectConfig() - dsn, instance_name = ( - (dsn, None) - if dsn is not None and re.match('(?i)^[a-z]+://', dsn) - else (None, dsn) - ) + if dsn and DSN_RE.match(dsn): + instance_name = None + else: + instance_name, dsn = dsn, None has_compound_options = _resolve_config_options( resolved_config, @@ -534,6 +562,10 @@ def _parse_connect_dsn_and_args( (password, '"password" option') if password is not None else None ), + secret_key=( + (secret_key, '"secret_key" option') + if secret_key is not None else None + ), tls_ca=( (tls_ca, '"tls_ca" option') if tls_ca is not None else None @@ -553,7 +585,7 @@ def _parse_connect_dsn_and_args( wait_until_available=( (wait_until_available, '"wait_until_available" option') if wait_until_available is not None else None - ) + ), ) if has_compound_options is False: @@ -574,10 +606,12 @@ 
def _parse_connect_dsn_and_args( env_database = os.getenv('EDGEDB_DATABASE') env_user = os.getenv('EDGEDB_USER') env_password = os.getenv('EDGEDB_PASSWORD') + env_secret_key = os.getenv('EDGEDB_SECRET_KEY') env_tls_ca = os.getenv('EDGEDB_TLS_CA') env_tls_ca_file = os.getenv('EDGEDB_TLS_CA_FILE') env_tls_security = os.getenv('EDGEDB_CLIENT_TLS_SECURITY') env_wait_until_available = os.getenv('EDGEDB_WAIT_UNTIL_AVAILABLE') + cloud_profile = os.getenv('EDGEDB_CLOUD_PROFILE') has_compound_options = _resolve_config_options( resolved_config, @@ -617,6 +651,10 @@ def _parse_connect_dsn_and_args( (env_password, '"EDGEDB_PASSWORD" environment variable') if env_password is not None else None ), + secret_key=( + (env_secret_key, '"EDGEDB_SECRET_KEY" environment variable') + if env_secret_key is not None else None + ), tls_ca=( (env_tls_ca, '"EDGEDB_TLS_CA" environment variable') if env_tls_ca is not None else None @@ -635,7 +673,12 @@ def _parse_connect_dsn_and_args( env_wait_until_available, '"EDGEDB_WAIT_UNTIL_AVAILABLE" environment variable' ) if env_wait_until_available is not None else None - ) + ), + cloud_profile=( + (cloud_profile, + '"EDGEDB_CLOUD_PROFILE" environment variable') + if cloud_profile is not None else None + ), ) if not has_compound_options: @@ -644,15 +687,31 @@ def _parse_connect_dsn_and_args( if os.path.exists(stash_dir): with open(os.path.join(stash_dir, 'instance-name'), 'rt') as f: instance_name = f.read().strip() + cloud_profile_file = os.path.join(stash_dir, 'cloud-profile') + if os.path.exists(cloud_profile_file): + with open(cloud_profile_file, 'rt') as f: + cloud_profile = f.read().strip() + else: + cloud_profile = None + + _resolve_config_options( + resolved_config, + '', + instance_name=( + instance_name, + f'project linked instance ("{instance_name}")' + ), + cloud_profile=( + cloud_profile, + f'project defined cloud profile ("{cloud_profile}")' + ), + ) - _resolve_config_options( - resolved_config, - '', - instance_name=( - instance_name, - f'project linked instance ("{instance_name}")' - ) - ) + opt_database_file = os.path.join(stash_dir, 'database') + if os.path.exists(opt_database_file): + with open(opt_database_file, 'rt') as f: + database = f.read().strip() + resolved_config.set_database(database, "project") else: raise errors.ClientConnectionError( f'Found `edgedb.toml` but the project is not initialized. 
' @@ -774,6 +833,11 @@ def strip_leading_slash(str): resolved_config._password, resolved_config.set_password ) + handle_dsn_part( + 'secret_key', None, + resolved_config._secret_key, resolved_config.set_secret_key + ) + handle_dsn_part( 'tls_ca_file', None, resolved_config._tls_ca_data, resolved_config.set_tls_ca_file @@ -794,6 +858,66 @@ def strip_leading_slash(str): resolved_config.add_server_settings(query) +def _jwt_base64_decode(payload): + remainder = len(payload) % 4 + if remainder == 2: + payload += '==' + elif remainder == 3: + payload += '=' + elif remainder != 0: + raise errors.ClientConnectionError("Invalid secret key") + payload = base64.urlsafe_b64decode(payload.encode("utf-8")) + return json.loads(payload.decode("utf-8")) + + +def _parse_cloud_instance_name_into_config( + resolved_config: ResolvedConnectConfig, + source: str, + org_slug: str, + instance_name: str, +): + org_slug = org_slug.lower() + instance_name = instance_name.lower() + + label = f"{instance_name}--{org_slug}" + if len(label) > DOMAIN_LABEL_MAX_LENGTH: + raise ValueError( + f"invalid instance name: cloud instance name length cannot exceed " + f"{DOMAIN_LABEL_MAX_LENGTH - 1} characters: " + f"{org_slug}/{instance_name}" + ) + secret_key = resolved_config.secret_key + if secret_key is None: + try: + config_dir = platform.config_dir() + if resolved_config._cloud_profile is None: + profile = profile_src = "default" + else: + profile = resolved_config._cloud_profile + profile_src = resolved_config._cloud_profile_source + path = config_dir / "cloud-credentials" / f"{profile}.json" + with open(path, "rt") as f: + secret_key = json.load(f)["secret_key"] + except Exception: + raise errors.ClientConnectionError( + "Cannot connect to cloud instances without secret key." + ) + resolved_config.set_secret_key( + secret_key, + f"cloud-credentials/{profile}.json specified by {profile_src}", + ) + try: + dns_zone = _jwt_base64_decode(secret_key.split(".", 2)[1])["iss"] + except errors.EdgeDBError: + raise + except Exception: + raise errors.ClientConnectionError("Invalid secret key") + payload = f"{org_slug}/{instance_name}".encode("utf-8") + dns_bucket = binascii.crc_hqx(payload, 0) % 100 + host = f"{label}.c-{dns_bucket:02d}.i.{dns_zone}" + resolved_config.set_host(host, source) + + def _resolve_config_options( resolved_config: ResolvedConnectConfig, compound_error: str, @@ -807,11 +931,13 @@ def _resolve_config_options( database=None, user=None, password=None, + secret_key=None, tls_ca=None, tls_ca_file=None, tls_security=None, server_settings=None, wait_until_available=None, + cloud_profile=None, ): if database is not None: resolved_config.set_database(*database) @@ -819,6 +945,8 @@ def _resolve_config_options( resolved_config.set_user(*user) if password is not None: resolved_config.set_password(*password) + if secret_key is not None: + resolved_config.set_secret_key(*secret_key) if tls_ca_file is not None: if tls_ca is not None: raise errors.ClientConnectionError( @@ -832,6 +960,8 @@ def _resolve_config_options( resolved_config.add_server_settings(server_settings[0]) if wait_until_available is not None: resolved_config.set_wait_until_available(*wait_until_available) + if cloud_profile is not None: + resolved_config._set_param('cloud_profile', *cloud_profile) compound_params = [ dsn, @@ -868,22 +998,23 @@ def _resolve_config_options( else: creds = cred_utils.validate_credentials(cred_data) source = "credentials" + elif INSTANCE_NAME_RE.match(instance_name[0]): + source = instance_name[1] + creds = 
cred_utils.read_credentials( + cred_utils.get_credentials_path(instance_name[0]), + ) else: - if ( - re.match( - '^[A-Za-z_][A-Za-z_0-9]*$', - instance_name[0] - ) is None - ): + name_match = CLOUD_INSTANCE_NAME_RE.match(instance_name[0]) + if name_match is None: raise ValueError( f'invalid DSN or instance name: "{instance_name[0]}"' ) - - creds = cred_utils.read_credentials( - cred_utils.get_credentials_path(instance_name[0]), - ) - source = instance_name[1] + org, inst = name_match.groups() + _parse_cloud_instance_name_into_config( + resolved_config, source, org, inst + ) + return True resolved_config.set_host(creds.get('host'), source) resolved_config.set_port(creds.get('port'), source) @@ -939,6 +1070,7 @@ def parse_connect_arguments( database, user, password, + secret_key, tls_ca, tls_ca_file, tls_security, @@ -970,6 +1102,7 @@ def parse_connect_arguments( database=database, user=user, password=password, + secret_key=secret_key, tls_ca=tls_ca, tls_ca_file=tls_ca_file, tls_security=tls_security, diff --git a/edgedb/datatypes/range.py b/edgedb/datatypes/range.py index eaeb4bcb..e3fd3d1e 100644 --- a/edgedb/datatypes/range.py +++ b/edgedb/datatypes/range.py @@ -16,8 +16,8 @@ # limitations under the License. # -from typing import Generic, Optional, TypeVar - +from typing import (TypeVar, Any, Generic, Optional, Iterable, Iterator, + Sequence) T = TypeVar("T") @@ -78,8 +78,10 @@ def is_empty(self) -> bool: def __bool__(self): return not self.is_empty() - def __eq__(self, other): - if not isinstance(other, Range): + def __eq__(self, other) -> bool: + if isinstance(other, Range): + o = other + else: return NotImplemented return ( @@ -87,13 +89,13 @@ def __eq__(self, other): self._upper, self._inc_lower, self._inc_upper, - self._empty - ) == ( - other._lower, - other._upper, - other._inc_lower, - other._inc_upper, self._empty, + ) == ( + o._lower, + o._upper, + o._inc_lower, + o._inc_upper, + o._empty, ) def __hash__(self) -> int: @@ -125,3 +127,39 @@ def __str__(self) -> str: return f"" __repr__ = __str__ + + +# TODO: maybe we should implement range and multirange operations as well as +# normalization of the sub-ranges? +class MultiRange(Iterable[T]): + + _ranges: Sequence[T] + + def __init__(self, iterable: Optional[Iterable[T]] = None) -> None: + if iterable is not None: + self._ranges = tuple(iterable) + else: + self._ranges = tuple() + + def __len__(self) -> int: + return len(self._ranges) + + def __iter__(self) -> Iterator[T]: + return iter(self._ranges) + + def __reversed__(self) -> Iterator[T]: + return reversed(self._ranges) + + def __str__(self) -> str: + return f'' + + __repr__ = __str__ + + def __eq__(self, other: Any) -> bool: + if isinstance(other, MultiRange): + return set(self._ranges) == set(other._ranges) + else: + return NotImplemented + + def __hash__(self) -> int: + return hash(self._ranges) diff --git a/edgedb/datatypes/relative_duration.pyx b/edgedb/datatypes/relative_duration.pyx index cf2c0ab9..27c0c648 100644 --- a/edgedb/datatypes/relative_duration.pyx +++ b/edgedb/datatypes/relative_duration.pyx @@ -108,7 +108,12 @@ cdef class RelativeDuration: buf.append(f'{min}M') if sec or fsec: - sign = '-' if min < 0 or fsec < 0 else '' + # If the original microseconds are negative we expect '-' in front + # of all non-zero hour/min/second components. The hour/min sign + # can be taken as is, but seconds are constructed out of sec and + # fsec parts, both of which have their own sign and thus we cannot + # just use their string representations directly. 
+ sign = '-' if self.microseconds < 0 else '' buf.append(f'{sign}{abs(sec)}') if fsec: diff --git a/edgedb/describe.py b/edgedb/describe.py index 05c16398..92e38854 100644 --- a/edgedb/describe.py +++ b/edgedb/describe.py @@ -91,3 +91,8 @@ class SparseObjectType(ObjectType): @dataclasses.dataclass(frozen=True) class RangeType(AnyType): value_type: AnyType + + +@dataclasses.dataclass(frozen=True) +class MultiRangeType(AnyType): + value_type: AnyType diff --git a/edgedb/errors/__init__.py b/edgedb/errors/__init__.py index c5ea6daa..424edc91 100644 --- a/edgedb/errors/__init__.py +++ b/edgedb/errors/__init__.py @@ -68,6 +68,7 @@ 'DuplicateFunctionDefinitionError', 'DuplicateConstraintDefinitionError', 'DuplicateCastDefinitionError', + 'DuplicateMigrationError', 'SessionTimeoutError', 'IdleSessionTimeoutError', 'QueryTimeoutError', @@ -78,6 +79,7 @@ 'DivisionByZeroError', 'NumericOutOfRangeError', 'AccessPolicyError', + 'QueryAssertionError', 'IntegrityError', 'ConstraintViolationError', 'CardinalityViolationError', @@ -86,11 +88,13 @@ 'TransactionConflictError', 'TransactionSerializationError', 'TransactionDeadlockError', + 'WatchError', 'ConfigurationError', 'AccessError', 'AuthenticationError', 'AvailabilityError', 'BackendUnavailableError', + 'ServerOfflineError', 'BackendError', 'UnsupportedBackendFeatureError', 'LogMessage', @@ -328,12 +332,17 @@ class DuplicateCastDefinitionError(DuplicateDefinitionError): _code = 0x_04_05_02_0A +class DuplicateMigrationError(DuplicateDefinitionError): + _code = 0x_04_05_02_0B + + class SessionTimeoutError(QueryError): _code = 0x_04_06_00_00 class IdleSessionTimeoutError(SessionTimeoutError): _code = 0x_04_06_01_00 + tags = frozenset({SHOULD_RETRY}) class QueryTimeoutError(SessionTimeoutError): @@ -368,6 +377,10 @@ class AccessPolicyError(InvalidValueError): _code = 0x_05_01_00_03 +class QueryAssertionError(InvalidValueError): + _code = 0x_05_01_00_04 + + class IntegrityError(ExecutionError): _code = 0x_05_02_00_00 @@ -403,6 +416,10 @@ class TransactionDeadlockError(TransactionConflictError): tags = frozenset({SHOULD_RETRY}) +class WatchError(ExecutionError): + _code = 0x_05_04_00_00 + + class ConfigurationError(EdgeDBError): _code = 0x_06_00_00_00 @@ -424,6 +441,11 @@ class BackendUnavailableError(AvailabilityError): tags = frozenset({SHOULD_RETRY}) +class ServerOfflineError(AvailabilityError): + _code = 0x_08_00_00_02 + tags = frozenset({SHOULD_RECONNECT, SHOULD_RETRY}) + + class BackendError(EdgeDBError): _code = 0x_09_00_00_00 diff --git a/edgedb/errors/_base.py b/edgedb/errors/_base.py index 9dec14f6..675ef567 100644 --- a/edgedb/errors/_base.py +++ b/edgedb/errors/_base.py @@ -17,6 +17,12 @@ # +import io +import os +import traceback +import unicodedata +import warnings + __all__ = ( 'EdgeDBError', 'EdgeDBMessage', ) @@ -79,6 +85,7 @@ class EdgeDBErrorMeta(Meta): class EdgeDBError(Exception, metaclass=EdgeDBErrorMeta): _code = None + _query = None tags = frozenset() def __init__(self, *args, **kwargs): @@ -93,15 +100,25 @@ def _position(self): # not a stable API method return int(self._read_str_field(FIELD_POSITION_START, -1)) + @property + def _position_start(self): + # not a stable API method + return int(self._read_str_field(FIELD_CHARACTER_START, -1)) + + @property + def _position_end(self): + # not a stable API method + return int(self._read_str_field(FIELD_CHARACTER_END, -1)) + @property def _line(self): # not a stable API method - return int(self._read_str_field(FIELD_LINE, -1)) + return int(self._read_str_field(FIELD_LINE_START, -1)) 
@property def _col(self): # not a stable API method - return int(self._read_str_field(FIELD_COLUMN, -1)) + return int(self._read_str_field(FIELD_COLUMN_START, -1)) @property def _hint(self): @@ -127,6 +144,35 @@ def _from_code(code, *args, **kwargs): exc._code = code return exc + def __str__(self): + msg = super().__str__() + if SHOW_HINT and self._query and self._position_start >= 0: + try: + return _format_error( + msg, + self._query, + self._position_start, + max(1, self._position_end - self._position_start), + self._line if self._line > 0 else "?", + self._col if self._col > 0 else "?", + self._hint or "error", + ) + except Exception: + return "".join( + ( + msg, + LINESEP, + LINESEP, + "During formatting of the above exception, " + "another exception occurred:", + LINESEP, + LINESEP, + traceback.format_exc(), + ) + ) + else: + return msg + def _lookup_cls(code: int, *, meta: type, default: type): try: @@ -180,6 +226,67 @@ def _severity_name(severity): return 'PANIC' +def _format_error(msg, query, start, offset, line, col, hint): + c = get_color() + rv = io.StringIO() + rv.write(f"{c.BOLD}{msg}{c.ENDC}{LINESEP}") + lines = query.splitlines(keepends=True) + num_len = len(str(len(lines))) + rv.write(f"{c.BLUE}{'':>{num_len}} ┌─{c.ENDC} query:{line}:{col}{LINESEP}") + rv.write(f"{c.BLUE}{'':>{num_len}} │ {c.ENDC}{LINESEP}") + for num, line in enumerate(lines): + length = len(line) + line = line.rstrip() # we'll use our own line separator + if start >= length: + # skip lines before the error + start -= length + continue + + if start >= 0: + # Error starts in current line, write the line before the error + first_half = repr(line[:start])[1:-1] + line = line[start:] + length -= start + rv.write(f"{c.BLUE}{num + 1:>{num_len}} │ {c.ENDC}{first_half}") + start = _unicode_width(first_half) + else: + # Multi-line error continues + rv.write(f"{c.BLUE}{num + 1:>{num_len}} │ {c.FAIL}│ {c.ENDC}") + + if offset > length: + # Error is ending beyond current line + line = repr(line)[1:-1] + rv.write(f"{c.FAIL}{line}{c.ENDC}{LINESEP}") + if start >= 0: + # Multi-line error starts + rv.write(f"{c.BLUE}{'':>{num_len}} │ " + f"{c.FAIL}╭─{'─' * start}^{c.ENDC}{LINESEP}") + offset -= length + start = -1 # mark multi-line + else: + # Error is ending within current line + first_half = repr(line[:offset])[1:-1] + line = repr(line[offset:])[1:-1] + rv.write(f"{c.FAIL}{first_half}{c.ENDC}{line}{LINESEP}") + size = _unicode_width(first_half) + if start >= 0: + # Mark single-line error + rv.write(f"{c.BLUE}{'':>{num_len}} │ {' ' * start}" + f"{c.FAIL}{'^' * size} {hint}{c.ENDC}") + else: + # End of multi-line error + rv.write(f"{c.BLUE}{'':>{num_len}} │ " + f"{c.FAIL}╰─{'─' * (size - 1)}^ {hint}{c.ENDC}") + break + return rv.getvalue() + + +def _unicode_width(text): + return sum(0 if unicodedata.category(c) in ('Mn', 'Cf') else + 2 if unicodedata.east_asian_width(c) == "W" else 1 + for c in text) + + FIELD_HINT = 0x_00_01 FIELD_DETAILS = 0x_00_02 FIELD_SERVER_TRACEBACK = 0x_01_01 @@ -187,8 +294,14 @@ def _severity_name(severity): # XXX: Subject to be changed/deprecated. 
FIELD_POSITION_START = 0x_FF_F1 FIELD_POSITION_END = 0x_FF_F2 -FIELD_LINE = 0x_FF_F3 -FIELD_COLUMN = 0x_FF_F4 +FIELD_LINE_START = 0x_FF_F3 +FIELD_COLUMN_START = 0x_FF_F4 +FIELD_UTF16_COLUMN_START = 0x_FF_F5 +FIELD_LINE_END = 0x_FF_F6 +FIELD_COLUMN_END = 0x_FF_F7 +FIELD_UTF16_COLUMN_END = 0x_FF_F8 +FIELD_CHARACTER_START = 0x_FF_F9 +FIELD_CHARACTER_END = 0x_FF_FA EDGE_SEVERITY_DEBUG = 20 @@ -198,3 +311,19 @@ def _severity_name(severity): EDGE_SEVERITY_ERROR = 120 EDGE_SEVERITY_FATAL = 200 EDGE_SEVERITY_PANIC = 255 + + +LINESEP = os.linesep + +try: + SHOW_HINT = {"default": True, "enabled": True, "disabled": False}[ + os.getenv("EDGEDB_ERROR_HINT", "default") + ] +except KeyError: + warnings.warn( + "EDGEDB_ERROR_HINT can only be one of: default, enabled or disabled" + ) + SHOW_HINT = False + + +from edgedb.color import get_color diff --git a/edgedb/protocol/codecs/array.pyx b/edgedb/protocol/codecs/array.pyx index 2f709531..2906f1e8 100644 --- a/edgedb/protocol/codecs/array.pyx +++ b/edgedb/protocol/codecs/array.pyx @@ -39,7 +39,8 @@ cdef class BaseArrayCodec(BaseCodec): if not isinstance( self.sub_codec, - (ScalarCodec, TupleCodec, NamedTupleCodec, RangeCodec) + (ScalarCodec, TupleCodec, NamedTupleCodec, EnumCodec, + RangeCodec, MultiRangeCodec) ): raise TypeError( 'only arrays of scalars are supported (got type {!r})'.format( @@ -60,7 +61,10 @@ cdef class BaseArrayCodec(BaseCodec): for i in range(objlen): item = obj[i] if item is None: - elem_data.write_int32(-1) + raise ValueError( + "invalid array element at index {}: " + "None is not allowed".format(i) + ) else: try: self.sub_codec.encode(elem_data, item) @@ -153,7 +157,7 @@ cdef class ArrayCodec(BaseArrayCodec): def make_type(self, describe_context): return describe.ArrayType( desc_id=uuid.UUID(bytes=self.tid), - name=None, + name=self.type_name, element_type=self.sub_codec.make_type(describe_context), ) diff --git a/edgedb/protocol/codecs/base.pyx b/edgedb/protocol/codecs/base.pyx index 3bd52bb0..a40f6e57 100644 --- a/edgedb/protocol/codecs/base.pyx +++ b/edgedb/protocol/codecs/base.pyx @@ -149,7 +149,7 @@ cdef class BaseRecordCodec(BaseCodec): if not isinstance( codec, (ScalarCodec, ArrayCodec, TupleCodec, NamedTupleCodec, - EnumCodec, RangeCodec), + EnumCodec, RangeCodec, MultiRangeCodec), ): self.encoder_flags |= RECORD_ENCODER_INVALID break diff --git a/edgedb/protocol/codecs/codecs.pyx b/edgedb/protocol/codecs/codecs.pyx index 315d7abe..ab97e498 100644 --- a/edgedb/protocol/codecs/codecs.pyx +++ b/edgedb/protocol/codecs/codecs.pyx @@ -17,6 +17,7 @@ # +import array import decimal import uuid import datetime @@ -24,6 +25,8 @@ from edgedb import describe from edgedb import enums from edgedb.datatypes import datatypes +from libc.string cimport memcpy + include "./edb_types.pxi" @@ -49,6 +52,9 @@ DEF CTYPE_ARRAY = 6 DEF CTYPE_ENUM = 7 DEF CTYPE_INPUT_SHAPE = 8 DEF CTYPE_RANGE = 9 +DEF CTYPE_OBJECT = 10 +DEF CTYPE_COMPOUND = 11 +DEF CTYPE_MULTIRANGE = 12 DEF CTYPE_ANNO_TYPENAME = 255 DEF _CODECS_BUILD_CACHE_SIZE = 200 @@ -91,8 +97,9 @@ cdef class CodecsRegistry: cdef BaseCodec _build_codec(self, FRBuffer *spec, list codecs_list, protocol_version): cdef: - uint8_t t = (frb_read(spec, 1)[0]) - bytes tid = frb_read(spec, 16)[:16] + uint32_t desc_len = 0 + uint8_t t + bytes tid uint16_t els uint16_t i uint32_t str_len @@ -101,12 +108,21 @@ cdef class CodecsRegistry: BaseCodec res BaseCodec sub_codec + if protocol_version >= (2, 0): + desc_len = frb_get_len(spec) - 16 - 1 + + t = (frb_read(spec, 1)[0]) + tid = frb_read(spec, 16)[:16] + res 
= self.codecs.get(tid, None) if res is None: res = self.codecs_build_cache.get(tid, None) if res is not None: # We have a codec for this "tid"; advance the buffer # so that we can process the next codec. + if desc_len > 0: + frb_read(spec, desc_len) + return res if t == CTYPE_SET: frb_read(spec, 2) @@ -150,6 +166,9 @@ cdef class CodecsRegistry: elif t == CTYPE_RANGE: frb_read(spec, 2) + elif t == CTYPE_MULTIRANGE: + frb_read(spec, 2) + elif t == CTYPE_ENUM: els = hton.unpack_int16(frb_read(spec, 2)) for i in range(els): @@ -179,7 +198,50 @@ cdef class CodecsRegistry: sub_codec = codecs_list[pos] res = SetCodec.new(tid, sub_codec) - elif t == CTYPE_SHAPE or t == CTYPE_INPUT_SHAPE: + elif t == CTYPE_SHAPE: + if protocol_version >= (2, 0): + ephemeral_free_shape = frb_read(spec, 1)[0] + objtype_pos = hton.unpack_int16(frb_read(spec, 2)) + + els = hton.unpack_int16(frb_read(spec, 2)) + codecs = cpython.PyTuple_New(els) + names = cpython.PyTuple_New(els) + flags = cpython.PyTuple_New(els) + cards = cpython.PyTuple_New(els) + for i in range(els): + flag = hton.unpack_uint32(frb_read(spec, 4)) # flags + cardinality = frb_read(spec, 1)[0] + + str_len = hton.unpack_uint32(frb_read(spec, 4)) + name = cpythonx.PyUnicode_FromStringAndSize( + frb_read(spec, str_len), str_len) + pos = hton.unpack_int16(frb_read(spec, 2)) + + if flag & datatypes._EDGE_POINTER_IS_LINKPROP: + name = "@" + name + cpython.Py_INCREF(name) + cpython.PyTuple_SetItem(names, i, name) + + sub_codec = codecs_list[pos] + cpython.Py_INCREF(sub_codec) + cpython.PyTuple_SetItem(codecs, i, sub_codec) + + cpython.Py_INCREF(flag) + cpython.PyTuple_SetItem(flags, i, flag) + + cpython.Py_INCREF(cardinality) + cpython.PyTuple_SetItem(cards, i, cardinality) + + if protocol_version >= (2, 0): + source_type_pos = hton.unpack_int16( + frb_read(spec, 2)) + source_type = codecs_list[source_type_pos] + + res = ObjectCodec.new( + tid, names, flags, cards, codecs, t == CTYPE_INPUT_SHAPE + ) + + elif t == CTYPE_INPUT_SHAPE: els = hton.unpack_int16(frb_read(spec, 2)) codecs = cpython.PyTuple_New(els) names = cpython.PyTuple_New(els) @@ -220,15 +282,60 @@ cdef class CodecsRegistry: res = BASE_SCALAR_CODECS[tid] elif t == CTYPE_SCALAR: - pos = hton.unpack_int16(frb_read(spec, 2)) - codec = codecs_list[pos] - if type(codec) is not ScalarCodec: - raise RuntimeError( - f'a scalar codec expected for base scalar type, ' - f'got {type(codec).__name__}') - res = (codecs_list[pos]).derive(tid) + if protocol_version >= (2, 0): + str_len = hton.unpack_uint32(frb_read(spec, 4)) + type_name = cpythonx.PyUnicode_FromStringAndSize( + frb_read(spec, str_len), str_len) + schema_defined = frb_read(spec, 1)[0] + + ancestor_count = hton.unpack_int16(frb_read(spec, 2)) + ancestors = [] + for _ in range(ancestor_count): + ancestor_pos = hton.unpack_int16( + frb_read(spec, 2)) + ancestor_codec = codecs_list[ancestor_pos] + if type(ancestor_codec) is not ScalarCodec: + raise RuntimeError( + f'a scalar codec expected for base scalar type, ' + f'got {type(ancestor_codec).__name__}') + ancestors.append(ancestor_codec) + + if ancestor_count == 0: + if tid in self.base_codec_overrides: + res = self.base_codec_overrides[tid] + else: + res = BASE_SCALAR_CODECS[tid] + else: + fundamental_codec = ancestors[-1] + if type(fundamental_codec) is not ScalarCodec: + raise RuntimeError( + f'a scalar codec expected for base scalar type, ' + f'got {type(fundamental_codec).__name__}') + res = (fundamental_codec).derive(tid) + res.type_name = type_name + else: + fundamental_pos = 
hton.unpack_int16( + frb_read(spec, 2)) + fundamental_codec = codecs_list[fundamental_pos] + if type(fundamental_codec) is not ScalarCodec: + raise RuntimeError( + f'a scalar codec expected for base scalar type, ' + f'got {type(fundamental_codec).__name__}') + res = (fundamental_codec).derive(tid) elif t == CTYPE_TUPLE: + if protocol_version >= (2, 0): + str_len = hton.unpack_uint32(frb_read(spec, 4)) + type_name = cpythonx.PyUnicode_FromStringAndSize( + frb_read(spec, str_len), str_len) + schema_defined = frb_read(spec, 1)[0] + ancestor_count = hton.unpack_int16(frb_read(spec, 2)) + for _ in range(ancestor_count): + ancestor_pos = hton.unpack_int16( + frb_read(spec, 2)) + ancestor_codec = codecs_list[ancestor_pos] + else: + type_name = None els = hton.unpack_int16(frb_read(spec, 2)) codecs = cpython.PyTuple_New(els) for i in range(els): @@ -239,8 +346,21 @@ cdef class CodecsRegistry: cpython.PyTuple_SetItem(codecs, i, sub_codec) res = TupleCodec.new(tid, codecs) + res.type_name = type_name elif t == CTYPE_NAMEDTUPLE: + if protocol_version >= (2, 0): + str_len = hton.unpack_uint32(frb_read(spec, 4)) + type_name = cpythonx.PyUnicode_FromStringAndSize( + frb_read(spec, str_len), str_len) + schema_defined = frb_read(spec, 1)[0] + ancestor_count = hton.unpack_int16(frb_read(spec, 2)) + for _ in range(ancestor_count): + ancestor_pos = hton.unpack_int16( + frb_read(spec, 2)) + ancestor_codec = codecs_list[ancestor_pos] + else: + type_name = None els = hton.unpack_int16(frb_read(spec, 2)) codecs = cpython.PyTuple_New(els) names = cpython.PyTuple_New(els) @@ -258,8 +378,21 @@ cdef class CodecsRegistry: cpython.PyTuple_SetItem(codecs, i, sub_codec) res = NamedTupleCodec.new(tid, names, codecs) + res.type_name = type_name elif t == CTYPE_ENUM: + if protocol_version >= (2, 0): + str_len = hton.unpack_uint32(frb_read(spec, 4)) + type_name = cpythonx.PyUnicode_FromStringAndSize( + frb_read(spec, str_len), str_len) + schema_defined = frb_read(spec, 1)[0] + ancestor_count = hton.unpack_int16(frb_read(spec, 2)) + for _ in range(ancestor_count): + ancestor_pos = hton.unpack_int16( + frb_read(spec, 2)) + ancestor_codec = codecs_list[ancestor_pos] + else: + type_name = None els = hton.unpack_int16(frb_read(spec, 2)) names = cpython.PyTuple_New(els) for i in range(els): @@ -271,8 +404,21 @@ cdef class CodecsRegistry: cpython.PyTuple_SetItem(names, i, name) res = EnumCodec.new(tid, names) + res.type_name = type_name elif t == CTYPE_ARRAY: + if protocol_version >= (2, 0): + str_len = hton.unpack_uint32(frb_read(spec, 4)) + type_name = cpythonx.PyUnicode_FromStringAndSize( + frb_read(spec, str_len), str_len) + schema_defined = frb_read(spec, 1)[0] + ancestor_count = hton.unpack_int16(frb_read(spec, 2)) + for _ in range(ancestor_count): + ancestor_pos = hton.unpack_int16( + frb_read(spec, 2)) + ancestor_codec = codecs_list[ancestor_pos] + else: + type_name = None pos = hton.unpack_int16(frb_read(spec, 2)) els = hton.unpack_int16(frb_read(spec, 2)) if els != 1: @@ -282,11 +428,53 @@ cdef class CodecsRegistry: dim_len = hton.unpack_int32(frb_read(spec, 4)) sub_codec = codecs_list[pos] res = ArrayCodec.new(tid, sub_codec, dim_len) + res.type_name = type_name elif t == CTYPE_RANGE: + if protocol_version >= (2, 0): + str_len = hton.unpack_uint32(frb_read(spec, 4)) + type_name = cpythonx.PyUnicode_FromStringAndSize( + frb_read(spec, str_len), str_len) + schema_defined = frb_read(spec, 1)[0] + ancestor_count = hton.unpack_int16(frb_read(spec, 2)) + for _ in range(ancestor_count): + ancestor_pos = hton.unpack_int16( + 
frb_read(spec, 2)) + ancestor_codec = codecs_list[ancestor_pos] + else: + type_name = None pos = hton.unpack_int16(frb_read(spec, 2)) sub_codec = codecs_list[pos] res = RangeCodec.new(tid, sub_codec) + res.type_name = type_name + + elif t == CTYPE_MULTIRANGE: + if protocol_version >= (2, 0): + str_len = hton.unpack_uint32(frb_read(spec, 4)) + type_name = cpythonx.PyUnicode_FromStringAndSize( + frb_read(spec, str_len), str_len) + schema_defined = frb_read(spec, 1)[0] + ancestor_count = hton.unpack_int16(frb_read(spec, 2)) + for _ in range(ancestor_count): + ancestor_pos = hton.unpack_int16( + frb_read(spec, 2)) + ancestor_codec = codecs_list[ancestor_pos] + else: + type_name = None + pos = hton.unpack_int16(frb_read(spec, 2)) + sub_codec = codecs_list[pos] + res = MultiRangeCodec.new(tid, sub_codec) + res.type_name = type_name + + elif t == CTYPE_OBJECT and protocol_version >= (2, 0): + # Ignore + frb_read(spec, desc_len) + res = NULL_CODEC + + elif t == CTYPE_COMPOUND and protocol_version >= (2, 0): + # Ignore + frb_read(spec, desc_len) + res = NULL_CODEC else: raise NotImplementedError( @@ -318,6 +506,7 @@ cdef class CodecsRegistry: cdef BaseCodec build_codec(self, bytes spec, protocol_version): cdef: FRBuffer buf + FRBuffer elem_buf BaseCodec res list codecs_list @@ -328,7 +517,16 @@ cdef class CodecsRegistry: codecs_list = [] while frb_get_len(&buf): - res = self._build_codec(&buf, codecs_list, protocol_version) + if protocol_version >= (2, 0): + desc_len = hton.unpack_int32(frb_read(&buf, 4)) + frb_slice_from(&elem_buf, &buf, desc_len) + res = self._build_codec( + &elem_buf, codecs_list, protocol_version) + if frb_get_len(&elem_buf): + raise RuntimeError( + f'unexpected trailing data in type descriptor datum') + else: + res = self._build_codec(&buf, codecs_list, protocol_version) if res is None: # An annotation; ignore. continue @@ -347,14 +545,16 @@ cdef dict BASE_SCALAR_CODECS = {} cdef register_base_scalar_codec( str name, pgproto.encode_func encoder, - pgproto.decode_func decoder): + pgproto.decode_func decoder, + object tid = None): cdef: BaseCodec codec - tid = TYPE_IDS.get(name) if tid is None: - raise RuntimeError(f'cannot find known ID for type {name!r}') + tid = TYPE_IDS.get(name) + if tid is None: + raise RuntimeError(f'cannot find known ID for type {name!r}') tid = tid.bytes if tid in BASE_SCALAR_CODECS: @@ -510,6 +710,94 @@ cdef config_memory_decode(pgproto.CodecContext settings, FRBuffer *buf): return datatypes.ConfigMemory(bytes=bytes) +DEF PGVECTOR_MAX_DIM = (1 << 16) - 1 + + +cdef pgvector_encode_memview(pgproto.CodecContext settings, WriteBuffer buf, + float[:] obj): + cdef: + float item + Py_ssize_t objlen + Py_ssize_t i + + objlen = len(obj) + if objlen > PGVECTOR_MAX_DIM: + raise ValueError('too many elements in vector value') + + buf.write_int32(4 + objlen*4) + buf.write_int16(objlen) + buf.write_int16(0) + for i in range(objlen): + buf.write_float(obj[i]) + + +cdef pgvector_encode(pgproto.CodecContext settings, WriteBuffer buf, + object obj): + cdef: + float item + Py_ssize_t objlen + float[:] memview + Py_ssize_t i + + # If we can take a typed memview of the object, we use that. + # That is good, because it means we can consume array.array and + # numpy.ndarray without needing to unbox. + # Otherwise we take the slow path, indexing into the array using + # the normal protocol. 
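    # Concretely (illustrative, not exhaustive): array.array('f', ...) and a
    # C-contiguous float32 numpy.ndarray expose a float buffer, so the
    # `memview = obj` assignment below succeeds without creating per-element
    # Python objects; a plain list, or an array with a non-float item type,
    # raises TypeError/ValueError instead and falls through to the generic
    # indexing loop further down.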
+ try: + memview = obj + except (ValueError, TypeError) as e: + pass + else: + pgvector_encode_memview(settings, buf, memview) + return + + if not _is_array_iterable(obj): + raise TypeError( + 'a sized iterable container expected (got type {!r})'.format( + type(obj).__name__)) + + # Annoyingly, this is literally identical code to the fast path... + # but the types are different in critical ways. + objlen = len(obj) + if objlen > PGVECTOR_MAX_DIM: + raise ValueError('too many elements in vector value') + + buf.write_int32(4 + objlen*4) + buf.write_int16(objlen) + buf.write_int16(0) + for i in range(objlen): + buf.write_float(obj[i]) + + +cdef object ONE_EL_ARRAY = array.array('f', [0.0]) + + +cdef pgvector_decode(pgproto.CodecContext settings, FRBuffer *buf): + cdef: + int32_t dim + Py_ssize_t size + Py_buffer view + char *p + float[:] array_view + + dim = hton.unpack_uint16(frb_read(buf, 2)) + frb_read(buf, 2) + + size = dim * 4 + p = frb_read(buf, size) + + # Create a float array with size dim + val = ONE_EL_ARRAY * dim + + # And fill it with the buffer contents + array_view = val + memcpy(&array_view[0], p, size) + val.byteswap() + + return val + + cdef checked_decimal_encode( pgproto.CodecContext settings, WriteBuffer buf, obj ): @@ -708,4 +996,12 @@ cdef register_base_scalar_codecs(): config_memory_decode) + register_base_scalar_codec( + 'ext::pgvector::vector', + pgvector_encode, + pgvector_decode, + uuid.UUID('9565dd88-04f5-11ee-a691-0b6ebe179825'), + ) + + register_base_scalar_codecs() diff --git a/edgedb/protocol/codecs/namedtuple.pyx b/edgedb/protocol/codecs/namedtuple.pyx index 930ee0ee..6514ef36 100644 --- a/edgedb/protocol/codecs/namedtuple.pyx +++ b/edgedb/protocol/codecs/namedtuple.pyx @@ -78,7 +78,7 @@ cdef class NamedTupleCodec(BaseNamedRecordCodec): def make_type(self, describe_context): return describe.NamedTupleType( desc_id=uuid.UUID(bytes=self.tid), - name=None, + name=self.type_name, element_types={ field: codec.make_type(describe_context) for field, codec in zip( diff --git a/edgedb/protocol/codecs/range.pxd b/edgedb/protocol/codecs/range.pxd index 13d642f2..9b232b10 100644 --- a/edgedb/protocol/codecs/range.pxd +++ b/edgedb/protocol/codecs/range.pxd @@ -25,3 +25,19 @@ cdef class RangeCodec(BaseCodec): @staticmethod cdef BaseCodec new(bytes tid, BaseCodec sub_codec) + + @staticmethod + cdef encode_range(WriteBuffer buf, object obj, BaseCodec sub_codec) + + @staticmethod + cdef decode_range(FRBuffer *buf, BaseCodec sub_codec) + + +@cython.final +cdef class MultiRangeCodec(BaseCodec): + + cdef: + BaseCodec sub_codec + + @staticmethod + cdef BaseCodec new(bytes tid, BaseCodec sub_codec) diff --git a/edgedb/protocol/codecs/range.pyx b/edgedb/protocol/codecs/range.pyx index 9555d969..ea573b89 100644 --- a/edgedb/protocol/codecs/range.pyx +++ b/edgedb/protocol/codecs/range.pyx @@ -46,7 +46,8 @@ cdef class RangeCodec(BaseCodec): return codec - cdef encode(self, WriteBuffer buf, object obj): + @staticmethod + cdef encode_range(WriteBuffer buf, object obj, BaseCodec sub_codec): cdef: uint8_t flags = 0 WriteBuffer sub_data @@ -56,10 +57,10 @@ cdef class RangeCodec(BaseCodec): bint inc_upper = obj.inc_upper bint empty = obj.is_empty() - if not isinstance(self.sub_codec, ScalarCodec): + if not isinstance(sub_codec, ScalarCodec): raise TypeError( 'only scalar ranges are supported (got type {!r})'.format( - type(self.sub_codec).__name__ + type(sub_codec).__name__ ) ) @@ -78,14 +79,14 @@ cdef class RangeCodec(BaseCodec): sub_data = WriteBuffer.new() if lower is not None: try: - 
self.sub_codec.encode(sub_data, lower) + sub_codec.encode(sub_data, lower) except TypeError as e: raise ValueError( 'invalid range lower bound: {}'.format( e.args[0])) from None if upper is not None: try: - self.sub_codec.encode(sub_data, upper) + sub_codec.encode(sub_data, upper) except TypeError as e: raise ValueError( 'invalid range upper bound: {}'.format( @@ -95,7 +96,8 @@ cdef class RangeCodec(BaseCodec): buf.write_byte(flags) buf.write_buffer(sub_data) - cdef decode(self, FRBuffer *buf): + @staticmethod + cdef decode_range(FRBuffer *buf, BaseCodec sub_codec): cdef: uint8_t flags = frb_read(buf, 1)[0] bint empty = (flags & RANGE_EMPTY) != 0 @@ -107,7 +109,6 @@ cdef class RangeCodec(BaseCodec): object upper = None int32_t sub_len FRBuffer sub_buf - BaseCodec sub_codec = self.sub_codec if has_lower: sub_len = hton.unpack_int32(frb_read(buf, 4)) @@ -137,12 +138,119 @@ cdef class RangeCodec(BaseCodec): empty=empty, ) + cdef encode(self, WriteBuffer buf, object obj): + RangeCodec.encode_range(buf, obj, self.sub_codec) + + cdef decode(self, FRBuffer *buf): + return RangeCodec.decode_range(buf, self.sub_codec) + cdef dump(self, int level = 0): return f'{level * " "}{self.name}\n{self.sub_codec.dump(level + 1)}' def make_type(self, describe_context): return describe.RangeType( desc_id=uuid.UUID(bytes=self.tid), - name=None, + name=self.type_name, value_type=self.sub_codec.make_type(describe_context), ) + + +@cython.final +cdef class MultiRangeCodec(BaseCodec): + + def __cinit__(self): + self.sub_codec = None + + @staticmethod + cdef BaseCodec new(bytes tid, BaseCodec sub_codec): + cdef: + MultiRangeCodec codec + + codec = MultiRangeCodec.__new__(MultiRangeCodec) + + codec.tid = tid + codec.name = 'MultiRange' + codec.sub_codec = sub_codec + + return codec + + cdef encode(self, WriteBuffer buf, object obj): + cdef: + WriteBuffer elem_data + Py_ssize_t objlen + Py_ssize_t elem_data_len + + if not isinstance(self.sub_codec, ScalarCodec): + raise TypeError( + f'only scalar multiranges are supported (got type ' + f'{type(self.sub_codec).__name__!r})' + ) + + if not _is_array_iterable(obj): + raise TypeError( + f'a sized iterable container expected (got type ' + f'{type(obj).__name__!r})' + ) + + objlen = len(obj) + if objlen > _MAXINT32: + raise ValueError('too many elements in multirange value') + + elem_data = WriteBuffer.new() + for item in obj: + try: + RangeCodec.encode_range(elem_data, item, self.sub_codec) + except TypeError as e: + raise ValueError( + f'invalid multirange element: {e.args[0]}') from None + + elem_data_len = elem_data.len() + if elem_data_len > _MAXINT32 - 4: + raise OverflowError( + f'size of encoded multirange datum exceeds the maximum ' + f'allowed {_MAXINT32 - 4} bytes') + + # Datum length + buf.write_int32(4 + elem_data_len) + # Number of elements in multirange + buf.write_int32(objlen) + buf.write_buffer(elem_data) + + cdef decode(self, FRBuffer *buf): + cdef: + Py_ssize_t elem_count = hton.unpack_int32( + frb_read(buf, 4)) + object result + Py_ssize_t i + int32_t elem_len + FRBuffer elem_buf + + result = cpython.PyList_New(elem_count) + for i in range(elem_count): + elem_len = hton.unpack_int32(frb_read(buf, 4)) + if elem_len == -1: + raise RuntimeError( + 'unexpected NULL element in multirange value') + else: + frb_slice_from(&elem_buf, buf, elem_len) + elem = RangeCodec.decode_range(&elem_buf, self.sub_codec) + if frb_get_len(&elem_buf): + raise RuntimeError( + f'unexpected trailing data in buffer after ' + f'multirange element decoding: ' + 
f'{frb_get_len(&elem_buf)}') + + cpython.Py_INCREF(elem) + cpython.PyList_SET_ITEM(result, i, elem) + + return range_mod.MultiRange(result) + + cdef dump(self, int level = 0): + return f'{level * " "}{self.name}\n{self.sub_codec.dump(level + 1)}' + + def make_type(self, describe_context): + return describe.MultiRangeType( + desc_id=uuid.UUID(bytes=self.tid), + name=self.type_name, + value_type=self.sub_codec.make_type(describe_context), + ) \ No newline at end of file diff --git a/edgedb/protocol/codecs/tuple.pyx b/edgedb/protocol/codecs/tuple.pyx index 68ed0352..d29415d6 100644 --- a/edgedb/protocol/codecs/tuple.pyx +++ b/edgedb/protocol/codecs/tuple.pyx @@ -81,7 +81,7 @@ cdef class TupleCodec(BaseRecordCodec): def make_type(self, describe_context): return describe.TupleType( desc_id=uuid.UUID(bytes=self.tid), - name=None, + name=self.type_name, element_types=tuple( codec.make_type(describe_context) for codec in self.fields_codecs diff --git a/edgedb/protocol/consts.pxi b/edgedb/protocol/consts.pxi index 6d93c07f..10fe4cf8 100644 --- a/edgedb/protocol/consts.pxi +++ b/edgedb/protocol/consts.pxi @@ -61,8 +61,8 @@ DEF TRANS_STATUS_IDLE = b'I' DEF TRANS_STATUS_INTRANS = b'T' DEF TRANS_STATUS_ERROR = b'E' -DEF PROTO_VER_MAJOR = 1 +DEF PROTO_VER_MAJOR = 2 DEF PROTO_VER_MINOR = 0 -DEF LEGACY_PROTO_VER_MAJOR = 0 -DEF LEGACY_PROTO_VER_MINOR_MIN = 13 +DEF MIN_PROTO_VER_MAJOR = 0 +DEF MIN_PROTO_VER_MINOR = 13 diff --git a/edgedb/protocol/protocol.pyx b/edgedb/protocol/protocol.pyx index 6bc5cd75..469d3cb4 100644 --- a/edgedb/protocol/protocol.pyx +++ b/edgedb/protocol/protocol.pyx @@ -148,7 +148,7 @@ cdef class SansIOProtocol: self.internal_reg = CodecsRegistry() self.server_settings = {} self.reset_status() - self.protocol_version = (PROTO_VER_MAJOR, 0) + self.protocol_version = (PROTO_VER_MAJOR, PROTO_VER_MINOR) self.state_type_id = NULL_CODEC_ID self.state_codec = None @@ -305,6 +305,7 @@ cdef class SansIOProtocol: elif mtype == ERROR_RESPONSE_MSG: exc = self.parse_error_message() + exc._query = query exc = self._amend_parse_error( exc, output_format, expect_one, required_one) @@ -435,6 +436,7 @@ cdef class SansIOProtocol: elif mtype == ERROR_RESPONSE_MSG: exc = self.parse_error_message() + exc._query = query if exc.get_code() == parameter_type_mismatch_code: if not isinstance(in_dc, NullCodec): buf = WriteBuffer.new() @@ -711,6 +713,27 @@ cdef class SansIOProtocol: else: self.fallthrough() + async def ping(self): + cdef char mtype + self.write(WriteBuffer.new_message(SYNC_MSG).end_message()) + exc = None + while True: + if not self.buffer.take_message(): + await self.wait_for_message() + mtype = self.buffer.get_message_type() + + if mtype == READY_FOR_COMMAND_MSG: + self.parse_sync_message() + break + elif mtype == ERROR_RESPONSE_MSG: + exc = self.parse_error_message() + self.buffer.finish_message() + break + else: + self.fallthrough() + if exc is not None: + raise exc + async def restore(self, bytes header, data_gen): cdef: WriteBuffer buf @@ -826,6 +849,8 @@ cdef class SansIOProtocol: 'user': self.con_params.user, 'database': self.con_params.database, } + if self.con_params.secret_key: + params['secret_key'] = self.con_params.secret_key handshake_buf.write_int16(len(params)) for k, v in params.items(): handshake_buf.write_len_prefixed_utf8(k) @@ -848,22 +873,19 @@ cdef class SansIOProtocol: minor = self.buffer.read_int16() # TODO: drop this branch when dropping protocol_v0 - if major == LEGACY_PROTO_VER_MAJOR: + if major == 0: self.is_legacy = True self.ignore_headers() 
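            # The version floor below is a lexicographic tuple comparison:
            # with a minimum of (0, 13), an advertised (0, 12) is rejected,
            # while (0, 13), (1, 0) and (2, 0) are all accepted and recorded
            # as self.protocol_version.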
self.buffer.finish_message() - if major != PROTO_VER_MAJOR and not ( - major == LEGACY_PROTO_VER_MAJOR and - minor >= LEGACY_PROTO_VER_MINOR_MIN - ): + if (major, minor) < (MIN_PROTO_VER_MAJOR, MIN_PROTO_VER_MINOR): raise errors.ClientConnectionError( f'the server requested an unsupported version of ' f'the protocol: {major}.{minor}' ) - - self.protocol_version = (major, minor) + else: + self.protocol_version = (major, minor) elif mtype == AUTH_REQUEST_MSG: # Authentication... diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..fed528d4 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,3 @@ +[build-system] +requires = ["setuptools"] +build-backend = "setuptools.build_meta" diff --git a/setup.py b/setup.py index 784c3d3e..73873448 100644 --- a/setup.py +++ b/setup.py @@ -287,7 +287,7 @@ def finalize_options(self): author_email='hello@magic.io', url='https://github.com/edgedb/edgedb-python', license='Apache License, Version 2.0', - packages=['edgedb'], + packages=setuptools.find_packages(), provides=['edgedb'], zip_safe=False, include_package_data=True, @@ -336,7 +336,7 @@ def finalize_options(self): include_dirs=INCLUDE_DIRS), ], cmdclass={'build_ext': build_ext}, - test_suite='tests.suite', + python_requires=">=3.7", install_requires=[ 'typing-extensions>=3.10.0; python_version < "3.8.0"', 'certifi>=2021.5.30; platform_system == "Windows"', diff --git a/tests/codegen/test-project2/generated_async_edgeql.py.assert b/tests/codegen/test-project2/generated_async_edgeql.py.assert index 40e256ba..75f91b5b 100644 --- a/tests/codegen/test-project2/generated_async_edgeql.py.assert +++ b/tests/codegen/test-project2/generated_async_edgeql.py.assert @@ -51,6 +51,9 @@ class LinkPropResultFriendsItem: class MyEnum(enum.Enum): THIS = "This" THAT = "That" + E_1 = "1" + F_B = "f. b" + F_B_1 = "f-b" @dataclasses.dataclass @@ -218,7 +221,7 @@ async def my_query( return await executor.query_single( """\ create scalar type MyScalar extending int64; - create scalar type MyEnum extending enum; + create scalar type MyEnum extending enum<'This', 'That', '1', 'f. b', 'f-b'>; select { a := $a, diff --git a/tests/codegen/test-project2/parpkg/subpkg/my_query.edgeql b/tests/codegen/test-project2/parpkg/subpkg/my_query.edgeql index 2a9b2e49..a00f8964 100644 --- a/tests/codegen/test-project2/parpkg/subpkg/my_query.edgeql +++ b/tests/codegen/test-project2/parpkg/subpkg/my_query.edgeql @@ -1,5 +1,5 @@ create scalar type MyScalar extending int64; -create scalar type MyEnum extending enum; +create scalar type MyEnum extending enum<'This', 'That', '1', 'f. b', 'f-b'>; select { a := $a, diff --git a/tests/codegen/test-project2/parpkg/subpkg/my_query_async_edgeql.py.assert b/tests/codegen/test-project2/parpkg/subpkg/my_query_async_edgeql.py.assert index ba841cc8..6225eb88 100644 --- a/tests/codegen/test-project2/parpkg/subpkg/my_query_async_edgeql.py.assert +++ b/tests/codegen/test-project2/parpkg/subpkg/my_query_async_edgeql.py.assert @@ -26,6 +26,9 @@ class NoPydanticValidation: class MyEnum(enum.Enum): THIS = "This" THAT = "That" + E_1 = "1" + F_B = "f. b" + F_B_1 = "f-b" @dataclasses.dataclass @@ -144,7 +147,7 @@ async def my_query( return await executor.query_single( """\ create scalar type MyScalar extending int64; - create scalar type MyEnum extending enum; + create scalar type MyEnum extending enum<'This', 'That', '1', 'f. 
b', 'f-b'>; select { a := $a, diff --git a/tests/codegen/test-project2/parpkg/subpkg/my_query_edgeql.py.assert b/tests/codegen/test-project2/parpkg/subpkg/my_query_edgeql.py.assert index 00fd747a..35f89338 100644 --- a/tests/codegen/test-project2/parpkg/subpkg/my_query_edgeql.py.assert +++ b/tests/codegen/test-project2/parpkg/subpkg/my_query_edgeql.py.assert @@ -17,6 +17,9 @@ MyScalar = int class MyEnum(enum.Enum): THIS = "This" THAT = "That" + E_1 = "1" + F_B = "f. b" + F_B_1 = "f-b" @dataclasses.dataclass @@ -135,7 +138,7 @@ def my_query( return executor.query_single( """\ create scalar type MyScalar extending int64; - create scalar type MyEnum extending enum; + create scalar type MyEnum extending enum<'This', 'That', '1', 'f. b', 'f-b'>; select { a := $a, diff --git a/tests/datatypes/test_datatypes.py b/tests/datatypes/test_datatypes.py index eaff8aff..741489d4 100644 --- a/tests/datatypes/test_datatypes.py +++ b/tests/datatypes/test_datatypes.py @@ -1003,3 +1003,165 @@ def test_array_6(self): self.assertNotEqual( edgedb.Array([1, 2, 3]), False) + + +class TestRange(unittest.TestCase): + + def test_range_empty_1(self): + t = edgedb.Range(empty=True) + self.assertEqual(t.lower, None) + self.assertEqual(t.upper, None) + self.assertFalse(t.inc_lower) + self.assertFalse(t.inc_upper) + self.assertTrue(t.is_empty()) + self.assertFalse(t) + + self.assertEqual(t, edgedb.Range(1, 1, empty=True)) + + with self.assertRaisesRegex(ValueError, 'conflicting arguments'): + edgedb.Range(1, 2, empty=True) + + def test_range_2(self): + t = edgedb.Range(1, 2) + self.assertEqual(repr(t), "") + self.assertEqual(str(t), "") + + self.assertEqual(t.lower, 1) + self.assertEqual(t.upper, 2) + self.assertTrue(t.inc_lower) + self.assertFalse(t.inc_upper) + self.assertFalse(t.is_empty()) + self.assertTrue(t) + + def test_range_3(self): + t = edgedb.Range(1) + self.assertEqual(t.lower, 1) + self.assertEqual(t.upper, None) + self.assertTrue(t.inc_lower) + self.assertFalse(t.inc_upper) + self.assertFalse(t.is_empty()) + + t = edgedb.Range(None, 1) + self.assertEqual(t.lower, None) + self.assertEqual(t.upper, 1) + self.assertFalse(t.inc_lower) + self.assertFalse(t.inc_upper) + self.assertFalse(t.is_empty()) + + t = edgedb.Range(None, None) + self.assertEqual(t.lower, None) + self.assertEqual(t.upper, None) + self.assertFalse(t.inc_lower) + self.assertFalse(t.inc_upper) + self.assertFalse(t.is_empty()) + + def test_range_4(self): + for il in (False, True): + for iu in (False, True): + t = edgedb.Range(1, 2, inc_lower=il, inc_upper=iu) + self.assertEqual(t.lower, 1) + self.assertEqual(t.upper, 2) + self.assertEqual(t.inc_lower, il) + self.assertEqual(t.inc_upper, iu) + self.assertFalse(t.is_empty()) + + def test_range_5(self): + # test hash + self.assertEqual( + { + edgedb.Range(None, 2, inc_upper=True), + edgedb.Range(1, 2), + edgedb.Range(1, 2), + edgedb.Range(1, 2), + edgedb.Range(None, 2, inc_upper=True), + }, + { + edgedb.Range(1, 2), + edgedb.Range(None, 2, inc_upper=True), + } + ) + + +class TestMultiRange(unittest.TestCase): + + def test_multirange_empty_1(self): + t = edgedb.MultiRange() + self.assertEqual(len(t), 0) + self.assertEqual(t, edgedb.MultiRange([])) + + def test_multirange_2(self): + t = edgedb.MultiRange([ + edgedb.Range(1, 2), + edgedb.Range(4), + ]) + self.assertEqual( + repr(t), ", ]>") + self.assertEqual( + str(t), ", ]>") + + self.assertEqual( + t, + edgedb.MultiRange([ + edgedb.Range(1, 2), + edgedb.Range(4), + ]) + ) + + def test_multirange_3(self): + ranges = [ + edgedb.Range(None, 0), + 
edgedb.Range(1, 2), + edgedb.Range(4), + ] + t = edgedb.MultiRange([ + edgedb.Range(None, 0), + edgedb.Range(1, 2), + edgedb.Range(4), + ]) + + for el, r in zip(t, ranges): + self.assertEqual(el, r) + + def test_multirange_4(self): + # test hash + self.assertEqual( + { + edgedb.MultiRange([ + edgedb.Range(1, 2), + edgedb.Range(4), + ]), + edgedb.MultiRange([edgedb.Range(None, 2, inc_upper=True)]), + edgedb.MultiRange([ + edgedb.Range(1, 2), + edgedb.Range(4), + ]), + edgedb.MultiRange([ + edgedb.Range(1, 2), + edgedb.Range(4), + ]), + edgedb.MultiRange([edgedb.Range(None, 2, inc_upper=True)]), + }, + { + edgedb.MultiRange([edgedb.Range(None, 2, inc_upper=True)]), + edgedb.MultiRange([ + edgedb.Range(1, 2), + edgedb.Range(4), + ]), + } + ) + + def test_multirange_5(self): + # test hash + self.assertEqual( + edgedb.MultiRange([ + edgedb.Range(None, 2, inc_upper=True), + edgedb.Range(5, 9), + edgedb.Range(5, 9), + edgedb.Range(5, 9), + edgedb.Range(None, 2, inc_upper=True), + ]), + edgedb.MultiRange([ + edgedb.Range(5, 9), + edgedb.Range(None, 2, inc_upper=True), + ]), + ) diff --git a/tests/shared-client-testcases b/tests/shared-client-testcases index 70433a6d..b8959be8 160000 --- a/tests/shared-client-testcases +++ b/tests/shared-client-testcases @@ -1 +1 @@ -Subproject commit 70433a6da0f3f1c9e991fdac7bb7f7ccab5ad878 +Subproject commit b8959be8968aceeeac2af3da7639de02b19d7030 diff --git a/tests/test_async_query.py b/tests/test_async_query.py index af334216..796bf87f 100644 --- a/tests/test_async_query.py +++ b/tests/test_async_query.py @@ -61,12 +61,10 @@ async def test_async_parse_error_recover_01(self): with self.assertRaises(edgedb.EdgeQLSyntaxError): await self.client.query('select syntax error') - with self.assertRaisesRegex(edgedb.EdgeQLSyntaxError, - 'Unexpected end of line'): + with self.assertRaises(edgedb.EdgeQLSyntaxError): await self.client.query('select (') - with self.assertRaisesRegex(edgedb.EdgeQLSyntaxError, - 'Unexpected end of line'): + with self.assertRaises(edgedb.EdgeQLSyntaxError): await self.client.query_json('select (') for _ in range(10): @@ -380,6 +378,12 @@ async def test_async_args_03(self): 'combine positional and named parameters'): await self.client.query('select $0 + $bar;') + with self.assertRaisesRegex(edgedb.InvalidArgumentError, + "None is not allowed"): + await self.client.query( + "select >$0", [1, None, 3] + ) + async def test_async_args_04(self): aware_datetime = datetime.datetime.now(datetime.timezone.utc) naive_datetime = datetime.datetime.now() @@ -794,6 +798,83 @@ async def test_async_range_02(self): ) self.assertEqual([edgedb.Range(1, 2)], result) + async def test_async_multirange_01(self): + has_range = await self.client.query( + "select schema::ObjectType filter .name = 'schema::MultiRange'") + if not has_range: + raise unittest.SkipTest( + "server has no support for std::multirange") + + samples = [ + ('multirange', [ + edgedb.MultiRange(), + dict( + input=edgedb.MultiRange([edgedb.Range(empty=True)]), + output=edgedb.MultiRange(), + ), + edgedb.MultiRange([ + edgedb.Range(None, 0), + edgedb.Range(1, 2), + edgedb.Range(4), + ]), + dict( + input=edgedb.MultiRange([ + edgedb.Range(None, 2, inc_upper=True), + edgedb.Range(5, 9), + edgedb.Range(5, 9), + edgedb.Range(5, 9), + edgedb.Range(None, 2, inc_upper=True), + ]), + output=edgedb.MultiRange([ + edgedb.Range(5, 9), + edgedb.Range(None, 3), + ]), + ), + dict( + input=edgedb.MultiRange([ + edgedb.Range(None, 2), + edgedb.Range(-5, 9), + edgedb.Range(13), + ]), + output=edgedb.MultiRange([ + 
edgedb.Range(None, 9), + edgedb.Range(13), + ]), + ), + ]), + ] + + for typename, sample_data in samples: + for sample in sample_data: + with self.subTest(sample=sample, typname=typename): + stmt = f"SELECT <{typename}>$0" + if isinstance(sample, dict): + inputval = sample['input'] + outputval = sample['output'] + else: + inputval = outputval = sample + + result = await self.client.query_single(stmt, inputval) + err_msg = ( + "unexpected result for {} when passing {!r}: " + "received {!r}, expected {!r}".format( + typename, inputval, result, outputval)) + + self.assertEqual(result, outputval, err_msg) + + async def test_async_multirange_02(self): + has_range = await self.client.query( + "select schema::ObjectType filter .name = 'schema::MultiRange'") + if not has_range: + raise unittest.SkipTest( + "server has no support for std::multirange") + + result = await self.client.query_single( + "SELECT >>$0", + [edgedb.MultiRange([edgedb.Range(1, 2)])] + ) + self.assertEqual([edgedb.MultiRange([edgedb.Range(1, 2)])], result) + async def test_async_wait_cancel_01(self): underscored_lock = await self.client.query_single(""" SELECT EXISTS( @@ -844,7 +925,7 @@ async def exec_to_fail(): g.create_task(exec_to_fail()) - await asyncio.wait_for(fut, 1) + await asyncio.wait_for(fut, 5) await asyncio.sleep(0.1) with self.assertRaises(asyncio.TimeoutError): @@ -1024,3 +1105,18 @@ async def test_dup_link_prop_name(self): DROP TYPE test::dup_link_prop_name_p; DROP TYPE test::dup_link_prop_name; ''') + + async def test_transaction_state(self): + with self.assertRaisesRegex(edgedb.QueryError, "cannot assign to.*id"): + async for tx in self.client.transaction(): + async with tx: + await tx.execute(''' + INSERT test::Tmp { id := $0, tmp := '' } + ''', uuid.uuid4()) + + client = self.client.with_config(allow_user_specified_id=True) + async for tx in client.transaction(): + async with tx: + await tx.execute(''' + INSERT test::Tmp { id := $0, tmp := '' } + ''', uuid.uuid4()) diff --git a/tests/test_async_tx.py b/tests/test_async_tx.py index 71a10287..8ceeb239 100644 --- a/tests/test_async_tx.py +++ b/tests/test_async_tx.py @@ -16,6 +16,7 @@ # limitations under the License. 
# +import asyncio import itertools import edgedb @@ -89,3 +90,17 @@ async def test_async_transaction_commit_failure(self): async with tx: await tx.execute("start migration to {};") self.assertEqual(await self.client.query_single("select 42"), 42) + + async def test_async_transaction_exclusive(self): + async for tx in self.client.transaction(): + async with tx: + query = "select sys::_sleep(0.01)" + f1 = self.loop.create_task(tx.execute(query)) + f2 = self.loop.create_task(tx.execute(query)) + with self.assertRaisesRegex( + edgedb.InterfaceError, + "concurrent queries within the same transaction " + "are not allowed" + ): + await asyncio.wait_for(f1, timeout=5) + await asyncio.wait_for(f2, timeout=5) diff --git a/tests/test_codegen.py b/tests/test_codegen.py index ddc6307f..35580303 100644 --- a/tests/test_codegen.py +++ b/tests/test_codegen.py @@ -62,7 +62,7 @@ async def run(*args, extra_env=None): stderr=subprocess.STDOUT, ) try: - await asyncio.wait_for(p.wait(), 30) + await asyncio.wait_for(p.wait(), 120) except asyncio.TimeoutError: p.terminate() await p.wait() diff --git a/tests/test_con_utils.py b/tests/test_con_utils.py index 0582541a..c2820a6e 100644 --- a/tests/test_con_utils.py +++ b/tests/test_con_utils.py @@ -32,6 +32,7 @@ class TestConUtils(unittest.TestCase): + maxDiff = 1000 error_mapping = { 'credentials_file_not_found': ( @@ -46,6 +47,8 @@ class TestConUtils(unittest.TestCase): RuntimeError, 'cannot read credentials'), 'invalid_dsn_or_instance_name': ( ValueError, 'invalid DSN or instance name'), + 'invalid_instance_name': ( + ValueError, 'invalid instance name'), 'invalid_dsn': (ValueError, 'invalid DSN'), 'unix_socket_unsupported': ( ValueError, 'unix socket paths not supported'), @@ -68,7 +71,12 @@ class TestConUtils(unittest.TestCase): 'file_not_found': (FileNotFoundError, 'No such file or directory'), 'invalid_tls_security': ( ValueError, 'tls_security can only be one of `insecure`, ' - '|tls_security must be set to strict') + '|tls_security must be set to strict'), + 'invalid_secret_key': ( + errors.ClientConnectionError, "Invalid secret key"), + 'secret_key_not_found': ( + errors.ClientConnectionError, + "Cannot connect to cloud instances without secret key"), } @contextlib.contextmanager @@ -98,6 +106,7 @@ def run_testcase(self, testcase): env = testcase.get('env', {}) test_env = {'EDGEDB_HOST': None, 'EDGEDB_PORT': None, 'EDGEDB_USER': None, 'EDGEDB_PASSWORD': None, + 'EDGEDB_SECRET_KEY': None, 'EDGEDB_DATABASE': None, 'PGSSLMODE': None, 'XDG_CONFIG_HOME': None} test_env.update(env) @@ -105,7 +114,7 @@ def run_testcase(self, testcase): fs = testcase.get('fs') opts = testcase.get('opts', {}) - dsn = opts.get('dsn') + dsn = opts['instance'] if 'instance' in opts else opts.get('dsn') credentials = opts.get('credentials') credentials_file = opts.get('credentialsFile') host = opts.get('host') @@ -113,6 +122,7 @@ def run_testcase(self, testcase): database = opts.get('database') user = opts.get('user') password = opts.get('password') + secret_key = opts.get('secretKey') tls_ca = opts.get('tlsCA') tls_ca_file = opts.get('tlsCAFile') tls_security = opts.get('tlsSecurity') @@ -172,6 +182,12 @@ def run_testcase(self, testcase): files[instance] = v['instance-name'] project = os.path.join(dir, 'project-path') files[project] = v['project-path'] + if 'cloud-profile' in v: + profile = os.path.join(dir, 'cloud-profile') + files[profile] = v['cloud-profile'] + if 'database' in v: + database_file = os.path.join(dir, 'database') + files[database_file] = v['database'] del files[f] 
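                    # Roughly, this emulates an `edgedb project` stash entry:
                    # the single credentials path from the shared test case is
                    # expanded into sibling 'instance-name', 'project-path'
                    # and (when present) 'cloud-profile' / 'database' files
                    # keyed by absolute path, which is the layout the
                    # connection resolver reads for linked projects.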
es.enter_context( @@ -219,6 +235,7 @@ def mocked_open(filepath, *args, **kwargs): database=database, user=user, password=password, + secret_key=secret_key, tls_ca=tls_ca, tls_ca_file=tls_ca_file, tls_security=tls_security, @@ -235,6 +252,7 @@ def mocked_open(filepath, *args, **kwargs): 'database': connect_config.database, 'user': connect_config.user, 'password': connect_config.password, + 'secretKey': connect_config.secret_key, 'tlsCAData': connect_config._tls_ca_data, 'tlsSecurity': connect_config.tls_security, 'serverSettings': connect_config.server_settings, @@ -285,6 +303,7 @@ def test_test_connect_params_run_testcase(self): 'database': 'edgedb', 'user': '__test__', 'password': None, + 'secretKey': None, 'tlsCAData': None, 'tlsSecurity': 'strict', 'serverSettings': {}, @@ -378,6 +397,7 @@ def test_project_config(self): credentials_file=None, user=None, password=None, + secret_key=None, database=None, tls_ca=None, tls_ca_file=None, diff --git a/tests/test_datetime.py b/tests/test_datetime.py index 08199077..ff7dfbc0 100644 --- a/tests/test_datetime.py +++ b/tests/test_datetime.py @@ -25,6 +25,11 @@ from edgedb.datatypes.datatypes import RelativeDuration, DateDuration +USECS_PER_HOUR = 3600000000 +USECS_PER_MINUTE = 60000000 +USECS_PER_SEC = 1000000 + + class TestDatetimeTypes(tb.SyncQueryTestCase): async def test_duration_01(self): @@ -60,6 +65,57 @@ async def test_duration_01(self): ''', durs) self.assertEqual(list(durs_from_db), durs) + async def test_duration_02(self): + # Make sure that when we break down the microseconds into the bigger + # components we still get consistent values. + tdn1h = timedelta(microseconds=-USECS_PER_HOUR) + tdn1m = timedelta(microseconds=-USECS_PER_MINUTE) + tdn1s = timedelta(microseconds=-USECS_PER_SEC) + tdn1us = timedelta(microseconds=-1) + durs = [ + ( + tdn1h, tdn1m, + timedelta(microseconds=-USECS_PER_HOUR - USECS_PER_MINUTE), + ), + ( + tdn1h, tdn1s, + timedelta(microseconds=-USECS_PER_HOUR - USECS_PER_SEC), + ), + ( + tdn1m, tdn1s, + timedelta(microseconds=-USECS_PER_MINUTE - USECS_PER_SEC), + ), + ( + tdn1h, tdn1us, + timedelta(microseconds=-USECS_PER_HOUR - 1), + ), + ( + tdn1m, tdn1us, + timedelta(microseconds=-USECS_PER_MINUTE - 1), + ), + ( + tdn1s, tdn1us, + timedelta(microseconds=-USECS_PER_SEC - 1), + ), + ] + + # Test encode + durs_enc = self.client.query(''' + WITH args := array_unpack( + >>$0) + SELECT args.0 + args.1 = args.2; + ''', durs) + + # Test decode + durs_dec = self.client.query(''' + WITH args := array_unpack( + >>$0) + SELECT (args.0 + args.1, args.2); + ''', durs) + + self.assertEqual(durs_enc, [True] * len(durs)) + self.assertEqual(list(durs_dec), [(d[2], d[2]) for d in durs]) + async def test_relative_duration_01(self): try: self.client.query("SELECT '1y'") @@ -124,6 +180,41 @@ async def test_relative_duration_02(self): self.assertEqual(repr(d1), '') + async def test_relative_duration_03(self): + # Make sure that when we break down the microseconds into the bigger + # components we still get the sign correctly in string + # representation. 
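        # For example, -USECS_PER_HOUR - 1 is -3_600_000_001 microseconds,
        # i.e. negative one hour *and* negative one microsecond: both
        # components must carry the minus sign in the textual form rather
        # than partially cancelling each other out.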
+ durs = [ + RelativeDuration(microseconds=-USECS_PER_HOUR), + RelativeDuration(microseconds=-USECS_PER_MINUTE), + RelativeDuration(microseconds=-USECS_PER_SEC), + RelativeDuration(microseconds=-USECS_PER_HOUR - USECS_PER_MINUTE), + RelativeDuration(microseconds=-USECS_PER_HOUR - USECS_PER_SEC), + RelativeDuration(microseconds=-USECS_PER_MINUTE - USECS_PER_SEC), + RelativeDuration(microseconds=-USECS_PER_HOUR - USECS_PER_MINUTE - + USECS_PER_SEC), + RelativeDuration(microseconds=-USECS_PER_HOUR - 1), + RelativeDuration(microseconds=-USECS_PER_MINUTE - 1), + RelativeDuration(microseconds=-USECS_PER_SEC - 1), + RelativeDuration(microseconds=-1), + ] + + # Test that RelativeDuration.__str__ formats the + # same as + durs_as_text = self.client.query(''' + WITH args := array_unpack(>$0) + SELECT args; + ''', durs) + + # Test encode/decode roundtrip + durs_from_db = self.client.query(''' + WITH args := array_unpack(>$0) + SELECT args; + ''', durs) + + self.assertEqual(durs_as_text, [str(d) for d in durs]) + self.assertEqual(list(durs_from_db), durs) + async def test_date_duration_01(self): try: self.client.query("SELECT '1y'") @@ -168,3 +259,32 @@ async def test_date_duration_01(self): self.assertEqual(db_dur, str(client_dur)) self.assertEqual(list(durs_from_db), durs) + + async def test_date_duration_02(self): + # Make sure that when we break down the microseconds into the bigger + # components we still get the sign correctly in string + # representation. + durs = [ + DateDuration(months=11), + DateDuration(months=12), + DateDuration(months=13), + DateDuration(months=-11), + DateDuration(months=-12), + DateDuration(months=-13), + ] + + # Test that DateDuration.__str__ formats the + # same as + durs_as_text = self.client.query(''' + WITH args := array_unpack(>$0) + SELECT args; + ''', durs) + + # Test encode/decode roundtrip + durs_from_db = self.client.query(''' + WITH args := array_unpack(>$0) + SELECT args; + ''', durs) + + self.assertEqual(durs_as_text, [str(d) for d in durs]) + self.assertEqual(list(durs_from_db), durs) diff --git a/tests/test_enum.py b/tests/test_enum.py index a5e99b15..86bf40b6 100644 --- a/tests/test_enum.py +++ b/tests/test_enum.py @@ -99,3 +99,12 @@ async def test_enum_03(self): c_red = await self.client.query_single('SELECT "red"') c_red2 = await self.client.query_single('SELECT $0', c_red) self.assertIs(c_red, c_red2) + + async def test_enum_04(self): + enums = await self.client.query_single( + 'SELECT >$0', ['red', 'white'] + ) + enums2 = await self.client.query_single( + 'SELECT >$0', enums + ) + self.assertEqual(enums, enums2) diff --git a/tests/test_sync_query.py b/tests/test_sync_query.py index 0d793616..79dae829 100644 --- a/tests/test_sync_query.py +++ b/tests/test_sync_query.py @@ -53,12 +53,10 @@ def test_sync_parse_error_recover_01(self): with self.assertRaises(edgedb.EdgeQLSyntaxError): self.client.query('select syntax error') - with self.assertRaisesRegex(edgedb.EdgeQLSyntaxError, - 'Unexpected end of line'): + with self.assertRaises(edgedb.EdgeQLSyntaxError): self.client.query('select (') - with self.assertRaisesRegex(edgedb.EdgeQLSyntaxError, - 'Unexpected end of line'): + with self.assertRaises(edgedb.EdgeQLSyntaxError): self.client.query_json('select (') for _ in range(10): @@ -868,3 +866,18 @@ def test_sync_banned_transaction(self): r'cannot execute transaction control commands', ): self.client.execute('start transaction') + + def test_transaction_state(self): + with self.assertRaisesRegex(edgedb.QueryError, "cannot assign to.*id"): + for tx in 
self.client.transaction(): + with tx: + tx.execute(''' + INSERT test::Tmp { id := $0, tmp := '' } + ''', uuid.uuid4()) + + client = self.client.with_config(allow_user_specified_id=True) + for tx in client.transaction(): + with tx: + tx.execute(''' + INSERT test::Tmp { id := $0, tmp := '' } + ''', uuid.uuid4()) diff --git a/tests/test_sync_tx.py b/tests/test_sync_tx.py index eb1abc0e..3ed2fc55 100644 --- a/tests/test_sync_tx.py +++ b/tests/test_sync_tx.py @@ -17,6 +17,7 @@ # import itertools +from concurrent.futures import ThreadPoolExecutor import edgedb @@ -97,3 +98,18 @@ def test_sync_transaction_commit_failure(self): with tx: tx.execute("start migration to {};") self.assertEqual(self.client.query_single("select 42"), 42) + + def test_sync_transaction_exclusive(self): + for tx in self.client.transaction(): + with tx: + query = "select sys::_sleep(0.01)" + with ThreadPoolExecutor(max_workers=2) as executor: + f1 = executor.submit(tx.execute, query) + f2 = executor.submit(tx.execute, query) + with self.assertRaisesRegex( + edgedb.InterfaceError, + "concurrent queries within the same transaction " + "are not allowed" + ): + f1.result(timeout=5) + f2.result(timeout=5) diff --git a/tests/test_vector.py b/tests/test_vector.py new file mode 100644 index 00000000..ede4a3d0 --- /dev/null +++ b/tests/test_vector.py @@ -0,0 +1,131 @@ +# +# This source file is part of the EdgeDB open source project. +# +# Copyright 2019-present MagicStack Inc. and the EdgeDB authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from edgedb import _testbase as tb +import edgedb + +import array + + +# An array.array subtype where indexing doesn't work. +# We use this to verify that the non-boxing memoryview based +# fast path works, since the slow path won't work on this object. 
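# Overriding __getitem__ does not affect the C-level buffer protocol, so the
# `memview = obj` fast path in pgvector_encode() still succeeds for this
# subclass; only the fallback path, which indexes element by element, would
# trip the assertion below.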
+class brokenarray(array.array): + def __getitem__(self, i): + raise AssertionError("the fast path wasn't used!") + + +class TestVector(tb.SyncQueryTestCase): + def setUp(self): + super().setUp() + + if not self.client.query_required_single(''' + select exists ( + select sys::ExtensionPackage filter .name = 'pgvector' + ) + '''): + self.skipTest("feature not implemented") + + self.client.execute(''' + create extension pgvector; + ''') + + def tearDown(self): + try: + self.client.execute(''' + drop extension pgvector; + ''') + finally: + super().tearDown() + + async def test_vector_01(self): + val = self.client.query_single(''' + select [1.5,2.0,3.8] + ''') + self.assertTrue(isinstance(val, array.array)) + self.assertEqual(val, array.array('f', [1.5, 2.0, 3.8])) + + val = self.client.query_single( + ''' + select $0 + ''', + [3.0, 9.0, -42.5], + ) + self.assertEqual(val, '[3, 9, -42.5]') + + val = self.client.query_single( + ''' + select $0 + ''', + array.array('f', [3.0, 9.0, -42.5]) + ) + self.assertEqual(val, '[3, 9, -42.5]') + + val = self.client.query_single( + ''' + select $0 + ''', + array.array('i', [1, 2, 3]), + ) + self.assertEqual(val, '[1, 2, 3]') + + # Test that the fast-path works: if the encoder tries to + # call __getitem__ on this brokenarray, it will fail. + val = self.client.query_single( + ''' + select $0 + ''', + brokenarray('f', [3.0, 9.0, -42.5]) + ) + self.assertEqual(val, '[3, 9, -42.5]') + + # I don't think it's worth adding a dependency to test this, + # but this works too: + # import numpy as np + # val = self.client.query_single( + # ''' + # select $0 + # ''', + # np.asarray([3.0, 9.0, -42.5], dtype=np.float32), + # ) + # self.assertEqual(val, '[3,9,-42.5]') + + # Some sad path tests + with self.assertRaises(edgedb.InvalidArgumentError): + self.client.query_single( + ''' + select $0 + ''', + [3.0, None, -42.5], + ) + + with self.assertRaises(edgedb.InvalidArgumentError): + self.client.query_single( + ''' + select $0 + ''', + [3.0, 'x', -42.5], + ) + + with self.assertRaises(edgedb.InvalidArgumentError): + self.client.query_single( + ''' + select $0 + ''', + 'foo', + )
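# A minimal usage sketch of the codecs exercised above. Assumptions: a
# resolvable (e.g. project-linked) instance, the pgvector extension enabled
# in the target database, and a server with multirange support; connection
# setup is illustrative, while the casts mirror the codec registration and
# the tests in this patch.

if __name__ == "__main__":  # pragma: no cover - illustration only
    import array

    import edgedb

    client = edgedb.create_client()
    try:
        # Vectors travel as float32; array.array('f') (or a float32 numpy
        # array) takes the non-boxing memoryview fast path in the encoder.
        vec = client.query_single(
            "select <ext::pgvector::vector>$0",
            array.array("f", [1.5, 2.0, 3.8]),
        )
        assert isinstance(vec, array.array)

        # Multiranges round-trip as edgedb.MultiRange of edgedb.Range values.
        mr = client.query_single(
            "select <multirange<int64>>$0",
            edgedb.MultiRange([edgedb.Range(1, 2), edgedb.Range(4)]),
        )
        assert isinstance(mr, edgedb.MultiRange)
    finally:
        client.close()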