Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add iterable providers #22

Merged
merged 3 commits into from
Jan 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ coverage.xml
# virtualenv
.venv
venv*/

.python-version

# python cached files
*.py[cod]
2 changes: 1 addition & 1 deletion Taskfile.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ tasks:
desc: Lint python source files
cmds:
- "{{.RUNNER}} ruff check {{.SOURCES}}"
- "{{.RUNNER}} ruff format --checm {{.SOURCES}}"
- "{{.RUNNER}} ruff format --check {{.SOURCES}}"

format:
desc: Format python source files
Expand Down
25 changes: 11 additions & 14 deletions aioinject/_features/generics.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,18 @@
from types import GenericAlias
from typing import TYPE_CHECKING, Any, TypeGuard

from aioinject._utils import is_iterable_generic_collection


if TYPE_CHECKING:
from aioinject.providers import Dependency


def _is_generic_alias(type_: Any) -> TypeGuard[GenericAlias]:
# we currently don't support tuple, list, dict, set, type
return isinstance(
type_,
types.GenericAlias | t._GenericAlias, # type: ignore[attr-defined] # noqa: SLF001
) and t.get_origin(type_) not in (tuple, list, dict, set, type)
) and not is_iterable_generic_collection(type_)


def _get_orig_bases(type_: type) -> tuple[type, ...] | None:
Expand All @@ -33,51 +34,47 @@ def _get_generic_arguments(type_: Any) -> list[t.TypeVar] | None:
return None


@functools.lru_cache
@functools.cache
def _get_generic_args_map(type_: type[object]) -> dict[str, type[object]]:
if _is_generic_alias(type_):
args = type_.__args__
params: dict[str, Any] = {
param.__name__: param
for param in type_.__origin__.__parameters__ # type: ignore[attr-defined]
}
# TODO(Doctor, nrbnlulu): Tests pass with strct=True, is this needed?
return dict(zip(params, args, strict=False))
return dict(zip(params, type_.__args__, strict=True))

args_map = {}
if orig_bases := _get_orig_bases(type_):
# find the generic parent
for base in orig_bases:
if _is_generic_alias(base):
args = base.__args__
if _is_generic_alias(base): # noqa: SIM102
if params := {
param.__name__: param
for param in getattr(base.__origin__, "__parameters__", ())
}:
args_map.update(
dict(zip(params, args, strict=True)),
dict(zip(params, base.__args__, strict=True)),
)
return args_map


@functools.lru_cache
@functools.cache
def get_generic_parameter_map(
provided_type: type[object],
dependencies: tuple[Dependency[Any], ...],
) -> dict[str, type[object]]:
args_map = _get_generic_args_map(provided_type) # type: ignore[arg-type]
result = {}
for dependency in dependencies:
inner_type = dependency.inner_type
if args_map and (
generic_arguments := _get_generic_arguments(dependency.type_)
generic_arguments := _get_generic_arguments(inner_type)
):
# This is a generic type, we need to resolve the type arguments
# and pass them to the provider.
resolved_args = [
args_map[arg.__name__] for arg in generic_arguments
]
# We can use `[]` when we drop support for 3.10
result[dependency.name] = dependency.type_.__getitem__(
*resolved_args
)
result[dependency.name] = inner_type.__getitem__(*resolved_args)
return result
35 changes: 19 additions & 16 deletions aioinject/_store.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import asyncio
import collections
import contextlib
import enum
Expand All @@ -11,6 +10,8 @@
from types import TracebackType
from typing import TYPE_CHECKING, Any, Literal, TypeVar

import anyio

from aioinject._utils import enter_context_maybe, enter_sync_context_maybe
from aioinject.providers import DependencyLifetime

Expand All @@ -33,28 +34,28 @@ def __init__(
exit_stack: contextlib.AsyncExitStack | None = None,
sync_exit_stack: contextlib.ExitStack | None = None,
) -> None:
self._cache: dict[type, Any] = {}
self._cache: dict[Provider[Any], Any] = {}
self._exit_stack = exit_stack or contextlib.AsyncExitStack()
self._sync_exit_stack = sync_exit_stack or contextlib.ExitStack()

