Skip to content

Commit

Permalink
feat: add iterable providers
Browse files Browse the repository at this point in the history
  • Loading branch information
fadedDexofan committed Jan 4, 2025
1 parent 21ced5b commit 990f5ab
Show file tree
Hide file tree
Showing 12 changed files with 330 additions and 99 deletions.
2 changes: 1 addition & 1 deletion Taskfile.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ tasks:
desc: Lint python source files
cmds:
- "{{.RUNNER}} ruff check {{.SOURCES}}"
- "{{.RUNNER}} ruff format --checm {{.SOURCES}}"
- "{{.RUNNER}} ruff format --check {{.SOURCES}}"

format:
desc: Format python source files
Expand Down
46 changes: 30 additions & 16 deletions aioinject/_store.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import asyncio
import collections
import contextlib
import enum
Expand All @@ -11,6 +10,8 @@
from types import TracebackType
from typing import TYPE_CHECKING, Any, Literal, TypeVar

import anyio

from aioinject._utils import enter_context_maybe, enter_sync_context_maybe
from aioinject.providers import DependencyLifetime

Expand All @@ -33,28 +34,32 @@ def __init__(
exit_stack: contextlib.AsyncExitStack | None = None,
sync_exit_stack: contextlib.ExitStack | None = None,
) -> None:
self._cache: dict[type, Any] = {}
self._cache: dict[tuple[type[Any], Any], Any] = {}
self._exit_stack = exit_stack or contextlib.AsyncExitStack()
self._sync_exit_stack = sync_exit_stack or contextlib.ExitStack()

def get(self, provider: Provider[T]) -> T | Literal[NotInCache.sentinel]:
return self._cache.get(provider.type_, NotInCache.sentinel)
return self._cache.get(self._cache_key(provider), NotInCache.sentinel)

def add(self, provider: Provider[T], obj: T) -> None:
if provider.lifetime is not DependencyLifetime.transient:
self._cache[provider.type_] = obj
self._cache[self._cache_key(provider)] = obj

def lock(
self,
provider: Provider[Any],
) -> AbstractAsyncContextManager[bool]:
return contextlib.nullcontext(provider.type_ not in self._cache)
return contextlib.nullcontext(
self._cache_key(provider) not in self._cache
)

def sync_lock(
self,
provider: Provider[Any],
) -> AbstractContextManager[bool]:
return contextlib.nullcontext(provider.type_ not in self._cache)
return contextlib.nullcontext(
self._cache_key(provider) not in self._cache
)

