Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bug with config_path when using save(config, save_dc_types=True) #284

Open
wants to merge 20 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ repos:
require_serial: true
- id: check-added-large-files
require_serial: true
- id: check-merge-conflict
require_serial: true
args: ["--assume-in-merge"]

- repo: https://github.com/charliermarsh/ruff-pre-commit
# Ruff version.
Expand All @@ -41,7 +44,7 @@ repos:

# python docstring formatting
- repo: https://github.com/myint/docformatter
rev: v1.5.1
rev: v1.7.5
hooks:
- id: docformatter
exclude: ^test/test_docstrings.py
Expand All @@ -63,7 +66,6 @@ repos:
- id: nbstripout
require_serial: true


# md formatting
- repo: https://github.com/executablebooks/mdformat
rev: 0.7.16
Expand Down
4 changes: 2 additions & 2 deletions examples/docstrings/docstrings_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,15 @@ class DocStringsExample:

# comment above 42
attribute4: float = 1.0 # inline comment
"""docstring below (this appears in --help)"""
"""Docstring below (this appears in --help)"""

# comment above (this appears in --help) 46
attribute5: float = 1.0 # inline comment

attribute6: float = 1.0 # inline comment (this appears in --help)

attribute7: float = 1.0 # inline comment
"""docstring below (this appears in --help)"""
"""Docstring below (this appears in --help)"""


parser.add_arguments(DocStringsExample, "example")
Expand Down
2 changes: 1 addition & 1 deletion examples/merging/multiple_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class Config:
run_name: str = "train" # Some parameter for the run name.
some_int: int = 10 # an optional int parameter.
log_dir: str = "logs" # an optional string parameter.
"""the logging directory to use.
"""The logging directory to use.

(This is an attribute docstring for the log_dir attribute, and shows up when using the "--help"
argument!)
Expand Down
5 changes: 4 additions & 1 deletion examples/simple/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,10 @@ class HParams:
expected += """
parser = ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

group = parser.add_argument_group(title="HParams ['hparams']", description="Set of options for the training of a Model.")
group = parser.add_argument_group(
title="HParams ['hparams']",
description="Set of options for the training of a Model.",
)
group.add_argument(*['--num_layers'], **{'type': int, 'required': False, 'dest': 'hparams.num_layers', 'default': 4, 'help': ' '})
group.add_argument(*['--num_units'], **{'type': int, 'required': False, 'dest': 'hparams.num_units', 'default': 64, 'help': ' '})
group.add_argument(*['--optimizer'], **{'type': str, 'required': False, 'dest': 'hparams.optimizer', 'default': 'ADAM', 'help': ' '})
Expand Down
12 changes: 2 additions & 10 deletions examples/simple/flag.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,9 @@
from dataclasses import dataclass

from simple_parsing import ArgumentParser
from simple_parsing import ArgumentParser, parse
from simple_parsing.helpers import flag


def parse(cls, args: str = ""):
"""Removes some boilerplate code from the examples."""
parser = ArgumentParser() # Create an argument parser
parser.add_arguments(cls, dest="hparams") # add arguments for the dataclass
ns = parser.parse_args(args.split()) # parse the given `args`
return ns.hparams


@dataclass
class HParams:
"""Set of options for the training of a Model."""
Expand All @@ -32,7 +24,7 @@ class HParams:
"""

# Example 2 using the flags negative prefix
assert parse(HParams, "--no-train") == HParams(train=False)
assert parse(HParams, args="--no-train") == HParams(train=False)


# showing what --help outputs
Expand Down
5 changes: 4 additions & 1 deletion examples/simple/inheritance.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,10 @@ class MAML(Method):
expected += """
parser = ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

group = parser.add_argument_group(title="MAML ['hparams']", description="Overwrites some of the default values and adds new arguments/attributes.")
group = parser.add_argument_group(
title="MAML ['hparams']",
description="Overwrites some of the default values and adds new arguments/attributes.",
)
group.add_argument(*['--num_layers'], **{'type': int, 'required': False, 'dest': 'hparams.num_layers', 'default': 6, 'help': ' '})
group.add_argument(*['--num_units'], **{'type': int, 'required': False, 'dest': 'hparams.num_units', 'default': 128, 'help': ' '})
group.add_argument(*['--optimizer'], **{'type': str, 'required': False, 'dest': 'hparams.optimizer', 'default': 'ADAM', 'help': ' '})
Expand Down
2 changes: 1 addition & 1 deletion examples/ugly/ugly_example_after.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ class RenderingParams:

@dataclass
class Parameters:
"""base options."""
"""Base options."""

# Dataset parameters.
dataset: DatasetParams = field(default_factory=DatasetParams)
Expand Down
2 changes: 1 addition & 1 deletion examples/ugly/ugly_example_before.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@


class Parameters:
"""base options."""
"""Base options."""

def __init__(self):
"""Constructor."""
Expand Down
1 change: 1 addition & 0 deletions simple_parsing/docstring.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def get_attribute_docstring(
dataclass: type, field_name: str, accumulate_from_bases: bool = True
) -> AttributeDocString:
"""Returns the docstrings of a dataclass field.

NOTE: a docstring can either be:
- An inline comment, starting with <#>
- A Comment on the preceding line, starting with <#>
Expand Down
37 changes: 36 additions & 1 deletion simple_parsing/helpers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,20 @@
"""Collection of helper classes and functions to reduce boilerplate code."""
from .fields import *
from .fields import (
choice,
dict_field,
field,
flag,
flags,
list_field,
mutable_field,
set_field,
subparsers,
)
from .flatten import FlattenedAccess
from .hparams import HyperParameters
from .partial import Partial, config_for
from .serialization import FrozenSerializable, Serializable, SimpleJsonEncoder, encode
from .subgroups import subgroups

