From 79a0bea6203325f10c0b3788f5f414da609d13ff Mon Sep 17 00:00:00 2001 From: Sebastian Hoffmann Date: Tue, 7 Jan 2025 17:20:04 +0100 Subject: [PATCH] feat: extended @dml.root_only to also work with Stage and Pipeline callback methods, closes #17 --- dmlcloud/core/distributed.py | 155 +++++++++++++++++++++++-- dmlcloud/core/logging.py | 10 +- test/conftest.py | 57 +++++++++ test/test_root_only.py | 218 +++++++++++++++++++++++++++++++++++ 4 files changed, 423 insertions(+), 17 deletions(-) create mode 100644 test/test_root_only.py diff --git a/dmlcloud/core/distributed.py b/dmlcloud/core/distributed.py index 681f809..1d13980 100644 --- a/dmlcloud/core/distributed.py +++ b/dmlcloud/core/distributed.py @@ -1,13 +1,22 @@ +import inspect import os +import sys from contextlib import contextmanager +from datetime import timedelta from functools import wraps +from typing import Callable, TYPE_CHECKING import torch +import torch.distributed import torch.distributed as dist from ..util.tcp import find_free_port, get_local_ips +if TYPE_CHECKING: + from dmlcloud import Pipeline, Stage # noqa: F401 + + __all__ = [ 'has_slurm', 'has_environment', @@ -29,6 +38,7 @@ DEFAULT_PORT = os.environ.get('DMLCLOUD_PORT', 41312) # dml +LONG_TIMEOUT = 24 * 60 * 60 # timeout for long running barriers, default is 24 hours class _WorkerInfo: @@ -75,40 +85,161 @@ def has_mpi(): return False -def is_root(): +def is_root(group: dist.ProcessGroup = None): """ Check if the current rank is the root rank (rank 0). + + Args: + group (ProcessGroup, optional): The process group to work on. If None (default), the default process group will be used. """ - return dist.get_rank() == 0 + return dist.get_rank(group) == 0 -def root_only(fn): +def root_only( + fn: Callable | type, + group: torch.distributed.ProcessGroup = None, + synchronize: bool = True, + timeout: int = LONG_TIMEOUT, +) -> Callable | type: """ Decorator for methods that should only be called on the root rank. - """ - @wraps(fn) - def wrapper(*args, **kwargs): - if is_root(): - return fn(*args, **kwargs) + Can also be applied to individual callback methods of :class:`Pipeline` and :class:`Stage`, or to the whole class. + In that case, :attr:`Pipeline.gloo_group` is used as process group. - return wrapper + If ``synchronize=True``, a monitored_barrier before or after the function call depending on the rank. + This can be important to prevent timeouts from future all_reduce operations if non-root ranks move on before the root rank has finished. + + Args: + fn: The function to decorate or a subclass of :class:`Pipeline` or :class:`Stage`. + group: The process group to work on. If None (default), the default process group will be used. + synchronize: If True, a barrier is inserted before or after the function call depending on the rank. Default is True. + timeout: Timeout in seconds for the monitored_barrier. Default is 24 hours. + + Returns: + The decorated function or class. + + Examples: + + Annotating an individual function: + + >>> @root_only + >>> def my_function(): + >>> print('Only the root rank prints this.') + + Annotating a whole :class:`Stage` subclass: + + >>> @root_only + >>> class MyStage(Stage): + >>> def pre_stage(self): + >>> print('Only the root rank prints this.') + >>> + >>> def run_epoch(self): + >>> print('Only the root rank prints this.') + >>> + >>> def post_stage(self): + >>> print('Only the root rank prints this.') + + Annotating individual methods of :class:`Stage`: + + >>> class MyStage(Stage): + >>> def pre_stage(self): + >>> print('All ranks print this.') + >>> + >>> @root_only + >>> def post_stage(self): + >>> print('Only the root rank prints this.') + """ + + if not inspect.isclass(fn): + + @wraps(fn) + def wrapper(*args, **kwargs): + if is_root(group): + ret = fn(*args, **kwargs) + if synchronize: + dist.monitored_barrier(group, timeout=timedelta(seconds=timeout), wait_all_ranks=True) + return ret + elif synchronize: + dist.monitored_barrier(group, timeout=timedelta(seconds=timeout), wait_all_ranks=True) + + return wrapper + + elif 'dmlcloud.core.pipeline' in sys.modules and issubclass( + fn, sys.modules['dmlcloud.core.pipeline'].Pipeline + ): # avoids circular imports + pipeline_cls = fn + + def make_wrapper(method): + @wraps(method) + def pipeline_wrapper(self, *args, **kwargs): + if is_root(group): + ret = method(self, *args, **kwargs) + if synchronize: + dist.monitored_barrier(self.gloo_group, timeout=timedelta(seconds=timeout), wait_all_ranks=True) + return ret + elif synchronize: + dist.monitored_barrier(self.gloo_group, timeout=timedelta(seconds=timeout), wait_all_ranks=True) + + return pipeline_wrapper + + pipeline_cls.pre_run = make_wrapper(pipeline_cls.pre_run) + pipeline_cls.post_run = make_wrapper(pipeline_cls.post_run) + + return pipeline_cls + + elif 'dmlcloud.core.stage' in sys.modules and issubclass( + fn, sys.modules['dmlcloud.core.stage'].Stage + ): # avoids circular imports + stage_cls = fn + + def make_wrapper(method): + @wraps(method) + def stage_wrapper(self, *args, **kwargs): + if is_root(group): + ret = method(self, *args, **kwargs) + if synchronize: + dist.monitored_barrier( + self.pipe.gloo_group, timeout=timedelta(seconds=timeout), wait_all_ranks=True + ) + return ret + elif synchronize: + dist.monitored_barrier( + self.pipe.gloo_group, timeout=timedelta(seconds=timeout), wait_all_ranks=True + ) + + return stage_wrapper + + stage_cls.pre_stage = make_wrapper(stage_cls.pre_stage) + stage_cls.post_stage = make_wrapper(stage_cls.post_stage) + stage_cls.pre_epoch = make_wrapper(stage_cls.pre_epoch) + stage_cls.post_epoch = make_wrapper(stage_cls.post_epoch) + stage_cls.run_epoch = make_wrapper(stage_cls.run_epoch) + + return stage_cls + + else: + raise ValueError('root_only can only be applied to functions, Pipeline, or Stage subclasses.') @contextmanager -def root_first(): +def root_first(group: dist.ProcessGroup = None): """ Context manager that ensures that the root rank executes the code first before all other ranks. This is realized by inserting a barrier before or after the code block depending on the rank. + Notice, that only a regular barrier is used, and, hence, the default timeout of 1800000 seconds applies for nccl. + + Args: + group (ProcessGroup, optional): The process group to work on. If None (default), the default process group will be used. """ if is_root(): try: yield finally: - dist.barrier() + dist.barrier(group) else: - dist.barrier() + dist.barrier(group) try: yield finally: diff --git a/dmlcloud/core/logging.py b/dmlcloud/core/logging.py index 1630d9d..936ba97 100644 --- a/dmlcloud/core/logging.py +++ b/dmlcloud/core/logging.py @@ -15,7 +15,7 @@ import torch import torch.distributed -from . import distributed as dmldist +from . import distributed as dml_distributed logger = logging.getLogger('dmlcloud') @@ -173,16 +173,16 @@ def print_worker(*values, sep=' ', end="\n", file=None, flush=True, barrier=Fals if barrier: torch.distributed.barrier() - modified_values = [f'Worker {dmldist.rank()}'] - if dmldist.local_node() is not None: - modified_values += [f'({dmldist.local_node()}.{dmldist.local_rank()})'] + modified_values = [f'Worker {dml_distributed.rank()}'] + if dml_distributed.local_node() is not None: + modified_values += [f'({dml_distributed.local_node()}.{dml_distributed.local_rank()})'] modified_values.extend(values) print(*modified_values, sep=sep, end=end, file=file, flush=flush) if barrier: torch.distributed.barrier() -@dmldist.root_only +@dml_distributed.root_only def print_root(*values, sep=' ', end="\n", file=None, flush=True): """ Print the values to a stream if the current rank is the root rank. diff --git a/test/conftest.py b/test/conftest.py index ad9bb12..0764c5d 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -1,5 +1,8 @@ +from multiprocessing import Pipe, Process + import pytest import torch +import torch.distributed from dmlcloud.core.distributed import init @@ -8,3 +11,57 @@ def torch_distributed(): init(kind='dummy') yield torch.distributed.destroy_process_group() + + +class DistributedEnvironment: + @staticmethod + def bind(tmpdir): + def init(world_size, timeout=5 * 60, daemon=True): + return DistributedEnvironment(world_size, timeout, daemon, str(tmpdir / 'filestore')) + + return init + + def __init__(self, world_size: int, timeout: int = 5 * 60, daemon: bool = True, file: str = None): + self.world_size = world_size + self.timeout = timeout + self.daemon = daemon + self.file = str(file) + + def _run(self, rank, conn, func, *args, **kwargs): + store = torch.distributed.FileStore(self.file, self.world_size) + torch.distributed.init_process_group(backend='gloo', world_size=self.world_size, rank=rank, store=store) + + torch.distributed.barrier() + ret = func(*args, **kwargs) # TODO: need to handle exceptions + torch.distributed.barrier() + + conn.send(ret) + + torch.distributed.destroy_process_group() + + def start(self, func, *args, **kwargs): + self.processes = [] + self.conns = [] + for rank in range(self.world_size): + recv_conn, send_conn = Pipe() + process_args = (rank, send_conn, func) + args + process_kwargs = dict(kwargs) + process = Process(target=self._run, args=process_args, kwargs=process_kwargs, daemon=self.daemon) + self.conns.append(recv_conn) + self.processes.append(process) + + for process in self.processes: + process.start() + + return_values = [] + for process, conn in zip(self.processes, self.conns): # TODO: should probably be a context manager + ret = conn.recv() + return_values.append(ret) + process.join(self.timeout) + + return return_values + + +@pytest.fixture +def distributed_environment(tmp_path): + return DistributedEnvironment.bind(tmp_path) diff --git a/test/test_root_only.py b/test/test_root_only.py new file mode 100644 index 0000000..40bc9fb --- /dev/null +++ b/test/test_root_only.py @@ -0,0 +1,218 @@ +import sys + +import dmlcloud as dml +import pytest + + +@dml.root_only +def return_root_rank(): + """TEST_DOC_STRING""" + return dml.rank() + + +@dml.root_only +class RootOnlyStage(dml.Stage): + """TEST_DOC_STRING""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.cb_executed = { + 'pre_stage': False, + 'post_stage': False, + 'pre_epoch': False, + 'post_epoch': False, + 'run_epoch': False, + } + + def pre_stage(self): + """TEST_DOC_STRING""" + self.cb_executed['pre_stage'] = True + + def post_stage(self): + """TEST_DOC_STRING""" + self.cb_executed['post_stage'] = True + + def pre_epoch(self): + """TEST_DOC_STRING""" + self.cb_executed['pre_epoch'] = True + + def post_epoch(self): + """TEST_DOC_STRING""" + self.cb_executed['post_epoch'] = True + + def run_epoch(self): + """TEST_DOC_STRING""" + self.cb_executed['run_epoch'] = True + + +class PartialRootOnlyStage(dml.Stage): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.cb_executed = { + 'pre_stage': False, + 'post_stage': False, + 'pre_epoch': False, + 'post_epoch': False, + 'run_epoch': False, + } + + def pre_stage(self): + self.cb_executed['pre_stage'] = True + + @dml.root_only + def post_stage(self): + """TEST_DOC_STRING""" + self.cb_executed['post_stage'] = True + + @dml.root_only + def pre_epoch(self): + """TEST_DOC_STRING""" + self.cb_executed['pre_epoch'] = True + + def post_epoch(self): + self.cb_executed['post_epoch'] = True + + @dml.root_only + def run_epoch(self): + self.cb_executed['run_epoch'] = True + + +@dml.root_only +class RootOnlyPipeline(dml.Pipeline): + """TEST_DOC_STRING""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.cb_executed = { + 'pre_run': False, + 'post_run': False, + } + + def pre_run(self): + """TEST_DOC_STRING""" + self.cb_executed['pre_run'] = True + + def post_run(self): + """TEST_DOC_STRING""" + self.cb_executed['post_run'] = True + + +class PartialRootOnlyPipeline(dml.Pipeline): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.cb_executed = { + 'pre_run': False, + 'post_run': False, + } + + @dml.root_only + def pre_run(self): + """TEST_DOC_STRING""" + self.cb_executed['pre_run'] = True + + def post_run(self): + self.cb_executed['post_run'] = True + + +class TestRootOnly: + def test_function(self, distributed_environment): + ranks = distributed_environment(4).start(return_root_rank) + assert ranks == [0, None, None, None] + assert return_root_rank.__name__ == 'return_root_rank' + assert return_root_rank.__doc__ == 'TEST_DOC_STRING' + + def test_stage(self, distributed_environment): + def run(): + stage = RootOnlyStage(epochs=1) + pipe = dml.Pipeline() + pipe.append(stage) + pipe.run() + return stage.cb_executed + + results = distributed_environment(3).start(run) + + assert [r['pre_stage'] for r in results] == [True, False, False] + assert [r['post_stage'] for r in results] == [True, False, False] + assert [r['pre_epoch'] for r in results] == [True, False, False] + assert [r['post_epoch'] for r in results] == [True, False, False] + assert [r['run_epoch'] for r in results] == [True, False, False] + + assert RootOnlyStage.__name__ == 'RootOnlyStage' + assert RootOnlyStage.__doc__ == 'TEST_DOC_STRING' + + assert RootOnlyStage.pre_stage.__name__ == 'pre_stage' + assert RootOnlyStage.pre_stage.__doc__ == 'TEST_DOC_STRING' + + assert RootOnlyStage.post_stage.__name__ == 'post_stage' + assert RootOnlyStage.post_stage.__doc__ == 'TEST_DOC_STRING' + + assert RootOnlyStage.pre_epoch.__name__ == 'pre_epoch' + assert RootOnlyStage.pre_epoch.__doc__ == 'TEST_DOC_STRING' + + assert RootOnlyStage.post_epoch.__name__ == 'post_epoch' + assert RootOnlyStage.post_epoch.__doc__ == 'TEST_DOC_STRING' + + assert RootOnlyStage.run_epoch.__name__ == 'run_epoch' + assert RootOnlyStage.run_epoch.__doc__ == 'TEST_DOC_STRING' + + def test_partial_stage(self, distributed_environment): + def run(): + stage = PartialRootOnlyStage(epochs=1) + pipe = dml.Pipeline() + pipe.append(stage) + pipe.run() + return stage.cb_executed + + results = distributed_environment(3).start(run) + + assert [r['pre_stage'] for r in results] == [True, True, True] + assert [r['post_stage'] for r in results] == [True, False, False] + assert [r['pre_epoch'] for r in results] == [True, False, False] + assert [r['post_epoch'] for r in results] == [True, True, True] + assert [r['run_epoch'] for r in results] == [True, False, False] + + assert PartialRootOnlyStage.post_stage.__name__ == 'post_stage' + assert PartialRootOnlyStage.post_stage.__doc__ == 'TEST_DOC_STRING' + + assert PartialRootOnlyStage.pre_epoch.__name__ == 'pre_epoch' + assert PartialRootOnlyStage.pre_epoch.__doc__ == 'TEST_DOC_STRING' + + def test_pipeline(self, distributed_environment): + def run(): + pipe = RootOnlyPipeline() + pipe.append(RootOnlyStage(epochs=1)) + pipe.run() + return pipe.cb_executed + + results = distributed_environment(3).start(run) + + assert [r['pre_run'] for r in results] == [True, False, False] + assert [r['post_run'] for r in results] == [True, False, False] + + assert RootOnlyPipeline.__name__ == 'RootOnlyPipeline' + assert RootOnlyPipeline.__doc__ == 'TEST_DOC_STRING' + + assert RootOnlyPipeline.pre_run.__name__ == 'pre_run' + assert RootOnlyPipeline.pre_run.__doc__ == 'TEST_DOC_STRING' + + assert RootOnlyPipeline.post_run.__name__ == 'post_run' + assert RootOnlyPipeline.post_run.__doc__ == 'TEST_DOC_STRING' + + def test_partial_pipeline(self, distributed_environment): + def run(): + pipe = PartialRootOnlyPipeline() + pipe.append(RootOnlyStage(epochs=1)) + pipe.run() + return pipe.cb_executed + + results = distributed_environment(3).start(run) + + assert [r['pre_run'] for r in results] == [True, False, False] + assert [r['post_run'] for r in results] == [True, True, True] + + assert PartialRootOnlyPipeline.pre_run.__name__ == 'pre_run' + assert PartialRootOnlyPipeline.pre_run.__doc__ == 'TEST_DOC_STRING' + + +if __name__ == '__main__': + sys.exit(pytest.main([__file__]))