@typing.overload
async def enter_context(
Expand Down Expand Up @@ -111,6 +116,9 @@ def __exit__(
def close(self) -> None:
self.__exit__(None, None, None)

def _cache_key(self, provider: Provider[T]) -> tuple[type[T], Any]:
return provider.type_, provider.impl


class SingletonStore(InstanceStore):
def __init__(
Expand All @@ -119,18 +127,23 @@ def __init__(
sync_exit_stack: contextlib.ExitStack | None = None,
) -> None:
super().__init__(exit_stack, sync_exit_stack)
self._locks: dict[type, asyncio.Lock] = collections.defaultdict(
asyncio.Lock,
self._locks: dict[tuple[type[Any], Any], anyio.Lock] = (
collections.defaultdict(
anyio.Lock,
)
)
self._sync_locks: dict[type, threading.Lock] = collections.defaultdict(
threading.Lock,
self._sync_locks: dict[tuple[type[Any], Any], threading.Lock] = (
collections.defaultdict(
threading.Lock,
)
)

@contextlib.asynccontextmanager
async def lock(self, provider: Provider[Any]) -> AsyncIterator[bool]:
if provider.type_ not in self._cache:
async with self._locks[provider.type_]:
yield provider.type_ not in self._cache
cache_key = self._cache_key(provider)
if cache_key not in self._cache:
async with self._locks[cache_key]:
yield cache_key not in self._cache
return
yield False

Expand All @@ -139,8 +152,9 @@ def sync_lock(
self,
provider: Provider[Any],
) -> Iterator[bool]:
if provider.type_ not in self._cache:
with self._sync_locks[provider.type_]:
yield provider.type_ not in self._cache
cache_key = self._cache_key(provider)
if cache_key not in self._cache:
with self._sync_locks[cache_key]:
yield cache_key not in self._cache
return
yield False
2 changes: 1 addition & 1 deletion aioinject/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,5 @@

P = ParamSpec("P")
T = TypeVar("T")
Providers: TypeAlias = dict[type[T], "Provider[T]"]
Providers: TypeAlias = dict[type[T], list["Provider[T]"]]
AnyCtx: TypeAlias = Union["InjectionContext", "SyncInjectionContext"]
15 changes: 13 additions & 2 deletions aioinject/_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from __future__ import annotations

import collections.abc
import contextlib
import inspect
import sys
Expand All @@ -15,7 +18,7 @@


_T = TypeVar("_T")
_F = TypeVar("_F", bound=typing.Callable[..., Any])
_F = TypeVar("_F", bound=Callable[..., Any])

sentinel = object()

Expand All @@ -37,7 +40,7 @@ def clear_wrapper(wrapper: _F) -> _F:


def get_inject_annotations(
function: typing.Callable[..., Any],
function: Callable[..., Any],
) -> dict[str, Any]:
with remove_annotation(function.__annotations__, "return"):
return {
Expand Down Expand Up @@ -115,3 +118,11 @@ def get_return_annotation(
context: dict[str, Any],
) -> type[Any]:
return eval(ret_annotation, context) # noqa: S307


def is_iterable_generic_collection(type_: Any) -> bool:
if not (origin := typing.get_origin(type_)):
return False
return collections.abc.Iterable in inspect.getmro(origin) or issubclass(
origin, collections.abc.Iterable
)
69 changes: 45 additions & 24 deletions aioinject/containers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import contextlib
from collections import defaultdict
from collections.abc import Iterator, Sequence
from contextlib import AsyncExitStack
from types import TracebackType
Expand All @@ -25,7 +26,7 @@ def __init__(self, extensions: Sequence[Extension] | None = None) -> None:
self._exit_stack = AsyncExitStack()
self._singletons = SingletonStore(exit_stack=self._exit_stack)

self.providers: _types.Providers[Any] = {}
self.providers: _types.Providers[Any] = defaultdict(list)
self.type_context: dict[str, type[Any]] = {}
self.extensions = extensions or []
self._init_extensions(self.extensions)
Expand All @@ -35,27 +36,42 @@ def _init_extensions(self, extensions: Sequence[Extension]) -> None:
if isinstance(extension, OnInitExtension):
extension.on_init(self)

def register(
self,
*providers: Provider[Any],
) -> None:
def register(self, *providers: Provider[Any]) -> None:
for provider in providers:
if provider.type_ in self.providers:
msg = (
f"Provider for type {provider.type_} is already registered"
)
raise ValueError(msg)
self._register(provider)

self.providers[provider.type_] = provider
if class_name := getattr(provider.type_, "__name__", None):
self.type_context[class_name] = provider.type_
def try_register(self, *providers: Provider[Any]) -> None:
for provider in providers:
with contextlib.suppress(ValueError):
self._register(provider)

def _register(self, provider: Provider[Any]) -> None:
existing_impls = {
existing_provider.impl
for existing_provider in self.providers.get(provider.type_, [])
}
if provider.impl in existing_impls:
msg = (
f"Provider for type {provider.type_} with same "
f"implementation already registered"
)
raise ValueError(msg)

self.providers[provider.type_].append(provider)

class_name = getattr(provider.type_, "__name__", None)
if class_name and class_name not in self.type_context:
self.type_context[class_name] = provider.type_

def get_provider(self, type_: type[T]) -> Provider[T]:
try:
return self.providers[type_]
except KeyError as exc:
err_msg = f"Provider for type {type_.__qualname__} not found"
raise ValueError(err_msg) from exc
return self.get_providers(type_)[0]

def get_providers(self, type_: type[T]) -> list[Provider[T]]:
if providers := self.providers[type_]:
return providers

err_msg = f"Providers for type {type_.__qualname__} not found"
raise ValueError(err_msg)

def context(
self,
Expand All @@ -79,18 +95,23 @@ def sync_context(

@contextlib.contextmanager
def override(self, *providers: Provider[Any]) -> Iterator[None]:
previous: dict[type[Any], Provider[Any] | None] = {}
for provider in providers:
previous[provider.type_] = self.providers.get(provider.type_)
self.providers[provider.type_] = provider
previous = {
provider.type_: self.providers.get(provider.type_, None)
for provider in providers
}
overridden = defaultdict(
list,
{provider.type_: [provider] for provider in providers},
)

self.providers.update(overridden)

try:
yield
finally:
for provider in providers:
del self.providers[provider.type_]
prev = previous[provider.type_]
if prev is not None:
if (prev := previous[provider.type_]) is not None:
self.providers[provider.type_] = prev

async def __aenter__(self) -> Self:
Expand Down
Loading

0 comments on commit 990f5ab

Please sign in to comment.