Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make it possible to specify multiple config files #54

Merged
merged 1 commit into from
Feb 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 0 additions & 14 deletions bytelatent/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

import numpy as np
import yaml
from omegaconf import OmegaConf
from pydantic import BaseModel, ConfigDict

from bytelatent.checkpoint import CheckpointArgs
Expand Down Expand Up @@ -38,19 +37,6 @@ def get_rng_state(seed: int, rank: int, world_size: int) -> dict[str, Any]:
return np.random.default_rng((seed, rank, world_size)).bit_generator.state


def parse_args(args_cls):
cli_args = OmegaConf.from_cli()
file_cfg = OmegaConf.load(cli_args.config)
# We remove 'config' attribute from config as the underlying DataClass does not have it
del cli_args.config

default_cfg = OmegaConf.create(args_cls().model_dump())
cfg = OmegaConf.merge(default_cfg, file_cfg, cli_args)
cfg = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True)
pydantic_args = args_cls.model_validate(cfg)
return pydantic_args


TRAIN_DATA_FILE_PATTERN = "*.chunk.*.jsonl"


Expand Down
73 changes: 73 additions & 0 deletions bytelatent/config_parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import copy
from typing import Type, TypeVar

import omegaconf
from omegaconf import DictConfig, OmegaConf
from pydantic import BaseModel


def parse_file_config(path: str) -> DictConfig:
file_cfg = OmegaConf.load(path)
if not isinstance(file_cfg, DictConfig):
raise ValueError(
f"File paths must parse to DictConfig, but it was: {type(file_cfg)}"
)
return file_cfg


def recursively_parse_config(cfg: DictConfig) -> list[DictConfig]:
if "config" not in cfg:
return [cfg]

ordered_cfgs = []
cfg = copy.deepcopy(cfg)
config_arg = cfg["config"]
del cfg["config"]
ordered_cfgs.append(cfg)

if isinstance(config_arg, str):
file_cfg = parse_file_config(config_arg)
sub_configs = recursively_parse_config(file_cfg)
ordered_cfgs = sub_configs + ordered_cfgs
elif isinstance(config_arg, omegaconf.listconfig.ListConfig):
sub_configs = []
for c in config_arg:
if not isinstance(c, str):
raise ValueError(
f'If "config" is specified, it must be either a string path or a list of string paths. It was config={config_arg}'
)
config_to_parse = parse_file_config(c)
sub_configs.extend(recursively_parse_config(config_to_parse))
ordered_cfgs = sub_configs + ordered_cfgs
else:
raise ValueError(
f'If "config" is specified, it must be either a string path or a list of string paths, it was config={config_arg}'
)
return ordered_cfgs


def parse_args_with_default(
*, default_cfg: DictConfig | None = None, cli_args: DictConfig | None = None
):
if cli_args is None:
cli_args = OmegaConf.from_cli()
assert isinstance(
cli_args, DictConfig
), f"CLI Args must be a DictConfig, not {type(cli_args)}"
ordered_cfgs = recursively_parse_config(cli_args)
if default_cfg is not None:
ordered_cfgs.insert(0, default_cfg)
cfg = OmegaConf.merge(*ordered_cfgs)
return OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True)


T = TypeVar("T", bound=BaseModel)


def parse_args_to_pydantic_model(
args_cls: Type[T], cli_args: DictConfig | None = None
) -> T:
default_cfg = OmegaConf.create(args_cls().model_dump())
parsed_cfg = parse_args_with_default(default_cfg=default_cfg, cli_args=cli_args)
pydantic_args = args_cls.model_validate(parsed_cfg)
return pydantic_args
4 changes: 1 addition & 3 deletions bytelatent/configs/debug.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,11 @@ model:
recompute_attn: false
custom_bwd: false
layer_ckpt: "none"
patch_only_encoder: false
patch_only_decoder: false
use_local_encoder_transformer: true
init_use_gaussian: true
init_use_depth: "current"
attn_bias_type: "block_causal"
attn_impl: "xformers"
attn_bias_type: "block_causal"
alpha_depth: "disabled"
max_length: 256
local_attention_window_len: 512
Expand Down
4 changes: 2 additions & 2 deletions bytelatent/configs/entropy_model.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
# Evals can be activated by uncommenting its config
# python -m launchers.stool config=apps/main/configs/debug.yaml nodes=8 account=fair_amaia_cw_codegen qos=lowest

dump_dir: /tmp/
dump_dir: /tmp/blt-entropy
name: "debug"
steps: 100_000
max_steps: null
probe_freq: null
seed: 777
optim:
Expand Down Expand Up @@ -35,7 +36,6 @@ entropy_model:
attn_impl: "xformers"

