diff --git a/.gitignore b/.gitignore index bfd22c9..0e17e1a 100644 --- a/.gitignore +++ b/.gitignore @@ -16,7 +16,7 @@ coverage.xml # virtualenv .venv venv*/ - +.python-version # python cached files *.py[cod] diff --git a/Taskfile.yml b/Taskfile.yml index 75505c3..3605cce 100644 --- a/Taskfile.yml +++ b/Taskfile.yml @@ -14,7 +14,7 @@ tasks: desc: Lint python source files cmds: - "{{.RUNNER}} ruff check {{.SOURCES}}" - - "{{.RUNNER}} ruff format --checm {{.SOURCES}}" + - "{{.RUNNER}} ruff format --check {{.SOURCES}}" format: desc: Format python source files diff --git a/aioinject/_features/generics.py b/aioinject/_features/generics.py index bad3c97..9c7a262 100644 --- a/aioinject/_features/generics.py +++ b/aioinject/_features/generics.py @@ -6,17 +6,18 @@ from types import GenericAlias from typing import TYPE_CHECKING, Any, TypeGuard +from aioinject._utils import is_iterable_generic_collection + if TYPE_CHECKING: from aioinject.providers import Dependency def _is_generic_alias(type_: Any) -> TypeGuard[GenericAlias]: - # we currently don't support tuple, list, dict, set, type return isinstance( type_, types.GenericAlias | t._GenericAlias, # type: ignore[attr-defined] # noqa: SLF001 - ) and t.get_origin(type_) not in (tuple, list, dict, set, type) + ) and not is_iterable_generic_collection(type_) def _get_orig_bases(type_: type) -> tuple[type, ...] | None: @@ -33,34 +34,31 @@ def _get_generic_arguments(type_: Any) -> list[t.TypeVar] | None: return None -@functools.lru_cache +@functools.cache def _get_generic_args_map(type_: type[object]) -> dict[str, type[object]]: if _is_generic_alias(type_): - args = type_.__args__ params: dict[str, Any] = { param.__name__: param for param in type_.__origin__.__parameters__ # type: ignore[attr-defined] } - # TODO(Doctor, nrbnlulu): Tests pass with strct=True, is this needed? - return dict(zip(params, args, strict=False)) + return dict(zip(params, type_.__args__, strict=True)) args_map = {} if orig_bases := _get_orig_bases(type_): # find the generic parent for base in orig_bases: - if _is_generic_alias(base): - args = base.__args__ + if _is_generic_alias(base): # noqa: SIM102 if params := { param.__name__: param for param in getattr(base.__origin__, "__parameters__", ()) }: args_map.update( - dict(zip(params, args, strict=True)), + dict(zip(params, base.__args__, strict=True)), ) return args_map -@functools.lru_cache +@functools.cache def get_generic_parameter_map( provided_type: type[object], dependencies: tuple[Dependency[Any], ...], @@ -68,8 +66,9 @@ def get_generic_parameter_map( args_map = _get_generic_args_map(provided_type) # type: ignore[arg-type] result = {} for dependency in dependencies: + inner_type = dependency.inner_type if args_map and ( - generic_arguments := _get_generic_arguments(dependency.type_) + generic_arguments := _get_generic_arguments(inner_type) ): # This is a generic type, we need to resolve the type arguments # and pass them to the provider. @@ -77,7 +76,5 @@ def get_generic_parameter_map( args_map[arg.__name__] for arg in generic_arguments ] # We can use `[]` when we drop support for 3.10 - result[dependency.name] = dependency.type_.__getitem__( - *resolved_args - ) + result[dependency.name] = inner_type.__getitem__(*resolved_args) return result diff --git a/aioinject/_store.py b/aioinject/_store.py index 16f0469..23eeac4 100644 --- a/aioinject/_store.py +++ b/aioinject/_store.py @@ -1,6 +1,5 @@ from __future__ import annotations -import asyncio import collections import contextlib import enum @@ -11,6 +10,8 @@ from types import TracebackType from typing import TYPE_CHECKING, Any, Literal, TypeVar +import anyio + from aioinject._utils import enter_context_maybe, enter_sync_context_maybe from aioinject.providers import DependencyLifetime @@ -33,28 +34,28 @@ def __init__( exit_stack: contextlib.AsyncExitStack | None = None, sync_exit_stack: contextlib.ExitStack | None = None, ) -> None: - self._cache: dict[type, Any] = {} + self._cache: dict[Provider[Any], Any] = {} self._exit_stack = exit_stack or contextlib.AsyncExitStack() self._sync_exit_stack = sync_exit_stack or contextlib.ExitStack() def get(self, provider: Provider[T]) -> T | Literal[NotInCache.sentinel]: - return self._cache.get(provider.type_, NotInCache.sentinel) + return self._cache.get(provider, NotInCache.sentinel) def add(self, provider: Provider[T], obj: T) -> None: if provider.lifetime is not DependencyLifetime.transient: - self._cache[provider.type_] = obj + self._cache[provider] = obj def lock( self, provider: Provider[Any], ) -> AbstractAsyncContextManager[bool]: - return contextlib.nullcontext(provider.type_ not in self._cache) + return contextlib.nullcontext(provider not in self._cache) def sync_lock( self, provider: Provider[Any], ) -> AbstractContextManager[bool]: - return contextlib.nullcontext(provider.type_ not in self._cache) + return contextlib.nullcontext(provider not in self._cache) @typing.overload async def enter_context( @@ -119,18 +120,20 @@ def __init__( sync_exit_stack: contextlib.ExitStack | None = None, ) -> None: super().__init__(exit_stack, sync_exit_stack) - self._locks: dict[type, asyncio.Lock] = collections.defaultdict( - asyncio.Lock, + self._locks: dict[Provider[Any], anyio.Lock] = collections.defaultdict( + anyio.Lock, ) - self._sync_locks: dict[type, threading.Lock] = collections.defaultdict( - threading.Lock, + self._sync_locks: dict[Provider[Any], threading.Lock] = ( + collections.defaultdict( + threading.Lock, + ) ) @contextlib.asynccontextmanager async def lock(self, provider: Provider[Any]) -> AsyncIterator[bool]: - if provider.type_ not in self._cache: - async with self._locks[provider.type_]: - yield provider.type_ not in self._cache + if provider not in self._cache: + async with self._locks[provider]: + yield provider not in self._cache return yield False @@ -139,8 +142,8 @@ def sync_lock( self, provider: Provider[Any], ) -> Iterator[bool]: - if provider.type_ not in self._cache: - with self._sync_locks[provider.type_]: - yield provider.type_ not in self._cache + if provider not in self._cache: + with self._sync_locks[provider]: + yield provider not in self._cache return yield False diff --git a/aioinject/_types.py b/aioinject/_types.py index b0a5916..163549c 100644 --- a/aioinject/_types.py +++ b/aioinject/_types.py @@ -6,5 +6,5 @@ P = ParamSpec("P") T = TypeVar("T") -Providers: TypeAlias = dict[type[T], "Provider[T]"] +Providers: TypeAlias = dict[type[T], list["Provider[T]"]] AnyCtx: TypeAlias = Union["InjectionContext", "SyncInjectionContext"] diff --git a/aioinject/_utils.py b/aioinject/_utils.py index e2381fb..943a41d 100644 --- a/aioinject/_utils.py +++ b/aioinject/_utils.py @@ -1,4 +1,8 @@ +from __future__ import annotations + +import collections.abc import contextlib +import functools import inspect import sys import typing @@ -15,7 +19,7 @@ _T = TypeVar("_T") -_F = TypeVar("_F", bound=typing.Callable[..., Any]) +_F = TypeVar("_F", bound=Callable[..., Any]) sentinel = object() @@ -37,7 +41,7 @@ def clear_wrapper(wrapper: _F) -> _F: def get_inject_annotations( - function: typing.Callable[..., Any], + function: Callable[..., Any], ) -> dict[str, Any]: with remove_annotation(function.__annotations__, "return"): return { @@ -67,13 +71,13 @@ async def enter_context_maybe( ), stack: AsyncExitStack, ) -> _T: - if isinstance(resolved, contextlib.ContextDecorator): - return stack.enter_context(resolved) # type: ignore[arg-type] - if isinstance(resolved, contextlib.AsyncContextDecorator): return await stack.enter_async_context( resolved, # type: ignore[arg-type] ) + if isinstance(resolved, contextlib.ContextDecorator): + return stack.enter_context(resolved) # type: ignore[arg-type] + return resolved # type: ignore[return-value] @@ -115,3 +119,12 @@ def get_return_annotation( context: dict[str, Any], ) -> type[Any]: return eval(ret_annotation, context) # noqa: S307 + + +@functools.cache +def is_iterable_generic_collection(type_: Any) -> bool: + if not (origin := typing.get_origin(type_)): + return False + return collections.abc.Iterable in inspect.getmro(origin) or issubclass( + origin, collections.abc.Iterable + ) diff --git a/aioinject/containers.py b/aioinject/containers.py index 75db95b..a5c94e2 100644 --- a/aioinject/containers.py +++ b/aioinject/containers.py @@ -1,4 +1,5 @@ import contextlib +from collections import defaultdict from collections.abc import Iterator, Sequence from contextlib import AsyncExitStack from types import TracebackType @@ -25,7 +26,7 @@ def __init__(self, extensions: Sequence[Extension] | None = None) -> None: self._exit_stack = AsyncExitStack() self._singletons = SingletonStore(exit_stack=self._exit_stack) - self.providers: _types.Providers[Any] = {} + self.providers: _types.Providers[Any] = defaultdict(list) self.type_context: dict[str, type[Any]] = {} self.extensions = extensions or [] self._init_extensions(self.extensions) @@ -35,27 +36,42 @@ def _init_extensions(self, extensions: Sequence[Extension]) -> None: if isinstance(extension, OnInitExtension): extension.on_init(self) - def register( - self, - *providers: Provider[Any], - ) -> None: + def register(self, *providers: Provider[Any]) -> None: for provider in providers: - if provider.type_ in self.providers: - msg = ( - f"Provider for type {provider.type_} is already registered" - ) - raise ValueError(msg) + self._register(provider) - self.providers[provider.type_] = provider - if class_name := getattr(provider.type_, "__name__", None): - self.type_context[class_name] = provider.type_ + def try_register(self, *providers: Provider[Any]) -> None: + for provider in providers: + with contextlib.suppress(ValueError): + self._register(provider) + + def _register(self, provider: Provider[Any]) -> None: + existing_impls = { + existing_provider.impl + for existing_provider in self.providers.get(provider.type_, []) + } + if provider.impl in existing_impls: + msg = ( + f"Provider for type {provider.type_} with same " + f"implementation already registered" + ) + raise ValueError(msg) + + self.providers[provider.type_].append(provider) + + class_name = getattr(provider.type_, "__name__", None) + if class_name and class_name not in self.type_context: + self.type_context[class_name] = provider.type_ def get_provider(self, type_: type[T]) -> Provider[T]: - try: - return self.providers[type_] - except KeyError as exc: - err_msg = f"Provider for type {type_.__qualname__} not found" - raise ValueError(err_msg) from exc + return self.get_providers(type_)[0] + + def get_providers(self, type_: type[T]) -> list[Provider[T]]: + if providers := self.providers[type_]: + return providers + + err_msg = f"Providers for type {type_.__qualname__} not found" + raise ValueError(err_msg) def context( self, @@ -79,18 +95,23 @@ def sync_context( @contextlib.contextmanager def override(self, *providers: Provider[Any]) -> Iterator[None]: - previous: dict[type[Any], Provider[Any] | None] = {} - for provider in providers: - previous[provider.type_] = self.providers.get(provider.type_) - self.providers[provider.type_] = provider + previous = { + provider.type_: self.providers.get(provider.type_, None) + for provider in providers + } + overridden = defaultdict( + list, + {provider.type_: [provider] for provider in providers}, + ) + + self.providers.update(overridden) try: yield finally: for provider in providers: del self.providers[provider.type_] - prev = previous[provider.type_] - if prev is not None: + if (prev := previous[provider.type_]) is not None: self.providers[provider.type_] = prev async def __aenter__(self) -> Self: diff --git a/aioinject/context.py b/aioinject/context.py index 1a6e886..aa901c7 100644 --- a/aioinject/context.py +++ b/aioinject/context.py @@ -2,6 +2,7 @@ import contextvars import inspect +from collections import defaultdict from collections.abc import Callable, Coroutine, Iterable, Mapping, Sequence from contextvars import ContextVar from types import TracebackType @@ -9,6 +10,7 @@ TYPE_CHECKING, Any, Generic, + Literal, TypeVar, overload, ) @@ -52,7 +54,7 @@ def __init__( self._store = InstanceStore() self._token: contextvars.Token[AnyCtx] | None = None - self._providers: _types.Providers[Any] = {} + self._providers: _types.Providers[Any] = defaultdict(list) self._closed = False @@ -61,35 +63,73 @@ def _get_store(self, lifetime: DependencyLifetime) -> InstanceStore: return self._singletons return self._store - def _get_provider(self, type_: type[_T]) -> Provider[_T]: - return self._providers.get(type_) or self._container.get_provider( + def _get_providers(self, type_: type[_T]) -> list[Provider[_T]]: + return self._providers.get(type_) or self._container.get_providers( type_, ) def register(self, provider: Provider[Any]) -> None: - self._providers[provider.type_] = provider + self._providers[provider.type_].append(provider) class InjectionContext(_BaseInjectionContext[ContextExtension]): - async def resolve( + async def resolve(self, type_: type[_T]) -> _T: + return await self._resolve(type_, is_iterable=False) + + async def resolve_iterable(self, type_: type[_T]) -> list[_T]: + return await self._resolve(type_, is_iterable=True) + + @overload + async def _resolve( + self, + type_: type[_T], + *, + is_iterable: Literal[False], + ) -> _T: ... + + @overload + async def _resolve( + self, + type_: type[_T], + *, + is_iterable: Literal[True], + ) -> list[_T]: ... + + async def _resolve( self, type_: type[_T], + *, + is_iterable: bool, + ) -> _T | list[_T]: + providers = self._get_providers(type_) + if not is_iterable: + return await self._resolve_provider(providers[-1]) # type: ignore[arg-type] + return [ + await self._resolve_provider(provider) for provider in providers + ] + + async def _resolve_provider( + self, + provider: Provider[_T], ) -> _T: - provider = self._get_provider(type_) store = self._get_store(provider.lifetime) if (cached := store.get(provider)) is not NotInCache.sentinel: return cached - provider_dependencies = provider.resolve_dependencies( + provider_dependencies = provider.collect_dependencies( context=self._container.type_context ) dependencies_map = get_generic_parameter_map( - type_, # type: ignore[arg-type] + provider.type_, # type: ignore[arg-type] provider_dependencies, ) dependencies = { - dependency.name: await self.resolve( - dependencies_map.get(dependency.name, dependency.type_) + dependency.name: await self._resolve( # type: ignore[call-overload] + type_=dependencies_map.get( + dependency.name, + dependency.inner_type, + ), + is_iterable=dependency.is_iterable, ) for dependency in provider_dependencies } @@ -97,25 +137,25 @@ async def resolve( if provider.lifetime is DependencyLifetime.singleton: async with store.lock(provider) as should_provide: if should_provide: - return await self._resolve(provider, store, dependencies) - return store.get( # type: ignore[return-value] # pragma: no cover - provider, - ) + return await self._provide_and_store( + provider, store, dependencies + ) + return store.get(provider) # type: ignore[return-value] # pragma: no cover - return await self._resolve(provider, store, dependencies) + return await self._provide_and_store(provider, store, dependencies) - async def _resolve( + async def _provide_and_store( self, provider: Provider[_T], store: InstanceStore, - dependencies: Mapping[str, Any], + dependencies: Mapping[str, object], ) -> _T: - resolved = await provider.provide(dependencies) + provided = await provider.provide(dependencies) if provider.is_generator: - resolved = await store.enter_context(resolved) - store.add(provider, resolved) - await self._on_resolve(provider=provider, instance=resolved) - return resolved + provided = await store.enter_context(provided) + store.add(provider, provided) + await self._on_resolve(provider=provider, instance=provided) + return provided @overload async def execute( @@ -142,14 +182,15 @@ async def execute( *args: Any, **kwargs: Any, ) -> _T: - resolved = {} - for dependency in dependencies: - if dependency.name in kwargs: - continue - - resolved[dependency.name] = await self.resolve( - type_=dependency.type_, + resolved = { + dependency.name: await self._resolve( # type: ignore[call-overload] + type_=dependency.inner_type, + is_iterable=dependency.is_iterable, ) + for dependency in dependencies + if dependency.name not in kwargs + } + if inspect.iscoroutinefunction(function): return await function(*args, **kwargs, **resolved) return function(*args, **kwargs, **resolved) # type: ignore[return-value] @@ -178,43 +219,87 @@ async def __aexit__( class SyncInjectionContext(_BaseInjectionContext[SyncContextExtension]): - def resolve( + def resolve(self, type_: type[_T]) -> _T: + return self._resolve(type_, is_iterable=False) + + def resolve_iterable(self, type_: type[_T]) -> list[_T]: + return self._resolve(type_, is_iterable=True) + + @overload + def _resolve( + self, + type_: type[_T], + *, + is_iterable: Literal[False], + ) -> _T: ... + + @overload + def _resolve( + self, + type_: type[_T], + *, + is_iterable: Literal[True], + ) -> list[_T]: ... + + def _resolve( self, type_: type[_T], + *, + is_iterable: bool, + ) -> _T | list[_T]: + providers = self._get_providers(type_) + if not is_iterable: + return self._resolve_provider(providers[-1]) + return [self._resolve_provider(provider) for provider in providers] + + def _resolve_provider( + self, + provider: Provider[_T], ) -> _T: - provider = self._get_provider(type_) store = self._get_store(provider.lifetime) if (cached := store.get(provider)) is not NotInCache.sentinel: return cached - dependencies = {} - for dependency in provider.resolve_dependencies( - self._container.type_context, - ): - dependencies[dependency.name] = self.resolve( - type_=dependency.type_, + provider_dependencies = provider.collect_dependencies( + context=self._container.type_context + ) + dependencies_map = get_generic_parameter_map( + provider.type_, # type: ignore[arg-type] + provider_dependencies, + ) + dependencies = { + dependency.name: self._resolve( # type: ignore[call-overload] + type_=dependencies_map.get( + dependency.name, + dependency.inner_type, + ), + is_iterable=dependency.is_iterable, ) + for dependency in provider_dependencies + } if provider.lifetime is DependencyLifetime.singleton: with store.sync_lock(provider) as should_provide: if should_provide: - return self._resolve(provider, store, dependencies) + return self._provide_and_store( + provider, store, dependencies + ) return store.get(provider) # type: ignore[return-value] # pragma: no cover - return self._resolve(provider, store, dependencies) + return self._provide_and_store(provider, store, dependencies) - def _resolve( + def _provide_and_store( self, provider: Provider[_T], store: InstanceStore, - dependencies: Mapping[str, Any], + dependencies: Mapping[str, object], ) -> _T: - resolved = provider.provide_sync(dependencies) + provided = provider.provide_sync(dependencies) if provider.is_generator: - resolved = store.enter_sync_context(resolved) - store.add(provider, resolved) - self._on_resolve(provider=provider, instance=resolved) - return resolved + provided = store.enter_sync_context(provided) + store.add(provider, provided) + self._on_resolve(provider=provider, instance=provided) + return provided def execute( self, @@ -223,11 +308,14 @@ def execute( *args: Any, **kwargs: Any, ) -> _T: - resolved = {} - for dependency in dependencies: - if dependency.name in kwargs: - continue - resolved[dependency.name] = self.resolve(type_=dependency.type_) + resolved = { + dependency.name: self._resolve( # type: ignore[call-overload] + type_=dependency.inner_type, + is_iterable=dependency.is_iterable, + ) + for dependency in dependencies + if dependency.name not in kwargs + } return function(*args, **kwargs, **resolved) def _on_resolve(self, provider: Provider[T], instance: T) -> None: diff --git a/aioinject/providers.py b/aioinject/providers.py index 0598c86..921e753 100644 --- a/aioinject/providers.py +++ b/aioinject/providers.py @@ -7,6 +7,7 @@ import typing from collections.abc import Mapping from dataclasses import dataclass +from functools import cached_property from inspect import isclass from typing import ( Annotated, @@ -26,6 +27,7 @@ get_fn_ns, get_return_annotation, is_context_manager_function, + is_iterable_generic_collection, remove_annotation, ) from aioinject.markers import Inject @@ -34,11 +36,25 @@ _T = TypeVar("_T") -@dataclass(slots=True, kw_only=True, frozen=True) +@dataclass(kw_only=True) class Dependency(Generic[_T]): name: str type_: type[_T] + @cached_property + def inner_type(self) -> type[_T]: + return typing.cast( + type[_T], + typing.get_args(self.type_)[0] if self.is_iterable else self.type_, + ) + + @cached_property + def is_iterable(self) -> bool: + return is_iterable_generic_collection(self.type_) # type: ignore[arg-type] + + def __hash__(self) -> int: + return hash(self.type_) + def _get_annotation_args(type_hint: Any) -> tuple[type, tuple[Any, ...]]: try: @@ -215,7 +231,7 @@ async def provide(self, kwargs: Mapping[str, Any]) -> _T: ... def provide_sync(self, kwargs: Mapping[str, Any]) -> _T: ... - def resolve_dependencies( + def collect_dependencies( self, context: dict[str, Any] | None = None, ) -> tuple[Dependency[object], ...]: diff --git a/aioinject/validation/__init__.py b/aioinject/validation/__init__.py index 894a393..5d8a4f7 100644 --- a/aioinject/validation/__init__.py +++ b/aioinject/validation/__init__.py @@ -4,6 +4,7 @@ from aioinject.validation._builtin import ( ForbidDependency, all_dependencies_are_present, + all_providers_for_type_have_equal_lifetime, ) from aioinject.validation._validate import validate_container from aioinject.validation.abc import ContainerValidator @@ -26,4 +27,5 @@ "all_dependencies_are_present", "ForbidDependency", "validate_container", + "all_providers_for_type_have_equal_lifetime", ] diff --git a/aioinject/validation/_builtin.py b/aioinject/validation/_builtin.py index 6f21c1d..f05fc08 100644 --- a/aioinject/validation/_builtin.py +++ b/aioinject/validation/_builtin.py @@ -1,4 +1,5 @@ from collections.abc import Callable, Sequence +from itertools import chain from typing import Any import aioinject @@ -13,20 +14,15 @@ def all_dependencies_are_present( container: aioinject.Container, ) -> Sequence[ContainerValidationError]: - errors = [] - for provider in container.providers.values(): - for dependency in provider.resolve_dependencies( - container.type_context, - ): - dep_type = dependency.type_ - if dep_type not in container.providers: - error = DependencyNotFoundError( - message=f"Provider for type {dep_type} not found", - dependency=dep_type, - ) - errors.append(error) - - return errors + return [ + DependencyNotFoundError( + message=f"Provider for type {dependency.type_} not found", + dependency=dependency.type_, + ) + for provider in chain.from_iterable(container.providers.values()) + for dependency in provider.collect_dependencies(container.type_context) + if dependency.type_ not in container.providers + ] class ForbidDependency(ContainerValidator): @@ -43,11 +39,11 @@ def __call__( container: aioinject.Container, ) -> Sequence[ContainerValidationError]: errors = [] - for provider in container.providers.values(): + for provider in chain.from_iterable(container.providers.values()): if not self.dependant(provider): continue - for dependency in provider.resolve_dependencies( + for dependency in provider.collect_dependencies( container.type_context, ): dep_type = dependency.type_ @@ -57,4 +53,17 @@ def __call__( if self.dependency(dependency_provider): msg = f"Provider {provider!r} cannot depend on {dependency_provider!r}" errors.append(ContainerValidationError(msg)) + return errors + + +def all_providers_for_type_have_equal_lifetime( + container: aioinject.Container, +) -> Sequence[ContainerValidationError]: + return [ + ContainerValidationError( + f"Type {type_} has providers with different scopes" + ) + for type_, providers in container.providers.items() + if len({provider.lifetime for provider in providers}) > 1 + ] diff --git a/tests/container/mod_tests/provider_fn_deffered_dep_missuse.py b/tests/container/mod_tests/provider_fn_deferred_dep_misuse.py similarity index 100% rename from tests/container/mod_tests/provider_fn_deffered_dep_missuse.py rename to tests/container/mod_tests/provider_fn_deferred_dep_misuse.py diff --git a/tests/container/mod_tests/provider_fn_with_deffered_dep.py b/tests/container/mod_tests/provider_fn_with_deferred_dep.py similarity index 100% rename from tests/container/mod_tests/provider_fn_with_deffered_dep.py rename to tests/container/mod_tests/provider_fn_with_deferred_dep.py diff --git a/tests/container/test_container.py b/tests/container/test_container.py index f0aed5c..29fee48 100644 --- a/tests/container/test_container.py +++ b/tests/container/test_container.py @@ -38,7 +38,7 @@ def test_can_register_single(container: Container) -> None: provider = providers.Scoped(_ServiceA) container.register(provider) - expected = {_ServiceA: provider} + expected = {_ServiceA: [provider]} assert container.providers == expected @@ -46,7 +46,7 @@ def test_can_register_batch(container: Container) -> None: provider1 = providers.Scoped(_ServiceA) provider2 = providers.Scoped(_ServiceB) container.register(provider1, provider2) - excepted = {_ServiceA: provider1, _ServiceB: provider2} + excepted = {_ServiceA: [provider1], _ServiceB: [provider2]} assert container.providers == excepted @@ -57,23 +57,46 @@ def test_cant_register_multiple_providers_for_same_type( with pytest.raises( ValueError, - match="^Provider for type is already registered$", + match="^Provider for type with same implementation already registered$", ): container.register(Scoped(int)) +def test_can_try_register(container: Container) -> None: + def same_impl() -> _ServiceA: + return _ServiceA() + + provider = providers.Scoped(same_impl, _ServiceA) + container.register(provider) + + expected = {_ServiceA: [provider]} + assert container.providers == expected + + container.try_register(providers.Scoped(same_impl, _ServiceA)) + assert container.providers == expected + + def test_can_retrieve_single_provider(container: Container) -> None: int_provider = providers.Scoped(int) container.register(int_provider) assert container.get_provider(int) +def test_can_retrieve_multiple_providers(container: Container) -> None: + int_providers = [ + providers.Scoped(lambda: 1, int), + providers.Scoped(lambda: 2, int), + ] + container.register(*int_providers) + assert len(container.get_providers(int)) == len(int_providers) + + def test_missing_provider() -> None: container = Container() with pytest.raises(ValueError) as exc_info: # noqa: PT011 assert container.get_provider(_ServiceA) - msg = f"Provider for type {_ServiceA.__qualname__} not found" + msg = f"Providers for type {_ServiceA.__qualname__} not found" assert str(exc_info.value) == msg diff --git a/tests/container/test_future_annotations.py b/tests/container/test_future_annotations.py index bbe96bb..dbec3dd 100644 --- a/tests/container/test_future_annotations.py +++ b/tests/container/test_future_annotations.py @@ -8,11 +8,11 @@ from aioinject.providers import Scoped -async def test_deffered_dependecies() -> None: +async def test_deferred_dependencies() -> None: if TYPE_CHECKING: from decimal import Decimal - def some_deffered_type() -> Decimal: + def some_deferred_type() -> Decimal: from decimal import Decimal return Decimal("1.0") @@ -26,23 +26,19 @@ def __init__(self, decimal: Decimal) -> None: def register_decimal_scoped() -> None: from decimal import Decimal - container.register(Scoped(some_deffered_type, Decimal)) + container.register(Scoped(some_deferred_type, Decimal)) register_decimal_scoped() container.register(Scoped(DoubledDecimal)) async with container.context() as ctx: assert (await ctx.resolve(DoubledDecimal)).decimal == DoubledDecimal( - some_deffered_type(), + some_deferred_type(), ).decimal -def test_provider_fn_with_deffered_dep() -> None: - pass - - -def test_provider_fn_deffered_dep_missuse() -> None: +def test_provider_fn_deferred_dep_misuse() -> None: with pytest.raises(ValueError) as exc_info: # noqa: PT011 from tests.container.mod_tests import ( - provider_fn_deffered_dep_missuse, # noqa: F401 + provider_fn_deferred_dep_misuse, # noqa: F401 ) assert exc_info.match("Or it's type is not defined yet.") diff --git a/tests/features/test_generics.py b/tests/features/test_generics.py index b08a8a6..fdb6cd0 100644 --- a/tests/features/test_generics.py +++ b/tests/features/test_generics.py @@ -1,12 +1,16 @@ +import abc +from collections.abc import Awaitable, Callable from typing import Generic, TypeVar import pytest from aioinject import Container, Object, Scoped -from aioinject.providers import Dependency +from aioinject.providers import Dependency, Transient T = TypeVar("T") +ReqT = TypeVar("ReqT") +ResT = TypeVar("ResT") class GenericService(Generic[T]): @@ -24,20 +28,20 @@ class ConstrainedGenericDependency(WithGenericDependency[int]): async def test_generic_dependency() -> None: - assert Scoped(GenericService[int]).resolve_dependencies() == ( + assert Scoped(GenericService[int]).collect_dependencies() == ( Dependency( name="dependency", type_=str, ), ) - assert Scoped(WithGenericDependency[int]).resolve_dependencies() == ( + assert Scoped(WithGenericDependency[int]).collect_dependencies() == ( Dependency( name="dependency", type_=int, ), ) - assert Scoped(ConstrainedGenericDependency).resolve_dependencies() == ( + assert Scoped(ConstrainedGenericDependency).collect_dependencies() == ( Dependency( name="dependency", type_=int, @@ -190,3 +194,51 @@ def so_generic(self) -> T: # pragma: no cover instance = await ctx.resolve(GenericClass) assert isinstance(instance, GenericClass) assert instance.a == MEANING_OF_LIFE_INT + + +async def test_can_resolve_generic_iterable() -> None: + class MiddlewareBase(abc.ABC, Generic[ReqT, ResT]): + @abc.abstractmethod + async def __call__( + self, + request: ReqT, + handle: Callable[[ReqT], Awaitable[ResT]], + ) -> ResT: + pass + + class FirstMiddleware(MiddlewareBase[ReqT, ResT]): + async def __call__( + self, + request: ReqT, + handle: Callable[[ReqT], Awaitable[ResT]], + ) -> ResT: + return await handle(request) + + class SecondMiddleware(MiddlewareBase[ReqT, ResT]): + async def __call__( + self, + request: ReqT, + handle: Callable[[ReqT], Awaitable[ResT]], + ) -> ResT: + return await handle(request) + + class ThirdMiddleware(MiddlewareBase[ReqT, ResT]): + async def __call__( + self, + request: ReqT, + handle: Callable[[ReqT], Awaitable[ResT]], + ) -> ResT: + return await handle(request) + + container = Container() + container.register( + Transient(FirstMiddleware, MiddlewareBase[str, str]), + Transient(SecondMiddleware, MiddlewareBase[int, int]), + Transient(ThirdMiddleware, MiddlewareBase[int, int]), + ) + + async with container.context() as ctx: + instances = await ctx.resolve_iterable(MiddlewareBase[int, int]) # type: ignore[type-abstract] + assert len(instances) == 2 # noqa: PLR2004 + assert isinstance(instances[0], SecondMiddleware) + assert isinstance(instances[1], ThirdMiddleware) diff --git a/tests/providers/test_object.py b/tests/providers/test_object.py index 21e6c8f..9c229d0 100644 --- a/tests/providers/test_object.py +++ b/tests/providers/test_object.py @@ -36,7 +36,7 @@ def test_should_have_no_dependencies( ) -> None: for obj in dependencies_test_data: provider = Object(object_=obj) - assert not provider.resolve_dependencies() + assert not provider.collect_dependencies() def test_should_have_empty_type_hints( diff --git a/tests/providers/test_scoped.py b/tests/providers/test_scoped.py index a504541..c8a0f09 100644 --- a/tests/providers/test_scoped.py +++ b/tests/providers/test_scoped.py @@ -108,7 +108,7 @@ def factory() -> None: pass provider = providers.Scoped(factory) - assert provider.resolve_dependencies() == () + assert provider.collect_dependencies() == () def test_dependencies() -> None: @@ -137,7 +137,7 @@ def factory( type_=str, ), ) - assert provider.resolve_dependencies() == expected + assert provider.collect_dependencies() == expected def iterable() -> Iterator[int]: diff --git a/tests/test_inject.py b/tests/test_inject.py index 75539df..d6e091d 100644 --- a/tests/test_inject.py +++ b/tests/test_inject.py @@ -1,3 +1,5 @@ +import abc +from collections.abc import Sequence from typing import Annotated, NewType import pytest @@ -127,3 +129,44 @@ def __init__(self, a: A, b: B) -> None: service = await ctx.resolve(Service) assert service.a == 1 assert service.b == 2 # noqa: PLR2004 + + +async def test_iterable_provider() -> None: + class ILogger(abc.ABC): + @abc.abstractmethod + def log(self, msg: str) -> None: ... + + class ConsoleLogger(ILogger): + def log(self, msg: str) -> None: + print(msg) # noqa: T201 + + class FileLogger(ILogger): + def log(self, msg: str) -> None: + pass + + class Service: + def __init__( + self, + all_loggers: Sequence[ILogger], + actual_logger: ILogger, + ) -> None: + self.all_loggers = all_loggers + self.actual_logger = actual_logger + + container = Container() + container.register(Scoped(ConsoleLogger, type_=ILogger)) + container.register(Scoped(FileLogger, type_=ILogger)) + container.register(Scoped(Service)) + + async with container.context() as ctx: + service = await ctx.resolve(Service) + assert isinstance(service.all_loggers[0], ConsoleLogger) + assert isinstance(service.all_loggers[1], FileLogger) + assert isinstance(service.actual_logger, FileLogger) + + loggers = await ctx.resolve_iterable(ILogger) # type: ignore[type-abstract] + assert len(loggers) == 2 # noqa: PLR2004 + + with container.sync_context() as sync_ctx: + loggers = sync_ctx.resolve_iterable(ILogger) # type: ignore[type-abstract] + assert len(loggers) == 2 # noqa: PLR2004 diff --git a/tests/validation/test_lifetime_validator.py b/tests/validation/test_lifetime_validator.py new file mode 100644 index 0000000..d967bdd --- /dev/null +++ b/tests/validation/test_lifetime_validator.py @@ -0,0 +1,44 @@ +import pytest + +from aioinject import Container, providers +from aioinject.validation import ( + all_providers_for_type_have_equal_lifetime, + validate_container, +) +from aioinject.validation.error import ContainerValidationErrorGroup + + +_VALIDATORS = [all_providers_for_type_have_equal_lifetime] + + +class IDependency: + pass + + +class SingletonImpl(IDependency): + pass + + +class ScopedImpl(IDependency): + pass + + +def test_ok() -> None: + container = Container() + container.register(providers.Scoped(ScopedImpl, type_=IDependency)) + container.register(providers.Scoped(SingletonImpl, type_=IDependency)) + + validate_container(container, _VALIDATORS) + + +def test_err() -> None: + container = Container() + container.register(providers.Scoped(ScopedImpl, type_=IDependency)) + container.register(providers.Singleton(SingletonImpl, type_=IDependency)) + + with pytest.raises(ContainerValidationErrorGroup) as exc_info: + validate_container(container, _VALIDATORS) + + assert len(exc_info.value.errors) == 1 + err = exc_info.value.errors[0] + assert "has providers with different scopes" in err.message diff --git a/uv.lock b/uv.lock index 6350764..53f1390 100644 --- a/uv.lock +++ b/uv.lock @@ -343,7 +343,7 @@ name = "click" version = "8.1.7" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "colorama", marker = "platform_system == 'Windows'" }, + { name = "colorama", marker = "sys_platform == 'win32'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/96/d3/f04c7bfcf5c1862a2a5b845c6b2b360488cf47af55dfa79c98f6a6bf98b5/click-8.1.7.tar.gz", hash = "sha256:ca9853ad459e787e2192211578cc907e7594e294c7ccc834310722b41b9ca6de", size = 336121 } wheels = [ @@ -751,7 +751,7 @@ version = "1.6.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "click" }, - { name = "colorama", marker = "platform_system == 'Windows'" }, + { name = "colorama", marker = "sys_platform == 'win32'" }, { name = "ghp-import" }, { name = "jinja2" }, { name = "markdown" },