From 979b77f1f58b0d8c94888de3b95e72df544b00ac Mon Sep 17 00:00:00 2001 From: nir Date: Mon, 23 Dec 2024 10:24:33 +0200 Subject: [PATCH] fix: update get_typevars to return a list of TypeVars and enhance nested generic tests --- aioinject/context.py | 9 ++++--- tests/features/test_generics.py | 45 ++++++++++++++++++++++++++++----- 2 files changed, 45 insertions(+), 9 deletions(-) diff --git a/aioinject/context.py b/aioinject/context.py index 51b6a83..b90ffcc 100644 --- a/aioinject/context.py +++ b/aioinject/context.py @@ -81,11 +81,14 @@ def is_generic_alias(type_: Any) -> TypeGuard[GenericAlias]: def get_orig_bases(type_: type) -> tuple[type, ...] | None: return getattr(type_, "__orig_bases__", None) -def get_typevars(type_: Any) -> tuple[t.TypeVar, ...] | None: +def get_typevars(type_: Any) -> list[t.TypeVar] | None: if is_generic_alias(type_): args = t.get_args(type_) - if all(isinstance(arg, t.TypeVar) for arg in args): - return args + return [ + arg + for arg in args + if isinstance(arg, t.TypeVar) + ] return None class InjectionContext(_BaseInjectionContext[ContextExtension]): async def resolve( diff --git a/tests/features/test_generics.py b/tests/features/test_generics.py index 980492f..02a64b1 100644 --- a/tests/features/test_generics.py +++ b/tests/features/test_generics.py @@ -73,10 +73,12 @@ class NestedGenericService(Generic[T]): def __init__(self, service: T) -> None: self.service = service -MEANING_OF_LIFE = 42 +MEANING_OF_LIFE_INT = 42 +MEANING_OF_LIFE_STR = "42" + class Something: def __init__(self) -> None: - self.a = MEANING_OF_LIFE + self.a = MEANING_OF_LIFE_INT async def test_nested_generics() -> None: container = Container() @@ -84,7 +86,7 @@ async def test_nested_generics() -> None: Scoped(NestedGenericService[WithGenericDependency[Something]]), Scoped(WithGenericDependency[Something]), Scoped(Something), - Object(MEANING_OF_LIFE), + Object(MEANING_OF_LIFE_INT), Object("42")) async with container.context() as ctx: @@ -92,7 +94,7 @@ async def test_nested_generics() -> None: assert isinstance(instance, NestedGenericService) assert isinstance(instance.service, WithGenericDependency) assert isinstance(instance.service.dependency, Something) - assert instance.service.dependency.a == MEANING_OF_LIFE + assert instance.service.dependency.a == MEANING_OF_LIFE_INT IS_PY_312 = sys.version_info >= (3, 12) skip_ifnot_312 = pytest.mark.skipif(not IS_PY_312, reason="Python 3.12+ required") @@ -113,7 +115,7 @@ async def test_nested_unresolved_generic() -> None: instance = await ctx.resolve(TestNestedUnresolvedGeneric[int]) assert isinstance(instance, TestNestedUnresolvedGeneric) assert isinstance(instance.service, WithGenericDependency) - assert instance.service.dependency == 42 + assert instance.service.dependency == MEANING_OF_LIFE_INT @@ -134,4 +136,35 @@ class GenericImpl(TestNestedUnresolvedGeneric[str]): instance = await ctx.resolve(GenericImpl) assert isinstance(instance, GenericImpl) assert isinstance(instance.service, WithGenericDependency) - assert instance.service.dependency == "42" \ No newline at end of file + assert instance.service.dependency == "42" + + + +async def test_partially_resolved_generic() -> None: + K = TypeVar("K") + class TwoGeneric(Generic[T, K]): + def __init__(self, a: WithGenericDependency[T], b: WithGenericDependency[K]) -> None: + self.a = a + self.b = b + + + class UsesTwoGeneric(Generic[T]): + def __init__(self, service: TwoGeneric[T, str]) -> None: + self.service = service + + container = Container() + container.register(Scoped(UsesTwoGeneric[int]), + Scoped(TwoGeneric[int, str]), + Scoped(WithGenericDependency[int]), + Scoped(WithGenericDependency[str]), + Object(MEANING_OF_LIFE_INT), + Object("42")) + + async with container.context() as ctx: + instance = await ctx.resolve(UsesTwoGeneric[int]) + assert isinstance(instance, UsesTwoGeneric) + assert isinstance(instance.service, TwoGeneric) + assert isinstance(instance.service.a, WithGenericDependency) + assert isinstance(instance.service.b, WithGenericDependency) + assert instance.service.a.dependency == MEANING_OF_LIFE_INT + assert instance.service.b.dependency == MEANING_OF_LIFE_STR