Skip to content

Commit

Permalink
fix: update get_typevars to return a list of TypeVars and enhance nes…
Browse files Browse the repository at this point in the history
…ted generic tests
  • Loading branch information
nrbnlulu committed Dec 23, 2024
1 parent 46e1815 commit 979b77f
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 9 deletions.
9 changes: 6 additions & 3 deletions aioinject/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,14 @@ def is_generic_alias(type_: Any) -> TypeGuard[GenericAlias]:
def get_orig_bases(type_: type) -> tuple[type, ...] | None:
return getattr(type_, "__orig_bases__", None)

def get_typevars(type_: Any) -> tuple[t.TypeVar, ...] | None:
def get_typevars(type_: Any) -> list[t.TypeVar] | None:
if is_generic_alias(type_):
args = t.get_args(type_)
if all(isinstance(arg, t.TypeVar) for arg in args):
return args
return [
arg
for arg in args
if isinstance(arg, t.TypeVar)
]
return None
class InjectionContext(_BaseInjectionContext[ContextExtension]):
async def resolve(
Expand Down
45 changes: 39 additions & 6 deletions tests/features/test_generics.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,26 +73,28 @@ class NestedGenericService(Generic[T]):
def __init__(self, service: T) -> None:
self.service = service

MEANING_OF_LIFE = 42
MEANING_OF_LIFE_INT = 42
MEANING_OF_LIFE_STR = "42"

class Something:
def __init__(self) -> None:
self.a = MEANING_OF_LIFE
self.a = MEANING_OF_LIFE_INT

async def test_nested_generics() -> None:
container = Container()
container.register(
Scoped(NestedGenericService[WithGenericDependency[Something]]),
Scoped(WithGenericDependency[Something]),
Scoped(Something),
Object(MEANING_OF_LIFE),
Object(MEANING_OF_LIFE_INT),
Object("42"))

async with container.context() as ctx:
instance = await ctx.resolve(NestedGenericService[WithGenericDependency[Something]])
assert isinstance(instance, NestedGenericService)
assert isinstance(instance.service, WithGenericDependency)
assert isinstance(instance.service.dependency, Something)
assert instance.service.dependency.a == MEANING_OF_LIFE
assert instance.service.dependency.a == MEANING_OF_LIFE_INT

IS_PY_312 = sys.version_info >= (3, 12)
skip_ifnot_312 = pytest.mark.skipif(not IS_PY_312, reason="Python 3.12+ required")
Expand All @@ -113,7 +115,7 @@ async def test_nested_unresolved_generic() -> None:
instance = await ctx.resolve(TestNestedUnresolvedGeneric[int])
assert isinstance(instance, TestNestedUnresolvedGeneric)
assert isinstance(instance.service, WithGenericDependency)
assert instance.service.dependency == 42
assert instance.service.dependency == MEANING_OF_LIFE_INT



Expand All @@ -134,4 +136,35 @@ class GenericImpl(TestNestedUnresolvedGeneric[str]):
instance = await ctx.resolve(GenericImpl)
assert isinstance(instance, GenericImpl)
assert isinstance(instance.service, WithGenericDependency)
assert instance.service.dependency == "42"
assert instance.service.dependency == "42"



async def test_partially_resolved_generic() -> None:
K = TypeVar("K")
class TwoGeneric(Generic[T, K]):
def __init__(self, a: WithGenericDependency[T], b: WithGenericDependency[K]) -> None:
self.a = a
self.b = b


class UsesTwoGeneric(Generic[T]):
def __init__(self, service: TwoGeneric[T, str]) -> None:
self.service = service

container = Container()
container.register(Scoped(UsesTwoGeneric[int]),
Scoped(TwoGeneric[int, str]),
Scoped(WithGenericDependency[int]),
Scoped(WithGenericDependency[str]),
Object(MEANING_OF_LIFE_INT),
Object("42"))

async with container.context() as ctx:
instance = await ctx.resolve(UsesTwoGeneric[int])
assert isinstance(instance, UsesTwoGeneric)
assert isinstance(instance.service, TwoGeneric)
assert isinstance(instance.service.a, WithGenericDependency)
assert isinstance(instance.service.b, WithGenericDependency)
assert instance.service.a.dependency == MEANING_OF_LIFE_INT
assert instance.service.b.dependency == MEANING_OF_LIFE_STR

0 comments on commit 979b77f

Please sign in to comment.