From 6dce30e5538b7de3ed291e0741f5a63b7481bd2d Mon Sep 17 00:00:00 2001
From: Arun Jose <40291569+arunjose696@users.noreply.github.com>
Date: Thu, 1 Aug 2024 18:33:11 +0200
Subject: [PATCH] FIX-#7355: Cpu count would be set incorrectly on a cluster (#7356)

Signed-off-by: arunjose696
---
 modin/config/envvars.py                      | 18 ++++++++++++++++++
 modin/core/execution/dask/common/utils.py    |  1 +
 modin/core/execution/ray/common/utils.py     |  1 +
 modin/core/execution/unidist/common/utils.py |  1 +
 4 files changed, 21 insertions(+)

diff --git a/modin/config/envvars.py b/modin/config/envvars.py
index 8654ebe30c1..3635c63d026 100644
--- a/modin/config/envvars.py
+++ b/modin/config/envvars.py
@@ -332,6 +332,24 @@ class CpuCount(EnvironmentVariable, type=int):
 
     varname = "MODIN_CPUS"
 
+    @classmethod
+    def _put(cls, value: int) -> None:
+        """
+        Put specific value if CpuCount wasn't set by a user yet.
+
+        Parameters
+        ----------
+        value : int
+            Config value to set.
+
+        Notes
+        -----
+        This method is used to set CpuCount from cluster resources internally
+        and should not be called by a user.
+        """
+        if cls.get_value_source() == ValueSource.DEFAULT:
+            cls.put(value)
+
     @classmethod
     def _get_default(cls) -> int:
         """
diff --git a/modin/core/execution/dask/common/utils.py b/modin/core/execution/dask/common/utils.py
index 067a94fcdf0..52b4e38f53d 100644
--- a/modin/core/execution/dask/common/utils.py
+++ b/modin/core/execution/dask/common/utils.py
@@ -74,3 +74,4 @@ def _disable_warnings():
 
     num_cpus = len(client.ncores())
     NPartitions._put(num_cpus)
+    CpuCount._put(num_cpus)
diff --git a/modin/core/execution/ray/common/utils.py b/modin/core/execution/ray/common/utils.py
index cc2010fc7fb..d419a61a0d2 100644
--- a/modin/core/execution/ray/common/utils.py
+++ b/modin/core/execution/ray/common/utils.py
@@ -151,6 +151,7 @@ def initialize_ray(
 
     num_cpus = int(ray.cluster_resources()["CPU"])
     NPartitions._put(num_cpus)
+    CpuCount._put(num_cpus)
 
     # TODO(https://github.com/ray-project/ray/issues/28216): remove this
     # workaround once Ray gives a better way to suppress task errors.
diff --git a/modin/core/execution/unidist/common/utils.py b/modin/core/execution/unidist/common/utils.py
index 5aa31698b6a..6455d194b25 100644
--- a/modin/core/execution/unidist/common/utils.py
+++ b/modin/core/execution/unidist/common/utils.py
@@ -42,6 +42,7 @@ def initialize_unidist():
 
     num_cpus = sum(v["CPU"] for v in unidist.cluster_resources().values())
     modin_cfg.NPartitions._put(num_cpus)
+    modin_cfg.CpuCount._put(num_cpus)
 
 
 def deserialize(obj):  # pragma: no cover