diff --git a/dmlcloud/core/distributed.py b/dmlcloud/core/distributed.py index 1a62ab9..681f809 100644 --- a/dmlcloud/core/distributed.py +++ b/dmlcloud/core/distributed.py @@ -285,9 +285,20 @@ def _init_process_group_slurm(port=DEFAULT_PORT, **kwargs): _WorkerInfo.RANK = int(os.environ['SLURM_PROCID']) _WorkerInfo.WORLD_SIZE = int(os.environ['SLURM_NTASKS']) _WorkerInfo.LOCAL_RANK = int(os.environ['SLURM_LOCALID']) - _WorkerInfo.LOCAL_WORLD_SIZE = int(os.environ['SLURM_STEP_TASKS_PER_NODE']) _WorkerInfo.NODE_ID = int(os.environ['SLURM_NODEID']) + # Determine local world size via SLURM_TASKS_PER_NODE + # Format example: 2(x3),4,1 + tasks_per_node_raw = os.environ['SLURM_TASKS_PER_NODE'].split(',') + tasks_per_node = [] + for t in tasks_per_node_raw: + if '(x' in t: + ntasks, nnodes = t.split('(x') + tasks_per_node.extend([int(ntasks)] * int(nnodes[:-1])) + else: + tasks_per_node.append(int(t)) + _WorkerInfo.LOCAL_WORLD_SIZE = tasks_per_node[_WorkerInfo.NODE_ID] + ip = os.environ['SLURM_SRUN_COMM_HOST'] dist.init_process_group(