diff --git a/draccus/wrappers/choice_wrapper.py b/draccus/wrappers/choice_wrapper.py index a52a2f7..ec7fe8e 100644 --- a/draccus/wrappers/choice_wrapper.py +++ b/draccus/wrappers/choice_wrapper.py @@ -3,7 +3,7 @@ import inspect from dataclasses import Field from functools import cached_property -from typing import Dict, Optional, Type +from typing import Dict, Optional, Sequence, Type from ..choice_types import CHOICE_TYPE_KEY, ChoiceType from ..parsers.decoding import has_custom_decoder @@ -195,7 +195,7 @@ def register_actions(self, parser: argparse.ArgumentParser) -> None: has_field_wrapper = False - for child in children.values(): + for child in children: from .dataclass_wrapper import DataclassWrapper if isinstance(child, DataclassWrapper): @@ -225,7 +225,7 @@ def register_actions(self, parser: argparse.ArgumentParser) -> None: ) @cached_property - def _children(self) -> Dict[str, Optional[Wrapper]]: + def _children(self) -> Sequence[Optional[Wrapper]]: from .dataclass_wrapper import DataclassWrapper from .field_wrapper import FieldWrapper @@ -246,6 +246,9 @@ def _wrap_child(child: Type) -> Optional[Wrapper]: elif child is None or child is type(None): return None else: - raise ValueError(f"Unexpected child type: {child}") + assert self._field is not None + wrapper = FieldWrapper(parent=self.parent, field=self._field, preferred_help=self.preferred_help) + wrapper.required = False + return wrapper - return {child.__name__: _wrap_child(child) for child in self.union.__args__} + return [_wrap_child(child) for child in self.union.__args__] diff --git a/draccus/wrappers/dataclass_wrapper.py b/draccus/wrappers/dataclass_wrapper.py index faf8ac7..2b397f6 100644 --- a/draccus/wrappers/dataclass_wrapper.py +++ b/draccus/wrappers/dataclass_wrapper.py @@ -180,7 +180,6 @@ def _wrap_field( return child_wrapper elif utils.is_optional_or_union_with_dataclass_type_arg(field.type): - # TODO(dlwh): I don't like this. Add UnionWrapper or something name = field.name from .choice_wrapper import UnionWrapper diff --git a/tests/test_union.py b/tests/test_union.py index 354d3a5..3a96e30 100644 --- a/tests/test_union.py +++ b/tests/test_union.py @@ -1,5 +1,5 @@ -from dataclasses import dataclass -from typing import Union +from dataclasses import dataclass, field +from typing import Dict, Union import pytest @@ -148,3 +148,16 @@ class Foo(TestSetup): Bar: The fields `y` are not valid for Bar""".strip() in str( e.value ) + + +def test_union_argparse_dict(): + @dataclass + class Bar: + y: int + + @dataclass + class Foo(TestSetup): + x: Optional[Union[Bar, Dict[str, Bar]]] = field(default=None) + + foo = Foo.setup('--x \'{"a": {"y": 1}, "b": {"y": 2}}\'') + assert foo.x == {"a": Bar(y=1), "b": Bar(y=2)}