-
Notifications
You must be signed in to change notification settings - Fork 106
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Make it possible to specify multiple config files
Summary: Make it possible to specify multiple config files. Parsing the CLI is no longer a special case; it just uses the same config inheritance method. Test Plan: Test that this interpolates in the right order via unit tests. Sample usage: loads the internal config, which references bytelatent/configs/entropy_model.yaml. The precedence order is: - Default pydantic args - Included configs, e.g. `config` - CLI args ``` python -m bytelatent.print_config config=internal/configs/entropy_model.yaml eval=null ``` Summary: Test Plan:
- Loading branch information
Showing
13 changed files
with
286 additions
and
27 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
import copy | ||
from typing import Type, TypeVar | ||
|
||
import omegaconf | ||
from omegaconf import DictConfig, OmegaConf | ||
from pydantic import BaseModel | ||
|
||
|
||
def parse_file_config(path: str) -> DictConfig:
    """Load a config file and require that its top level is a mapping.

    Args:
        path: Path handed to ``OmegaConf.load``.

    Returns:
        The parsed file as a ``DictConfig``.

    Raises:
        ValueError: If the file parses to anything other than a
            ``DictConfig`` (for example a top-level YAML list).
    """
    file_cfg = OmegaConf.load(path)
    if not isinstance(file_cfg, DictConfig):
        # Name the file: with nested includes the caller may not know which
        # path in the chain was malformed.
        raise ValueError(
            f"File paths must parse to DictConfig, but {path!r} parsed to: {type(file_cfg)}"
        )
    return file_cfg
|
||
|
||
def recursively_parse_config(cfg: DictConfig) -> list[DictConfig]:
    """Flatten a config and the files it includes into a merge-ordered list.

    The ``"config"`` key of ``cfg`` may hold a single path or a list of paths.
    Each referenced file is loaded with ``parse_file_config`` and expanded
    recursively, so included files may themselves include other files.

    Args:
        cfg: Config to expand.

    Returns:
        Configs ordered from first-to-merge to last-to-merge: everything
        contributed by the included files (in listed order, each expanded
        depth-first), followed by a copy of ``cfg`` with its ``"config"`` key
        removed. If ``cfg`` has no ``"config"`` key the result is ``[cfg]``,
        holding the very same object.

    Raises:
        ValueError: If ``"config"`` is neither a string nor a list of strings.
    """
    if "config" not in cfg:
        return [cfg]

    # Copy first so deleting the key does not touch the caller's config.
    own_cfg = copy.deepcopy(cfg)
    include_spec = own_cfg["config"]
    del own_cfg["config"]

    # Normalise the single-path form to a one-element sequence so both forms
    # share the loop below.
    if isinstance(include_spec, str):
        include_paths = [include_spec]
    elif isinstance(include_spec, omegaconf.listconfig.ListConfig):
        include_paths = include_spec
    else:
        raise ValueError(
            f'If "config" is specified, it must be either a string path or a list of string paths, it was config={include_spec}'
        )

    included_cfgs = []
    for include_path in include_paths:
        if not isinstance(include_path, str):
            raise ValueError(
                f'If "config" is specified, it must be either a string path or a list of string paths. It was config={include_spec}'
            )
        included_cfgs.extend(
            recursively_parse_config(parse_file_config(include_path))
        )

    # Included files come first so that this config's own keys are merged last.
    return included_cfgs + [own_cfg]
|
||
|
||
def parse_args_with_default(
    *, default_cfg: DictConfig | None = None, cli_args: DictConfig | None = None
):
    """Merge defaults, included config files and CLI args into one container.

    Merge order (each later entry is merged on top of the earlier ones):
    ``default_cfg`` if given, then the configs pulled in through ``config``
    keys, then ``cli_args`` itself.

    Args:
        default_cfg: Optional base config placed first in the merge.
        cli_args: Parsed command-line config. When ``None`` it is read with
            ``OmegaConf.from_cli()``.

    Returns:
        The result of ``OmegaConf.to_container`` on the merged config, with
        interpolations resolved.

    Raises:
        TypeError: If ``cli_args`` is not a ``DictConfig``.
        omegaconf.MissingMandatoryValue: presumably raised by
            ``to_container(..., throw_on_missing=True)`` when a ``???`` value
            survives the merge -- confirm against the OmegaConf docs.
    """
    if cli_args is None:
        cli_args = OmegaConf.from_cli()
    # Explicit raise instead of ``assert`` so the check survives ``python -O``.
    if not isinstance(cli_args, DictConfig):
        raise TypeError(f"CLI Args must be a DictConfig, not {type(cli_args)}")

    ordered_cfgs = recursively_parse_config(cli_args)
    if default_cfg is not None:
        # Defaults go first so every other layer can override them.
        ordered_cfgs = [default_cfg, *ordered_cfgs]
    cfg = OmegaConf.merge(*ordered_cfgs)
    return OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True)
