diff --git a/docs/usage.rst b/docs/usage.rst index 4e1f16b2..3a40ce58 100644 --- a/docs/usage.rst +++ b/docs/usage.rst @@ -29,9 +29,21 @@ Before running the benchmarks 2. Set the ``$MILABENCH_CONFIG`` environment variable to the configuration file that represents the benchmark suite you want to run. Normally it should be set to ``config/standard.yaml``. -3. ``milabench install``: Install the individual benchmarks in virtual environments. +3. Setup huggingface access -4. ``milabench prepare``: Download the datasets, weights, etc. + 1. Request access to gated models + + - `Llama-2-7b `_ + - `Llama-3.1-8B `_ + - `Llama-3.1-70B `_ + + 2. Create a new `read token `_ to download the models + + 3. Add the token to your environment ``export MILABENCH_HF_TOKEN={your_token}`` + +4. ``milabench install``: Install the individual benchmarks in virtual environments. + +5. ``milabench prepare``: Download the datasets, weights, etc. If the machine has both NVIDIA/CUDA and AMD/ROCm GPUs, you may have to set the ``MILABENCH_GPU_ARCH`` environment variable as well, to either ``cuda`` or ``rocm``. diff --git a/milabench/cli/gated.py b/milabench/cli/gated.py index 3bf15828..43c26889 100644 --- a/milabench/cli/gated.py +++ b/milabench/cli/gated.py @@ -1,34 +1,42 @@ - +from collections import defaultdict from milabench.common import arguments, _get_multipack -def cli_gated(): - args = arguments() +def cli_gated(args=None): + """Print instruction to get access to gated models""" + + if args is None: + args = arguments() benchmarks = _get_multipack(args, return_config=True) - gated_bench = [] + urls = defaultdict(list) for bench, config in benchmarks.items(): tags = config.get("tags", []) if "gated" in tags and 'url' in config: - gated_bench.append((bench, config)) + urls[config["url"]].append((bench, config)) - if len(gated_bench) > 0: - print("benchmark use gated models or datasets") - print("You need to request permission to huggingface") + + if len(urls) > 0: + # + # This match the documentation in milabench/docs/usage.rst + # + print("#. Setup huggingface access: benchmark use gated models or datasets") + print(" You need to request permission to huggingface") + print() + print(" 1. Request access to gated models") print() - for bench, config in gated_bench: - print(f"{bench}") - print(f" url: {config.get('url')}") + for url, benches in urls.items(): + names = ' '.join([k for k, _ in benches]) + print(f" - `{names} <{url}>`_") + + print() + print(" 2. Create a new `read token `_ to download the models") + print() + print(" 3. Add the token to your environment ``export MILABENCH_HF_TOKEN={your_token}``") print() - print("Create a new token") - print(" - https://huggingface.co/settings/tokens/new?tokenType=read") - print("") - print("Add your token to your environment") - print(" export MILABENCH_HF_TOKEN={your_token}") - print("") print("Now you are ready to execute `milabench prepare`") diff --git a/milabench/cli/install.py b/milabench/cli/install.py index 10d33a1d..7a82761c 100644 --- a/milabench/cli/install.py +++ b/milabench/cli/install.py @@ -6,7 +6,7 @@ from ..common import get_multipack, run_with_loggers from ..log import DataReporter, TerminalFormatter, TextReporter - +from .gated import cli_gated # fmt: off @dataclass @@ -55,7 +55,7 @@ def cli_install(args=None): mp = get_multipack(run_name="install.{time}", overrides=overrides) - return run_with_loggers( + rc = run_with_loggers( mp.do_install(), loggers=[ TerminalFormatter(), @@ -66,3 +66,8 @@ def cli_install(args=None): ], mp=mp, ) + + # Print info about setting up milabench for gated models + cli_gated() + + return rc