data:
s3_profile: blt
root_dir: ???
sources:
dclm_baseline_1.0: 1.0
Expand Down
8 changes: 2 additions & 6 deletions bytelatent/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,15 @@
import os
from collections import defaultdict
from datetime import datetime
from pathlib import Path
from typing import Any

import torch
from lm_eval import simple_evaluate
from lm_eval.api.instance import Instance
from lm_eval.api.model import LM
from omegaconf import OmegaConf
from pydantic import BaseModel, ConfigDict

from bytelatent.args import EvalArgs, ValidationArgs, parse_args
from bytelatent.args import EvalArgs, ValidationArgs
from bytelatent.checkpoint import CONSOLIDATE_FOLDER, consolidate_checkpoints
from bytelatent.config_parser import parse_args_to_pydantic_model
from bytelatent.data.file_util import get_fs
from bytelatent.distributed import (
DistributedArgs,
Expand All @@ -29,7 +26,6 @@
PackedCausalTransformerGenerator,
load_consolidated_model_and_tokenizer,
)
from bytelatent.transformer import LMTransformer, LMTransformerArgs

EVAL_FOLDER_NAME = "{:010d}"

Expand Down
11 changes: 11 additions & 0 deletions bytelatent/print_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from bytelatent.args import TrainArgs
from bytelatent.config_parser import parse_args_to_pydantic_model


def main():
train_args = parse_args_to_pydantic_model(TrainArgs)
print(train_args.model_dump_json(indent=4))


if __name__ == "__main__":
main()
180 changes: 180 additions & 0 deletions bytelatent/test_config_parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
import os

import pytest
from omegaconf import DictConfig, MissingMandatoryValue, OmegaConf
from pydantic import BaseModel, ConfigDict

from bytelatent.config_parser import (
parse_args_to_pydantic_model,
parse_file_config,
recursively_parse_config,
)

FIXTURE_DIR = "fixtures/test-cfgs"


def test_parse_file_config():
with pytest.raises(ValueError):
cfg = parse_file_config(os.path.join(FIXTURE_DIR, "list.yaml"))
assert isinstance(cfg, DictConfig)


def test_nop():
cfg = OmegaConf.create({"a": 1})
parsed_cfgs = recursively_parse_config(cfg)
assert len(parsed_cfgs) == 1
assert parsed_cfgs[0] == cfg


def test_root():
cli_cfg = OmegaConf.create({"config": os.path.join(FIXTURE_DIR, "root.yaml")})
parsed_cfgs = recursively_parse_config(cli_cfg)
assert len(parsed_cfgs) == 2
assert len(parsed_cfgs[1]) == 0
assert parsed_cfgs[0]["seed"] == -1
with pytest.raises(MissingMandatoryValue):
assert parsed_cfgs[0]["b"]["y"] is not None

# Test basic cli override
cli_cfg = OmegaConf.create(
{"config": os.path.join(FIXTURE_DIR, "root.yaml"), "seed": 42}
)
parsed_cfgs = recursively_parse_config(cli_cfg)
assert parsed_cfgs[1]["seed"] == 42
cfg = OmegaConf.merge(*parsed_cfgs)
assert cfg["seed"] == 42


def test_one_level_include():
cli_cfg = OmegaConf.create({"config": os.path.join(FIXTURE_DIR, "middle.yaml")})
parsed_cfgs = recursively_parse_config(cli_cfg)
assert len(parsed_cfgs) == 3
assert parsed_cfgs[0]["seed"] == -1
assert parsed_cfgs[1]["b"]["y"] == 10
assert len(parsed_cfgs[2]) == 0
cfg = OmegaConf.merge(*parsed_cfgs)
assert cfg["b"]["y"] == 10

cli_cfg = OmegaConf.create(
{"config": os.path.join(FIXTURE_DIR, "middle.yaml"), "b": {"y": 100}}
)
parsed_cfgs = recursively_parse_config(cli_cfg)
assert len(parsed_cfgs) == 3
assert parsed_cfgs[0]["seed"] == -1
assert parsed_cfgs[1]["b"]["y"] == 10
assert parsed_cfgs[2]["b"]["y"] == 100
cfg = OmegaConf.merge(*parsed_cfgs)
assert cfg["b"]["y"] == 100


def test_two_level_include():
cli_cfg = OmegaConf.create(
{"config": os.path.join(FIXTURE_DIR, "top.yaml"), "p": 500, "b": {"z": -2}}
)
parsed_cfgs = recursively_parse_config(cli_cfg)
assert len(parsed_cfgs) == 4
assert parsed_cfgs[0]["seed"] == -1
assert parsed_cfgs[1]["b"]["y"] == 10
assert parsed_cfgs[2]["hello"] == "world"
assert parsed_cfgs[3]["p"] == 500
assert parsed_cfgs[3]["b"]["z"] == -2
cfg = OmegaConf.merge(*parsed_cfgs)
assert cfg["a"] == 1
assert cfg["seed"] == -1
assert cfg["b"]["x"] == 0
assert cfg["b"]["y"] == 10
assert cfg["b"]["z"] == -2
assert cfg["hello"] == "world"


