Skip to content

Commit

Permalink
more typing
Browse files Browse the repository at this point in the history
  • Loading branch information
joente committed Oct 1, 2024
1 parent 071960a commit f8ee554
Show file tree
Hide file tree
Showing 9 changed files with 57 additions and 41 deletions.
2 changes: 1 addition & 1 deletion aiogcd/connector/client_token.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def _read_token_file(self):
return token
return None

async def get(self):
async def get(self) -> str:
"""Returns the access token. If _refresh_ts is passed, the token will
be refreshed. A lock is used to prevent refreshing the token twice.
Expand Down
56 changes: 32 additions & 24 deletions aiogcd/connector/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import os
import json
import aiohttp
from typing import Iterable, Optional, Any
from typing import Iterable, Optional, Any, Union
from .client_token import Token
from .service_account_token import ServiceAccountToken
from .entity import Entity
Expand All @@ -27,7 +27,7 @@
_MAX_LOOPS = 128


def _get_api_endpoint():
def _get_api_endpoint() -> str:
emu_host = os.getenv('DATASTORE_EMULATOR_HOST')
if emu_host is None:
return DEFAULT_API_ENDPOINT
Expand Down Expand Up @@ -73,7 +73,7 @@ def __init__(
async def connect(self):
await self._token.connect()

async def insert_entities(self, entities):
async def insert_entities(self, entities) -> tuple[bool, ...]:
"""Returns a tuple containing boolean values. Each boolean value is
True in case of a successful mutation and False if not. The order of
booleans is the same as the supplied tuple or list.
Expand All @@ -89,7 +89,7 @@ async def insert_entities(self, entities):
# alias
entities = insert_entities

async def insert_entity(self, entity: Entity):
async def insert_entity(self, entity: Entity) -> bool:
"""Returns True if successful or False if not. In case of False then
most likely a conflict was detected.
Expand All @@ -101,7 +101,8 @@ async def insert_entity(self, entity: Entity):
"""
return (await self._commit_entities_or_keys([entity], 'insert'))[0]

async def upsert_entities(self, entities: Iterable[Entity]):
async def upsert_entities(self, entities: Iterable[Entity]) -> \
tuple[bool, ...]:
"""Returns a tuple containing boolean values. Each boolean value is
True in case of a successful mutation and False if not. The order of
booleans is the same as the supplied tuple or list.
Expand All @@ -113,7 +114,7 @@ async def upsert_entities(self, entities: Iterable[Entity]):
"""
return await self._commit_entities_or_keys(entities, 'upsert')

async def upsert_entity(self, entity: Entity):
async def upsert_entity(self, entity: Entity) -> bool:
"""Returns True if successful or False if not. In case of False then
most likely a conflict was detected.
Expand All @@ -124,7 +125,8 @@ async def upsert_entity(self, entity: Entity):
"""
return (await self._commit_entities_or_keys([entity], 'upsert'))[0]

async def update_entities(self, entities: Iterable[Entity]):
async def update_entities(self, entities: Iterable[Entity]) -> \
tuple[bool, ...]:
"""Returns a tuple containing boolean values. Each boolean value is
True in case of a successful mutation and False if not. The order of
booleans is the same as the supplied tuple or list.
Expand All @@ -136,7 +138,7 @@ async def update_entities(self, entities: Iterable[Entity]):
"""
return await self._commit_entities_or_keys(entities, 'update')

async def update_entity(self, entity: Entity):
async def update_entity(self, entity: Entity) -> bool:
"""Returns True if successful or False if not. In case of False then
most likely a conflict was detected.
Expand All @@ -147,7 +149,7 @@ async def update_entity(self, entity: Entity):
"""
return (await self._commit_entities_or_keys([entity], 'update'))[0]

async def delete_keys(self, keys: Iterable[Key]):
async def delete_keys(self, keys: Iterable[Key]) -> tuple[bool, ...]:
"""Returns a tuple containing boolean values. Each boolean value is
True in case of a successful mutation and False if not. The order of
booleans is the same as the supplied tuple or list.
Expand All @@ -159,7 +161,7 @@ async def delete_keys(self, keys: Iterable[Key]):
"""
return await self._commit_entities_or_keys(keys, 'delete')

async def delete_key(self, key: Key):
async def delete_key(self, key: Key) -> bool:
"""Returns True if successful or False if not. In case of False then
most likely a conflict was detected.
Expand All @@ -170,7 +172,8 @@ async def delete_key(self, key: Key):
"""
return (await self._commit_entities_or_keys([key], 'delete'))[0]

