Skip to content

Commit

Permalink
Fix future annotations (#23)
Browse files Browse the repository at this point in the history
* partial fix for #21

* test unions
  • Loading branch information
dlwh authored Nov 15, 2024
1 parent b405346 commit 6db809c
Show file tree
Hide file tree
Showing 7 changed files with 106 additions and 53 deletions.
4 changes: 3 additions & 1 deletion draccus/parsers/decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,16 +96,18 @@ def decode_dataclass(cls: Type[Dataclass], d: Dict[str, Any], path: Sequence[str

logger.debug(f"from_dict for {cls}")

hints = typing.get_type_hints(cls)

for field in fields(cls): # type: ignore
name = field.name
field_type = hints.get(name, field.type)
if name not in obj_dict:
# if field.default is MISSING and field.default_factory is MISSING:
# logger.warning(f"Couldn't find the field '{name}' in the dict with keys {list(d.keys())}")
continue

raw_value = obj_dict.pop(name)
try:
field_type = field.type
logger.debug(f"Decode name = {name}, type = {field_type}")
field_value = get_decoding_fn(field_type)(raw_value, (*path, name)) # type: ignore
except ParsingError as e:
Expand Down
8 changes: 6 additions & 2 deletions draccus/wrappers/choice_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,9 @@ def _children(self) -> Sequence[Optional[Wrapper]]:
def _wrap_child(child: Type) -> Optional[Wrapper]:
if has_custom_decoder(child):
assert self._field is not None
field = FieldWrapper(parent=self.parent, field=self._field, preferred_help=self.preferred_help)
field = FieldWrapper(
parent=self.parent, field=self._field, preferred_help=self.preferred_help, field_type=child
)
field.required = False
return field
elif dataclasses.is_dataclass(child):
Expand All @@ -247,7 +249,9 @@ def _wrap_child(child: Type) -> Optional[Wrapper]:
return None
else:
assert self._field is not None
wrapper = FieldWrapper(parent=self.parent, field=self._field, preferred_help=self.preferred_help)
wrapper = FieldWrapper(
parent=self.parent, field=self._field, preferred_help=self.preferred_help, field_type=child
)
wrapper.required = False
return wrapper

Expand Down
33 changes: 21 additions & 12 deletions draccus/wrappers/dataclass_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import argparse
import dataclasses
import typing
from logging import getLogger
from typing import Dict, List, Optional, Type, Union, cast

Expand Down Expand Up @@ -48,9 +49,13 @@ def __init__(
self.defaults = [default] # type: ignore

self.optional: bool = False
hints = typing.get_type_hints(dataclass)

for field in dataclasses.fields(self.dataclass): # type: ignore
child = _wrap_field(self, field, preferred_help=self.preferred_help)
# have to get the real type of the field because of __future__ annotations
field_type = hints.get(field.name, field.type)

child = _wrap_field(self, field, preferred_help=self.preferred_help, field_type=field_type)
if child is not None:
self._children.append(child)

Expand Down Expand Up @@ -150,45 +155,49 @@ def _wrap_field(
parent: Optional[Wrapper],
field: dataclasses.Field,
preferred_help: str = docstring.HelpOrder.inline,
field_type: Optional[Type] = None,
) -> Optional[Wrapper]:
if not field.init:
return None

elif has_custom_decoder(field.type):
field_wrapper = FieldWrapper(field, parent=parent, preferred_help=preferred_help)
if field_type is None:
field_type = field.type

if has_custom_decoder(field_type):
field_wrapper = FieldWrapper(field, parent=parent, preferred_help=preferred_help, field_type=field_type)
logger.debug(f"wrapped field at {field_wrapper.dest} has a default value of {field_wrapper.default}")
return field_wrapper
elif utils.is_choice_type(field.type):
elif utils.is_choice_type(field_type):
from .choice_wrapper import ChoiceWrapper

return ChoiceWrapper(
cast(Type[ChoiceType], field.type), field.name, parent=parent, _field=field, preferred_help=preferred_help
cast(Type[ChoiceType], field_type), field.name, parent=parent, _field=field, preferred_help=preferred_help
)

elif utils.is_tuple_or_list_of_dataclasses(field.type):
elif utils.is_tuple_or_list_of_dataclasses(field_type):
logger.debug(f"wrapped field at {field.name} is a list of dataclasses, treating a ordinary field for argparse")
field_wrapper = FieldWrapper(field, parent=parent, preferred_help=preferred_help)
field_wrapper = FieldWrapper(field, parent=parent, preferred_help=preferred_help, field_type=field_type)
return field_wrapper
# raise NotImplementedError(
# f"Field {field.name} is of type {field.type}, which isn't supported yet. (container of a dataclass type)"
# )

elif dataclasses.is_dataclass(field.type):
elif dataclasses.is_dataclass(field_type):
# handle a nested dataclass attribute
dataclass, name = (cast(DataclassType, field.type)), field.name
dataclass, name = (cast(DataclassType, field_type)), field.name
child_wrapper = DataclassWrapper(dataclass, name, parent=parent, _field=field, preferred_help=preferred_help)
return child_wrapper

elif utils.is_optional_or_union_with_dataclass_type_arg(field.type):
elif utils.is_optional_or_union_with_dataclass_type_arg(field_type):
name = field.name
from .choice_wrapper import UnionWrapper

wrapper = UnionWrapper(field.type, name=name, parent=parent, _field=field, preferred_help=preferred_help)
wrapper = UnionWrapper(field_type, name=name, parent=parent, _field=field, preferred_help=preferred_help)
return wrapper

else:
# a normal attribute
field_wrapper = FieldWrapper(field, parent=parent, preferred_help=preferred_help)
field_wrapper = FieldWrapper(field, parent=parent, preferred_help=preferred_help, field_type=field_type)
logger.debug(f"wrapped field at {field_wrapper.dest} has a default value of {field_wrapper.default}")
# self._children.append(field_wrapper)
return field_wrapper
3 changes: 2 additions & 1 deletion draccus/wrappers/field_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,13 @@ def field(self) -> dataclasses.Field:
def __init__(
self,
field: dataclasses.Field,
field_type: Type,
parent: Optional[Wrapper] = None,
preferred_help: str = docstring.HelpOrder.inline,
):
self._field = field
self._parent: Optional[Wrapper] = parent
self._type = field_type
# Holders used to 'cache' the properties.
# (could've used cached_property with Python 3.8).
self._option_strings: Optional[Set[str]] = None
Expand All @@ -50,7 +52,6 @@ def __init__(
self._dest: Optional[str] = None
# the argparse-related options:
self._arg_options: Dict[str, Any] = {}
self._type: Optional[Type[Any]] = None

# preferred parse for docstring / help text in < inline | above | below >
self.preferred_help = preferred_help
Expand Down
3 changes: 1 addition & 2 deletions tests/test_decoding.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import json
from dataclasses import dataclass, field
from enum import Enum, auto
from typing import Dict, List, Optional, Tuple
from typing import Dict, Tuple

import yaml

import draccus
from draccus.utils import DraccusException

from .testutils import *
Expand Down
29 changes: 29 additions & 0 deletions tests/test_future_annotations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from __future__ import annotations

from dataclasses import dataclass, field
from typing import List

import draccus

# it seems like typing.gettypehints doesn't really work with locals so, we just make these module scope


@dataclass
class A:
b: int = 1


@dataclass
class C:
a: A = A()
elems: List[A] = field(default_factory=list)


def test_future_annotations():
an_a: A = draccus.parse(config_class=A, args="")
assert an_a.b == 1


def test_nested_future_annotations():
c: C = draccus.parse(config_class=C, args="")
assert c.a.b == 1
79 changes: 44 additions & 35 deletions tests/test_union.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Dict, Union

Expand Down Expand Up @@ -95,69 +97,76 @@ class Foo(TestSetup):
)


def test_decode_union_with_dataclass_and_atomic():
@dataclass
class Baz:
z: int
@dataclass
class Baz_u:
z: int

@dataclass
class Foo(TestSetup):
x: Union[bool, Baz]

foo = Foo.setup("--x false")
@dataclass
class Foo_u(TestSetup):
x: Union[bool, Baz_u]


def test_decode_union_with_dataclass_and_atomic():
foo = Foo_u.setup("--x false")
assert foo.x is False

foo = Foo.setup("--x.z 1")
assert foo.x == Baz(z=1)
foo = Foo_u.setup("--x.z 1")
assert foo.x == Baz_u(z=1)

try:
foo = Foo.setup("--x.z 1.2")
foo = Foo_u.setup("--x.z 1.2")
raise AssertionError()
except DecodingError:
pass


def test_union_error_message_dataclasses():
@dataclass
class Baz:
z: int
y: str
@dataclass
class Baz_e:
z: int
y: str

@dataclass
class Bar:
z: bool

@dataclass
class Foo(TestSetup):
x: Union[Baz, Bar] = 0
@dataclass
class Bar_e:
z: bool


@dataclass
class Foo_e(TestSetup):
x: Union[Baz_e, Bar_e] = Bar_e(False)


def test_union_error_message_dataclasses():
with pytest.raises(DecodingError) as e:
Foo.setup("--x.z 1.2.3")
Foo_e.setup("--x.z 1.2.3")

assert """`x`: Could not decode the value into any of the given types:
Baz: `z`: Couldn't parse '1.2.3' into an int
Bar: `z`: Couldn't parse '1.2.3' into a bool""".strip() in str(
Baz_e: `z`: Couldn't parse '1.2.3' into an int
Bar_e: `z`: Couldn't parse '1.2.3' into a bool""".strip() in str(
e.value
)

with pytest.raises(DecodingError) as e:
Foo.setup("--x.y foo")
Foo_e.setup("--x.y foo")

assert """`x`: Could not decode the value into any of the given types:
Baz: Missing required field(s) `z` for Baz
Bar: The fields `y` are not valid for Bar""".strip() in str(
Baz_e: Missing required field(s) `z` for Baz_e
Bar_e: The fields `y` are not valid for Bar_e""".strip() in str(
e.value
)


def test_union_argparse_dict():
@dataclass
class Bar:
y: int
@dataclass
class Bar:
y: int

@dataclass
class Foo(TestSetup):
x: Optional[Union[Bar, Dict[str, Bar]]] = field(default=None)

@dataclass
class Foo(TestSetup):
x: Optional[Union[Bar, Dict[str, Bar]]] = field(default=None)


def test_union_argparse_dict():
foo = Foo.setup('--x \'{"a": {"y": 1}, "b": {"y": 2}}\'')
assert foo.x == {"a": Bar(y=1), "b": Bar(y=2)}

0 comments on commit 6db809c

Please sign in to comment.