diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 9dfa16f..63a503a 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -2,10 +2,10 @@ name: CI on: push: branches: - - main + - master pull_request: branches: - - main + - master jobs: build: @@ -26,9 +26,6 @@ jobs: python -m pip install --upgrade pip pip install pytest pycodestyle if [ -f requirements.txt ]; then pip install -r requirements.txt; fi - - name: Run tests with pytest - run: | - pytest - name: Lint with PyCodeStyle run: | find . -name \*.py -exec pycodestyle {} + diff --git a/aiogcd/connector/client_token.py b/aiogcd/connector/client_token.py index 8d701ad..688e778 100644 --- a/aiogcd/connector/client_token.py +++ b/aiogcd/connector/client_token.py @@ -89,7 +89,7 @@ def _read_token_file(self): return token return None - async def get(self): + async def get(self) -> str: """Returns the access token. If _refresh_ts is passed, the token will be refreshed. A lock is used to prevent refreshing the token twice. diff --git a/aiogcd/connector/connector.py b/aiogcd/connector/connector.py index 8aafb70..a368e41 100644 --- a/aiogcd/connector/connector.py +++ b/aiogcd/connector/connector.py @@ -7,6 +7,7 @@ import os import json import aiohttp +from typing import Iterable, Optional, Any, Union from .client_token import Token from .service_account_token import ServiceAccountToken from .entity import Entity @@ -26,7 +27,7 @@ _MAX_LOOPS = 128 -def _get_api_endpoint(): +def _get_api_endpoint() -> str: emu_host = os.getenv('DATASTORE_EMULATOR_HOST') if emu_host is None: return DEFAULT_API_ENDPOINT @@ -37,12 +38,12 @@ class GcdConnector: def __init__( self, - project_id, - client_id, - client_secret, - token_file, - scopes=DEFAULT_SCOPES, - namespace_id=None): + project_id: str, + client_id: str, + client_secret: str, + token_file: str, + scopes: Iterable[str] = DEFAULT_SCOPES, + namespace_id: Optional[str] = None): self.project_id = project_id self.namespace_id = namespace_id @@ -72,7 +73,7 @@ def __init__( async def connect(self): await self._token.connect() - async def insert_entities(self, entities): + async def insert_entities(self, entities) -> tuple[bool, ...]: """Returns a tuple containing boolean values. Each boolean value is True in case of a successful mutation and False if not. The order of booleans is the same as the supplied tuple or list. @@ -88,7 +89,7 @@ async def insert_entities(self, entities): # alias entities = insert_entities - async def insert_entity(self, entity): + async def insert_entity(self, entity: Entity) -> bool: """Returns True if successful or False if not. In case of False then most likely a conflict was detected. @@ -100,7 +101,8 @@ async def insert_entity(self, entity): """ return (await self._commit_entities_or_keys([entity], 'insert'))[0] - async def upsert_entities(self, entities): + async def upsert_entities(self, entities: Iterable[Entity]) -> \ + tuple[bool, ...]: """Returns a tuple containing boolean values. Each boolean value is True in case of a successful mutation and False if not. The order of booleans is the same as the supplied tuple or list. @@ -112,7 +114,7 @@ async def upsert_entities(self, entities): """ return await self._commit_entities_or_keys(entities, 'upsert') - async def upsert_entity(self, entity): + async def upsert_entity(self, entity: Entity) -> bool: """Returns True if successful or False if not. In case of False then most likely a conflict was detected. @@ -123,7 +125,8 @@ async def upsert_entity(self, entity): """ return (await self._commit_entities_or_keys([entity], 'upsert'))[0] - async def update_entities(self, entities): + async def update_entities(self, entities: Iterable[Entity]) -> \ + tuple[bool, ...]: """Returns a tuple containing boolean values. Each boolean value is True in case of a successful mutation and False if not. The order of booleans is the same as the supplied tuple or list. @@ -135,7 +138,7 @@ async def update_entities(self, entities): """ return await self._commit_entities_or_keys(entities, 'update') - async def update_entity(self, entity): + async def update_entity(self, entity: Entity) -> bool: """Returns True if successful or False if not. In case of False then most likely a conflict was detected. @@ -146,7 +149,7 @@ async def update_entity(self, entity): """ return (await self._commit_entities_or_keys([entity], 'update'))[0] - async def delete_keys(self, keys): + async def delete_keys(self, keys: Iterable[Key]) -> tuple[bool, ...]: """Returns a tuple containing boolean values. Each boolean value is True in case of a successful mutation and False if not. The order of booleans is the same as the supplied tuple or list. @@ -158,7 +161,7 @@ async def delete_keys(self, keys): """ return await self._commit_entities_or_keys(keys, 'delete') - async def delete_key(self, key): + async def delete_key(self, key: Key) -> bool: """Returns True if successful or False if not. In case of False then most likely a conflict was detected. @@ -169,7 +172,8 @@ async def delete_key(self, key): """ return (await self._commit_entities_or_keys([key], 'delete'))[0] - async def commit(self, mutations): + async def commit(self, mutations: Iterable[dict[str, Any]]) -> \ + tuple[dict, ...]: """Commit mutations. The only supported commit mode is NON_TRANSACTIONAL. @@ -205,7 +209,7 @@ async def commit(self, mutations): resp.status )) - async def run_query(self, data): + async def run_query(self, data) -> list[dict]: """Return entities by given query data. :param data: see the following link for the data format: @@ -216,7 +220,7 @@ async def run_query(self, data): results, _ = await self._run_query(data) return results - async def _run_query(self, data): + async def _run_query(self, data) -> tuple[list[dict], Optional[str]]: results = [] cursor = None @@ -276,11 +280,12 @@ async def _run_query(self, data): return results, cursor - async def _get_entities_cursor(self, data): + async def _get_entities_cursor(self, data) -> \ + tuple[list[Entity], Optional[str]]: results, cursor = await self._run_query(data) return [Entity(result['entity']) for result in results], cursor - async def get_entities(self, data): + async def get_entities(self, data) -> list[Entity]: """Return entities by given query data. :param data: see the following link for the data format: @@ -291,12 +296,12 @@ async def get_entities(self, data): results, _ = await self._run_query(data) return [Entity(result['entity']) for result in results] - async def get_keys(self, data): + async def get_keys(self, data) -> list[Key]: data['query']['projection'] = [{'property': {'name': '__key__'}}] results, _ = await self._run_query(data) return [Key(result['entity']['key']) for result in results] - async def get_entity(self, data): + async def get_entity(self, data) -> Optional[Entity]: """Return an entity object by given query data. :param data: see the following link for the data format: @@ -308,19 +313,24 @@ async def get_entity(self, data): result = await self.get_entities(data) return result[0] if result else None - async def get_key(self, data): + async def get_key(self, data) -> Optional[Key]: data['query']['limit'] = 1 result = await self.get_keys(data) return result[0] if result else None - async def get_entities_by_kind(self, kind, offset=None, limit=None, - cursor=None): + async def get_entities_by_kind(self, kind: str, + offset: Optional[int] = None, + limit: Optional[int] = None, + cursor: Optional[str] = None) -> Union[ + list[Entity], + tuple[list[Entity], Optional[str]] + ]: """Returns entities by kind. When a limit is set, this function returns a list and a cursor. If no limit is used, then only the list will be returned. """ - query = {'kind': [{'name': kind}]} + query: dict[str, Any] = {'kind': [{'name': kind}]} data = {'query': query} if cursor: query['startCursor'] = cursor @@ -333,10 +343,13 @@ async def get_entities_by_kind(self, kind, offset=None, limit=None, query['limit'] = limit return await self._get_entities_cursor(data) - async def get_entities_by_keys(self, keys, missing=None, deferred=None, - eventual=False): + async def get_entities_by_keys(self, keys: Iterable[Key], + missing: Optional[list[Any]] = None, + deferred: Optional[list[Key]] = None, + eventual: bool = False) -> list[Entity]: """Returns entity objects for the given keys or an empty list in case - no entity is found. + no entity is found. The order of entities might not be equal to the + order of provided keys. :param keys: list of Key objects :return: list of Entity objects. @@ -384,8 +397,10 @@ def data(): return entities - async def get_entity_by_key(self, key, missing=None, deferred=None, - eventual=False): + async def get_entity_by_key(self, key: Key, + missing: Optional[list[Any]] = None, + deferred: Optional[list[Key]] = None, + eventual: bool = False) -> Optional[Entity]: """Returns an entity object for the given key or None in case no entity is found. @@ -397,7 +412,7 @@ async def get_entity_by_key(self, key, missing=None, deferred=None, if entity: return entity[0] - async def _get_headers(self): + async def _get_headers(self) -> dict[str, str]: token = await self._token.get() return { 'Authorization': 'Bearer {}'.format(token), @@ -405,7 +420,7 @@ async def _get_headers(self): } @staticmethod - def _check_mutation_result(entity_or_key, mutation_result): + def _check_mutation_result(entity_or_key, mutation_result) -> bool: if 'key' in mutation_result: # The automatically allocated key. # Set only when the mutation allocated a key. @@ -413,7 +428,8 @@ def _check_mutation_result(entity_or_key, mutation_result): return not mutation_result.get('conflictDetected', False) - async def _commit_entities_or_keys(self, entities_or_keys, method): + async def _commit_entities_or_keys(self, entities_or_keys, method) -> \ + tuple[bool, ...]: mutations = [ {method: entity_or_key.get_dict()} for entity_or_key in entities_or_keys] @@ -429,11 +445,11 @@ async def _commit_entities_or_keys(self, entities_or_keys, method): class GcdServiceAccountConnector(GcdConnector): def __init__( self, - project_id, - service_file, - session=None, - scopes=None, - namespace_id=None): + project_id: str, + service_file: str, + session: Optional[aiohttp.ClientSession] = None, + scopes: Optional[Iterable[str]] = None, + namespace_id: Optional[str] = None): scopes = scopes or list(DEFAULT_SCOPES) self.project_id = project_id diff --git a/aiogcd/connector/decoder.py b/aiogcd/connector/decoder.py index 63837d9..409f3d8 100644 --- a/aiogcd/connector/decoder.py +++ b/aiogcd/connector/decoder.py @@ -13,9 +13,9 @@ class Decoder(Buffer): _idx = None _end = None - def __new__(cls, *args, ks=None): + def __new__(cls, *args, ks): assert ks is not None, \ - 'Key string is required, for example: Decoder(ks=)' + 'Key string is required, for example: Decoder(ks=)' decoder = super().__new__(cls) @@ -97,7 +97,7 @@ def get_var_int64(self): return result - def get_prefixed_string(self): + def get_prefixed_string(self) -> str: n = self.get_var_int32() if self._idx + n > len(self): raise BufferDecodeError('truncated') diff --git a/aiogcd/connector/entity.py b/aiogcd/connector/entity.py index 6812f4f..80f2a04 100644 --- a/aiogcd/connector/entity.py +++ b/aiogcd/connector/entity.py @@ -12,7 +12,7 @@ class Entity: - def __init__(self, entity_res): + def __init__(self, entity_res: dict): """Initialize an Entity object. Example: diff --git a/aiogcd/connector/key.py b/aiogcd/connector/key.py index 4fa5257..9db47d6 100644 --- a/aiogcd/connector/key.py +++ b/aiogcd/connector/key.py @@ -4,6 +4,7 @@ Author: Jeroen van der Heijden """ import base64 +from typing import Optional from .buffer import Buffer from .buffer import BufferDecodeError from .path import Path @@ -40,8 +41,10 @@ class Key: """ _ks = None - def __init__(self, *args, ks=None, path=None, project_id=None, - namespace_id=None): + def __init__(self, *args, ks: Optional[str] = None, + path: Optional[Path] = None, + project_id: Optional[str] = None, + namespace_id: Optional[str] = None): if len(args) == 1 and isinstance(args[0], dict): assert ks is None and path is None and project_id is None, \ self.KEY_INIT_MSG @@ -147,8 +150,9 @@ def _extract_id_or_name(pair): return None @staticmethod - def _deserialize_ks(ks): - """Returns a Key() object from a key string.""" + def _deserialize_ks(ks: str): + """Returns a tuple with the project_id, namespace_id and Path + from a key string.""" decoder = Decoder(ks=ks) project_id = None diff --git a/aiogcd/connector/path.py b/aiogcd/connector/path.py index 157e4c2..db76489 100644 --- a/aiogcd/connector/path.py +++ b/aiogcd/connector/path.py @@ -3,29 +3,18 @@ Created on: May 19, 2017 Author: Jeroen van der Heijden """ +from typing import Iterable, Union from .pathelement import PathElement from .pathelement import path_element_from_decoder from .buffer import BufferDecodeError -def path_from_decoder(decoder): - pairs = [] - while decoder: - tt = decoder.get_var_int32() - if tt == 11: - pairs.append(path_element_from_decoder(decoder)) - continue - - if tt == 0: - raise BufferDecodeError('corrupted') - - return Path(pairs=pairs) - - class Path: - def __init__(self, pairs): - self._path = tuple( + def __init__(self, pairs: Union[ + Iterable[PathElement], + Iterable[tuple[str, Union[int, str]]]]): + self._path: tuple[PathElement] = tuple( pe if isinstance(pe, PathElement) else PathElement(*pe) for pe in pairs) @@ -44,17 +33,31 @@ def __getitem__(self, item): def __repr__(self): return str(self.get_as_tuple()) - def get_dict(self): + def get_dict(self) -> dict: return {'path': [pe.get_dict() for pe in self._path]} @property - def byte_size(self): + def byte_size(self) -> int: n = 2 * len(self._path) for path_element in self._path: n += path_element.byte_size return n - def get_as_tuple(self): + def get_as_tuple(self) -> tuple[tuple[str, Union[str, int]], ...]: """Returns a tuple of pairs (tuples) representing the key path of an entity. Useful for composing entities with a specific ancestor.""" return tuple((pe.kind, pe.id) for pe in self._path) + + +def path_from_decoder(decoder) -> Path: + pairs = [] + while decoder: + tt = decoder.get_var_int32() + if tt == 11: + pairs.append(path_element_from_decoder(decoder)) + continue + + if tt == 0: + raise BufferDecodeError('corrupted') + + return Path(pairs=pairs) diff --git a/aiogcd/connector/pathelement.py b/aiogcd/connector/pathelement.py index f4656a6..4563395 100644 --- a/aiogcd/connector/pathelement.py +++ b/aiogcd/connector/pathelement.py @@ -1,39 +1,13 @@ +from typing import Union from .buffer import BufferDecodeError TYPE_ID = 0 TYPE_NAME = 1 -def path_element_from_decoder(decoder): - kind = None - name_or_id = None - - while True: - tt = decoder.get_var_int32() - - if tt == 12: - break - if tt == 18: - kind = decoder.get_prefixed_string() - continue - if tt == 24: - name_or_id = decoder.get_var_int64() - continue - if tt == 34: - name_or_id = decoder.get_prefixed_string() - continue - if tt == 0: - raise BufferDecodeError('corrupt') - - assert kind is not None and name_or_id is not None, \ - 'Expecting a path element with a kind and name/id.' - - return PathElement(kind, name_or_id) - - class PathElement: - def __init__(self, kind, name_or_id): + def __init__(self, kind: str, name_or_id: Union[str, int]): assert name_or_id is None or isinstance(name_or_id, (int, str)), \ 'Expecting a str or int type but got: {}'.format( type(name_or_id)) @@ -55,7 +29,7 @@ def encode(self, buffer): buffer.add_prefixed_string(self.id) @property - def byte_size(self): + def byte_size(self) -> int: n = self._size_str(self.kind) if isinstance(self.id, int): n += 1 + self._size_var_int(self.id) @@ -64,7 +38,7 @@ def byte_size(self): return n + 1 - def get_dict(self): + def get_dict(self) -> dict: if isinstance(self.id, int): return {'kind': self.kind, 'id': str(self.id)} @@ -74,12 +48,12 @@ def get_dict(self): return {'kind': self.kind} @classmethod - def _size_str(cls, s): + def _size_str(cls, s) -> int: sz = len(s) return cls._size_var_int(sz) + sz @staticmethod - def _size_var_int(n): + def _size_var_int(n) -> int: if n < 0: return 10 @@ -90,3 +64,30 @@ def _size_var_int(n): if n == 0: break return result + + +def path_element_from_decoder(decoder) -> PathElement: + kind = None + name_or_id = None + + while True: + tt = decoder.get_var_int32() + + if tt == 12: + break + if tt == 18: + kind = decoder.get_prefixed_string() + continue + if tt == 24: + name_or_id = decoder.get_var_int64() + continue + if tt == 34: + name_or_id = decoder.get_prefixed_string() + continue + if tt == 0: + raise BufferDecodeError('corrupt') + + assert kind is not None and name_or_id is not None, \ + 'Expecting a path element with a kind and name/id.' + + return PathElement(kind, name_or_id) diff --git a/aiogcd/connector/service_account_token.py b/aiogcd/connector/service_account_token.py index 0c527bd..a7ff8a3 100644 --- a/aiogcd/connector/service_account_token.py +++ b/aiogcd/connector/service_account_token.py @@ -1,6 +1,7 @@ from urllib.parse import urlencode, quote_plus from asyncio_extras.contextmanager import async_contextmanager from asyncio_extras.asyncyield import yield_ +from typing import Optional, Iterable import asyncio import aiohttp import datetime @@ -8,9 +9,8 @@ import jwt import logging import time -import typing -ScopeList = typing.List[str] +ScopeList = Iterable[str] JWT_GRANT_TYPE = 'urn:ietf:params:oauth:grant-type:jwt-bearer' GCLOUD_TOKEN_DURATION = 3600 MISMATCH = "Project name passed to Token does not match service_file's " \ @@ -30,7 +30,8 @@ async def ensure_session(session): class ServiceAccountToken(): def __init__(self, project_id: str, service_file: str, - scopes: ScopeList, session: aiohttp.ClientSession = None): + scopes: ScopeList, + session: Optional[aiohttp.ClientSession] = None): self.project_id = project_id diff --git a/aiogcd/orm/filter.py b/aiogcd/orm/filter.py index c4eedb6..b863c4a 100644 --- a/aiogcd/orm/filter.py +++ b/aiogcd/orm/filter.py @@ -3,13 +3,17 @@ Created on: May 19, 2017 Author: Jeroen van der Heijden """ +from typing import Any, Optional from ..connector.key import Key from ..connector import GcdConnector class Filter(dict): - def __init__(self, model, *filters, has_ancestor=None, key=None): + def __init__(self, model, *filters, + has_ancestor: Optional[Key] = None, + key: Optional[Key] = None): + self._model = model self._cursor = None filters = list(filters) @@ -35,7 +39,9 @@ def __init__(self, model, *filters, has_ancestor=None, key=None): 'value': {'keyValue': key.get_dict()}, 'op': 'EQUAL'}) - filter_dict = {'query': {'kind': [{'name': self._model.get_kind()}]}} + filter_dict: dict[str, Any] = { + 'query': {'kind': [{'name': self._model.get_kind()}]} + } if self._model.__namespace__: filter_dict['partitionId'] = { @@ -86,7 +92,9 @@ def _set_limit(self, limit): def cursor(self): return self._cursor - def order_by(self, *order): + def order_by(self, *order: Any): + # TODO type Value.name, *order could be: + # list[Union[Type[Value], tuple[str, str]]] self['query']['order'] = [ { 'property': {'name': p[0]}, @@ -99,12 +107,12 @@ def order_by(self, *order): ] return self - def limit(self, limit, start_cursor=None): + def limit(self, limit: int, start_cursor: Optional[str] = None): self._set_limit(limit) self._set_start_cursor(start_cursor) return self - async def get_entity(self, gcd: GcdConnector): + async def get_entity(self, gcd: GcdConnector) -> Any: """Return a GcdModel instance from the supplied filter. :param gcd: GcdConnector instance. @@ -114,7 +122,8 @@ async def get_entity(self, gcd: GcdConnector): return None if entity is None else self._model(entity) async def get_entities( - self, gcd: GcdConnector, offset=None, limit=None) -> list: + self, gcd: GcdConnector, offset: Optional[int] = None, + limit: Optional[int] = None) -> list[Any]: """Returns a list containing GcdModel instances from the supplied filter. @@ -127,6 +136,7 @@ async def get_entities( self._set_limit(limit) entities, cursor = await gcd._get_entities_cursor(self) self._cursor = cursor + # TODO return type should be list[Type[GcdModel]] return [self._model(ent) for ent in entities] async def get_key(self, gcd: GcdConnector): @@ -138,7 +148,8 @@ async def get_key(self, gcd: GcdConnector): return await gcd.get_key(self) async def get_keys( - self, gcd: GcdConnector, offset=None, limit=None) -> list: + self, gcd: GcdConnector, offset: Optional[int] = None, + limit: Optional[int] = None) -> list[Key]: """Returns a list containing Gcd keys from the supplied filter. :param gcd: GcdConnector instance. @@ -150,7 +161,7 @@ async def get_keys( self._set_limit(limit) return await gcd.get_keys(self) - def set_offset_limit(self, offset, limit): + def set_offset_limit(self, offset: int, limit: int): """Set offset and limit for Filter query. :param offset: can be int or None(to avoid setting offset) :param limit: can be int or None(to avoid setting limit) diff --git a/aiogcd/orm/model.py b/aiogcd/orm/model.py index 0893dad..85ae8b1 100644 --- a/aiogcd/orm/model.py +++ b/aiogcd/orm/model.py @@ -3,9 +3,10 @@ Created on: May 19, 2017 Author: Jeroen van der Heijden """ -import functools +from typing import Any, Optional from ..connector.entity import Entity from ..orm.properties.value import Value +from ..connector import GcdConnector from ..connector.key import Key from .filter import Filter from ..connector.timestampvalue import TimestampValue @@ -139,7 +140,9 @@ def del_property(self, prop): super().del_property(prop) @classmethod - def filter(cls, *filters, has_ancestor=None, key=None): + def filter(cls, *filters: dict[str, Any], + has_ancestor: Optional[Key] = None, + key: Optional[Key] = None): return Filter( cls, *filters, @@ -153,10 +156,13 @@ def get_kind(cls): return cls.__kind__ @classmethod - async def get_entities(cls, gcd, offset=None, limit=None): + async def get_entities(cls, gcd: GcdConnector, + offset: Optional[int] = None, + limit: Optional[int] = None): return await Filter(cls).get_entities(gcd, offset, limit) - def serializable_dict(self, key_as=None, include_none=False): + def serializable_dict(self, key_as: Optional[str] = None, + include_none: bool = False): """Serialize a GcdModel to a Python dict. :param key_as: If key_as is set to a string value, then the key string diff --git a/aiogcd/orm/utils.py b/aiogcd/orm/utils.py index 31c1969..03619e6 100644 --- a/aiogcd/orm/utils.py +++ b/aiogcd/orm/utils.py @@ -3,6 +3,7 @@ Created on: May 19, 2017 Author: Jeroen van der Heijden """ +from typing import Callable, Union class ProtectedList(list): @@ -10,7 +11,7 @@ class ProtectedList(list): def __init__( self, *args, - protect=True): + protect: Union[Callable[..., None], bool] = True): self._protect = protect super().__init__(*args) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..cb94532 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,6 @@ +[tool.pyright] +pythonVersion = "3.12" +pythonPlatform = "Linux" +include = [ + "aiogcd" +] \ No newline at end of file