async def commit(self, mutations: Iterable[dict[str, Any]]):
async def commit(self, mutations: Iterable[dict[str, Any]]) -> \
tuple[dict, ...]:
"""Commit mutations.
The only supported commit mode is NON_TRANSACTIONAL.
Expand Down Expand Up @@ -206,7 +209,7 @@ async def commit(self, mutations: Iterable[dict[str, Any]]):
resp.status
))

async def run_query(self, data):
async def run_query(self, data) -> list[dict]:
"""Return entities by given query data.
:param data: see the following link for the data format:
Expand All @@ -217,7 +220,7 @@ async def run_query(self, data):
results, _ = await self._run_query(data)
return results

async def _run_query(self, data):
async def _run_query(self, data) -> tuple[list[dict], Optional[str]]:
results = []
cursor = None

Expand Down Expand Up @@ -277,11 +280,12 @@ async def _run_query(self, data):

return results, cursor

async def _get_entities_cursor(self, data):
async def _get_entities_cursor(self, data) -> \
tuple[list[Entity], Optional[str]]:
results, cursor = await self._run_query(data)
return [Entity(result['entity']) for result in results], cursor

async def get_entities(self, data):
async def get_entities(self, data) -> list[Entity]:
"""Return entities by given query data.
:param data: see the following link for the data format:
Expand All @@ -292,12 +296,12 @@ async def get_entities(self, data):
results, _ = await self._run_query(data)
return [Entity(result['entity']) for result in results]

async def get_keys(self, data):
async def get_keys(self, data) -> list[Key]:
data['query']['projection'] = [{'property': {'name': '__key__'}}]
results, _ = await self._run_query(data)
return [Key(result['entity']['key']) for result in results]

async def get_entity(self, data):
async def get_entity(self, data) -> Optional[Entity]:
"""Return an entity object by given query data.
:param data: see the following link for the data format:
Expand All @@ -309,15 +313,18 @@ async def get_entity(self, data):
result = await self.get_entities(data)
return result[0] if result else None

async def get_key(self, data):
async def get_key(self, data) -> Optional[Key]:
data['query']['limit'] = 1
result = await self.get_keys(data)
return result[0] if result else None

async def get_entities_by_kind(self, kind: str,
offset: Optional[int] = None,
limit: Optional[int] = None,
cursor: Optional[str] = None):
cursor: Optional[str] = None) -> Union[
list[Entity],
tuple[list[Entity], Optional[str]]
]:
"""Returns entities by kind.
When a limit is set, this function returns a list and a cursor.
Expand All @@ -339,7 +346,7 @@ async def get_entities_by_kind(self, kind: str,
async def get_entities_by_keys(self, keys: Iterable[Key],
missing: Optional[list[Any]] = None,
deferred: Optional[list[Key]] = None,
eventual: bool = False):
eventual: bool = False) -> list[Entity]:
"""Returns entity objects for the given keys or an empty list in case
no entity is found. The order of entities might not be equal to the
order of provided keys.
Expand Down Expand Up @@ -393,7 +400,7 @@ def data():
async def get_entity_by_key(self, key: Key,
missing: Optional[list[Any]] = None,
deferred: Optional[list[Key]] = None,
eventual: bool = False):
eventual: bool = False) -> Optional[Entity]:
"""Returns an entity object for the given key or None in case no
entity is found.
Expand All @@ -405,23 +412,24 @@ async def get_entity_by_key(self, key: Key,
if entity:
return entity[0]

async def _get_headers(self):
async def _get_headers(self) -> dict[str, str]:
token = await self._token.get()
return {
'Authorization': 'Bearer {}'.format(token),
'Content-Type': 'application/json'
}

@staticmethod
def _check_mutation_result(entity_or_key, mutation_result):
def _check_mutation_result(entity_or_key, mutation_result) -> bool:
if 'key' in mutation_result:
# The automatically allocated key.
# Set only when the mutation allocated a key.
entity_or_key.key = Key(mutation_result['key'])

return not mutation_result.get('conflictDetected', False)

async def _commit_entities_or_keys(self, entities_or_keys, method):
async def _commit_entities_or_keys(self, entities_or_keys, method) -> \
tuple[bool, ...]:
mutations = [
{method: entity_or_key.get_dict()}
for entity_or_key in entities_or_keys]
Expand Down
2 changes: 1 addition & 1 deletion aiogcd/connector/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def get_var_int64(self):

return result

def get_prefixed_string(self):
def get_prefixed_string(self) -> str:
n = self.get_var_int32()
if self._idx + n > len(self):
raise BufferDecodeError('truncated')
Expand Down
2 changes: 1 addition & 1 deletion aiogcd/connector/entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

class Entity:

def __init__(self, entity_res):
def __init__(self, entity_res: dict):
"""Initialize an Entity object.
Example:
Expand Down
3 changes: 2 additions & 1 deletion aiogcd/connector/key.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,8 @@ def _extract_id_or_name(pair):