try:
from .serialization import YamlSerializable
Expand All @@ -13,3 +24,27 @@
# For backward compatibility purposes
JsonSerializable = Serializable
SimpleEncoder = SimpleJsonEncoder

__all__ = [
"FlattenedAccess",
"HyperParameters",
"Partial",
"config_for",
"FrozenSerializable",
"Serializable",
"SimpleJsonEncoder",
"encode",
"JsonSerializable",
"SimpleEncoder",
"YamlSerializable",
"field",
"choice",
"list_field",
"dict_field",
"set_field",
"mutable_field",
"subparsers",
"flag",
"flags",
"subgroups",
]
4 changes: 2 additions & 2 deletions simple_parsing/helpers/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ def _decoding_fn(value: Any) -> Any:


def list_field(*default_items: T, **kwargs) -> list[T]:
"""shorthand function for setting a `list` attribute on a dataclass, so that every instance of
"""Shorthand function for setting a `list` attribute on a dataclass, so that every instance of
the dataclass doesn't share the same list.

Accepts any of the arguments of the `dataclasses.field` function.
Expand All @@ -285,7 +285,7 @@ def list_field(*default_items: T, **kwargs) -> list[T]:


def dict_field(default_items: dict[K, V] | Iterable[tuple[K, V]] = (), **kwargs) -> dict[K, V]:
"""shorthand function for setting a `dict` attribute on a dataclass, so that every instance of
"""Shorthand function for setting a `dict` attribute on a dataclass, so that every instance of
the dataclass doesn't share the same `dict`.

NOTE: Do not use keyword arguments as you usually would with a dictionary
Expand Down
27 changes: 21 additions & 6 deletions simple_parsing/helpers/serialization/serializable.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import functools
import json
import pickle
import warnings
Expand Down Expand Up @@ -207,7 +208,8 @@ def __init_subclass__(
if parent in SerializableMixin.subclasses and parent is not SerializableMixin:
decode_into_subclasses = parent.decode_into_subclasses
logger.debug(
f"Parent class {parent} has decode_into_subclasses = {decode_into_subclasses}"
f"Parent class {parent} has decode_into_subclasses = "
f"{decode_into_subclasses}"
)
break

Expand All @@ -219,7 +221,10 @@ def __init_subclass__(
register_decoding_fn(cls, cls.from_dict)

def to_dict(
self, dict_factory: type[dict] = dict, recurse: bool = True, save_dc_types: bool = False
self,
dict_factory: type[dict] = dict,
recurse: bool = True,
save_dc_types: bool | int = False,
) -> dict:
"""Serializes this dataclass to a dict.

Expand Down Expand Up @@ -595,6 +600,7 @@ def loads_yaml(

def read_file(path: str | Path) -> dict:
"""Returns the contents of the given file as a dictionary.

Uses the right function depending on `path.suffix`:
{
".yml": yaml.safe_load,
Expand All @@ -613,7 +619,7 @@ def save(
obj: Any,
path: str | Path,
format: FormatExtension | None = None,
save_dc_types: bool = False,
save_dc_types: bool | int = False,
**kwargs,
) -> None:
"""Save the given dataclass or dictionary to the given file."""
Expand Down Expand Up @@ -704,7 +710,7 @@ def to_dict(
dc: DataclassT,
dict_factory: type[dict] = dict,
recurse: bool = True,
save_dc_types: bool = False,
save_dc_types: bool | int = False,
) -> dict:
"""Serializes this dataclass to a dict.

Expand Down Expand Up @@ -736,6 +742,11 @@ def to_dict(
else:
d[DC_TYPE_KEY] = module + "." + class_name

# Decrement save_dc_types if it is an int, so that we only save the type of the subgroups
# dataclass, not all dataclasses recursively.
if save_dc_types is not True and save_dc_types > 0:
save_dc_types -= 1

for f in fields(dc):
name = f.name
value = getattr(dc, name)
Expand Down Expand Up @@ -763,7 +774,8 @@ def to_dict(
encoded = encoding_fn(value)
except Exception as e:
logger.error(
f"Unable to encode value {value} of type {type(value)}! Leaving it as-is. (exception: {e})"
f"Unable to encode value {value} of type {type(value)}! Leaving it as-is. "
f"(exception: {e})"
)
encoded = value
d[name] = encoded
Expand Down Expand Up @@ -832,6 +844,7 @@ def from_dict(
if name not in obj_dict:
if (
field.metadata.get("to_dict", True)
and field.init
and field.default is MISSING
and field.default_factory is MISSING
):
Expand Down Expand Up @@ -928,6 +941,7 @@ def is_dataclass_or_optional_dataclass_type(t: type) -> bool:
return is_dataclass(t) or (is_optional(t) and is_dataclass(get_args(t)[0]))


@functools.lru_cache(maxsize=None)
def _locate(path: str) -> Any:
"""COPIED FROM Hydra: https://github.com/facebookresearch/hydra/blob/f8940600d0ab5c695961ad83ab
d042ffe9458caf/hydra/_internal/utils.py#L614.
Expand Down Expand Up @@ -968,7 +982,8 @@ def _locate(path: str) -> Any:
except ModuleNotFoundError as exc_import:
raise ImportError(
f"Error loading '{path}':\n{repr(exc_import)}"
+ f"\nAre you sure that '{part}' is importable from module '{parent_dotpath}'?"
+ f"\nAre you sure that '{part}' is importable from module "
f"'{parent_dotpath}'?"
) from exc_import
except Exception as exc_import:
raise ImportError(
Expand Down
Loading
Loading