diff --git a/bytelatent/args.py b/bytelatent/args.py index 47bd0f9..dd1fef5 100644 --- a/bytelatent/args.py +++ b/bytelatent/args.py @@ -5,7 +5,6 @@ import numpy as np import yaml -from omegaconf import OmegaConf from pydantic import BaseModel, ConfigDict from bytelatent.checkpoint import CheckpointArgs @@ -38,19 +37,6 @@ def get_rng_state(seed: int, rank: int, world_size: int) -> dict[str, Any]: return np.random.default_rng((seed, rank, world_size)).bit_generator.state -def parse_args(args_cls): - cli_args = OmegaConf.from_cli() - file_cfg = OmegaConf.load(cli_args.config) - # We remove 'config' attribute from config as the underlying DataClass does not have it - del cli_args.config - - default_cfg = OmegaConf.create(args_cls().model_dump()) - cfg = OmegaConf.merge(default_cfg, file_cfg, cli_args) - cfg = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True) - pydantic_args = args_cls.model_validate(cfg) - return pydantic_args - - TRAIN_DATA_FILE_PATTERN = "*.chunk.*.jsonl" diff --git a/bytelatent/config_parser.py b/bytelatent/config_parser.py new file mode 100644 index 0000000..2e310a2 --- /dev/null +++ b/bytelatent/config_parser.py @@ -0,0 +1,73 @@ +import copy +from typing import Type, TypeVar + +import omegaconf +from omegaconf import DictConfig, OmegaConf +from pydantic import BaseModel + + +def parse_file_config(path: str) -> DictConfig: + file_cfg = OmegaConf.load(path) + if not isinstance(file_cfg, DictConfig): + raise ValueError( + f"File paths must parse to DictConfig, but it was: {type(file_cfg)}" + ) + return file_cfg + + +def recursively_parse_config(cfg: DictConfig) -> list[DictConfig]: + if "config" not in cfg: + return [cfg] + + ordered_cfgs = [] + cfg = copy.deepcopy(cfg) + config_arg = cfg["config"] + del cfg["config"] + ordered_cfgs.append(cfg) + + if isinstance(config_arg, str): + file_cfg = parse_file_config(config_arg) + sub_configs = recursively_parse_config(file_cfg) + ordered_cfgs = sub_configs + ordered_cfgs + elif isinstance(config_arg, omegaconf.listconfig.ListConfig): + sub_configs = [] + for c in config_arg: + if not isinstance(c, str): + raise ValueError( + f'If "config" is specified, it must be either a string path or a list of string paths. It was config={config_arg}' + ) + config_to_parse = parse_file_config(c) + sub_configs.extend(recursively_parse_config(config_to_parse)) + ordered_cfgs = sub_configs + ordered_cfgs + else: + raise ValueError( + f'If "config" is specified, it must be either a string path or a list of string paths, it was config={config_arg}' + ) + return ordered_cfgs + + +def parse_args_with_default( + *, default_cfg: DictConfig | None = None, cli_args: DictConfig | None = None +): + if cli_args is None: + cli_args = OmegaConf.from_cli() + assert isinstance( + cli_args, DictConfig + ), f"CLI Args must be a DictConfig, not {type(cli_args)}" + ordered_cfgs = recursively_parse_config(cli_args) + if default_cfg is not None: + ordered_cfgs.insert(0, default_cfg) + cfg = OmegaConf.merge(*ordered_cfgs) + return OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True) + + +T = TypeVar("T", bound=BaseModel) + + +def parse_args_to_pydantic_model( + args_cls: Type[T], cli_args: DictConfig | None = None +) -> T: + default_cfg = OmegaConf.create(args_cls().model_dump()) + parsed_cfg = parse_args_with_default(default_cfg=default_cfg, cli_args=cli_args) + pydantic_args = args_cls.model_validate(parsed_cfg) + return pydantic_args diff --git a/bytelatent/configs/debug.yaml b/bytelatent/configs/debug.yaml index 07d489f..1369364 100644 --- a/bytelatent/configs/debug.yaml +++ b/bytelatent/configs/debug.yaml @@ -56,13 +56,11 @@ model: recompute_attn: false custom_bwd: false layer_ckpt: "none" - patch_only_encoder: false - patch_only_decoder: false use_local_encoder_transformer: true init_use_gaussian: true init_use_depth: "current" - attn_bias_type: "block_causal" attn_impl: "xformers" + attn_bias_type: "block_causal" alpha_depth: "disabled" max_length: 256 local_attention_window_len: 512 diff --git a/bytelatent/configs/entropy_model.yaml b/bytelatent/configs/entropy_model.yaml index d7c27b7..79cc85b 100644 --- a/bytelatent/configs/entropy_model.yaml +++ b/bytelatent/configs/entropy_model.yaml @@ -2,9 +2,10 @@ # Evals can be activated by uncommenting its config # python -m launchers.stool config=apps/main/configs/debug.yaml nodes=8 account=fair_amaia_cw_codegen qos=lowest -dump_dir: /tmp/ +dump_dir: /tmp/blt-entropy name: "debug" steps: 100_000 +max_steps: null probe_freq: null seed: 777 optim: @@ -35,7 +36,6 @@ entropy_model: attn_impl: "xformers" data: - s3_profile: blt root_dir: ??? sources: dclm_baseline_1.0: 1.0 diff --git a/bytelatent/eval.py b/bytelatent/eval.py index ae73066..50e17cd 100644 --- a/bytelatent/eval.py +++ b/bytelatent/eval.py @@ -5,18 +5,15 @@ import os from collections import defaultdict from datetime import datetime -from pathlib import Path -from typing import Any import torch from lm_eval import simple_evaluate from lm_eval.api.instance import Instance from lm_eval.api.model import LM -from omegaconf import OmegaConf -from pydantic import BaseModel, ConfigDict -from bytelatent.args import EvalArgs, ValidationArgs, parse_args +from bytelatent.args import EvalArgs, ValidationArgs from bytelatent.checkpoint import CONSOLIDATE_FOLDER, consolidate_checkpoints +from bytelatent.config_parser import parse_args_to_pydantic_model from bytelatent.data.file_util import get_fs from bytelatent.distributed import ( DistributedArgs, @@ -29,7 +26,6 @@ PackedCausalTransformerGenerator, load_consolidated_model_and_tokenizer, ) -from bytelatent.transformer import LMTransformer, LMTransformerArgs EVAL_FOLDER_NAME = "{:010d}" diff --git a/bytelatent/print_config.py b/bytelatent/print_config.py new file mode 100644 index 0000000..0bc99e7 --- /dev/null +++ b/bytelatent/print_config.py @@ -0,0 +1,11 @@ +from bytelatent.args import TrainArgs +from bytelatent.config_parser import parse_args_to_pydantic_model + + +def main(): + train_args = parse_args_to_pydantic_model(TrainArgs) + print(train_args.model_dump_json(indent=4)) + + +if __name__ == "__main__": + main() diff --git a/bytelatent/test_config_parser.py b/bytelatent/test_config_parser.py new file mode 100644 index 0000000..c1ec99b --- /dev/null +++ b/bytelatent/test_config_parser.py @@ -0,0 +1,180 @@ +import os + +import pytest +from omegaconf import DictConfig, MissingMandatoryValue, OmegaConf +from pydantic import BaseModel, ConfigDict + +from bytelatent.config_parser import ( + parse_args_to_pydantic_model, + parse_file_config, + recursively_parse_config, +) + +FIXTURE_DIR = "fixtures/test-cfgs" + + +def test_parse_file_config(): + with pytest.raises(ValueError): + cfg = parse_file_config(os.path.join(FIXTURE_DIR, "list.yaml")) + assert isinstance(cfg, DictConfig) + + +def test_nop(): + cfg = OmegaConf.create({"a": 1}) + parsed_cfgs = recursively_parse_config(cfg) + assert len(parsed_cfgs) == 1 + assert parsed_cfgs[0] == cfg + + +def test_root(): + cli_cfg = OmegaConf.create({"config": os.path.join(FIXTURE_DIR, "root.yaml")}) + parsed_cfgs = recursively_parse_config(cli_cfg) + assert len(parsed_cfgs) == 2 + assert len(parsed_cfgs[1]) == 0 + assert parsed_cfgs[0]["seed"] == -1 + with pytest.raises(MissingMandatoryValue): + assert parsed_cfgs[0]["b"]["y"] is not None + + # Test basic cli override + cli_cfg = OmegaConf.create( + {"config": os.path.join(FIXTURE_DIR, "root.yaml"), "seed": 42} + ) + parsed_cfgs = recursively_parse_config(cli_cfg) + assert parsed_cfgs[1]["seed"] == 42 + cfg = OmegaConf.merge(*parsed_cfgs) + assert cfg["seed"] == 42 + + +def test_one_level_include(): + cli_cfg = OmegaConf.create({"config": os.path.join(FIXTURE_DIR, "middle.yaml")}) + parsed_cfgs = recursively_parse_config(cli_cfg) + assert len(parsed_cfgs) == 3 + assert parsed_cfgs[0]["seed"] == -1 + assert parsed_cfgs[1]["b"]["y"] == 10 + assert len(parsed_cfgs[2]) == 0 + cfg = OmegaConf.merge(*parsed_cfgs) + assert cfg["b"]["y"] == 10 + + cli_cfg = OmegaConf.create( + {"config": os.path.join(FIXTURE_DIR, "middle.yaml"), "b": {"y": 100}} + ) + parsed_cfgs = recursively_parse_config(cli_cfg) + assert len(parsed_cfgs) == 3 + assert parsed_cfgs[0]["seed"] == -1 + assert parsed_cfgs[1]["b"]["y"] == 10 + assert parsed_cfgs[2]["b"]["y"] == 100 + cfg = OmegaConf.merge(*parsed_cfgs) + assert cfg["b"]["y"] == 100 + + +def test_two_level_include(): + cli_cfg = OmegaConf.create( + {"config": os.path.join(FIXTURE_DIR, "top.yaml"), "p": 500, "b": {"z": -2}} + ) + parsed_cfgs = recursively_parse_config(cli_cfg) + assert len(parsed_cfgs) == 4 + assert parsed_cfgs[0]["seed"] == -1 + assert parsed_cfgs[1]["b"]["y"] == 10 + assert parsed_cfgs[2]["hello"] == "world" + assert parsed_cfgs[3]["p"] == 500 + assert parsed_cfgs[3]["b"]["z"] == -2 + cfg = OmegaConf.merge(*parsed_cfgs) + assert cfg["a"] == 1 + assert cfg["seed"] == -1 + assert cfg["b"]["x"] == 0 + assert cfg["b"]["y"] == 10 + assert cfg["b"]["z"] == -2 + assert cfg["hello"] == "world" + + +def test_multiple_includes(): + cli_cfg = OmegaConf.create( + { + "config": [ + os.path.join(FIXTURE_DIR, "top.yaml"), + os.path.join(FIXTURE_DIR, "override.yaml"), + ], + "p": 500, + "b": {"z": -2}, + } + ) + parsed_cfgs = recursively_parse_config(cli_cfg) + assert len(parsed_cfgs) == 5 + assert parsed_cfgs[0]["seed"] == -1 + assert parsed_cfgs[1]["b"]["y"] == 10 + assert parsed_cfgs[2]["hello"] == "world" + assert parsed_cfgs[3]["a"] == 100 + assert parsed_cfgs[4]["p"] == 500 + assert parsed_cfgs[4]["b"]["z"] == -2 + cfg = OmegaConf.merge(*parsed_cfgs) + assert cfg["a"] == 100 + assert cfg["seed"] == -1 + assert cfg["b"]["x"] == 0 + assert cfg["b"]["y"] == 10 + assert cfg["b"]["z"] == -2 + assert cfg["hello"] == "world" + + cli_cfg = OmegaConf.create( + { + "config": [ + os.path.join(FIXTURE_DIR, "top.yaml"), + os.path.join(FIXTURE_DIR, "override.yaml"), + ], + "p": 500, + "b": {"z": -2}, + "a": 1000, + } + ) + parsed_cfgs = recursively_parse_config(cli_cfg) + assert len(parsed_cfgs) == 5 + assert parsed_cfgs[0]["seed"] == -1 + assert parsed_cfgs[1]["b"]["y"] == 10 + assert parsed_cfgs[2]["hello"] == "world" + assert parsed_cfgs[3]["a"] == 100 + assert parsed_cfgs[4]["p"] == 500 + assert parsed_cfgs[4]["b"]["z"] == -2 + cfg = OmegaConf.merge(*parsed_cfgs) + assert cfg["a"] == 1000 + assert cfg["seed"] == -1 + assert cfg["b"]["x"] == 0 + assert cfg["b"]["y"] == 10 + assert cfg["b"]["z"] == -2 + assert cfg["hello"] == "world" + + +class SubConfig(BaseModel): + model_config = ConfigDict(extra="forbid") + x: int = -100 + y: int = -100 + z: int = -5 + + +class SampleConfig(BaseModel): + model_config = ConfigDict(extra="forbid") + a: int = -100 + seed: int = -100 + b: SubConfig = SubConfig() + hello: str = "" + p: int = -100 + + +def test_pydantic_parse(): + cli_cfg = OmegaConf.create( + { + "config": [ + os.path.join(FIXTURE_DIR, "top.yaml"), + os.path.join(FIXTURE_DIR, "override.yaml"), + ], + "p": 500, + "a": 1000, + } + ) + cfg = parse_args_to_pydantic_model(SampleConfig, cli_args=cli_cfg) + assert isinstance(cfg, SampleConfig) + assert cfg.a == 1000 + assert cfg.p == 500 + assert cfg.seed == -1 + assert cfg.b.x == 0 + assert cfg.b.y == 10 + assert cfg.b.z == -5 + assert cfg.hello == "world" diff --git a/bytelatent/train.py b/bytelatent/train.py index 3669167..ad74b44 100644 --- a/bytelatent/train.py +++ b/bytelatent/train.py @@ -23,8 +23,9 @@ from torch.distributed.checkpoint.stateful import Stateful from torch.optim import lr_scheduler -from bytelatent.args import TrainArgs, parse_args +from bytelatent.args import TrainArgs from bytelatent.checkpoint import CheckpointManager, load_from_checkpoint +from bytelatent.config_parser import parse_args_to_pydantic_model from bytelatent.data.file_util import get_fs from bytelatent.data.iterators.abstract_iterator import get_state_and_refresh from bytelatent.data.iterators.multiprocess_iterator import ( @@ -824,7 +825,7 @@ class LMTransformerArgsgs: Plus all the default values in TrainArgs dataclass. """ - train_args = parse_args(TrainArgs) + train_args = parse_args_to_pydantic_model(TrainArgs) if train_args.debug_dynamo: import torch._dynamo diff --git a/fixtures/test-cfgs/list.yaml b/fixtures/test-cfgs/list.yaml new file mode 100644 index 0000000..b5d8bb5 --- /dev/null +++ b/fixtures/test-cfgs/list.yaml @@ -0,0 +1 @@ +[1, 2, 3] diff --git a/fixtures/test-cfgs/middle.yaml b/fixtures/test-cfgs/middle.yaml new file mode 100644 index 0000000..a476d8d --- /dev/null +++ b/fixtures/test-cfgs/middle.yaml @@ -0,0 +1,3 @@ +config: fixtures/test-cfgs/root.yaml +b: + y: 10 diff --git a/fixtures/test-cfgs/override.yaml b/fixtures/test-cfgs/override.yaml new file mode 100644 index 0000000..456df7b --- /dev/null +++ b/fixtures/test-cfgs/override.yaml @@ -0,0 +1 @@ +a: 100 diff --git a/fixtures/test-cfgs/root.yaml b/fixtures/test-cfgs/root.yaml new file mode 100644 index 0000000..dc4d285 --- /dev/null +++ b/fixtures/test-cfgs/root.yaml @@ -0,0 +1,6 @@ +seed: -1 +a: 1 +b: + x: 0 + y: ??? + z: ??? diff --git a/fixtures/test-cfgs/top.yaml b/fixtures/test-cfgs/top.yaml new file mode 100644 index 0000000..632866c --- /dev/null +++ b/fixtures/test-cfgs/top.yaml @@ -0,0 +1,3 @@ +config: fixtures/test-cfgs/middle.yaml + +hello: world