From f8ee554da0027e037a92a926af583e4f21cb4130 Mon Sep 17 00:00:00 2001 From: Jeroen van der Heijden Date: Tue, 1 Oct 2024 15:11:38 +0200 Subject: [PATCH] more typing --- aiogcd/connector/client_token.py | 2 +- aiogcd/connector/connector.py | 56 ++++++++++++++++++-------------- aiogcd/connector/decoder.py | 2 +- aiogcd/connector/entity.py | 2 +- aiogcd/connector/key.py | 3 +- aiogcd/connector/path.py | 13 +++++--- aiogcd/connector/pathelement.py | 13 ++++---- aiogcd/orm/filter.py | 4 ++- aiogcd/orm/model.py | 3 +- 9 files changed, 57 insertions(+), 41 deletions(-) diff --git a/aiogcd/connector/client_token.py b/aiogcd/connector/client_token.py index 8d701ad..688e778 100644 --- a/aiogcd/connector/client_token.py +++ b/aiogcd/connector/client_token.py @@ -89,7 +89,7 @@ def _read_token_file(self): return token return None - async def get(self): + async def get(self) -> str: """Returns the access token. If _refresh_ts is passed, the token will be refreshed. A lock is used to prevent refreshing the token twice. diff --git a/aiogcd/connector/connector.py b/aiogcd/connector/connector.py index c2494d5..a368e41 100644 --- a/aiogcd/connector/connector.py +++ b/aiogcd/connector/connector.py @@ -7,7 +7,7 @@ import os import json import aiohttp -from typing import Iterable, Optional, Any +from typing import Iterable, Optional, Any, Union from .client_token import Token from .service_account_token import ServiceAccountToken from .entity import Entity @@ -27,7 +27,7 @@ _MAX_LOOPS = 128 -def _get_api_endpoint(): +def _get_api_endpoint() -> str: emu_host = os.getenv('DATASTORE_EMULATOR_HOST') if emu_host is None: return DEFAULT_API_ENDPOINT @@ -73,7 +73,7 @@ def __init__( async def connect(self): await self._token.connect() - async def insert_entities(self, entities): + async def insert_entities(self, entities) -> tuple[bool, ...]: """Returns a tuple containing boolean values. Each boolean value is True in case of a successful mutation and False if not. The order of booleans is the same as the supplied tuple or list. @@ -89,7 +89,7 @@ async def insert_entities(self, entities): # alias entities = insert_entities - async def insert_entity(self, entity: Entity): + async def insert_entity(self, entity: Entity) -> bool: """Returns True if successful or False if not. In case of False then most likely a conflict was detected. @@ -101,7 +101,8 @@ async def insert_entity(self, entity: Entity): """ return (await self._commit_entities_or_keys([entity], 'insert'))[0] - async def upsert_entities(self, entities: Iterable[Entity]): + async def upsert_entities(self, entities: Iterable[Entity]) -> \ + tuple[bool, ...]: """Returns a tuple containing boolean values. Each boolean value is True in case of a successful mutation and False if not. The order of booleans is the same as the supplied tuple or list. @@ -113,7 +114,7 @@ async def upsert_entities(self, entities: Iterable[Entity]): """ return await self._commit_entities_or_keys(entities, 'upsert') - async def upsert_entity(self, entity: Entity): + async def upsert_entity(self, entity: Entity) -> bool: """Returns True if successful or False if not. In case of False then most likely a conflict was detected. @@ -124,7 +125,8 @@ async def upsert_entity(self, entity: Entity): """ return (await self._commit_entities_or_keys([entity], 'upsert'))[0] - async def update_entities(self, entities: Iterable[Entity]): + async def update_entities(self, entities: Iterable[Entity]) -> \ + tuple[bool, ...]: """Returns a tuple containing boolean values. Each boolean value is True in case of a successful mutation and False if not. The order of booleans is the same as the supplied tuple or list. @@ -136,7 +138,7 @@ async def update_entities(self, entities: Iterable[Entity]): """ return await self._commit_entities_or_keys(entities, 'update') - async def update_entity(self, entity: Entity): + async def update_entity(self, entity: Entity) -> bool: """Returns True if successful or False if not. In case of False then most likely a conflict was detected. @@ -147,7 +149,7 @@ async def update_entity(self, entity: Entity): """ return (await self._commit_entities_or_keys([entity], 'update'))[0] - async def delete_keys(self, keys: Iterable[Key]): + async def delete_keys(self, keys: Iterable[Key]) -> tuple[bool, ...]: """Returns a tuple containing boolean values. Each boolean value is True in case of a successful mutation and False if not. The order of booleans is the same as the supplied tuple or list. @@ -159,7 +161,7 @@ async def delete_keys(self, keys: Iterable[Key]): """ return await self._commit_entities_or_keys(keys, 'delete') - async def delete_key(self, key: Key): + async def delete_key(self, key: Key) -> bool: """Returns True if successful or False if not. In case of False then most likely a conflict was detected. @@ -170,7 +172,8 @@ async def delete_key(self, key: Key): """ return (await self._commit_entities_or_keys([key], 'delete'))[0] - async def commit(self, mutations: Iterable[dict[str, Any]]): + async def commit(self, mutations: Iterable[dict[str, Any]]) -> \ + tuple[dict, ...]: """Commit mutations. The only supported commit mode is NON_TRANSACTIONAL. @@ -206,7 +209,7 @@ async def commit(self, mutations: Iterable[dict[str, Any]]): resp.status )) - async def run_query(self, data): + async def run_query(self, data) -> list[dict]: """Return entities by given query data. :param data: see the following link for the data format: @@ -217,7 +220,7 @@ async def run_query(self, data): results, _ = await self._run_query(data) return results - async def _run_query(self, data): + async def _run_query(self, data) -> tuple[list[dict], Optional[str]]: results = [] cursor = None @@ -277,11 +280,12 @@ async def _run_query(self, data): return results, cursor - async def _get_entities_cursor(self, data): + async def _get_entities_cursor(self, data) -> \ + tuple[list[Entity], Optional[str]]: results, cursor = await self._run_query(data) return [Entity(result['entity']) for result in results], cursor - async def get_entities(self, data): + async def get_entities(self, data) -> list[Entity]: """Return entities by given query data. :param data: see the following link for the data format: @@ -292,12 +296,12 @@ async def get_entities(self, data): results, _ = await self._run_query(data) return [Entity(result['entity']) for result in results] - async def get_keys(self, data): + async def get_keys(self, data) -> list[Key]: data['query']['projection'] = [{'property': {'name': '__key__'}}] results, _ = await self._run_query(data) return [Key(result['entity']['key']) for result in results] - async def get_entity(self, data): + async def get_entity(self, data) -> Optional[Entity]: """Return an entity object by given query data. :param data: see the following link for the data format: @@ -309,7 +313,7 @@ async def get_entity(self, data): result = await self.get_entities(data) return result[0] if result else None - async def get_key(self, data): + async def get_key(self, data) -> Optional[Key]: data['query']['limit'] = 1 result = await self.get_keys(data) return result[0] if result else None @@ -317,7 +321,10 @@ async def get_key(self, data): async def get_entities_by_kind(self, kind: str, offset: Optional[int] = None, limit: Optional[int] = None, - cursor: Optional[str] = None): + cursor: Optional[str] = None) -> Union[ + list[Entity], + tuple[list[Entity], Optional[str]] + ]: """Returns entities by kind. When a limit is set, this function returns a list and a cursor. @@ -339,7 +346,7 @@ async def get_entities_by_kind(self, kind: str, async def get_entities_by_keys(self, keys: Iterable[Key], missing: Optional[list[Any]] = None, deferred: Optional[list[Key]] = None, - eventual: bool = False): + eventual: bool = False) -> list[Entity]: """Returns entity objects for the given keys or an empty list in case no entity is found. The order of entities might not be equal to the order of provided keys. @@ -393,7 +400,7 @@ def data(): async def get_entity_by_key(self, key: Key, missing: Optional[list[Any]] = None, deferred: Optional[list[Key]] = None, - eventual: bool = False): + eventual: bool = False) -> Optional[Entity]: """Returns an entity object for the given key or None in case no entity is found. @@ -405,7 +412,7 @@ async def get_entity_by_key(self, key: Key, if entity: return entity[0] - async def _get_headers(self): + async def _get_headers(self) -> dict[str, str]: token = await self._token.get() return { 'Authorization': 'Bearer {}'.format(token), @@ -413,7 +420,7 @@ async def _get_headers(self): } @staticmethod - def _check_mutation_result(entity_or_key, mutation_result): + def _check_mutation_result(entity_or_key, mutation_result) -> bool: if 'key' in mutation_result: # The automatically allocated key. # Set only when the mutation allocated a key. @@ -421,7 +428,8 @@ def _check_mutation_result(entity_or_key, mutation_result): return not mutation_result.get('conflictDetected', False) - async def _commit_entities_or_keys(self, entities_or_keys, method): + async def _commit_entities_or_keys(self, entities_or_keys, method) -> \ + tuple[bool, ...]: mutations = [ {method: entity_or_key.get_dict()} for entity_or_key in entities_or_keys] diff --git a/aiogcd/connector/decoder.py b/aiogcd/connector/decoder.py index 99e6ff1..409f3d8 100644 --- a/aiogcd/connector/decoder.py +++ b/aiogcd/connector/decoder.py @@ -97,7 +97,7 @@ def get_var_int64(self): return result - def get_prefixed_string(self): + def get_prefixed_string(self) -> str: n = self.get_var_int32() if self._idx + n > len(self): raise BufferDecodeError('truncated') diff --git a/aiogcd/connector/entity.py b/aiogcd/connector/entity.py index 6812f4f..80f2a04 100644 --- a/aiogcd/connector/entity.py +++ b/aiogcd/connector/entity.py @@ -12,7 +12,7 @@ class Entity: - def __init__(self, entity_res): + def __init__(self, entity_res: dict): """Initialize an Entity object. Example: diff --git a/aiogcd/connector/key.py b/aiogcd/connector/key.py index ae8b86c..9db47d6 100644 --- a/aiogcd/connector/key.py +++ b/aiogcd/connector/key.py @@ -151,7 +151,8 @@ def _extract_id_or_name(pair): @staticmethod def _deserialize_ks(ks: str): - """Returns a Key() object from a key string.""" + """Returns a tuple with the project_id, namespace_id and Path + from a key string.""" decoder = Decoder(ks=ks) project_id = None diff --git a/aiogcd/connector/path.py b/aiogcd/connector/path.py index 7a8a7d0..db76489 100644 --- a/aiogcd/connector/path.py +++ b/aiogcd/connector/path.py @@ -3,6 +3,7 @@ Created on: May 19, 2017 Author: Jeroen van der Heijden """ +from typing import Iterable, Union from .pathelement import PathElement from .pathelement import path_element_from_decoder from .buffer import BufferDecodeError @@ -10,8 +11,10 @@ class Path: - def __init__(self, pairs): - self._path = tuple( + def __init__(self, pairs: Union[ + Iterable[PathElement], + Iterable[tuple[str, Union[int, str]]]]): + self._path: tuple[PathElement] = tuple( pe if isinstance(pe, PathElement) else PathElement(*pe) for pe in pairs) @@ -30,17 +33,17 @@ def __getitem__(self, item): def __repr__(self): return str(self.get_as_tuple()) - def get_dict(self): + def get_dict(self) -> dict: return {'path': [pe.get_dict() for pe in self._path]} @property - def byte_size(self): + def byte_size(self) -> int: n = 2 * len(self._path) for path_element in self._path: n += path_element.byte_size return n - def get_as_tuple(self): + def get_as_tuple(self) -> tuple[tuple[str, Union[str, int]], ...]: """Returns a tuple of pairs (tuples) representing the key path of an entity. Useful for composing entities with a specific ancestor.""" return tuple((pe.kind, pe.id) for pe in self._path) diff --git a/aiogcd/connector/pathelement.py b/aiogcd/connector/pathelement.py index 5522512..4563395 100644 --- a/aiogcd/connector/pathelement.py +++ b/aiogcd/connector/pathelement.py @@ -1,3 +1,4 @@ +from typing import Union from .buffer import BufferDecodeError TYPE_ID = 0 @@ -6,7 +7,7 @@ class PathElement: - def __init__(self, kind, name_or_id): + def __init__(self, kind: str, name_or_id: Union[str, int]): assert name_or_id is None or isinstance(name_or_id, (int, str)), \ 'Expecting a str or int type but got: {}'.format( type(name_or_id)) @@ -28,7 +29,7 @@ def encode(self, buffer): buffer.add_prefixed_string(self.id) @property - def byte_size(self): + def byte_size(self) -> int: n = self._size_str(self.kind) if isinstance(self.id, int): n += 1 + self._size_var_int(self.id) @@ -37,7 +38,7 @@ def byte_size(self): return n + 1 - def get_dict(self): + def get_dict(self) -> dict: if isinstance(self.id, int): return {'kind': self.kind, 'id': str(self.id)} @@ -47,12 +48,12 @@ def get_dict(self): return {'kind': self.kind} @classmethod - def _size_str(cls, s): + def _size_str(cls, s) -> int: sz = len(s) return cls._size_var_int(sz) + sz @staticmethod - def _size_var_int(n): + def _size_var_int(n) -> int: if n < 0: return 10 @@ -89,4 +90,4 @@ def path_element_from_decoder(decoder) -> PathElement: assert kind is not None and name_or_id is not None, \ 'Expecting a path element with a kind and name/id.' - return PathElement(kind, name_or_id) \ No newline at end of file + return PathElement(kind, name_or_id) diff --git a/aiogcd/orm/filter.py b/aiogcd/orm/filter.py index 69b8f1b..d1d97a7 100644 --- a/aiogcd/orm/filter.py +++ b/aiogcd/orm/filter.py @@ -39,7 +39,9 @@ def __init__(self, model, *filters, 'value': {'keyValue': key.get_dict()}, 'op': 'EQUAL'}) - filter_dict: dict[str, Any] = {'query': {'kind': [{'name': self._model.get_kind()}]}} + filter_dict: dict[str, Any] = { + 'query': {'kind': [{'name': self._model.get_kind()}]} + } if self._model.__namespace__: filter_dict['partitionId'] = { diff --git a/aiogcd/orm/model.py b/aiogcd/orm/model.py index 5e997d8..ebb5112 100644 --- a/aiogcd/orm/model.py +++ b/aiogcd/orm/model.py @@ -141,7 +141,8 @@ def del_property(self, prop): super().del_property(prop) @classmethod - def filter(cls, *filters: dict[str, Any], has_ancestor: Optional[Key] = None, + def filter(cls, *filters: dict[str, Any], + has_ancestor: Optional[Key] = None, key: Optional[Key] = None): return Filter( cls,