diff --git a/docs/source/tune_cli.rst b/docs/source/tune_cli.rst index 5981fd8b80..0bb1b0c424 100644 --- a/docs/source/tune_cli.rst +++ b/docs/source/tune_cli.rst @@ -17,7 +17,7 @@ with a short description of each. .. code-block:: bash $ tune --help - usage: tune [-h] {download,ls,cp,run,validate} ... + usage: tune [-h] {download,ls,cp,run,validate,cat} ... Welcome to the torchtune CLI! @@ -25,7 +25,7 @@ with a short description of each. -h, --help show this help message and exit subcommands: - {download,ls,cp,run,validate} + {download,ls,cp,run,validate,cat} download Download a model from the Hugging Face Hub. ls List all built-in recipes and configs ... @@ -233,3 +233,72 @@ The ``tune validate `` command will validate that your config is formatt # If you've copied over a built-in config and want to validate custom changes $ tune validate my_configs/llama3/8B_full.yaml Config is well-formed! + +.. _tune_cat_cli_label: + +Inspect a config +--------------------- + +The ``tune cat `` command pretty prints a configuration file, making it easy to use ``tune run`` with confidence. This command is useful for inspecting the structure and contents of a config file before running a recipe, ensuring that all parameters are correctly set. + +You can also use the ``--sort`` option to print the config in sorted order, which can help in quickly locating specific keys. + +.. list-table:: + :widths: 30 60 + + * - \--sort + - Print the config in sorted order. + +**Workflow Example** + +1. **List all available configs:** + + Use the ``tune ls`` command to list all the built-in recipes and configs within torchtune. + + .. code-block:: bash + + $ tune ls + RECIPE CONFIG + full_finetune_single_device llama2/7B_full_low_memory + code_llama2/7B_full_low_memory + llama3/8B_full_single_device + mistral/7B_full_low_memory + phi3/mini_full_low_memory + full_finetune_distributed llama2/7B_full + llama2/13B_full + llama3/8B_full + llama3/70B_full + ... + +2. **Inspect the contents of a config:** + + Use the ``tune cat`` command to pretty print the contents of a specific config. This helps you understand the structure and parameters of the config. + + .. code-block:: bash + + $ tune cat llama2/7B_full + output_dir: /tmp/torchtune/llama2_7B/full + tokenizer: + _component_: torchtune.models.llama2.llama2_tokenizer + path: /tmp/Llama-2-7b-hf/tokenizer.model + max_seq_len: null + ... + + You can also print the config in sorted order: + + .. code-block:: bash + + $ tune cat llama2/7B_full --sort + +3. **Run a recipe with parameter override:** + + After inspecting the config, you can use the ``tune run`` command to run a recipe with the config. You can also override specific parameters directly from the command line. For example, to override the `output_dir` parameter: + + .. code-block:: bash + + $ tune run full_finetune_distributed --config llama2/7B_full output_dir=./ + + Learn more about config overrides :ref:`here `. + +.. note:: + You can find all the cat-able configs via the ``tune ls`` command. diff --git a/tests/torchtune/_cli/test_cat.py b/tests/torchtune/_cli/test_cat.py new file mode 100644 index 0000000000..4ab310afa2 --- /dev/null +++ b/tests/torchtune/_cli/test_cat.py @@ -0,0 +1,81 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import runpy +import sys + +import pytest +from tests.common import TUNE_PATH + + +class TestTuneCatCommand: + """This class tests the `tune cat` command.""" + + def test_cat_valid_config(self, capsys, monkeypatch): + testargs = "tune cat llama2/7B_full".split() + monkeypatch.setattr(sys, "argv", testargs) + runpy.run_path(TUNE_PATH, run_name="__main__") + + captured = capsys.readouterr() + output = captured.out.rstrip("\n") + + # Check for key sections that should be in the YAML output + assert "output_dir:" in output + assert "tokenizer:" in output + assert "model:" in output + + def test_cat_recipe_name_shows_error(self, capsys, monkeypatch): + testargs = "tune cat full_finetune_single_device".split() + monkeypatch.setattr(sys, "argv", testargs) + runpy.run_path(TUNE_PATH, run_name="__main__") + + captured = capsys.readouterr() + output = captured.out.rstrip("\n") + + assert "is a recipe, not a config" in output + + def test_cat_non_existent_config(self, capsys, monkeypatch): + testargs = "tune cat non_existent_config".split() + monkeypatch.setattr(sys, "argv", testargs) + + with pytest.raises(SystemExit): + runpy.run_path(TUNE_PATH, run_name="__main__") + + captured = capsys.readouterr() + err = captured.err.rstrip("\n") + + assert ( + "Invalid config format: 'non_existent_config'. Must be YAML (.yaml/.yml)" + in err + ) + + def test_cat_invalid_yaml_file(self, capsys, monkeypatch, tmpdir): + invalid_yaml = tmpdir / "invalid.yaml" + invalid_yaml.write_text("invalid: yaml: file", encoding="utf-8") + + testargs = f"tune cat {invalid_yaml}".split() + monkeypatch.setattr(sys, "argv", testargs) + + with pytest.raises(SystemExit): + runpy.run_path(TUNE_PATH, run_name="__main__") + + captured = capsys.readouterr() + err = captured.err.rstrip("\n") + + assert "Error parsing YAML file" in err + + def test_cat_external_yaml_file(self, capsys, monkeypatch, tmpdir): + valid_yaml = tmpdir / "external.yaml" + valid_yaml.write_text("key: value", encoding="utf-8") + + testargs = f"tune cat {valid_yaml}".split() + monkeypatch.setattr(sys, "argv", testargs) + runpy.run_path(TUNE_PATH, run_name="__main__") + + captured = capsys.readouterr() + output = captured.out.rstrip("\n") + + assert "key: value" in output diff --git a/torchtune/_cli/cat.py b/torchtune/_cli/cat.py new file mode 100644 index 0000000000..fb575a3aee --- /dev/null +++ b/torchtune/_cli/cat.py @@ -0,0 +1,128 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import textwrap + +from pathlib import Path +from typing import List, Optional + +import yaml +from torchtune._cli.subcommand import Subcommand +from torchtune._recipe_registry import Config, get_all_recipes + +ROOT = Path(__file__).parent.parent.parent + + +class Cat(Subcommand): + """Holds all the logic for the `tune cat` subcommand.""" + + def __init__(self, subparsers: argparse._SubParsersAction): + super().__init__() + self._parser = subparsers.add_parser( + "cat", + prog="tune cat", + help="Pretty print a config, making it easy to know which parameters you can override with `tune run`.", + description="Pretty print a config, making it easy to know which parameters you can override with `tune run`.", + epilog=textwrap.dedent( + """\ + examples: + $ tune cat llama2/7B_full + output_dir: /tmp/torchtune/llama2_7B/full + tokenizer: + _component_: torchtune.models.llama2.llama2_tokenizer + path: /tmp/Llama-2-7b-hf/tokenizer.model + max_seq_len: null + ... + + # Pretty print the config in sorted order + $ tune cat llama2/7B_full --sort + + # Pretty print the contents of LOCALFILE.yaml + $ tune cat LOCALFILE.yaml + + You can now easily override a key based on your findings from `tune cat`: + $ tune run full_finetune_distributed --config llama2/7B_full output_dir=./ + + Need to find all the "cat"-able configs? Try `tune ls`! + """ + ), + formatter_class=argparse.RawTextHelpFormatter, + ) + self._parser.add_argument( + "config_name", type=str, help="Name of the config to print" + ) + self._parser.set_defaults(func=self._cat_cmd) + self._parser.add_argument( + "--sort", action="store_true", help="Print the config in sorted order" + ) + + def _get_all_recipes(self) -> List[str]: + return [recipe.name for recipe in get_all_recipes()] + + def _get_config(self, config_str: str) -> Optional[Config]: + # Search through all recipes + for recipe in get_all_recipes(): + for config in recipe.configs: + if config.name == config_str: + return config + + def _print_yaml_file(self, file: str, sort_keys: bool) -> None: + try: + with open(file, "r") as f: + data = yaml.safe_load(f) + if data: + print( + yaml.dump( + data, + default_flow_style=False, + sort_keys=sort_keys, + indent=4, + width=80, + allow_unicode=True, + ), + end="", + ) + except yaml.YAMLError as e: + self._parser.error(f"Error parsing YAML file: {e}") + + def _cat_cmd(self, args: argparse.Namespace) -> None: + """Display the contents of a configuration file. + + Handles both predefined configurations and direct file paths, ensuring: + - Input is not a recipe name + - File exists + - File is YAML format + + Args: + args (argparse.Namespace): Command-line arguments containing 'config_name' attribute + """ + config_str = args.config_name + + # Immediately handle recipe name case + if config_str in self._get_all_recipes(): + print( + f"'{config_str}' is a recipe, not a config. Please use a config name." + ) + return + + # Resolve config path + config = self._get_config(config_str) + if config: + config_path = ROOT / "recipes" / "configs" / config.file_path + else: + config_path = Path(config_str) + if config_path.suffix.lower() not in {".yaml", ".yml"}: + self._parser.error( + f"Invalid config format: '{config_path}'. Must be YAML (.yaml/.yml)" + ) + return + + if not config_path.exists(): + self._parser.error(f"Config '{config_str}' not found.") + return + + self._print_yaml_file(str(config_path), args.sort) diff --git a/torchtune/_cli/tune.py b/torchtune/_cli/tune.py index ea9c58bbe1..13d3bbae0b 100644 --- a/torchtune/_cli/tune.py +++ b/torchtune/_cli/tune.py @@ -6,6 +6,8 @@ import argparse +from torchtune._cli.cat import Cat + from torchtune._cli.cp import Copy from torchtune._cli.download import Download from torchtune._cli.ls import List @@ -33,6 +35,7 @@ def __init__(self): Copy.create(subparsers) Run.create(subparsers) Validate.create(subparsers) + Cat.create(subparsers) def parse_args(self) -> argparse.Namespace: """Parse CLI arguments"""