diff --git a/draccus/wrappers/choice_wrapper.py b/draccus/wrappers/choice_wrapper.py index df7c297..be1ddc7 100644 --- a/draccus/wrappers/choice_wrapper.py +++ b/draccus/wrappers/choice_wrapper.py @@ -7,7 +7,7 @@ from ..choice_types import CHOICE_TYPE_KEY, ChoiceType from ..parsers.decoding import has_custom_decoder -from ..utils import canonicalize_union, is_union +from ..utils import canonicalize_union, is_choice_type, is_union from . import FieldWrapper, docstring from .wrapper import AggregateWrapper, Wrapper @@ -237,12 +237,14 @@ def _wrap_child(child: Type) -> Optional[Wrapper]: ) field.required = False return field + elif is_choice_type(child): + return ChoiceWrapper( + child, name=self.name, parent=self.parent, _field=self._field, preferred_help=self.preferred_help + ) elif dataclasses.is_dataclass(child): return DataclassWrapper( child, name=self.name, parent=self.parent, _field=self._field, preferred_help=self.preferred_help ) - elif inspect.isclass(child) and issubclass(child, ChoiceType): - return ChoiceWrapper(child, parent=self.parent, _field=self._field, preferred_help=self.preferred_help) elif is_union(child): return UnionWrapper(child, parent=self.parent, _field=self._field, preferred_help=self.preferred_help) elif child is None or child is type(None): diff --git a/tests/test_optional_choice_type.py b/tests/test_optional_choice_type.py new file mode 100644 index 0000000..be68485 --- /dev/null +++ b/tests/test_optional_choice_type.py @@ -0,0 +1,91 @@ +# test_optional_choice_type.py + +from dataclasses import dataclass +from typing import Optional + +import pytest + +import draccus +from draccus.choice_types import ChoiceRegistry +from draccus.utils import DecodingError +from tests.testutils import TestSetup + + +@dataclass +class Person(ChoiceRegistry): + name: str + + +@dataclass +class Adult(Person): + age: int + + +@dataclass +class Child(Person): + favorite_toy: str + + +Person.register_subclass("adult", Adult) +Person.register_subclass("child", Child) + + +@dataclass +class Profile(TestSetup): + person: Optional[Person] = None + + +def test_optional_choice_empty(): + profile = Profile.setup("") + assert profile.person is None + + +def test_optional_choice_adult(): + profile = Profile.setup("--person.type adult --person.name Bob --person.age 30") + assert profile.person == Adult(name="Bob", age=30) + + +def test_optional_choice_child(): + profile = Profile.setup("--person.type child --person.name Alice --person.favorite_toy truck") + assert profile.person == Child(name="Alice", favorite_toy="truck") + + +def test_invalid_choice(): + with pytest.raises(SystemExit) as excinfo: + Profile.setup("--person.type invalid_type --person.name Jill") + assert excinfo.type is SystemExit + assert excinfo.value.code == 2 + + +def test_invalid_fields_adult(): + with pytest.raises(DecodingError): + Profile.setup("--person.type adult --person.name Bob --person.age 30 --person.favorite_toy truck") + + +def test_encode_optional_none(): + profile = Profile() + assert draccus.encode(profile) == {"person": None} + + +def test_encode_optional_child(): + profile = Profile(person=Child(name="Kevin", favorite_toy="ball")) + encoded = draccus.encode(profile) + assert encoded == { + "person": { + "type": "child", + "name": "Kevin", + "favorite_toy": "ball", + } + } + + +def test_encode_optional_adult(): + profile = Profile(person=Adult(name="Bob", age=42)) + encoded = draccus.encode(profile) + assert encoded == { + "person": { + "type": "adult", + "name": "Bob", + "age": 42, + } + }