Skip to content

Commit

Permalink
Replace resolve_type method with type_ property
Browse files Browse the repository at this point in the history
  • Loading branch information
ThirVondukr committed Jan 28, 2024
1 parent 56327f3 commit dfa62ba
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 59 deletions.
24 changes: 10 additions & 14 deletions aioinject/_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,27 +34,23 @@ def __init__(self) -> None:
self._sync_exit_stack = contextlib.ExitStack()

def get(self, provider: Provider[T]) -> T | Literal[NotInCache.sentinel]:
return self._cache.get(provider.resolve_type(), NotInCache.sentinel)
return self._cache.get(provider.type_, NotInCache.sentinel)

def add(self, provider: Provider[T], obj: T) -> None:
if provider.lifetime is not DependencyLifetime.transient:
self._cache[provider.resolve_type()] = obj
self._cache[provider.type_] = obj

def lock(
self,
provider: Provider[Any],
) -> AbstractAsyncContextManager[bool]:
return contextlib.nullcontext(
provider.resolve_type() not in self._cache,
)
return contextlib.nullcontext(provider.type_ not in self._cache)

def sync_lock(
self,
provider: Provider[Any],
) -> AbstractContextManager[bool]:
return contextlib.nullcontext(
provider.resolve_type() not in self._cache,
)
return contextlib.nullcontext(provider.type_ not in self._cache)

@typing.overload
async def enter_context(
Expand Down Expand Up @@ -128,9 +124,9 @@ def __init__(self) -> None:

@contextlib.asynccontextmanager
async def lock(self, provider: Provider[Any]) -> AsyncIterator[bool]:
if provider.resolve_type() not in self._cache:
async with self._locks[provider.resolve_type()]:
yield provider.resolve_type() not in self._cache
if provider.type_ not in self._cache:
async with self._locks[provider.type_]:
yield provider.type_ not in self._cache
return
yield False

Expand All @@ -139,8 +135,8 @@ def sync_lock(
self,
provider: Provider[Any],
) -> Iterator[bool]:
if provider.resolve_type() not in self._cache:
with self._sync_locks[provider.resolve_type()]:
yield provider.resolve_type() not in self._cache
if provider.type_ not in self._cache:
with self._sync_locks[provider.type_]:
yield provider.type_ not in self._cache
return
yield False
19 changes: 9 additions & 10 deletions aioinject/containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,13 @@ def register(
self,
provider: Provider[Any],
) -> None:
provider_type = provider.resolve_type(self.type_context)
if provider_type in self.providers:
msg = f"Provider for type {provider_type} is already registered"
if provider.type_ in self.providers:
msg = f"Provider for type {provider.type_} is already registered"
raise ValueError(msg)

self.providers[provider_type] = provider
if class_name := getattr(provider_type, "__name__", None):
self.type_context[class_name] = provider_type
self.providers[provider.type_] = provider
if class_name := getattr(provider.type_, "__name__", None):
self.type_context[class_name] = provider.type_

def get_provider(self, type_: type[_T]) -> Provider[_T]:
try:
Expand All @@ -58,14 +57,14 @@ def override(
self,
provider: Provider[Any],
) -> Iterator[None]:
previous = self.providers.get(provider.resolve_type())
self.providers[provider.resolve_type()] = provider
previous = self.providers.get(provider.type_)
self.providers[provider.type_] = provider

yield

del self.providers[provider.resolve_type()]
del self.providers[provider.type_]
if previous is not None:
self.providers[provider.resolve_type()] = previous
self.providers[provider.type_] = previous

async def __aenter__(self) -> Self:
return self
Expand Down
39 changes: 5 additions & 34 deletions aioinject/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,14 +124,11 @@ def _get_provider_type_hints(
)


def _guess_return_type(
factory: _FactoryType[_T],
context: dict[str, type[Any]] | None = None,
) -> type[_T]:
def _guess_return_type(factory: _FactoryType[_T]) -> type[_T]:
if isclass(factory):
return typing.cast(type[_T], factory)

type_hints = _get_type_hints(factory, context=context)
type_hints = _get_type_hints(factory)
try:
return_type = type_hints["return"]
except KeyError as e:
Expand Down Expand Up @@ -167,29 +164,16 @@ class DependencyLifetime(enum.Enum):
@runtime_checkable
class Provider(Protocol[_T]):
impl: Any
type_: type[_T]
lifetime: DependencyLifetime
_cached_dependencies: tuple[Dependency[object], ...]
_cached_type: type[_T]

async def provide(self, kwargs: Mapping[str, Any]) -> _T:
...

def provide_sync(self, kwargs: Mapping[str, Any]) -> _T:
...

def _resolve_type_impl(
self,
context: dict[str, Any] | None = None,
) -> type[_T]:
...

def resolve_type(self, context: dict[str, Any] | None = None) -> type[_T]:
try:
return self._cached_type
except AttributeError:
self._cached_type = self._resolve_type_impl(context)
return self._cached_type

def resolve_dependencies(
self,
context: dict[str, Any] | None = None,
Expand All @@ -214,11 +198,7 @@ def is_generator(self) -> bool:
return is_context_manager_function(self.impl)

def __repr__(self) -> str: # pragma: no cover
try:
type_ = repr(self.resolve_type())
except NameError:
type_ = "UNKNOWN"
return f"{self.__class__.__qualname__}(type={type_}, implementation={self.impl})"
return f"{self.__class__.__qualname__}(type={self.type_}, implementation={self.impl})"


class Scoped(Provider[_T]):
Expand All @@ -229,14 +209,8 @@ def __init__(
factory: _FactoryType[_T],
type_: type[_T] | None = None,
) -> None:
self.type_ = type_
self.impl = factory

def _resolve_type_impl(
self,
context: dict[str, Any] | None = None,
) -> type[_T]:
return self.type_ or _guess_return_type(self.impl, context=context)
self.type_ = type_ or _guess_return_type(factory)

def provide_sync(self, kwargs: Mapping[str, Any]) -> _T:
return self.impl(**kwargs) # type: ignore[return-value]
Expand Down Expand Up @@ -297,9 +271,6 @@ def __init__(
self.type_ = type_ or type(object_)
self.impl = object_

def _resolve_type_impl(self, _: dict[str, Any] | None = None) -> type[_T]:
return self.type_

def provide_sync(self, kwargs: Mapping[str, Any]) -> _T: # noqa: ARG002
return self.impl

Expand Down
2 changes: 1 addition & 1 deletion tests/providers/test_scoped.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,4 +160,4 @@ async def async_gen() -> AsyncGenerator[int, None]:
@pytest.mark.parametrize("factory", [iterable, gen, async_iterable, async_gen])
def test_generator_return_types(factory: Any) -> None:
provider = providers.Scoped(factory)
assert provider.resolve_type() is int
assert provider.type_ is int

0 comments on commit dfa62ba

Please sign in to comment.