|
||
|
||
T = TypeVar("T", bound=BaseModel) | ||
|
||
|
||
def parse_args_to_pydantic_model(
    args_cls: Type[T], cli_args: DictConfig | None = None
) -> T:
    """Build an ``args_cls`` instance from its defaults, config files and CLI.

    Args:
        args_cls: Pydantic model class; it is instantiated with no arguments
            to obtain the default values.
        cli_args: Optional pre-parsed CLI config, forwarded to
            ``parse_args_with_default``.

    Returns:
        ``args_cls`` validated from the merged configuration.
    """
    # The model's own defaults form the first layer of the merge.
    model_defaults = args_cls().model_dump()
    merged = parse_args_with_default(
        default_cfg=OmegaConf.create(model_defaults), cli_args=cli_args
    )
    return args_cls.model_validate(merged)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
from bytelatent.args import TrainArgs | ||
from bytelatent.config_parser import parse_args_to_pydantic_model | ||
|
||
|
||
def main():
    """Parse ``TrainArgs`` from config files/CLI and print them as JSON."""
    resolved_args = parse_args_to_pydantic_model(TrainArgs)
    rendered = resolved_args.model_dump_json(indent=4)
    print(rendered)
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,180 @@ | ||
import os | ||
|
||
import pytest | ||
from omegaconf import DictConfig, MissingMandatoryValue, OmegaConf | ||
from pydantic import BaseModel, ConfigDict | ||
|
||
from bytelatent.config_parser import ( | ||
parse_args_to_pydantic_model, | ||
parse_file_config, | ||
recursively_parse_config, | ||
) | ||
|
||
FIXTURE_DIR = "fixtures/test-cfgs" | ||
|
||
|
||
def test_parse_file_config():
    """A top-level YAML list is rejected; a top-level mapping is accepted."""
    # list.yaml holds a bare list, which must not be accepted as a config.
    with pytest.raises(ValueError):
        parse_file_config(os.path.join(FIXTURE_DIR, "list.yaml"))

    # Keep the type check outside the ``raises`` block, where it actually runs.
    cfg = parse_file_config(os.path.join(FIXTURE_DIR, "root.yaml"))
    assert isinstance(cfg, DictConfig)
|
||
|
||
def test_nop():
    """A config without a "config" key comes back as a one-element list."""
    plain_cfg = OmegaConf.create({"a": 1})
    layers = recursively_parse_config(plain_cfg)
    assert len(layers) == 1
    assert layers[0] == plain_cfg
|
||
|
||
def test_root():
    """Including root.yaml yields the file layer followed by the CLI layer."""
    root_path = os.path.join(FIXTURE_DIR, "root.yaml")

    # No CLI overrides: the CLI layer is empty once "config" is stripped.
    layers = recursively_parse_config(OmegaConf.create({"config": root_path}))
    assert len(layers) == 2
    file_layer, cli_layer = layers
    assert len(cli_layer) == 0
    assert file_layer["seed"] == -1
    # root.yaml declares b.y as "???", so reading it must raise.
    with pytest.raises(MissingMandatoryValue):
        assert file_layer["b"]["y"] is not None

    # Test basic cli override
    layers = recursively_parse_config(
        OmegaConf.create({"config": root_path, "seed": 42})
    )
    assert layers[1]["seed"] == 42
    merged = OmegaConf.merge(*layers)
    assert merged["seed"] == 42
|
||
|
||
def test_one_level_include():
    """middle.yaml includes root.yaml, so one include yields three layers."""
    middle_path = os.path.join(FIXTURE_DIR, "middle.yaml")

    # Without CLI overrides the last layer (the CLI itself) is empty.
    layers = recursively_parse_config(OmegaConf.create({"config": middle_path}))
    assert len(layers) == 3
    root_layer, middle_layer, cli_layer = layers
    assert root_layer["seed"] == -1
    assert middle_layer["b"]["y"] == 10
    assert len(cli_layer) == 0
    merged = OmegaConf.merge(*layers)
    assert merged["b"]["y"] == 10

    # A CLI value for b.y lands in the last layer and wins the merge.
    layers = recursively_parse_config(
        OmegaConf.create({"config": middle_path, "b": {"y": 100}})
    )
    assert len(layers) == 3
    root_layer, middle_layer, cli_layer = layers
    assert root_layer["seed"] == -1
    assert middle_layer["b"]["y"] == 10
    assert cli_layer["b"]["y"] == 100
    merged = OmegaConf.merge(*layers)
    assert merged["b"]["y"] == 100
