Skip to content

Commit

Permalink
Cleaning up --help: Artifact Management Subcommands (pytorch#889)
Browse files Browse the repository at this point in the history
  • Loading branch information
Jack-Khuu authored and malfet committed Jul 17, 2024
1 parent 6e067f3 commit d1ae6b6
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 22 deletions.
62 changes: 53 additions & 9 deletions cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,25 +24,35 @@
).expanduser()


KNOWN_VERBS = ["chat", "browser", "download", "generate", "eval", "export", "list", "remove", "where"]
# Subcommands related to downloading and managing model artifacts
INVENTORY_VERBS = ["download", "list", "remove", "where"]

# List of all supported subcommands in torchchat
KNOWN_VERBS = ["chat", "browser", "generate", "eval", "export"] + INVENTORY_VERBS

# Handle CLI arguments that are common to a majority of subcommands.
def check_args(args, verb: str) -> None:
# Handle model download. Skip this for download, since it has slightly
# different semantics.
if (
verb not in ["download", "list", "remove"]
verb not in INVENTORY_VERBS
and args.model
and not is_model_downloaded(args.model, args.model_directory)
):
download_and_convert(args.model, args.model_directory, args.hf_token)


def add_arguments_for_verb(parser, verb: str):
def add_arguments_for_verb(parser, verb: str) -> None:
# Model specification. TODO Simplify this.
# A model can be specified using a positional model name or HuggingFace
# path. Alternatively, the model can be specified via --gguf-path or via
# an explicit --checkpoint-dir, --checkpoint-path, or --tokenizer-path.

if verb in INVENTORY_VERBS:
_configure_artifact_inventory_args(parser, verb)
_add_cli_metadata_args(parser)
return

parser.add_argument(
"model",
type=str,
Expand Down Expand Up @@ -191,12 +201,6 @@ def add_arguments_for_verb(parser, verb: str):
choices=allowable_dtype_names(),
help="Override the dtype of the model (default is the checkpoint dtype). Options: bf16, fp16, fp32, fast16, fast",
)
parser.add_argument(
"-v",
"--verbose",
action="store_true",
help="Verbose output",
)
parser.add_argument(
"--quantize",
type=str,
Expand Down Expand Up @@ -252,6 +256,46 @@ def add_arguments_for_verb(parser, verb: str):
default=5000,
help="Port for the web server in browser mode",
)
_add_cli_metadata_args(parser)


# Add CLI Args that are relevant to any subcommand execution
def _add_cli_metadata_args(parser) -> None:
parser.add_argument(
"-v",
"--verbose",
action="store_true",
help="Verbose output",
)


# Configure CLI Args specific to Model Artifact Management
def _configure_artifact_inventory_args(parser, verb: str) -> None:
if verb in ["download", "remove", "where"]:
parser.add_argument(
"model",
type=str,
nargs="?",
default=None,
help="Model name for well-known models",
)

if verb in INVENTORY_VERBS:
parser.add_argument(
"--model-directory",
type=Path,
default=default_model_dir,
help=f"The directory to store downloaded model artifacts. Default: {default_model_dir}",
)

if verb == "download":
parser.add_argument(
"--hf-token",
type=str,
default=None,
help="A HuggingFace API token to use when downloading model artifacts",
)


# Add CLI Args specific to Model Evaluation
def _add_evaluation_args(parser) -> None:
Expand Down
7 changes: 0 additions & 7 deletions download.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,13 +125,6 @@ def is_model_downloaded(model: str, models_dir: Path) -> bool:

# Subcommand to list available models.
def list_main(args) -> None:
# TODO It would be nice to have argparse validate this. However, we have
# model as an optional named parameter for all subcommands, so we'd
# probably need to move it to be registered per-command.
if args.model:
print("Usage: torchchat.py list")
return

model_configs = load_model_configs()

# Build the table in-memory so that we can align the text nicely.
Expand Down
17 changes: 11 additions & 6 deletions torchchat.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from cli import (
add_arguments_for_verb,
KNOWN_VERBS,
INVENTORY_VERBS,
arg_init,
check_args,
)
Expand Down Expand Up @@ -49,7 +50,11 @@

# Now parse the arguments
args = parser.parse_args()
args = arg_init(args)

# Don't initialize for Inventory management subcommands
# TODO: Remove when arg_init is refactored
if args.command not in INVENTORY_VERBS:
args = arg_init(args)
logging.basicConfig(
format="%(message)s", level=logging.DEBUG if args.verbose else logging.INFO
)
Expand All @@ -70,11 +75,6 @@
from browser.browser import main as browser_main

browser_main(args)
elif args.command == "download":
check_args(args, "download")
from download import download_main

download_main(args)
elif args.command == "generate":
check_args(args, "generate")
from generate import main as generate_main
Expand All @@ -89,6 +89,11 @@
from export import main as export_main

export_main(args)
elif args.command == "download":
check_args(args, "download")
from download import download_main

download_main(args)
elif args.command == "list":
check_args(args, "list")
from download import list_main
Expand Down

0 comments on commit d1ae6b6

Please sign in to comment.