diff --git a/aioinject/_store.py b/aioinject/_store.py index bf077a9..e62bc1a 100644 --- a/aioinject/_store.py +++ b/aioinject/_store.py @@ -34,27 +34,23 @@ def __init__(self) -> None: self._sync_exit_stack = contextlib.ExitStack() def get(self, provider: Provider[T]) -> T | Literal[NotInCache.sentinel]: - return self._cache.get(provider.resolve_type(), NotInCache.sentinel) + return self._cache.get(provider.type_, NotInCache.sentinel) def add(self, provider: Provider[T], obj: T) -> None: if provider.lifetime is not DependencyLifetime.transient: - self._cache[provider.resolve_type()] = obj + self._cache[provider.type_] = obj def lock( self, provider: Provider[Any], ) -> AbstractAsyncContextManager[bool]: - return contextlib.nullcontext( - provider.resolve_type() not in self._cache, - ) + return contextlib.nullcontext(provider.type_ not in self._cache) def sync_lock( self, provider: Provider[Any], ) -> AbstractContextManager[bool]: - return contextlib.nullcontext( - provider.resolve_type() not in self._cache, - ) + return contextlib.nullcontext(provider.type_ not in self._cache) @typing.overload async def enter_context( @@ -128,9 +124,9 @@ def __init__(self) -> None: @contextlib.asynccontextmanager async def lock(self, provider: Provider[Any]) -> AsyncIterator[bool]: - if provider.resolve_type() not in self._cache: - async with self._locks[provider.resolve_type()]: - yield provider.resolve_type() not in self._cache + if provider.type_ not in self._cache: + async with self._locks[provider.type_]: + yield provider.type_ not in self._cache return yield False @@ -139,8 +135,8 @@ def sync_lock( self, provider: Provider[Any], ) -> Iterator[bool]: - if provider.resolve_type() not in self._cache: - with self._sync_locks[provider.resolve_type()]: - yield provider.resolve_type() not in self._cache + if provider.type_ not in self._cache: + with self._sync_locks[provider.type_]: + yield provider.type_ not in self._cache return yield False diff --git a/aioinject/containers.py b/aioinject/containers.py index 9f0a525..83fca30 100644 --- a/aioinject/containers.py +++ b/aioinject/containers.py @@ -25,14 +25,13 @@ def register( self, provider: Provider[Any], ) -> None: - provider_type = provider.resolve_type(self.type_context) - if provider_type in self.providers: - msg = f"Provider for type {provider_type} is already registered" + if provider.type_ in self.providers: + msg = f"Provider for type {provider.type_} is already registered" raise ValueError(msg) - self.providers[provider_type] = provider - if class_name := getattr(provider_type, "__name__", None): - self.type_context[class_name] = provider_type + self.providers[provider.type_] = provider + if class_name := getattr(provider.type_, "__name__", None): + self.type_context[class_name] = provider.type_ def get_provider(self, type_: type[_T]) -> Provider[_T]: try: @@ -58,14 +57,14 @@ def override( self, provider: Provider[Any], ) -> Iterator[None]: - previous = self.providers.get(provider.resolve_type()) - self.providers[provider.resolve_type()] = provider + previous = self.providers.get(provider.type_) + self.providers[provider.type_] = provider yield - del self.providers[provider.resolve_type()] + del self.providers[provider.type_] if previous is not None: - self.providers[provider.resolve_type()] = previous + self.providers[provider.type_] = previous async def __aenter__(self) -> Self: return self diff --git a/aioinject/providers.py b/aioinject/providers.py index ed940df..eac526a 100644 --- a/aioinject/providers.py +++ b/aioinject/providers.py @@ -124,14 +124,11 @@ def _get_provider_type_hints( ) -def _guess_return_type( - factory: _FactoryType[_T], - context: dict[str, type[Any]] | None = None, -) -> type[_T]: +def _guess_return_type(factory: _FactoryType[_T]) -> type[_T]: if isclass(factory): return typing.cast(type[_T], factory) - type_hints = _get_type_hints(factory, context=context) + type_hints = _get_type_hints(factory) try: return_type = type_hints["return"] except KeyError as e: @@ -167,9 +164,9 @@ class DependencyLifetime(enum.Enum): @runtime_checkable class Provider(Protocol[_T]): impl: Any + type_: type[_T] lifetime: DependencyLifetime _cached_dependencies: tuple[Dependency[object], ...] - _cached_type: type[_T] async def provide(self, kwargs: Mapping[str, Any]) -> _T: ... @@ -177,19 +174,6 @@ async def provide(self, kwargs: Mapping[str, Any]) -> _T: def provide_sync(self, kwargs: Mapping[str, Any]) -> _T: ... - def _resolve_type_impl( - self, - context: dict[str, Any] | None = None, - ) -> type[_T]: - ... - - def resolve_type(self, context: dict[str, Any] | None = None) -> type[_T]: - try: - return self._cached_type - except AttributeError: - self._cached_type = self._resolve_type_impl(context) - return self._cached_type - def resolve_dependencies( self, context: dict[str, Any] | None = None, @@ -214,11 +198,7 @@ def is_generator(self) -> bool: return is_context_manager_function(self.impl) def __repr__(self) -> str: # pragma: no cover - try: - type_ = repr(self.resolve_type()) - except NameError: - type_ = "UNKNOWN" - return f"{self.__class__.__qualname__}(type={type_}, implementation={self.impl})" + return f"{self.__class__.__qualname__}(type={self.type_}, implementation={self.impl})" class Scoped(Provider[_T]): @@ -229,14 +209,8 @@ def __init__( factory: _FactoryType[_T], type_: type[_T] | None = None, ) -> None: - self.type_ = type_ self.impl = factory - - def _resolve_type_impl( - self, - context: dict[str, Any] | None = None, - ) -> type[_T]: - return self.type_ or _guess_return_type(self.impl, context=context) + self.type_ = type_ or _guess_return_type(factory) def provide_sync(self, kwargs: Mapping[str, Any]) -> _T: return self.impl(**kwargs) # type: ignore[return-value] @@ -297,9 +271,6 @@ def __init__( self.type_ = type_ or type(object_) self.impl = object_ - def _resolve_type_impl(self, _: dict[str, Any] | None = None) -> type[_T]: - return self.type_ - def provide_sync(self, kwargs: Mapping[str, Any]) -> _T: # noqa: ARG002 return self.impl diff --git a/tests/providers/test_scoped.py b/tests/providers/test_scoped.py index 44bfa21..f3ea76f 100644 --- a/tests/providers/test_scoped.py +++ b/tests/providers/test_scoped.py @@ -160,4 +160,4 @@ async def async_gen() -> AsyncGenerator[int, None]: @pytest.mark.parametrize("factory", [iterable, gen, async_iterable, async_gen]) def test_generator_return_types(factory: Any) -> None: provider = providers.Scoped(factory) - assert provider.resolve_type() is int + assert provider.type_ is int