Skip to content

Commit

Permalink
Enable Optional[ChoiceType] (#29)
Browse files Browse the repository at this point in the history
* Enable optional choice_type

* Add test
  • Loading branch information
aliberts authored Jan 4, 2025
1 parent f78f352 commit 55e456a
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 3 deletions.
8 changes: 5 additions & 3 deletions draccus/wrappers/choice_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from ..choice_types import CHOICE_TYPE_KEY, ChoiceType
from ..parsers.decoding import has_custom_decoder
from ..utils import canonicalize_union, is_union
from ..utils import canonicalize_union, is_choice_type, is_union
from . import FieldWrapper, docstring
from .wrapper import AggregateWrapper, Wrapper

Expand Down Expand Up @@ -237,12 +237,14 @@ def _wrap_child(child: Type) -> Optional[Wrapper]:
)
field.required = False
return field
elif is_choice_type(child):
return ChoiceWrapper(
child, name=self.name, parent=self.parent, _field=self._field, preferred_help=self.preferred_help
)
elif dataclasses.is_dataclass(child):
return DataclassWrapper(
child, name=self.name, parent=self.parent, _field=self._field, preferred_help=self.preferred_help
)
elif inspect.isclass(child) and issubclass(child, ChoiceType):
return ChoiceWrapper(child, parent=self.parent, _field=self._field, preferred_help=self.preferred_help)
elif is_union(child):
return UnionWrapper(child, parent=self.parent, _field=self._field, preferred_help=self.preferred_help)
elif child is None or child is type(None):
Expand Down
91 changes: 91 additions & 0 deletions tests/test_optional_choice_type.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
# test_optional_choice_type.py

from dataclasses import dataclass
from typing import Optional

import pytest

import draccus
from draccus.choice_types import ChoiceRegistry
from draccus.utils import DecodingError
from tests.testutils import TestSetup


@dataclass
class Person(ChoiceRegistry):
name: str


@dataclass
class Adult(Person):
age: int


@dataclass
class Child(Person):
favorite_toy: str


Person.register_subclass("adult", Adult)
Person.register_subclass("child", Child)


@dataclass
class Profile(TestSetup):
person: Optional[Person] = None


def test_optional_choice_empty():
profile = Profile.setup("")
assert profile.person is None


def test_optional_choice_adult():
profile = Profile.setup("--person.type adult --person.name Bob --person.age 30")
assert profile.person == Adult(name="Bob", age=30)


def test_optional_choice_child():
profile = Profile.setup("--person.type child --person.name Alice --person.favorite_toy truck")
assert profile.person == Child(name="Alice", favorite_toy="truck")


def test_invalid_choice():
with pytest.raises(SystemExit) as excinfo:
Profile.setup("--person.type invalid_type --person.name Jill")
assert excinfo.type is SystemExit
assert excinfo.value.code == 2


def test_invalid_fields_adult():
with pytest.raises(DecodingError):
Profile.setup("--person.type adult --person.name Bob --person.age 30 --person.favorite_toy truck")


def test_encode_optional_none():
profile = Profile()
assert draccus.encode(profile) == {"person": None}


def test_encode_optional_child():
profile = Profile(person=Child(name="Kevin", favorite_toy="ball"))
encoded = draccus.encode(profile)
assert encoded == {
"person": {
"type": "child",
"name": "Kevin",
"favorite_toy": "ball",
}
}


def test_encode_optional_adult():
profile = Profile(person=Adult(name="Bob", age=42))
encoded = draccus.encode(profile)
assert encoded == {
"person": {
"type": "adult",
"name": "Bob",
"age": 42,
}
}

0 comments on commit 55e456a

Please sign in to comment.