def get(self, provider: Provider[T]) -> T | Literal[NotInCache.sentinel]:
return self._cache.get(provider.type_, NotInCache.sentinel)
return self._cache.get(provider, NotInCache.sentinel)

def add(self, provider: Provider[T], obj: T) -> None:
if provider.lifetime is not DependencyLifetime.transient:
self._cache[provider.type_] = obj
self._cache[provider] = obj

def lock(
self,
provider: Provider[Any],
) -> AbstractAsyncContextManager[bool]:
return contextlib.nullcontext(provider.type_ not in self._cache)
return contextlib.nullcontext(provider not in self._cache)

def sync_lock(
self,
provider: Provider[Any],
) -> AbstractContextManager[bool]:
return contextlib.nullcontext(provider.type_ not in self._cache)
return contextlib.nullcontext(provider not in self._cache)

@typing.overload
async def enter_context(
Expand Down Expand Up @@ -119,18 +120,20 @@ def __init__(
sync_exit_stack: contextlib.ExitStack | None = None,
) -> None:
super().__init__(exit_stack, sync_exit_stack)
self._locks: dict[type, asyncio.Lock] = collections.defaultdict(
asyncio.Lock,
self._locks: dict[Provider[Any], anyio.Lock] = collections.defaultdict(
anyio.Lock,
)
self._sync_locks: dict[type, threading.Lock] = collections.defaultdict(
threading.Lock,
self._sync_locks: dict[Provider[Any], threading.Lock] = (
collections.defaultdict(
threading.Lock,
)
)

@contextlib.asynccontextmanager
async def lock(self, provider: Provider[Any]) -> AsyncIterator[bool]:
if provider.type_ not in self._cache:
async with self._locks[provider.type_]:
yield provider.type_ not in self._cache
if provider not in self._cache:
async with self._locks[provider]:
yield provider not in self._cache
return
yield False

Expand All @@ -139,8 +142,8 @@ def sync_lock(
self,
provider: Provider[Any],
) -> Iterator[bool]:
if provider.type_ not in self._cache:
with self._sync_locks[provider.type_]:
yield provider.type_ not in self._cache
if provider not in self._cache:
with self._sync_locks[provider]:
yield provider not in self._cache
return
yield False
2 changes: 1 addition & 1 deletion aioinject/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,5 @@

P = ParamSpec("P")
T = TypeVar("T")
Providers: TypeAlias = dict[type[T], "Provider[T]"]
Providers: TypeAlias = dict[type[T], list["Provider[T]"]]
AnyCtx: TypeAlias = Union["InjectionContext", "SyncInjectionContext"]
23 changes: 18 additions & 5 deletions aioinject/_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
from __future__ import annotations

import collections.abc
import contextlib
import functools
import inspect
import sys
import typing
Expand All @@ -15,7 +19,7 @@


_T = TypeVar("_T")
_F = TypeVar("_F", bound=typing.Callable[..., Any])
_F = TypeVar("_F", bound=Callable[..., Any])

sentinel = object()

Expand All @@ -37,7 +41,7 @@ def clear_wrapper(wrapper: _F) -> _F:


def get_inject_annotations(
function: typing.Callable[..., Any],
function: Callable[..., Any],
) -> dict[str, Any]:
with remove_annotation(function.__annotations__, "return"):
return {
Expand Down Expand Up @@ -67,13 +71,13 @@ async def enter_context_maybe(
),
stack: AsyncExitStack,
) -> _T:
if isinstance(resolved, contextlib.ContextDecorator):
return stack.enter_context(resolved) # type: ignore[arg-type]

if isinstance(resolved, contextlib.AsyncContextDecorator):
return await stack.enter_async_context(
resolved, # type: ignore[arg-type]
)
if isinstance(resolved, contextlib.ContextDecorator):
return stack.enter_context(resolved) # type: ignore[arg-type]

return resolved # type: ignore[return-value]


Expand Down Expand Up @@ -115,3 +119,12 @@ def get_return_annotation(
context: dict[str, Any],
) -> type[Any]:
return eval(ret_annotation, context) # noqa: S307


@functools.cache
def is_iterable_generic_collection(type_: Any) -> bool:
if not (origin := typing.get_origin(type_)):
return False
return collections.abc.Iterable in inspect.getmro(origin) or issubclass(
origin, collections.abc.Iterable
)
69 changes: 45 additions & 24 deletions aioinject/containers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import contextlib
from collections import defaultdict
from collections.abc import Iterator, Sequence
from contextlib import AsyncExitStack
from types import TracebackType
Expand All @@ -25,7 +26,7 @@ def __init__(self, extensions: Sequence[Extension] | None = None) -> None:
self._exit_stack = AsyncExitStack()
self._singletons = SingletonStore(exit_stack=self._exit_stack)

self.providers: _types.Providers[Any] = {}
self.providers: _types.Providers[Any] = defaultdict(list)
self.type_context: dict[str, type[Any]] = {}
self.extensions = extensions or []
self._init_extensions(self.extensions)
Expand All @@ -35,27 +36,42 @@ def _init_extensions(self, extensions: Sequence[Extension]) -> None:
if isinstance(extension, OnInitExtension):
extension.on_init(self)

def register(
self,
*providers: Provider[Any],
) -> None:
def register(self, *providers: Provider[Any]) -> None:
for provider in providers:
if provider.type_ in self.providers:
msg = (
f"Provider for type {provider.type_} is already registered"
)
raise ValueError(msg)
self._register(provider)

self.providers[provider.type_] = provider
if class_name := getattr(provider.type_, "__name__", None):
self.type_context[class_name] = provider.type_
def try_register(self, *providers: Provider[Any]) -> None:
for provider in providers:
with contextlib.suppress(ValueError):
self._register(provider)

def _register(self, provider: Provider[Any]) -> None:
existing_impls = {
existing_provider.impl
for existing_provider in self.providers.get(provider.type_, [])
}
if provider.impl in existing_impls:
msg = (
f"Provider for type {provider.type_} with same "
f"implementation already registered"
)
raise ValueError(msg)

self.providers[provider.type_].append(provider)

class_name = getattr(provider.type_, "__name__", None)
if class_name and class_name not in self.type_context:
self.type_context[class_name] = provider.type_

def get_provider(self, type_: type[T]) -> Provider[T]:
try:
return self.providers[type_]
except KeyError as exc:
err_msg = f"Provider for type {type_.__qualname__} not found"
raise ValueError(err_msg) from exc
return self.get_providers(type_)[0]

def get_providers(self, type_: type[T]) -> list[Provider[T]]:
if providers := self.providers[type_]:
return providers

err_msg = f"Providers for type {type_.__qualname__} not found"
raise ValueError(err_msg)

def context(
self,
Expand All @@ -79,18 +95,23 @@ def sync_context(

@contextlib.contextmanager
def override(self, *providers: Provider[Any]) -> Iterator[None]:
previous: dict[type[Any], Provider[Any] | None] = {}
for provider in providers:
previous[provider.type_] = self.providers.get(provider.type_)
self.providers[provider.type_] = provider
previous = {
provider.type_: self.providers.get(provider.type_, None)
for provider in providers
}
overridden = defaultdict(
list,
{provider.type_: [provider] for provider in providers},
)

self.providers.update(overridden)

try:
yield
finally:
for provider in providers:
del self.providers[provider.type_]
prev = previous[provider.type_]
if prev is not None:
if (prev := previous[provider.type_]) is not None:
self.providers[provider.type_] = prev

async def __aenter__(self) -> Self:
Expand Down
Loading
Loading