diff --git a/simple_parsing/helpers/serialization/serializable.py b/simple_parsing/helpers/serialization/serializable.py index e935ed89..6fa39f17 100644 --- a/simple_parsing/helpers/serialization/serializable.py +++ b/simple_parsing/helpers/serialization/serializable.py @@ -204,7 +204,7 @@ def __init_subclass__( register_decoding_fn(cls, cls.from_dict) def to_dict( - self, dict_factory: type[dict] = dict, recurse: bool = True, save_dc_types: bool = False + self, dict_factory: type[dict] = dict, recurse: bool = True, save_dc_types: int | bool = False ) -> dict: """Serializes this dataclass to a dict. @@ -597,7 +597,7 @@ def save( obj: Any, path: str | Path, format: FormatExtension | None = None, - save_dc_types: bool = False, + save_dc_types: int | bool = False, **kwargs, ) -> None: """Save the given dataclass or dictionary to the given file.""" @@ -688,7 +688,7 @@ def to_dict( dc: DataclassT, dict_factory: type[dict] = dict, recurse: bool = True, - save_dc_types: bool = False, + save_dc_types: int | bool = False, ) -> dict: """Serializes this dataclass to a dict. @@ -720,6 +720,10 @@ def to_dict( else: d[DC_TYPE_KEY] = module + "." + class_name + # Decrement save_dc_types if it is an int + if save_dc_types is not True and save_dc_types > 0: + save_dc_types -= 1 + for f in fields(dc): name = f.name value = getattr(dc, name) diff --git a/simple_parsing/helpers/subgroups.py b/simple_parsing/helpers/subgroups.py index d9487998..79754370 100644 --- a/simple_parsing/helpers/subgroups.py +++ b/simple_parsing/helpers/subgroups.py @@ -10,6 +10,7 @@ from typing_extensions import TypeAlias +from simple_parsing.helpers.serialization.serializable import to_dict from simple_parsing.utils import DataclassT, is_dataclass_instance, is_dataclass_type logger = get_logger(__name__) @@ -112,6 +113,14 @@ def subgroups( metadata["subgroup_default"] = default metadata["subgroup_dataclass_types"] = {} + def _encoding_fn(value: Any) -> dict: + """Custom encoding function that will simply represent the value as the + the key in the dict rather than the value itself. + """ + return to_dict(value, save_dc_types=1) + + kwargs.setdefault("encoding_fn", _encoding_fn) + subgroup_dataclass_types: dict[Key, type[DataclassT]] = {} choices = subgroups.keys() diff --git a/simple_parsing/parsing.py b/simple_parsing/parsing.py index ec81bdd1..940357b1 100644 --- a/simple_parsing/parsing.py +++ b/simple_parsing/parsing.py @@ -14,7 +14,7 @@ from collections import defaultdict from logging import getLogger from pathlib import Path -from typing import Any, Callable, Sequence, Type, overload +from typing import Any, Callable, Sequence, Type, cast, overload from simple_parsing.helpers.subgroups import SubgroupKey from simple_parsing.wrappers.dataclass_wrapper import DataclassWrapperType @@ -22,7 +22,7 @@ from . import utils from .conflicts import ConflictResolution, ConflictResolver from .help_formatter import SimpleHelpFormatter -from .helpers.serialization.serializable import read_file +from .helpers.serialization.serializable import DC_TYPE_KEY, from_dict, read_file from .utils import ( Dataclass, DataclassT, @@ -646,7 +646,6 @@ def _resolve_subgroups( if subgroup_field.subgroup_default is dataclasses.MISSING: assert argument_options["required"] else: - assert argument_options["default"] is subgroup_field.subgroup_default assert not is_dataclass_instance(argument_options["default"]) # TODO: Do we really need to care about this "SUPPRESS" stuff here? @@ -674,7 +673,7 @@ def _resolve_subgroups( # here. subgroup_dict = subgroup_field.subgroup_choices chosen_subgroup_key: SubgroupKey = getattr(parsed_args, dest) - assert chosen_subgroup_key in subgroup_dict + assert isinstance(chosen_subgroup_key, dict) or chosen_subgroup_key in subgroup_dict # Changing the default value of the (now parsed) field for the subgroup choice, # just so it shows (default: {chosen_subgroup_key}) on the command-line. @@ -687,7 +686,11 @@ def _resolve_subgroups( f"{chosen_subgroup_key!r}" ) - default_or_dataclass_fn = subgroup_dict[chosen_subgroup_key] + if isinstance(chosen_subgroup_key, dict): + default_or_dataclass_fn = from_dict(cast(Type[Dataclass], None), chosen_subgroup_key) + else: + default_or_dataclass_fn = subgroup_dict[chosen_subgroup_key] + if is_dataclass_instance(default_or_dataclass_fn): # The chosen value in the subgroup dict is a frozen dataclass instance. default = default_or_dataclass_fn @@ -1124,6 +1127,11 @@ def _create_dataclass_instance( # None. # TODO: (BUG!) This doesn't distinguish the case where the defaults are passed via the # command-line from the case where no arguments are passed at all! + dc_type = constructor_args.pop(DC_TYPE_KEY, None) + if dc_type is not None: + from simple_parsing.helpers.serialization.serializable import _locate + constructor = _locate(dc_type) + if wrapper.optional and wrapper.default is None: for field_wrapper in wrapper.fields: