Skip to content

Commit

Permalink
fix unions with dict inside them
Browse files Browse the repository at this point in the history
  • Loading branch information
dlwh committed Nov 13, 2024
1 parent 919084c commit 29b772d
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 8 deletions.
13 changes: 8 additions & 5 deletions draccus/wrappers/choice_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import inspect
from dataclasses import Field
from functools import cached_property
from typing import Dict, Optional, Type
from typing import Dict, Optional, Sequence, Type

from ..choice_types import CHOICE_TYPE_KEY, ChoiceType
from ..parsers.decoding import has_custom_decoder
Expand Down Expand Up @@ -195,7 +195,7 @@ def register_actions(self, parser: argparse.ArgumentParser) -> None:

has_field_wrapper = False

for child in children.values():
for child in children:
from .dataclass_wrapper import DataclassWrapper

if isinstance(child, DataclassWrapper):
Expand Down Expand Up @@ -225,7 +225,7 @@ def register_actions(self, parser: argparse.ArgumentParser) -> None:
)

@cached_property
def _children(self) -> Dict[str, Optional[Wrapper]]:
def _children(self) -> Sequence[Optional[Wrapper]]:
from .dataclass_wrapper import DataclassWrapper
from .field_wrapper import FieldWrapper

Expand All @@ -246,6 +246,9 @@ def _wrap_child(child: Type) -> Optional[Wrapper]:
elif child is None or child is type(None):
return None
else:
raise ValueError(f"Unexpected child type: {child}")
assert self._field is not None
wrapper = FieldWrapper(parent=self.parent, field=self._field, preferred_help=self.preferred_help)
wrapper.required = False
return wrapper

return {child.__name__: _wrap_child(child) for child in self.union.__args__}
return [_wrap_child(child) for child in self.union.__args__]
1 change: 0 additions & 1 deletion draccus/wrappers/dataclass_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,6 @@ def _wrap_field(
return child_wrapper

elif utils.is_optional_or_union_with_dataclass_type_arg(field.type):
# TODO(dlwh): I don't like this. Add UnionWrapper or something
name = field.name
from .choice_wrapper import UnionWrapper

Expand Down
17 changes: 15 additions & 2 deletions tests/test_union.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import Union
from dataclasses import dataclass, field
from typing import Dict, Union

import pytest

Expand Down Expand Up @@ -148,3 +148,16 @@ class Foo(TestSetup):
Bar: The fields `y` are not valid for Bar""".strip() in str(
e.value
)


def test_union_argparse_dict():
@dataclass
class Bar:
y: int

@dataclass
class Foo(TestSetup):
x: Optional[Union[Bar, Dict[str, Bar]]] = field(default=None)

foo = Foo.setup('--x \'{"a": {"y": 1}, "b": {"y": 2}}\'')
assert foo.x == {"a": Bar(y=1), "b": Bar(y=2)}

0 comments on commit 29b772d

Please sign in to comment.