def test_multiple_includes():
cli_cfg = OmegaConf.create(
{
"config": [
os.path.join(FIXTURE_DIR, "top.yaml"),
os.path.join(FIXTURE_DIR, "override.yaml"),
],
"p": 500,
"b": {"z": -2},
}
)
parsed_cfgs = recursively_parse_config(cli_cfg)
assert len(parsed_cfgs) == 5
assert parsed_cfgs[0]["seed"] == -1
assert parsed_cfgs[1]["b"]["y"] == 10
assert parsed_cfgs[2]["hello"] == "world"
assert parsed_cfgs[3]["a"] == 100
assert parsed_cfgs[4]["p"] == 500
assert parsed_cfgs[4]["b"]["z"] == -2
cfg = OmegaConf.merge(*parsed_cfgs)
assert cfg["a"] == 100
assert cfg["seed"] == -1
assert cfg["b"]["x"] == 0
assert cfg["b"]["y"] == 10
assert cfg["b"]["z"] == -2
assert cfg["hello"] == "world"

cli_cfg = OmegaConf.create(
{
"config": [
os.path.join(FIXTURE_DIR, "top.yaml"),
os.path.join(FIXTURE_DIR, "override.yaml"),
],
"p": 500,
"b": {"z": -2},
"a": 1000,
}
)
parsed_cfgs = recursively_parse_config(cli_cfg)
assert len(parsed_cfgs) == 5
assert parsed_cfgs[0]["seed"] == -1
assert parsed_cfgs[1]["b"]["y"] == 10
assert parsed_cfgs[2]["hello"] == "world"
assert parsed_cfgs[3]["a"] == 100
assert parsed_cfgs[4]["p"] == 500
assert parsed_cfgs[4]["b"]["z"] == -2
cfg = OmegaConf.merge(*parsed_cfgs)
assert cfg["a"] == 1000
assert cfg["seed"] == -1
assert cfg["b"]["x"] == 0
assert cfg["b"]["y"] == 10
assert cfg["b"]["z"] == -2
assert cfg["hello"] == "world"


class SubConfig(BaseModel):
model_config = ConfigDict(extra="forbid")
x: int = -100
y: int = -100
z: int = -5


class SampleConfig(BaseModel):
model_config = ConfigDict(extra="forbid")
a: int = -100
seed: int = -100
b: SubConfig = SubConfig()
hello: str = ""
p: int = -100


def test_pydantic_parse():
cli_cfg = OmegaConf.create(
{
"config": [
os.path.join(FIXTURE_DIR, "top.yaml"),
os.path.join(FIXTURE_DIR, "override.yaml"),
],
"p": 500,
"a": 1000,
}
)
cfg = parse_args_to_pydantic_model(SampleConfig, cli_args=cli_cfg)
assert isinstance(cfg, SampleConfig)
assert cfg.a == 1000
assert cfg.p == 500
assert cfg.seed == -1
assert cfg.b.x == 0
assert cfg.b.y == 10
assert cfg.b.z == -5
assert cfg.hello == "world"
5 changes: 3 additions & 2 deletions bytelatent/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@
from torch.distributed.checkpoint.stateful import Stateful
from torch.optim import lr_scheduler

from bytelatent.args import TrainArgs, parse_args
from bytelatent.args import TrainArgs
from bytelatent.checkpoint import CheckpointManager, load_from_checkpoint
from bytelatent.config_parser import parse_args_to_pydantic_model
from bytelatent.data.file_util import get_fs
from bytelatent.data.iterators.abstract_iterator import get_state_and_refresh
from bytelatent.data.iterators.multiprocess_iterator import (
Expand Down Expand Up @@ -824,7 +825,7 @@ class LMTransformerArgsgs:

Plus all the default values in TrainArgs dataclass.
"""
train_args = parse_args(TrainArgs)
train_args = parse_args_to_pydantic_model(TrainArgs)
if train_args.debug_dynamo:
import torch._dynamo

Expand Down
1 change: 1 addition & 0 deletions fixtures/test-cfgs/list.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
[1, 2, 3]
3 changes: 3 additions & 0 deletions fixtures/test-cfgs/middle.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
config: fixtures/test-cfgs/root.yaml
b:
y: 10
1 change: 1 addition & 0 deletions fixtures/test-cfgs/override.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
a: 100
6 changes: 6 additions & 0 deletions fixtures/test-cfgs/root.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
seed: -1
a: 1
b:
x: 0
y: ???
z: ???
3 changes: 3 additions & 0 deletions fixtures/test-cfgs/top.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
config: fixtures/test-cfgs/middle.yaml

hello: world