Skip to content

Commit

Permalink
Merge pull request #22 from fadedDexofan/feat/iterable-providers
Browse files Browse the repository at this point in the history
feat: add iterable providers
  • Loading branch information
ThirVondukr authored Jan 9, 2025
2 parents a4a7b50 + d7952ab commit a2b82d1
Show file tree
Hide file tree
Showing 21 changed files with 461 additions and 154 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ coverage.xml
# virtualenv
.venv
venv*/

.python-version

# python cached files
*.py[cod]
2 changes: 1 addition & 1 deletion Taskfile.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ tasks:
desc: Lint python source files
cmds:
- "{{.RUNNER}} ruff check {{.SOURCES}}"
- "{{.RUNNER}} ruff format --checm {{.SOURCES}}"
- "{{.RUNNER}} ruff format --check {{.SOURCES}}"

format:
desc: Format python source files
Expand Down
25 changes: 11 additions & 14 deletions aioinject/_features/generics.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,18 @@
from types import GenericAlias
from typing import TYPE_CHECKING, Any, TypeGuard

from aioinject._utils import is_iterable_generic_collection


if TYPE_CHECKING:
from aioinject.providers import Dependency


def _is_generic_alias(type_: Any) -> TypeGuard[GenericAlias]:
# we currently don't support tuple, list, dict, set, type
return isinstance(
type_,
types.GenericAlias | t._GenericAlias, # type: ignore[attr-defined] # noqa: SLF001
) and t.get_origin(type_) not in (tuple, list, dict, set, type)
) and not is_iterable_generic_collection(type_)


def _get_orig_bases(type_: type) -> tuple[type, ...] | None:
Expand All @@ -33,51 +34,47 @@ def _get_generic_arguments(type_: Any) -> list[t.TypeVar] | None:
return None


@functools.lru_cache
@functools.cache
def _get_generic_args_map(type_: type[object]) -> dict[str, type[object]]:
if _is_generic_alias(type_):
args = type_.__args__
params: dict[str, Any] = {
param.__name__: param
for param in type_.__origin__.__parameters__ # type: ignore[attr-defined]
}
# TODO(Doctor, nrbnlulu): Tests pass with strct=True, is this needed?
return dict(zip(params, args, strict=False))
return dict(zip(params, type_.__args__, strict=True))

args_map = {}
if orig_bases := _get_orig_bases(type_):
# find the generic parent
for base in orig_bases:
if _is_generic_alias(base):
args = base.__args__
if _is_generic_alias(base): # noqa: SIM102
if params := {
param.__name__: param
for param in getattr(base.__origin__, "__parameters__", ())
}:
args_map.update(
dict(zip(params, args, strict=True)),
dict(zip(params, base.__args__, strict=True)),
)
return args_map


@functools.lru_cache
@functools.cache
def get_generic_parameter_map(
provided_type: type[object],
dependencies: tuple[Dependency[Any], ...],
) -> dict[str, type[object]]:
args_map = _get_generic_args_map(provided_type) # type: ignore[arg-type]
result = {}
for dependency in dependencies:
inner_type = dependency.inner_type
if args_map and (
generic_arguments := _get_generic_arguments(dependency.type_)
generic_arguments := _get_generic_arguments(inner_type)
):
# This is a generic type, we need to resolve the type arguments
# and pass them to the provider.
resolved_args = [
args_map[arg.__name__] for arg in generic_arguments
]
# We can use `[]` when we drop support for 3.10
result[dependency.name] = dependency.type_.__getitem__(
*resolved_args
)
result[dependency.name] = inner_type.__getitem__(*resolved_args)
return result
35 changes: 19 additions & 16 deletions aioinject/_store.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import asyncio
import collections
import contextlib
import enum
Expand All @@ -11,6 +10,8 @@
from types import TracebackType
from typing import TYPE_CHECKING, Any, Literal, TypeVar

import anyio

from aioinject._utils import enter_context_maybe, enter_sync_context_maybe
from aioinject.providers import DependencyLifetime

