Skip to content

Commit

Permalink
'tune cat' command for pretty printing configuration files (#2298)
Browse files Browse the repository at this point in the history
  • Loading branch information
Ankur-singh authored Jan 30, 2025
1 parent 6764618 commit e6b9064
Show file tree
Hide file tree
Showing 4 changed files with 283 additions and 2 deletions.
73 changes: 71 additions & 2 deletions docs/source/tune_cli.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,15 @@ with a short description of each.
.. code-block:: bash
$ tune --help
usage: tune [-h] {download,ls,cp,run,validate} ...
usage: tune [-h] {download,ls,cp,run,validate,cat} ...
Welcome to the torchtune CLI!
options:
-h, --help show this help message and exit
subcommands:
{download,ls,cp,run,validate}
{download,ls,cp,run,validate,cat}
download Download a model from the Hugging Face Hub.
ls List all built-in recipes and configs
...
Expand Down Expand Up @@ -233,3 +233,72 @@ The ``tune validate <config>`` command will validate that your config is formatt
# If you've copied over a built-in config and want to validate custom changes
$ tune validate my_configs/llama3/8B_full.yaml
Config is well-formed!
.. _tune_cat_cli_label:

Inspect a config
---------------------

The ``tune cat <config>`` command pretty prints a configuration file, making it easy to use ``tune run`` with confidence. This command is useful for inspecting the structure and contents of a config file before running a recipe, ensuring that all parameters are correctly set.

You can also use the ``--sort`` option to print the config in sorted order, which can help in quickly locating specific keys.

.. list-table::
:widths: 30 60

* - \--sort
- Print the config in sorted order.

**Workflow Example**

1. **List all available configs:**

Use the ``tune ls`` command to list all the built-in recipes and configs within torchtune.

.. code-block:: bash
$ tune ls
RECIPE CONFIG
full_finetune_single_device llama2/7B_full_low_memory
code_llama2/7B_full_low_memory
llama3/8B_full_single_device
mistral/7B_full_low_memory
phi3/mini_full_low_memory
full_finetune_distributed llama2/7B_full
llama2/13B_full
llama3/8B_full
llama3/70B_full
...
2. **Inspect the contents of a config:**

Use the ``tune cat`` command to pretty print the contents of a specific config. This helps you understand the structure and parameters of the config.

.. code-block:: bash
$ tune cat llama2/7B_full
output_dir: /tmp/torchtune/llama2_7B/full
tokenizer:
_component_: torchtune.models.llama2.llama2_tokenizer
path: /tmp/Llama-2-7b-hf/tokenizer.model
max_seq_len: null
...
You can also print the config in sorted order:

.. code-block:: bash
$ tune cat llama2/7B_full --sort
3. **Run a recipe with parameter override:**

After inspecting the config, you can use the ``tune run`` command to run a recipe with the config. You can also override specific parameters directly from the command line. For example, to override the `output_dir` parameter:

.. code-block:: bash
$ tune run full_finetune_distributed --config llama2/7B_full output_dir=./
Learn more about config overrides :ref:`here <cli_override>`.

.. note::
You can find all the cat-able configs via the ``tune ls`` command.
81 changes: 81 additions & 0 deletions tests/torchtune/_cli/test_cat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import runpy
import sys

import pytest
from tests.common import TUNE_PATH


class TestTuneCatCommand:
"""This class tests the `tune cat` command."""

def test_cat_valid_config(self, capsys, monkeypatch):
testargs = "tune cat llama2/7B_full".split()
monkeypatch.setattr(sys, "argv", testargs)
runpy.run_path(TUNE_PATH, run_name="__main__")

captured = capsys.readouterr()
output = captured.out.rstrip("\n")

# Check for key sections that should be in the YAML output
assert "output_dir:" in output
assert "tokenizer:" in output
assert "model:" in output

def test_cat_recipe_name_shows_error(self, capsys, monkeypatch):
testargs = "tune cat full_finetune_single_device".split()
monkeypatch.setattr(sys, "argv", testargs)
runpy.run_path(TUNE_PATH, run_name="__main__")

captured = capsys.readouterr()
output = captured.out.rstrip("\n")

assert "is a recipe, not a config" in output

def test_cat_non_existent_config(self, capsys, monkeypatch):
testargs = "tune cat non_existent_config".split()
monkeypatch.setattr(sys, "argv", testargs)

with pytest.raises(SystemExit):
runpy.run_path(TUNE_PATH, run_name="__main__")

captured = capsys.readouterr()
err = captured.err.rstrip("\n")

assert (
"Invalid config format: 'non_existent_config'. Must be YAML (.yaml/.yml)"
in err
)

def test_cat_invalid_yaml_file(self, capsys, monkeypatch, tmpdir):
invalid_yaml = tmpdir / "invalid.yaml"
invalid_yaml.write_text("invalid: yaml: file", encoding="utf-8")

testargs = f"tune cat {invalid_yaml}".split()
monkeypatch.setattr(sys, "argv", testargs)

with pytest.raises(SystemExit):
runpy.run_path(TUNE_PATH, run_name="__main__")

captured = capsys.readouterr()
err = captured.err.rstrip("\n")

assert "Error parsing YAML file" in err

def test_cat_external_yaml_file(self, capsys, monkeypatch, tmpdir):
valid_yaml = tmpdir / "external.yaml"
valid_yaml.write_text("key: value", encoding="utf-8")

testargs = f"tune cat {valid_yaml}".split()
monkeypatch.setattr(sys, "argv", testargs)
runpy.run_path(TUNE_PATH, run_name="__main__")

captured = capsys.readouterr()
output = captured.out.rstrip("\n")

assert "key: value" in output
128 changes: 128 additions & 0 deletions torchtune/_cli/cat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import textwrap

from pathlib import Path
from typing import List, Optional

import yaml
from torchtune._cli.subcommand import Subcommand
from torchtune._recipe_registry import Config, get_all_recipes

ROOT = Path(__file__).parent.parent.parent


class Cat(Subcommand):
"""Holds all the logic for the `tune cat` subcommand."""

def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self._parser = subparsers.add_parser(
"cat",
prog="tune cat",
help="Pretty print a config, making it easy to know which parameters you can override with `tune run`.",
description="Pretty print a config, making it easy to know which parameters you can override with `tune run`.",
epilog=textwrap.dedent(
"""\
examples:
$ tune cat llama2/7B_full
output_dir: /tmp/torchtune/llama2_7B/full
tokenizer:
_component_: torchtune.models.llama2.llama2_tokenizer
path: /tmp/Llama-2-7b-hf/tokenizer.model
max_seq_len: null
...
# Pretty print the config in sorted order
$ tune cat llama2/7B_full --sort
# Pretty print the contents of LOCALFILE.yaml
$ tune cat LOCALFILE.yaml
You can now easily override a key based on your findings from `tune cat`:
$ tune run full_finetune_distributed --config llama2/7B_full output_dir=./
Need to find all the "cat"-able configs? Try `tune ls`!
"""
),
formatter_class=argparse.RawTextHelpFormatter,
)
self._parser.add_argument(
"config_name", type=str, help="Name of the config to print"
)
self._parser.set_defaults(func=self._cat_cmd)
self._parser.add_argument(
"--sort", action="store_true", help="Print the config in sorted order"
)

def _get_all_recipes(self) -> List[str]:
return [recipe.name for recipe in get_all_recipes()]

def _get_config(self, config_str: str) -> Optional[Config]:
# Search through all recipes
for recipe in get_all_recipes():
for config in recipe.configs:
if config.name == config_str:
return config

def _print_yaml_file(self, file: str, sort_keys: bool) -> None:
try:
with open(file, "r") as f:
data = yaml.safe_load(f)
if data:
print(
yaml.dump(
data,
default_flow_style=False,
sort_keys=sort_keys,
indent=4,
width=80,
allow_unicode=True,
),
end="",
)
except yaml.YAMLError as e:
self._parser.error(f"Error parsing YAML file: {e}")

def _cat_cmd(self, args: argparse.Namespace) -> None:
"""Display the contents of a configuration file.
Handles both predefined configurations and direct file paths, ensuring:
- Input is not a recipe name
- File exists
- File is YAML format
Args:
args (argparse.Namespace): Command-line arguments containing 'config_name' attribute
"""
config_str = args.config_name

# Immediately handle recipe name case
if config_str in self._get_all_recipes():
print(
f"'{config_str}' is a recipe, not a config. Please use a config name."
)
return

# Resolve config path
config = self._get_config(config_str)
if config:
config_path = ROOT / "recipes" / "configs" / config.file_path
else:
config_path = Path(config_str)
if config_path.suffix.lower() not in {".yaml", ".yml"}:
self._parser.error(
f"Invalid config format: '{config_path}'. Must be YAML (.yaml/.yml)"
)
return

if not config_path.exists():
self._parser.error(f"Config '{config_str}' not found.")
return

self._print_yaml_file(str(config_path), args.sort)
3 changes: 3 additions & 0 deletions torchtune/_cli/tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

import argparse

from torchtune._cli.cat import Cat

from torchtune._cli.cp import Copy
from torchtune._cli.download import Download
from torchtune._cli.ls import List
Expand Down Expand Up @@ -33,6 +35,7 @@ def __init__(self):
Copy.create(subparsers)
Run.create(subparsers)
Validate.create(subparsers)
Cat.create(subparsers)

def parse_args(self) -> argparse.Namespace:
"""Parse CLI arguments"""
Expand Down

0 comments on commit e6b9064

Please sign in to comment.