Make it possible to specify multiple config files
Summary:

Make it possible to specify multiple config files.
Parsing the CLI is no longer a special case; it uses the same config inheritance method.

Test Plan:

Test that this interpolates in the right order via unit tests.

Sample usage loads the internal config, which references bytelatent/configs/entropy_model.yaml. The precedence order (lowest to highest) is:

- Default pydantic args
- Included configs, e.g. files referenced via `config`
- CLI args

```
python -m bytelatent.print_config config=internal/configs/entropy_model.yaml eval=null
```
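
As a quick illustration of that precedence, here is a minimal sketch using the new `parse_args_with_default` helper from this commit. The config values (`seed`, `lr`) are made up for the example; defaults merge first and CLI args last, so CLI values win on conflicts:

```
from omegaconf import OmegaConf

from bytelatent.config_parser import parse_args_with_default

# Hypothetical in-memory stand-ins for the pydantic defaults and CLI args.
default_cfg = OmegaConf.create({"seed": 1, "lr": 0.001})
cli_args = OmegaConf.create({"seed": 42})  # no "config" key, so no files are included

# Defaults merge first, CLI args last, so the CLI override wins.
resolved = parse_args_with_default(default_cfg=default_cfg, cli_args=cli_args)
assert resolved == {"seed": 42, "lr": 0.001}
```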


EntilZha committed Feb 18, 2025
1 parent 9f29e0d commit 3117ac1
Showing 13 changed files with 286 additions and 27 deletions.
14 changes: 0 additions & 14 deletions bytelatent/args.py
@@ -5,7 +5,6 @@

import numpy as np
import yaml
from omegaconf import OmegaConf
from pydantic import BaseModel, ConfigDict

from bytelatent.checkpoint import CheckpointArgs
@@ -38,19 +37,6 @@ def get_rng_state(seed: int, rank: int, world_size: int) -> dict[str, Any]:
return np.random.default_rng((seed, rank, world_size)).bit_generator.state


def parse_args(args_cls):
cli_args = OmegaConf.from_cli()
file_cfg = OmegaConf.load(cli_args.config)
# We remove 'config' attribute from config as the underlying DataClass does not have it
del cli_args.config

default_cfg = OmegaConf.create(args_cls().model_dump())
cfg = OmegaConf.merge(default_cfg, file_cfg, cli_args)
cfg = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True)
pydantic_args = args_cls.model_validate(cfg)
return pydantic_args


TRAIN_DATA_FILE_PATTERN = "*.chunk.*.jsonl"


73 changes: 73 additions & 0 deletions bytelatent/config_parser.py
@@ -0,0 +1,73 @@
import copy
from typing import Type, TypeVar

import omegaconf
from omegaconf import DictConfig, OmegaConf
from pydantic import BaseModel


def parse_file_config(path: str) -> DictConfig:
file_cfg = OmegaConf.load(path)
if not isinstance(file_cfg, DictConfig):
raise ValueError(
f"File paths must parse to DictConfig, but it was: {type(file_cfg)}"
)
return file_cfg


def recursively_parse_config(cfg: DictConfig) -> list[DictConfig]:
if "config" not in cfg:
return [cfg]

ordered_cfgs = []
cfg = copy.deepcopy(cfg)
config_arg = cfg["config"]
del cfg["config"]
ordered_cfgs.append(cfg)

if isinstance(config_arg, str):
file_cfg = parse_file_config(config_arg)
sub_configs = recursively_parse_config(file_cfg)
ordered_cfgs = sub_configs + ordered_cfgs
elif isinstance(config_arg, omegaconf.listconfig.ListConfig):
sub_configs = []
for c in config_arg:
if not isinstance(c, str):
raise ValueError(
f'If "config" is specified, it must be either a string path or a list of string paths. It was config={config_arg}'
)
config_to_parse = parse_file_config(c)
sub_configs.extend(recursively_parse_config(config_to_parse))
ordered_cfgs = sub_configs + ordered_cfgs
else:
raise ValueError(
f'If "config" is specified, it must be either a string path or a list of string paths, it was config={config_arg}'
)
return ordered_cfgs


def parse_args_with_default(
*, default_cfg: DictConfig | None = None, cli_args: DictConfig | None = None
):
if cli_args is None:
cli_args = OmegaConf.from_cli()
assert isinstance(
cli_args, DictConfig
), f"CLI Args must be a DictConfig, not {type(cli_args)}"
ordered_cfgs = recursively_parse_config(cli_args)
if default_cfg is not None:
ordered_cfgs.insert(0, default_cfg)
cfg = OmegaConf.merge(*ordered_cfgs)
return OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True)


T = TypeVar("T", bound=BaseModel)


def parse_args_to_pydantic_model(
args_cls: Type[T], cli_args: DictConfig | None = None
) -> T:
default_cfg = OmegaConf.create(args_cls().model_dump())
parsed_cfg = parse_args_with_default(default_cfg=default_cfg, cli_args=cli_args)
pydantic_args = args_cls.model_validate(parsed_cfg)
return pydantic_args
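
For reference, a minimal sketch of the include ordering, assuming the test fixtures added below (`top.yaml` includes `middle.yaml`, which includes `root.yaml`): deeper includes are parsed first, so later configs override earlier ones when merged:

```
from omegaconf import OmegaConf

from bytelatent.config_parser import recursively_parse_config

# Parse order: root.yaml, middle.yaml, top.yaml, then the CLI leftovers.
cli = OmegaConf.create({"config": "fixtures/test-cfgs/top.yaml", "seed": 7})
cfgs = recursively_parse_config(cli)

assert len(cfgs) == 4
assert cfgs[0]["seed"] == -1        # root.yaml
assert cfgs[1]["b"]["y"] == 10      # middle.yaml
assert cfgs[2]["hello"] == "world"  # top.yaml
assert cfgs[3]["seed"] == 7         # CLI remainder; wins after OmegaConf.merge
```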
4 changes: 1 addition & 3 deletions bytelatent/configs/debug.yaml
@@ -56,13 +56,11 @@ model:
recompute_attn: false
custom_bwd: false
layer_ckpt: "none"
patch_only_encoder: false
patch_only_decoder: false
use_local_encoder_transformer: true
init_use_gaussian: true
init_use_depth: "current"
attn_bias_type: "block_causal"
attn_impl: "xformers"
attn_bias_type: "block_causal"
alpha_depth: "disabled"
max_length: 256
local_attention_window_len: 512
4 changes: 2 additions & 2 deletions bytelatent/configs/entropy_model.yaml
@@ -2,9 +2,10 @@
# Evals can be activated by uncommenting its config
# python -m launchers.stool config=apps/main/configs/debug.yaml nodes=8 account=fair_amaia_cw_codegen qos=lowest

dump_dir: /tmp/
dump_dir: /tmp/blt-entropy
name: "debug"
steps: 100_000
max_steps: null
probe_freq: null
seed: 777
optim:
@@ -35,7 +36,6 @@ entropy_model:
attn_impl: "xformers"

data:
s3_profile: blt
root_dir: ???
sources:
dclm_baseline_1.0: 1.0
8 changes: 2 additions & 6 deletions bytelatent/eval.py
@@ -5,18 +5,15 @@
import os
from collections import defaultdict
from datetime import datetime
from pathlib import Path
from typing import Any

import torch
from lm_eval import simple_evaluate
from lm_eval.api.instance import Instance
from lm_eval.api.model import LM
from omegaconf import OmegaConf
from pydantic import BaseModel, ConfigDict

from bytelatent.args import EvalArgs, ValidationArgs, parse_args
from bytelatent.args import EvalArgs, ValidationArgs
from bytelatent.checkpoint import CONSOLIDATE_FOLDER, consolidate_checkpoints
from bytelatent.config_parser import parse_args_to_pydantic_model
from bytelatent.data.file_util import get_fs
from bytelatent.distributed import (
DistributedArgs,
@@ -29,7 +26,6 @@
PackedCausalTransformerGenerator,
load_consolidated_model_and_tokenizer,
)
from bytelatent.transformer import LMTransformer, LMTransformerArgs

EVAL_FOLDER_NAME = "{:010d}"

11 changes: 11 additions & 0 deletions bytelatent/print_config.py
@@ -0,0 +1,11 @@
from bytelatent.args import TrainArgs
from bytelatent.config_parser import parse_args_to_pydantic_model


def main():
train_args = parse_args_to_pydantic_model(TrainArgs)
print(train_args.model_dump_json(indent=4))


if __name__ == "__main__":
main()
180 changes: 180 additions & 0 deletions bytelatent/test_config_parser.py
@@ -0,0 +1,180 @@
import os

import pytest
from omegaconf import DictConfig, MissingMandatoryValue, OmegaConf
from pydantic import BaseModel, ConfigDict

from bytelatent.config_parser import (
parse_args_to_pydantic_model,
parse_file_config,
recursively_parse_config,
)

FIXTURE_DIR = "fixtures/test-cfgs"


def test_parse_file_config():
with pytest.raises(ValueError):
cfg = parse_file_config(os.path.join(FIXTURE_DIR, "list.yaml"))
assert isinstance(cfg, DictConfig)


def test_nop():
cfg = OmegaConf.create({"a": 1})
parsed_cfgs = recursively_parse_config(cfg)
assert len(parsed_cfgs) == 1
assert parsed_cfgs[0] == cfg


def test_root():
cli_cfg = OmegaConf.create({"config": os.path.join(FIXTURE_DIR, "root.yaml")})
parsed_cfgs = recursively_parse_config(cli_cfg)
assert len(parsed_cfgs) == 2
assert len(parsed_cfgs[1]) == 0
assert parsed_cfgs[0]["seed"] == -1
with pytest.raises(MissingMandatoryValue):
assert parsed_cfgs[0]["b"]["y"] is not None

# Test basic cli override
cli_cfg = OmegaConf.create(
{"config": os.path.join(FIXTURE_DIR, "root.yaml"), "seed": 42}
)
parsed_cfgs = recursively_parse_config(cli_cfg)
assert parsed_cfgs[1]["seed"] == 42
cfg = OmegaConf.merge(*parsed_cfgs)
assert cfg["seed"] == 42


def test_one_level_include():
cli_cfg = OmegaConf.create({"config": os.path.join(FIXTURE_DIR, "middle.yaml")})
parsed_cfgs = recursively_parse_config(cli_cfg)
assert len(parsed_cfgs) == 3
assert parsed_cfgs[0]["seed"] == -1
assert parsed_cfgs[1]["b"]["y"] == 10
assert len(parsed_cfgs[2]) == 0
cfg = OmegaConf.merge(*parsed_cfgs)
assert cfg["b"]["y"] == 10

cli_cfg = OmegaConf.create(
{"config": os.path.join(FIXTURE_DIR, "middle.yaml"), "b": {"y": 100}}
)
parsed_cfgs = recursively_parse_config(cli_cfg)
assert len(parsed_cfgs) == 3
assert parsed_cfgs[0]["seed"] == -1
assert parsed_cfgs[1]["b"]["y"] == 10
assert parsed_cfgs[2]["b"]["y"] == 100
cfg = OmegaConf.merge(*parsed_cfgs)
assert cfg["b"]["y"] == 100


def test_two_level_include():
cli_cfg = OmegaConf.create(
{"config": os.path.join(FIXTURE_DIR, "top.yaml"), "p": 500, "b": {"z": -2}}
)
parsed_cfgs = recursively_parse_config(cli_cfg)
assert len(parsed_cfgs) == 4
assert parsed_cfgs[0]["seed"] == -1
assert parsed_cfgs[1]["b"]["y"] == 10
assert parsed_cfgs[2]["hello"] == "world"
assert parsed_cfgs[3]["p"] == 500
assert parsed_cfgs[3]["b"]["z"] == -2
cfg = OmegaConf.merge(*parsed_cfgs)
assert cfg["a"] == 1
assert cfg["seed"] == -1
assert cfg["b"]["x"] == 0
assert cfg["b"]["y"] == 10
assert cfg["b"]["z"] == -2
assert cfg["hello"] == "world"


def test_multiple_includes():
cli_cfg = OmegaConf.create(
{
"config": [
os.path.join(FIXTURE_DIR, "top.yaml"),
os.path.join(FIXTURE_DIR, "override.yaml"),
],
"p": 500,
"b": {"z": -2},
}
)
parsed_cfgs = recursively_parse_config(cli_cfg)
assert len(parsed_cfgs) == 5
assert parsed_cfgs[0]["seed"] == -1
assert parsed_cfgs[1]["b"]["y"] == 10
assert parsed_cfgs[2]["hello"] == "world"
assert parsed_cfgs[3]["a"] == 100
assert parsed_cfgs[4]["p"] == 500
assert parsed_cfgs[4]["b"]["z"] == -2
cfg = OmegaConf.merge(*parsed_cfgs)
assert cfg["a"] == 100
assert cfg["seed"] == -1
assert cfg["b"]["x"] == 0
assert cfg["b"]["y"] == 10
assert cfg["b"]["z"] == -2
assert cfg["hello"] == "world"

cli_cfg = OmegaConf.create(
{
"config": [
os.path.join(FIXTURE_DIR, "top.yaml"),
os.path.join(FIXTURE_DIR, "override.yaml"),
],
"p": 500,
"b": {"z": -2},
"a": 1000,
}
)
parsed_cfgs = recursively_parse_config(cli_cfg)
assert len(parsed_cfgs) == 5
assert parsed_cfgs[0]["seed"] == -1
assert parsed_cfgs[1]["b"]["y"] == 10
assert parsed_cfgs[2]["hello"] == "world"
assert parsed_cfgs[3]["a"] == 100
assert parsed_cfgs[4]["p"] == 500
assert parsed_cfgs[4]["b"]["z"] == -2
cfg = OmegaConf.merge(*parsed_cfgs)
assert cfg["a"] == 1000
assert cfg["seed"] == -1
assert cfg["b"]["x"] == 0
assert cfg["b"]["y"] == 10
assert cfg["b"]["z"] == -2
assert cfg["hello"] == "world"


class SubConfig(BaseModel):
model_config = ConfigDict(extra="forbid")
x: int = -100
y: int = -100
z: int = -5


class SampleConfig(BaseModel):
model_config = ConfigDict(extra="forbid")
a: int = -100
seed: int = -100
b: SubConfig = SubConfig()
hello: str = ""
p: int = -100


def test_pydantic_parse():
cli_cfg = OmegaConf.create(
{
"config": [
os.path.join(FIXTURE_DIR, "top.yaml"),
os.path.join(FIXTURE_DIR, "override.yaml"),
],
"p": 500,
"a": 1000,
}
)
cfg = parse_args_to_pydantic_model(SampleConfig, cli_args=cli_cfg)
assert isinstance(cfg, SampleConfig)
assert cfg.a == 1000
assert cfg.p == 500
assert cfg.seed == -1
assert cfg.b.x == 0
assert cfg.b.y == 10
assert cfg.b.z == -5
assert cfg.hello == "world"
5 changes: 3 additions & 2 deletions bytelatent/train.py
@@ -23,8 +23,9 @@
from torch.distributed.checkpoint.stateful import Stateful
from torch.optim import lr_scheduler

from bytelatent.args import TrainArgs, parse_args
from bytelatent.args import TrainArgs
from bytelatent.checkpoint import CheckpointManager, load_from_checkpoint
from bytelatent.config_parser import parse_args_to_pydantic_model
from bytelatent.data.file_util import get_fs
from bytelatent.data.iterators.abstract_iterator import get_state_and_refresh
from bytelatent.data.iterators.multiprocess_iterator import (
@@ -824,7 +825,7 @@ class LMTransformerArgsgs:
Plus all the default values in TrainArgs dataclass.
"""
train_args = parse_args(TrainArgs)
train_args = parse_args_to_pydantic_model(TrainArgs)
if train_args.debug_dynamo:
import torch._dynamo

Expand Down
1 change: 1 addition & 0 deletions fixtures/test-cfgs/list.yaml
@@ -0,0 +1 @@
[1, 2, 3]
3 changes: 3 additions & 0 deletions fixtures/test-cfgs/middle.yaml
@@ -0,0 +1,3 @@
config: fixtures/test-cfgs/root.yaml
b:
y: 10
1 change: 1 addition & 0 deletions fixtures/test-cfgs/override.yaml
@@ -0,0 +1 @@
a: 100
6 changes: 6 additions & 0 deletions fixtures/test-cfgs/root.yaml
@@ -0,0 +1,6 @@
seed: -1
a: 1
b:
x: 0
y: ???
z: ???
3 changes: 3 additions & 0 deletions fixtures/test-cfgs/top.yaml
@@ -0,0 +1,3 @@
config: fixtures/test-cfgs/middle.yaml

hello: world
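
Putting the fixtures together, a minimal end-to-end sketch: `b.y` and `b.z` are mandatory (`???`) in `root.yaml`, and `middle.yaml` fills in `b.y`, so `b.z` must come from the CLI or `to_container(..., throw_on_missing=True)` raises:

```
from omegaconf import OmegaConf

from bytelatent.config_parser import parse_args_with_default

# Resolve the top -> middle -> root chain with one CLI override for b.z.
cli = OmegaConf.create({"config": "fixtures/test-cfgs/top.yaml", "b": {"z": 5}})
cfg = parse_args_with_default(cli_args=cli)
assert cfg == {"seed": -1, "a": 1, "b": {"x": 0, "y": 10, "z": 5}, "hello": "world"}
```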