Expand All @@ -33,28 +34,28 @@ def __init__(
exit_stack: contextlib.AsyncExitStack | None = None,
sync_exit_stack: contextlib.ExitStack | None = None,
) -> None:
self._cache: dict[type, Any] = {}
self._cache: dict[Provider[Any], Any] = {}
self._exit_stack = exit_stack or contextlib.AsyncExitStack()
self._sync_exit_stack = sync_exit_stack or contextlib.ExitStack()

def get(self, provider: Provider[T]) -> T | Literal[NotInCache.sentinel]:
return self._cache.get(provider.type_, NotInCache.sentinel)
return self._cache.get(provider, NotInCache.sentinel)

def add(self, provider: Provider[T], obj: T) -> None:
if provider.lifetime is not DependencyLifetime.transient:
self._cache[provider.type_] = obj
self._cache[provider] = obj

def lock(
self,
provider: Provider[Any],
) -> AbstractAsyncContextManager[bool]:
return contextlib.nullcontext(provider.type_ not in self._cache)
return contextlib.nullcontext(provider not in self._cache)

def sync_lock(
self,
provider: Provider[Any],
) -> AbstractContextManager[bool]:
return contextlib.nullcontext(provider.type_ not in self._cache)
return contextlib.nullcontext(provider not in self._cache)

