From 865176fff7f391c07515594412c2e23873237e2a Mon Sep 17 00:00:00 2001
From: GeminiLn
Date: Fri, 6 Dec 2024 23:09:53 -0700
Subject: [PATCH 1/2] Fabric: add support for 'auto' accelerator

---
 src/lightning/fabric/cli.py | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/src/lightning/fabric/cli.py b/src/lightning/fabric/cli.py
index 5f18884e83d79..f4523bdb3f3a9 100644
--- a/src/lightning/fabric/cli.py
+++ b/src/lightning/fabric/cli.py
@@ -36,7 +36,7 @@
 _CLICK_AVAILABLE = RequirementCache("click")
 _LIGHTNING_SDK_AVAILABLE = RequirementCache("lightning_sdk")

-_SUPPORTED_ACCELERATORS = ("cpu", "gpu", "cuda", "mps", "tpu")
+_SUPPORTED_ACCELERATORS = ("cpu", "gpu", "cuda", "mps", "tpu", "auto")


 def _get_supported_strategies() -> list[str]:
@@ -208,6 +208,14 @@ def _set_env_variables(args: Namespace) -> None:

 def _get_num_processes(accelerator: str, devices: str) -> int:
     """Parse the `devices` argument to determine how many processes need to be launched on the current machine."""
+    if accelerator == "auto":
+        if torch.cuda.is_available():
+            accelerator = "cuda"
+        elif getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
+            accelerator = "mps"
+        else:
+            accelerator = "cpu"
+
     if accelerator == "gpu":
         parsed_devices = _parse_gpu_ids(devices, include_cuda=True, include_mps=True)
     elif accelerator == "cuda":

From 3949b29e5b9c689c077a0194ad1ad9d45e4f385b Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Sat, 7 Dec 2024 06:18:48 +0000
Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/lightning/fabric/cli.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/lightning/fabric/cli.py b/src/lightning/fabric/cli.py
index f4523bdb3f3a9..3f93472b044f6 100644
--- a/src/lightning/fabric/cli.py
+++ b/src/lightning/fabric/cli.py
@@ -215,7 +215,7 @@ def _get_num_processes(accelerator: str, devices: str) -> int:
             accelerator = "mps"
         else:
             accelerator = "cpu"
-
+
     if accelerator == "gpu":
         parsed_devices = _parse_gpu_ids(devices, include_cuda=True, include_mps=True)
     elif accelerator == "cuda":
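
The patch above resolves "auto" before the existing per-accelerator branches, so the rest of _get_num_processes only ever sees a concrete accelerator name. Below is a minimal standalone sketch of that resolution order, assuming only a plain PyTorch install; the helper name _resolve_auto_accelerator is illustrative and does not appear in the patch, which inlines the logic directly.

    # Sketch of the resolution order introduced by the patch: CUDA first,
    # then MPS, then CPU as the fallback.
    import torch


    def _resolve_auto_accelerator() -> str:
        if torch.cuda.is_available():
            return "cuda"
        # Older PyTorch builds may not expose torch.backends.mps, hence the getattr guard.
        if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
            return "mps"
        return "cpu"


    if __name__ == "__main__":
        # On a CUDA machine this prints "cuda"; on Apple Silicon "mps"; otherwise "cpu".
        print(_resolve_auto_accelerator())

With the change in place, an invocation along the lines of "fabric run script.py --accelerator=auto" (assuming the standard Fabric CLI entry point) would launch on CUDA if present, fall back to MPS, and otherwise run on CPU.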