Commit 79a0bea

feat: extended @dml.root_only to also work with Stage and Pipeline callback methods, closes #17
sehoffmann committed Jan 7, 2025
1 parent 1d2e783 commit 79a0bea
Showing 4 changed files with 423 additions and 17 deletions.
155 changes: 143 additions & 12 deletions dmlcloud/core/distributed.py
@@ -1,13 +1,22 @@
import inspect
import os
import sys
from contextlib import contextmanager
from datetime import timedelta
from functools import wraps
from typing import Callable, TYPE_CHECKING

import torch
import torch.distributed
import torch.distributed as dist

from ..util.tcp import find_free_port, get_local_ips


if TYPE_CHECKING:
    from dmlcloud import Pipeline, Stage  # noqa: F401


__all__ = [
    'has_slurm',
    'has_environment',

@@ -29,6 +38,7 @@


DEFAULT_PORT = os.environ.get('DMLCLOUD_PORT', 41312) # dml
LONG_TIMEOUT = 24 * 60 * 60 # timeout for long running barriers, default is 24 hours


class _WorkerInfo:
@@ -75,40 +85,161 @@ def has_mpi():
    return False


-def is_root():
+def is_root(group: dist.ProcessGroup = None):
    """
    Check if the current rank is the root rank (rank 0).

    Args:
        group (ProcessGroup, optional): The process group to work on. If None (default), the default process group will be used.
    """
-    return dist.get_rank() == 0
+    return dist.get_rank(group) == 0
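
A quick sketch of what the new ``group`` parameter enables (my own illustration, not part of the commit; assumes an initialized default group with at least two ranks):

import torch.distributed as dist
from dmlcloud.core.distributed import is_root

subgroup = dist.new_group(ranks=[0, 1])
if is_root(subgroup):  # true on the process that is rank 0 *within* the subgroup
    print('root of the subgroup')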


-def root_only(fn):
-    """
-    Decorator for methods that should only be called on the root rank.
-    """
-
-    @wraps(fn)
-    def wrapper(*args, **kwargs):
-        if is_root():
-            return fn(*args, **kwargs)
-
-    return wrapper

def root_only(
    fn: Callable | type,
    group: torch.distributed.ProcessGroup = None,
    synchronize: bool = True,
    timeout: int = LONG_TIMEOUT,
) -> Callable | type:
    """
    Decorator for methods that should only be called on the root rank.

    Can also be applied to individual callback methods of :class:`Pipeline` and :class:`Stage`, or to the whole class.
    In that case, :attr:`Pipeline.gloo_group` is used as process group.

    If ``synchronize=True``, a monitored_barrier is inserted before or after the function call depending on the rank.
    This can be important to prevent timeouts from future all_reduce operations if non-root ranks move on before the root rank has finished.

    Args:
        fn: The function to decorate or a subclass of :class:`Pipeline` or :class:`Stage`.
        group: The process group to work on. If None (default), the default process group will be used.
        synchronize: If True, a barrier is inserted before or after the function call depending on the rank. Default is True.
        timeout: Timeout in seconds for the monitored_barrier. Default is 24 hours.

    Returns:
        The decorated function or class.

    Examples:
        Annotating an individual function:

        >>> @root_only
        >>> def my_function():
        >>>     print('Only the root rank prints this.')

        Annotating a whole :class:`Stage` subclass:

        >>> @root_only
        >>> class MyStage(Stage):
        >>>     def pre_stage(self):
        >>>         print('Only the root rank prints this.')
        >>>
        >>>     def run_epoch(self):
        >>>         print('Only the root rank prints this.')
        >>>
        >>>     def post_stage(self):
        >>>         print('Only the root rank prints this.')

        Annotating individual methods of :class:`Stage`:

        >>> class MyStage(Stage):
        >>>     def pre_stage(self):
        >>>         print('All ranks print this.')
        >>>
        >>>     @root_only
        >>>     def post_stage(self):
        >>>         print('Only the root rank prints this.')
    """

    if not inspect.isclass(fn):

        @wraps(fn)
        def wrapper(*args, **kwargs):
            if is_root(group):
                ret = fn(*args, **kwargs)
                if synchronize:
                    dist.monitored_barrier(group, timeout=timedelta(seconds=timeout), wait_all_ranks=True)
                return ret
            elif synchronize:
                dist.monitored_barrier(group, timeout=timedelta(seconds=timeout), wait_all_ranks=True)

        return wrapper

    elif 'dmlcloud.core.pipeline' in sys.modules and issubclass(
        fn, sys.modules['dmlcloud.core.pipeline'].Pipeline
    ):  # avoids circular imports
        pipeline_cls = fn

        def make_wrapper(method):
            @wraps(method)
            def pipeline_wrapper(self, *args, **kwargs):
                if is_root(group):
                    ret = method(self, *args, **kwargs)
                    if synchronize:
                        dist.monitored_barrier(self.gloo_group, timeout=timedelta(seconds=timeout), wait_all_ranks=True)
                    return ret
                elif synchronize:
                    dist.monitored_barrier(self.gloo_group, timeout=timedelta(seconds=timeout), wait_all_ranks=True)

            return pipeline_wrapper

        pipeline_cls.pre_run = make_wrapper(pipeline_cls.pre_run)
        pipeline_cls.post_run = make_wrapper(pipeline_cls.post_run)

        return pipeline_cls

    elif 'dmlcloud.core.stage' in sys.modules and issubclass(
        fn, sys.modules['dmlcloud.core.stage'].Stage
    ):  # avoids circular imports
        stage_cls = fn

        def make_wrapper(method):
            @wraps(method)
            def stage_wrapper(self, *args, **kwargs):
                if is_root(group):
                    ret = method(self, *args, **kwargs)
                    if synchronize:
                        dist.monitored_barrier(
                            self.pipe.gloo_group, timeout=timedelta(seconds=timeout), wait_all_ranks=True
                        )
                    return ret
                elif synchronize:
                    dist.monitored_barrier(
                        self.pipe.gloo_group, timeout=timedelta(seconds=timeout), wait_all_ranks=True
                    )

            return stage_wrapper

        stage_cls.pre_stage = make_wrapper(stage_cls.pre_stage)
        stage_cls.post_stage = make_wrapper(stage_cls.post_stage)
        stage_cls.pre_epoch = make_wrapper(stage_cls.pre_epoch)
        stage_cls.post_epoch = make_wrapper(stage_cls.post_epoch)
        stage_cls.run_epoch = make_wrapper(stage_cls.run_epoch)

        return stage_cls

    else:
        raise ValueError('root_only can only be applied to functions, Pipeline, or Stage subclasses.')
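
Since ``root_only`` takes ``fn`` as its first positional argument rather than acting as a decorator factory, keyword arguments such as ``synchronize`` can be bound with ``functools.partial``. A minimal sketch (my own, not part of the commit; the decorated function is hypothetical):

import functools

from dmlcloud.core.distributed import root_only

unsynchronized_root_only = functools.partial(root_only, synchronize=False)

@unsynchronized_root_only
def log_summary():
    print('Only the root rank prints this, and no barrier is inserted afterwards.')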


@contextmanager
-def root_first():
+def root_first(group: dist.ProcessGroup = None):
    """
    Context manager that ensures that the root rank executes the code first before all other ranks.
    This is realized by inserting a barrier before or after the code block depending on the rank.

    Note that only a regular barrier is used; hence, the default timeout of 1800 seconds applies for nccl.

    Args:
        group (ProcessGroup, optional): The process group to work on. If None (default), the default process group will be used.
    """
    if is_root():
        try:
            yield
        finally:
-            dist.barrier()
+            dist.barrier(group)
    else:
-        dist.barrier()
+        dist.barrier(group)
        try:
            yield
        finally:
            …
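
A typical use of ``root_first`` is letting rank 0 populate a shared cache before the remaining ranks read from it. A minimal sketch (my own, not part of the commit; ``prepare_dataset`` is a hypothetical helper):

from dmlcloud.core.distributed import root_first

with root_first():
    dataset = prepare_dataset()  # rank 0 downloads/preprocesses first; other ranks then hit the cache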
10 changes: 5 additions & 5 deletions dmlcloud/core/logging.py
@@ -15,7 +15,7 @@
import torch
import torch.distributed

-from . import distributed as dmldist
+from . import distributed as dml_distributed


logger = logging.getLogger('dmlcloud')
@@ -173,16 +173,16 @@ def print_worker(*values, sep=' ', end="\n", file=None, flush=True, barrier=Fals

    if barrier:
        torch.distributed.barrier()
-    modified_values = [f'Worker {dmldist.rank()}']
-    if dmldist.local_node() is not None:
-        modified_values += [f'({dmldist.local_node()}.{dmldist.local_rank()})']
+    modified_values = [f'Worker {dml_distributed.rank()}']
+    if dml_distributed.local_node() is not None:
+        modified_values += [f'({dml_distributed.local_node()}.{dml_distributed.local_rank()})']
    modified_values.extend(values)
    print(*modified_values, sep=sep, end=end, file=file, flush=flush)
    if barrier:
        torch.distributed.barrier()


-@dmldist.root_only
+@dml_distributed.root_only
def print_root(*values, sep=' ', end="\n", file=None, flush=True):
    """
    Print the values to a stream if the current rank is the root rank.
…
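
A minimal usage sketch of the two helpers above (my own, not part of the commit), assuming they are importable from ``dmlcloud.core.logging``:

from dmlcloud.core.logging import print_root, print_worker

print_worker('finished epoch')  # every rank prints, prefixed with 'Worker <rank>'
print_root('all workers done')  # only the root rank prints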
57 changes: 57 additions & 0 deletions test/conftest.py
@@ -1,5 +1,8 @@
from multiprocessing import Pipe, Process

import pytest
import torch
import torch.distributed
from dmlcloud.core.distributed import init


@@ -8,3 +11,57 @@ def torch_distributed():
    init(kind='dummy')
    yield
    torch.distributed.destroy_process_group()


class DistributedEnvironment:
    @staticmethod
    def bind(tmpdir):
        def init(world_size, timeout=5 * 60, daemon=True):
            return DistributedEnvironment(world_size, timeout, daemon, str(tmpdir / 'filestore'))

        return init

    def __init__(self, world_size: int, timeout: int = 5 * 60, daemon: bool = True, file: str = None):
        self.world_size = world_size
        self.timeout = timeout
        self.daemon = daemon
        self.file = str(file)

    def _run(self, rank, conn, func, *args, **kwargs):
        store = torch.distributed.FileStore(self.file, self.world_size)
        torch.distributed.init_process_group(backend='gloo', world_size=self.world_size, rank=rank, store=store)

        torch.distributed.barrier()
        ret = func(*args, **kwargs)  # TODO: need to handle exceptions
        torch.distributed.barrier()

        conn.send(ret)

        torch.distributed.destroy_process_group()

    def start(self, func, *args, **kwargs):
        self.processes = []
        self.conns = []
        for rank in range(self.world_size):
            recv_conn, send_conn = Pipe()
            process_args = (rank, send_conn, func) + args
            process_kwargs = dict(kwargs)
            process = Process(target=self._run, args=process_args, kwargs=process_kwargs, daemon=self.daemon)
            self.conns.append(recv_conn)
            self.processes.append(process)

        for process in self.processes:
            process.start()

        return_values = []
        for process, conn in zip(self.processes, self.conns):  # TODO: should probably be a context manager
            ret = conn.recv()
            return_values.append(ret)
            process.join(self.timeout)

        return return_values


@pytest.fixture
def distributed_environment(tmp_path):
    return DistributedEnvironment.bind(tmp_path)
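
A hedged sketch of how a test might drive this fixture (the test itself is hypothetical, not part of the commit; it relies on the default 'fork' start method so the closure can be passed to the worker processes):

import torch.distributed as dist

def test_ranks(distributed_environment):
    def run():
        return dist.get_rank()

    # three gloo workers; start() returns their return values in rank order
    assert distributed_environment(3).start(run) == [0, 1, 2]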
…
