Skip to content

Commit

Permalink
fix: unparse config value to generate toml file (a16z#383)
Browse files Browse the repository at this point in the history
  • Loading branch information
daejunpark authored Oct 2, 2024
1 parent e2bd1fd commit 4cd642b
Show file tree
Hide file tree
Showing 2 changed files with 163 additions and 7 deletions.
51 changes: 44 additions & 7 deletions src/halmos/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os
import sys
from collections import OrderedDict
from collections.abc import Callable
from collections.abc import Callable, Generator
from dataclasses import MISSING, dataclass, fields
from dataclasses import field as dataclass_field
from typing import Any
Expand Down Expand Up @@ -52,14 +52,29 @@ def arg(
)


def ensure_non_empty(values: list | set | dict, raw_values: str) -> list:
if not values:
raise ValueError(f"required a non-empty list, but got {raw_values}")
return values


def parse_csv(values: str, sep: str = ",") -> Generator[Any, None, None]:
"""Parse a CSV string and return a generator of *non-empty* values."""
return (x for _x in values.split(sep) if (x := _x.strip()))


class ParseCSV(argparse.Action):
def __call__(self, parser, namespace, values, option_string=None):
values = ParseCSV.parse(values)
setattr(namespace, self.dest, values)

@staticmethod
def parse(values: str) -> list[int]:
return [int(x.strip()) for x in values.split(",")]
return ensure_non_empty([int(x) for x in parse_csv(values)], values)

@staticmethod
def unparse(values: list[int]) -> str:
return ",".join([str(v) for v in values])


class ParseErrorCodes(argparse.Action):
Expand All @@ -75,7 +90,13 @@ def parse(values: str) -> set[int]:
return set()

# support multiple bases: decimal, hex, etc.
return set(int(x.strip(), 0) for x in values.split(","))
return ensure_non_empty(set(int(x, 0) for x in parse_csv(values)), values)

@staticmethod
def unparse(values: set[int]) -> str:
if not values:
return "*"
return ",".join([f"0x{v:02x}" for v in values])


class ParseArrayLengths(argparse.Action):
Expand All @@ -88,13 +109,20 @@ def parse(values: str | None) -> dict[str, list[int]]:
if not values:
return {}

# TODO: update syntax: name1=size1,size2; name2=size3,...; ...
name_sizes_pairs = values.split(",")
# TODO: update syntax: name1={size1,size2},name2=size3,...
return {
name.strip(): [int(x.strip()) for x in sizes.split(";")]
for name, sizes in [x.split("=") for x in name_sizes_pairs]
name.strip(): ensure_non_empty(
[int(x) for x in parse_csv(sizes, sep=";")], sizes
)
for name, sizes in (x.split("=") for x in parse_csv(values))
}

@staticmethod
def unparse(values: dict[str, list[int]]) -> str:
return ",".join(
[f"{k}={';'.join([str(v) for v in vs])}" for k, vs in values.items()]
)


# TODO: add kw_only=True when we support Python>=3.10
@dataclass(frozen=True)
Expand Down Expand Up @@ -731,6 +759,10 @@ def _to_toml_str(value: Any, type) -> str:
continue

group_name = field_info.metadata.get("group", None)
if group_name == deprecated:
# skip deprecated options
continue

if group_name != current_group_name:
separator = "#" * 80
lines.append(f"\n{separator}")
Expand All @@ -746,6 +778,11 @@ def _to_toml_str(value: Any, type) -> str:
(value, source) = config.value_with_source(field_info.name)
default = field_info.metadata.get("global_default", None)

# unparse value if action is provided
# note: this is a workaround because not all types can be represented in toml syntax, e.g., sets.
if action := field_info.metadata.get("action", None):
value = action.unparse(value)

# callable defaults mean that the default value is not a hardcoded constant
# it depends on the context, so don't emit it in the config file unless it
# is explicitly set by the user on the command line
Expand Down
119 changes: 119 additions & 0 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@

from halmos.config import (
Config,
ParseArrayLengths,
ParseCSV,
ParseErrorCodes,
arg_parser,
default_config,
resolve_config_files,
Expand Down Expand Up @@ -185,3 +188,119 @@ def test_config_pickle(config, parser):
# then the config object should be the same
assert config == unpickled
assert unpickled.value_with_source("verbose") == (3, "command-line")


def test_parse_csv():
with pytest.raises(ValueError):
ParseCSV.parse("")
ParseCSV.parse(" ")
ParseCSV.parse(",")
assert ParseCSV.parse("0") == [0]
assert ParseCSV.parse("0,") == [0]
assert ParseCSV.parse("1,2,3") == [1, 2, 3]
assert ParseCSV.parse("1,2,3,") == [1, 2, 3]
assert ParseCSV.parse(" 1 , 2 , 3 ") == [1, 2, 3]
assert ParseCSV.parse(" , 1 , 2 , 3 , ") == [1, 2, 3]


def test_unparse_csv():
assert ParseCSV.unparse([]) == ""
assert ParseCSV.unparse([0]) == "0"
assert ParseCSV.unparse([1, 2, 3]) == "1,2,3"


def test_parse_csv_roundtrip():
test_cases = [
[0],
[1, 2, 3],
]

for original in test_cases:
unparsed = ParseCSV.unparse(original)
parsed = ParseCSV.parse(unparsed)
assert parsed == original, f"Roundtrip failed for {original}"


def test_parse_error_codes():
with pytest.raises(ValueError):
ParseErrorCodes.parse("")
ParseErrorCodes.parse(" ")
ParseErrorCodes.parse(",")
ParseErrorCodes.parse("1,*")
ParseErrorCodes.parse(",*")
ParseErrorCodes.parse("*,")
assert ParseErrorCodes.parse("*") == set()
assert ParseErrorCodes.parse(" * ") == set()
assert ParseErrorCodes.parse("0") == {0}
assert ParseErrorCodes.parse("0,") == {0}
assert ParseErrorCodes.parse("1,2,3") == {1, 2, 3}
assert ParseErrorCodes.parse("1,2,3,") == {1, 2, 3}
assert ParseErrorCodes.parse(" 1 , 2 , 3 ") == {1, 2, 3}
assert ParseErrorCodes.parse(" , 1 , 2 , 3 , ") == {1, 2, 3}
assert ParseErrorCodes.parse(" 0b10 , 0o10 , 10, 0x10 ") == {2, 8, 10, 16}


def test_unparse_error_codes():
assert ParseErrorCodes.unparse(set()) == "*"
assert ParseErrorCodes.unparse({0}) == "0x00"
assert ParseErrorCodes.unparse({1, 2}) in {"0x01,0x02", "0x02,0x01"}


def test_parse_error_codes_roundtrip():
test_cases = [
set(),
{0},
{1, 2},
{1, 2, 3},
]

for original in test_cases:
unparsed = ParseErrorCodes.unparse(original)
parsed = ParseErrorCodes.parse(unparsed)
assert parsed == original, f"Roundtrip failed for {original}"


def test_parse_array_lengths():
with pytest.raises(ValueError):
ParseArrayLengths.parse("x=")
ParseArrayLengths.parse("x= ")
ParseArrayLengths.parse("x=;")
assert ParseArrayLengths.parse("") == {}
assert ParseArrayLengths.parse(" ") == {}
assert ParseArrayLengths.parse(",") == {}
assert ParseArrayLengths.parse("x=1") == {"x": [1]}
assert ParseArrayLengths.parse("x=1,") == {"x": [1]}
assert ParseArrayLengths.parse("x=1;2,y=3") == {"x": [1, 2], "y": [3]}
assert ParseArrayLengths.parse("x=1;2;,y=3") == {"x": [1, 2], "y": [3]}
assert ParseArrayLengths.parse("x=1;2,y=3;") == {"x": [1, 2], "y": [3]}
assert ParseArrayLengths.parse("x=1;2,y=3,") == {"x": [1, 2], "y": [3]}
assert ParseArrayLengths.parse("x=1;2;,y=3;,") == {"x": [1, 2], "y": [3]}
assert ParseArrayLengths.parse(" x = 1 ; 2 , y = 3 ") == {"x": [1, 2], "y": [3]}
assert ParseArrayLengths.parse(" , x = 1 ; 2 , y = 3 , ") == {"x": [1, 2], "y": [3]}
assert ParseArrayLengths.parse(" , x = ; 1 ; 2 ; , y = ; 3 ; , ") == {
"x": [1, 2],
"y": [3],
}


def test_unparse_array_lengths():
assert ParseArrayLengths.unparse({}) == ""
assert ParseArrayLengths.unparse({"x": [1]}) == "x=1"
assert ParseArrayLengths.unparse({"x": [1, 2], "y": [3]}) in {
"x=1;2,y=3",
"y=3,x=1;2",
}


def test_parse_array_lengths_roundtrip():
test_cases = [
{},
{"x": [1]},
{"x": [1, 2], "y": [3]},
{"x": [1, 2, 3], "y": [4, 5], "z": [6]},
]

for original in test_cases:
unparsed = ParseArrayLengths.unparse(original)
parsed = ParseArrayLengths.parse(unparsed)
assert parsed == original, f"Roundtrip failed for {original}"

0 comments on commit 4cd642b

Please sign in to comment.