|
||
|
||
def test_two_level_include():
    """top.yaml -> middle.yaml -> root.yaml expands to four ordered layers."""
    top_path = os.path.join(FIXTURE_DIR, "top.yaml")
    cli_cfg = OmegaConf.create({"config": top_path, "p": 500, "b": {"z": -2}})

    layers = recursively_parse_config(cli_cfg)
    assert len(layers) == 4
    root_layer, middle_layer, top_layer, cli_layer = layers
    assert root_layer["seed"] == -1
    assert middle_layer["b"]["y"] == 10
    assert top_layer["hello"] == "world"
    assert cli_layer["p"] == 500
    assert cli_layer["b"]["z"] == -2

    # After merging, every layer's contribution is visible in one config.
    merged = OmegaConf.merge(*layers)
    assert merged["a"] == 1
    assert merged["seed"] == -1
    assert merged["b"]["x"] == 0
    assert merged["b"]["y"] == 10
    assert merged["b"]["z"] == -2
    assert merged["hello"] == "world"
|
||
|
||
def test_multiple_includes():
    """A list of includes is expanded in order, with the CLI layer last."""
    include_paths = [
        os.path.join(FIXTURE_DIR, "top.yaml"),
        os.path.join(FIXTURE_DIR, "override.yaml"),
    ]

    # Scenario 1: "a" comes from override.yaml (100).
    # Scenario 2: the CLI also sets "a", which must win the merge (1000).
    scenarios = (
        ({}, 100),
        ({"a": 1000}, 1000),
    )
    for extra_cli_args, expected_a in scenarios:
        cli_cfg = OmegaConf.create(
            {
                "config": include_paths,
                "p": 500,
                "b": {"z": -2},
                **extra_cli_args,
            }
        )
        layers = recursively_parse_config(cli_cfg)
        assert len(layers) == 5
        assert layers[0]["seed"] == -1
        assert layers[1]["b"]["y"] == 10
        assert layers[2]["hello"] == "world"
        assert layers[3]["a"] == 100
        assert layers[4]["p"] == 500
        assert layers[4]["b"]["z"] == -2

        merged = OmegaConf.merge(*layers)
        assert merged["a"] == expected_a
        assert merged["seed"] == -1
        assert merged["b"]["x"] == 0
        assert merged["b"]["y"] == 10
        assert merged["b"]["z"] == -2
        assert merged["hello"] == "world"
|
||
|
||
# Nested pydantic model used as the "b" section of SampleConfig in the tests.
class SubConfig(BaseModel):
    # Reject keys that are not declared below.
    model_config = ConfigDict(extra="forbid")
    # Defaults differ from the values in the YAML fixtures, so the tests can
    # tell a default apart from a value loaded from a file.
    x: int = -100
    y: int = -100
    z: int = -5
|
||
|
||
# Top-level pydantic model that test_pydantic_parse validates the merged
# configuration against.
class SampleConfig(BaseModel):
    # Reject keys that are not declared below.
    model_config = ConfigDict(extra="forbid")
    a: int = -100
    seed: int = -100
    # Nested section; see SubConfig for its fields and defaults.
    b: SubConfig = SubConfig()
    hello: str = ""
    p: int = -100
|
||
|
||
def test_pydantic_parse():
    """Defaults, included files and CLI values end up in a SampleConfig."""
    include_paths = [
        os.path.join(FIXTURE_DIR, fixture_name)
        for fixture_name in ("top.yaml", "override.yaml")
    ]
    cli_cfg = OmegaConf.create({"config": include_paths, "p": 500, "a": 1000})

    parsed = parse_args_to_pydantic_model(SampleConfig, cli_args=cli_cfg)
    assert isinstance(parsed, SampleConfig)
    # Values given on the CLI.
    assert parsed.a == 1000
    assert parsed.p == 500
    # Values coming from the included fixture files.
    assert parsed.seed == -1
    assert parsed.b.x == 0
    assert parsed.b.y == 10
    # No file or CLI arg gives b.z a concrete value here, so the expected
    # result is the SubConfig default.
    assert parsed.b.z == -5
    assert parsed.hello == "world"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
[1, 2, 3] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
config: fixtures/test-cfgs/root.yaml | ||
b: | ||
y: 10 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
a: 100 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
seed: -1 | ||
a: 1 | ||
b: | ||
x: 0 | ||
y: ??? | ||
z: ??? |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
config: fixtures/test-cfgs/middle.yaml | ||
|
||
hello: world |