Skip to content

Commit

Permalink
fix: slurm init due to exotic format, closes #39
Browse files Browse the repository at this point in the history
  • Loading branch information
sehoffmann committed Jan 6, 2025
1 parent 25d48b6 commit 39fd5e9
Showing 1 changed file with 12 additions and 1 deletion.
13 changes: 12 additions & 1 deletion dmlcloud/core/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,9 +285,20 @@ def _init_process_group_slurm(port=DEFAULT_PORT, **kwargs):
_WorkerInfo.RANK = int(os.environ['SLURM_PROCID'])
_WorkerInfo.WORLD_SIZE = int(os.environ['SLURM_NTASKS'])
_WorkerInfo.LOCAL_RANK = int(os.environ['SLURM_LOCALID'])
_WorkerInfo.LOCAL_WORLD_SIZE = int(os.environ['SLURM_STEP_TASKS_PER_NODE'])
_WorkerInfo.NODE_ID = int(os.environ['SLURM_NODEID'])

# Determine local world size via SLURM_TASKS_PER_NODE
# Format example: 2(x3),4,1
tasks_per_node_raw = os.environ['SLURM_TASKS_PER_NODE'].split(',')
tasks_per_node = []
for t in tasks_per_node_raw:
if '(x' in t:
ntasks, nnodes = t.split('(x')
tasks_per_node.extend([int(ntasks)] * int(nnodes[:-1]))
else:
tasks_per_node.append(int(t))
_WorkerInfo.LOCAL_WORLD_SIZE = tasks_per_node[_WorkerInfo.NODE_ID]

ip = os.environ['SLURM_SRUN_COMM_HOST']

dist.init_process_group(
Expand Down

0 comments on commit 39fd5e9

Please sign in to comment.