@staticmethod
def _deserialize_ks(ks: str):
"""Returns a Key() object from a key string."""
"""Returns a tuple with the project_id, namespace_id and Path
from a key string."""

decoder = Decoder(ks=ks)
project_id = None
Expand Down
13 changes: 8 additions & 5 deletions aiogcd/connector/path.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,18 @@
Created on: May 19, 2017
Author: Jeroen van der Heijden <[email protected]>
"""
from typing import Iterable, Union
from .pathelement import PathElement
from .pathelement import path_element_from_decoder
from .buffer import BufferDecodeError


class Path:

def __init__(self, pairs):
self._path = tuple(
def __init__(self, pairs: Union[
Iterable[PathElement],
Iterable[tuple[str, Union[int, str]]]]):
self._path: tuple[PathElement] = tuple(
pe if isinstance(pe, PathElement) else PathElement(*pe)
for pe in pairs)

Expand All @@ -30,17 +33,17 @@ def __getitem__(self, item):
def __repr__(self):
return str(self.get_as_tuple())

def get_dict(self):
def get_dict(self) -> dict:
return {'path': [pe.get_dict() for pe in self._path]}

@property
def byte_size(self):
def byte_size(self) -> int:
n = 2 * len(self._path)
for path_element in self._path:
n += path_element.byte_size
return n

def get_as_tuple(self):
def get_as_tuple(self) -> tuple[tuple[str, Union[str, int]], ...]:
"""Returns a tuple of pairs (tuples) representing the key path of an
entity. Useful for composing entities with a specific ancestor."""
return tuple((pe.kind, pe.id) for pe in self._path)
Expand Down
13 changes: 7 additions & 6 deletions aiogcd/connector/pathelement.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import Union
from .buffer import BufferDecodeError

TYPE_ID = 0
Expand All @@ -6,7 +7,7 @@

class PathElement:

def __init__(self, kind, name_or_id):
def __init__(self, kind: str, name_or_id: Union[str, int]):
assert name_or_id is None or isinstance(name_or_id, (int, str)), \
'Expecting a str or int type but got: {}'.format(
type(name_or_id))
Expand All @@ -28,7 +29,7 @@ def encode(self, buffer):
buffer.add_prefixed_string(self.id)

@property
def byte_size(self):
def byte_size(self) -> int:
n = self._size_str(self.kind)
if isinstance(self.id, int):
n += 1 + self._size_var_int(self.id)
Expand All @@ -37,7 +38,7 @@ def byte_size(self):

return n + 1

def get_dict(self):
def get_dict(self) -> dict:
if isinstance(self.id, int):
return {'kind': self.kind, 'id': str(self.id)}

Expand All @@ -47,12 +48,12 @@ def get_dict(self):
return {'kind': self.kind}

@classmethod
def _size_str(cls, s):
def _size_str(cls, s) -> int:
sz = len(s)
return cls._size_var_int(sz) + sz

@staticmethod
def _size_var_int(n):
def _size_var_int(n) -> int:
if n < 0:
return 10

Expand Down Expand Up @@ -89,4 +90,4 @@ def path_element_from_decoder(decoder) -> PathElement:
assert kind is not None and name_or_id is not None, \
'Expecting a path element with a kind and name/id.'

return PathElement(kind, name_or_id)
return PathElement(kind, name_or_id)
4 changes: 3 additions & 1 deletion aiogcd/orm/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,9 @@ def __init__(self, model, *filters,
'value': {'keyValue': key.get_dict()},
'op': 'EQUAL'})

filter_dict: dict[str, Any] = {'query': {'kind': [{'name': self._model.get_kind()}]}}
filter_dict: dict[str, Any] = {
'query': {'kind': [{'name': self._model.get_kind()}]}
}

if self._model.__namespace__:
filter_dict['partitionId'] = {
Expand Down
3 changes: 2 additions & 1 deletion aiogcd/orm/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,8 @@ def del_property(self, prop):
super().del_property(prop)

@classmethod
def filter(cls, *filters: dict[str, Any], has_ancestor: Optional[Key] = None,
def filter(cls, *filters: dict[str, Any],
has_ancestor: Optional[Key] = None,
key: Optional[Key] = None):
return Filter(
cls,
Expand Down

0 comments on commit f8ee554

Please sign in to comment.