@typing.overload
async def enter_context(
Expand Down Expand Up @@ -119,18 +120,20 @@ def __init__(
sync_exit_stack: contextlib.ExitStack | None = None,
) -> None:
super().__init__(exit_stack, sync_exit_stack)
self._locks: dict[type, asyncio.Lock] = collections.defaultdict(
asyncio.Lock,
self._locks: dict[Provider[Any], anyio.Lock] = collections.defaultdict(
anyio.Lock,
)
self._sync_locks: dict[type, threading.Lock] = collections.defaultdict(
threading.Lock,
self._sync_locks: dict[Provider[Any], threading.Lock] = (
collections.defaultdict(
threading.Lock,
)
)

@contextlib.asynccontextmanager
async def lock(self, provider: Provider[Any]) -> AsyncIterator[bool]:
if provider.type_ not in self._cache:
async with self._locks[provider.type_]:
yield provider.type_ not in self._cache
if provider not in self._cache:
async with self._locks[provider]:
yield provider not in self._cache
return
yield False

Expand All @@ -139,8 +142,8 @@ def sync_lock(
self,
provider: Provider[Any],
) -> Iterator[bool]:
if provider.type_ not in self._cache:
with self._sync_locks[provider.type_]:
yield provider.type_ not in self._cache
if provider not in self._cache:
with self._sync_locks[provider]:
yield provider not in self._cache
return
yield False
2 changes: 1 addition & 1 deletion aioinject/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,5 @@

P = ParamSpec("P")
T = TypeVar("T")
Providers: TypeAlias = dict[type[T], "Provider[T]"]
Providers: TypeAlias = dict[type[T], list["Provider[T]"]]
AnyCtx: TypeAlias = Union["InjectionContext", "SyncInjectionContext"]
23 changes: 18 additions & 5 deletions aioinject/_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
from __future__ import annotations

import collections.abc
import contextlib
import functools
import inspect
import sys
import typing
Expand All @@ -15,7 +19,7 @@


_T = TypeVar("_T")
_F = TypeVar("_F", bound=typing.Callable[..., Any])
_F = TypeVar("_F", bound=Callable[..., Any])

sentinel = object()

Expand All @@ -37,7 +41,7 @@ def clear_wrapper(wrapper: _F) -> _F:


def get_inject_annotations(
function: typing.Callable[..., Any],
function: Callable[..., Any],
) -> dict[str, Any]:
with remove_annotation(function.__annotations__, "return"):
return {
Expand Down Expand Up @@ -67,13 +71,13 @@ async def enter_context_maybe(
),
stack: AsyncExitStack,
) -> _T:
if isinstance(resolved, contextlib.ContextDecorator):
return stack.enter_context(resolved) # type: ignore[arg-type]

if isinstance(resolved, contextlib.AsyncContextDecorator):
return await stack.enter_async_context(
resolved, # type: ignore[arg-type]
)
if isinstance(resolved, contextlib.ContextDecorator):
return stack.enter_context(resolved) # type: ignore[arg-type]

return resolved # type: ignore[return-value]


Expand Down Expand Up @@ -115,3 +119,12 @@ def get_return_annotation(
context: dict[str, Any],
) -> type[Any]:
return eval(ret_annotation, context) # noqa: S307


@functools.cache
def is_iterable_generic_collection(type_: Any) -> bool:
if not (origin := typing.get_origin(type_)):
return False
return collections.abc.Iterable in inspect.getmro(origin) or issubclass(
origin, collections.abc.Iterable
)
69 changes: 45 additions & 24 deletions aioinject/containers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import contextlib
from collections import defaultdict
from collections.abc import Iterator, Sequence
from contextlib import AsyncExitStack
from types import TracebackType
Expand All @@ -25,7 +26,7 @@ def __init__(self, extensions: Sequence[Extension] | None = None) -> None:
self._exit_stack = AsyncExitStack()
self._singletons = SingletonStore(exit_stack=self._exit_stack)

self.providers: _types.Providers[Any] = {}
self.providers: _types.Providers[Any] = defaultdict(list)
self.type_context: dict[str, type[Any]] = {}
self.extensions = extensions or []
self._init_extensions(self.extensions)
Expand All @@ -35,27 +36,42 @@ def _init_extensions(self, extensions: Sequence[Extension]) -> None:
if isinstance(extension, OnInitExtension):
extension.on_init(self)

def register(
self,
*providers: Provider[Any],
) -> None:
def register(self, *providers: Provider[Any]) -> None:
for provider in providers:
if provider.type_ in self.providers:
msg = (
f"Provider for type {provider.type_} is already registered"
)
raise ValueError(msg)
self._register(provider)

self.providers[provider.type_] = provider
if class_name := getattr(provider.type_, "__name__", None):
self.type_context[class_name] = provider.type_
def try_register(self, *providers: Provider[Any]) -> None:
for provider in providers:
with contextlib.suppress(ValueError):
self._register(provider)

def _register(self, provider: Provider[Any]) -> None:
existing_impls = {
existing_provider.impl
for existing_provider in self.providers.get(provider.type_, [])
}
if provider.impl in existing_impls:
msg = (
f"Provider for type {provider.type_} with same "
f"implementation already registered"
)
raise ValueError(msg)

self.providers[provider.type_].append(provider)

class_name = getattr(provider.type_, "__name__", None)
if class_name and class_name not in self.type_context:
self.type_context[class_name] = provider.type_

def get_provider(self, type_: type[T]) -> Provider[T]:
try:
return self.providers[type_]
except KeyError as exc:
err_msg = f"Provider for type {type_.__qualname__} not found"
raise ValueError(err_msg) from exc
return self.get_providers(type_)[0]

def get_providers(self, type_: type[T]) -> list[Provider[T]]:
if providers := self.providers[type_]:
return providers

err_msg = f"Providers for type {type_.__qualname__} not found"
raise ValueError(err_msg)

def context(
self,
Expand All @@ -79,18 +95,23 @@ def sync_context(

@contextlib.contextmanager
def override(self, *providers: Provider[Any]) -> Iterator[None]:
previous: dict[type[Any], Provider[Any] | None] = {}
for provider in providers:
previous[provider.type_] = self.providers.get(provider.type_)
self.providers[provider.type_] = provider
previous = {
provider.type_: self.providers.get(provider.type_, None)
for provider in providers
}
overridden = defaultdict(
list,
{provider.type_: [provider] for provider in providers},
)

self.providers.update(overridden)

try:
yield
finally:
for provider in providers:
del self.providers[provider.type_]
prev = previous[provider.type_]
if prev is not None:
if (prev := previous[provider.type_]) is not None:
self.providers[provider.type_] = prev

async def __aenter__(self) -> Self:
Expand Down
Loading

0 comments on commit a2b82d1

Please sign in to comment.