From e801e0fd0982f701a5f07c3df014a7c25482cc90 Mon Sep 17 00:00:00 2001 From: Sebastian Hoffmann Date: Wed, 8 Jan 2025 15:42:17 +0100 Subject: [PATCH] feat: dml.seed(), closes #23 --- dmlcloud/__init__.py | 68 +++++++++++----------- dmlcloud/core/distributed.py | 47 +++++++++++++++ doc/dmlcloud.rst | 2 + examples/mnist.py | 19 +++++- test/test_seed.py | 109 +++++++++++++++++++++++++++++++++++ 5 files changed, 210 insertions(+), 35 deletions(-) create mode 100644 test/test_seed.py diff --git a/dmlcloud/__init__.py b/dmlcloud/__init__.py index e70e6df..fe27baf 100644 --- a/dmlcloud/__init__.py +++ b/dmlcloud/__init__.py @@ -29,7 +29,7 @@ from .core.pipeline import Pipeline __all__ += [ - Pipeline, + 'Pipeline', ] # Stage @@ -37,7 +37,7 @@ from .core.stage import Stage __all__ += [ - Stage, + 'Stage', ] # Callbacks @@ -66,26 +66,28 @@ rank, root_first, root_only, + seed, world_size, ) __all__ += [ - has_slurm, - has_environment, - has_mpi, - is_root, - root_only, - root_first, - rank, - world_size, - local_rank, - local_world_size, - local_node, - all_gather_object, - gather_object, - broadcast_object, - init, - deinitialize_torch_distributed, + 'has_slurm', + 'has_environment', + 'has_mpi', + 'is_root', + 'root_only', + 'root_first', + 'rank', + 'world_size', + 'local_rank', + 'local_world_size', + 'local_node', + 'all_gather_object', + 'gather_object', + 'broadcast_object', + 'init', + 'deinitialize_torch_distributed', + 'seed', ] # Metrics @@ -114,24 +116,24 @@ ) __all__ += [ - logger, - setup_logger, - reset_logger, - flush_logger, - print_root, - print_worker, - log, - debug, - info, - warning, - error, - critical, + 'logger', + 'setup_logger', + 'reset_logger', + 'flush_logger', + 'print_root', + 'print_worker', + 'log', + 'debug', + 'info', + 'warning', + 'error', + 'critical', ] from .core.model import count_parameters, scale_lr, wrap_ddp __all__ += [ - wrap_ddp, - scale_lr, - count_parameters, + 'wrap_ddp', + 'scale_lr', + 'count_parameters', ] diff --git a/dmlcloud/core/distributed.py b/dmlcloud/core/distributed.py index 1d13980..a91059f 100644 --- a/dmlcloud/core/distributed.py +++ b/dmlcloud/core/distributed.py @@ -1,11 +1,13 @@ import inspect import os +import random import sys from contextlib import contextmanager from datetime import timedelta from functools import wraps from typing import Callable, TYPE_CHECKING +import numpy as np import torch import torch.distributed import torch.distributed as dist @@ -34,6 +36,7 @@ 'broadcast_object', 'init', 'deinitialize_torch_distributed', + 'seed', ] @@ -557,3 +560,47 @@ def deinitialize_torch_distributed(): _WorkerInfo.LOCAL_WORLD_SIZE = None _WorkerInfo.NODE_ID = None dist.destroy_process_group() + + +def seed(seed: int | None = None, group: dist.ProcessGroup = None) -> int: + """ + Share's the seed from the root rank to all ranks in the group and seeds the random number generators. + + The following libraries are seeded: + - random + - numpy + - torch + - tensorflow (if installed and imported) + + Different ranks will be seeded differently, so that they do not generate the same random numbers. + + Args: + seed: The seed to share. If None, a random seed is generated. Default is None. + Value must be within the inclusive range `[-0x8000_0000_0000_0000, 0xffff_ffff_ffff_ffff]`. + Negative inputs are remapped to positive values with the formula `0xffff_ffff_ffff_ffff + seed`. + group: The process group to work on. If None, the default process group will be used. Default is None. + + Returns: + The original seed that was provided the root rank. 64-bit integer. + + Raises: + RuntimeError: If ``seed`` is not within the inclusive range `[-0x8000_0000_0000_0000, 0xffff_ffff_ffff_ffff]`. + """ + if seed is None: + seed = torch.seed() + seed = broadcast_object(seed, group=group) + + worker_seed = seed + rank() + torch.manual_seed(worker_seed) + random.seed(worker_seed) + + # numpy only supports 32bits, so we use torch to generate a 32bit seed + np_seed = torch.randint(0, 2**31, (1,)).item() + np.random.seed(np_seed) + + if 'tensorflow' in sys.modules: + import tensorflow as tf + + tf.random.set_seed(worker_seed) + + return seed diff --git a/doc/dmlcloud.rst b/doc/dmlcloud.rst index ea12535..2971166 100644 --- a/doc/dmlcloud.rst +++ b/doc/dmlcloud.rst @@ -23,6 +23,8 @@ dmlcloud provides a set of helper functions to simplify the use of torch.distrib :toctree: generated init + seed + deinitialize_torch_distributed is_root root_only diff --git a/examples/mnist.py b/examples/mnist.py index bc13083..d024883 100644 --- a/examples/mnist.py +++ b/examples/mnist.py @@ -1,3 +1,5 @@ +import argparse + import dmlcloud as dml import torch import torchmetrics @@ -96,8 +98,21 @@ def _val_epoch(self): def main(): - pipe = dml.Pipeline(name='MNIST') - pipe.append(MNISTStage(epochs=3)) + dml.init() + + parser = argparse.ArgumentParser() + parser.add_argument('--epochs', type=int, default=4) + parser.add_argument('--seed', type=int) + args = parser.parse_args() + + seed = dml.seed(args.seed) # This is a helper function to set the seed for all devices + config = { + 'seed': seed, + 'epochs': args.epochs, + } + + pipe = dml.Pipeline(config, name='MNIST') + pipe.append(MNISTStage(epochs=args.epochs)) pipe.enable_checkpointing('checkpoints') pipe.enable_wandb() pipe.run() diff --git a/test/test_seed.py b/test/test_seed.py new file mode 100644 index 0000000..0cc52b9 --- /dev/null +++ b/test/test_seed.py @@ -0,0 +1,109 @@ +import random +import sys + +import dmlcloud as dml +import numpy as np +import pytest +import torch + + +def seed(seed=None): + seed = dml.seed(seed) + state = dict( + seed=seed, + torch_state=np.array(torch.get_rng_state()), + numpy_state=np.random.get_state()[1], + random_state=np.array(random.getstate()[1]), + ) + return state + + +class TestSeed: + def test_single_worker_deterministic(self, torch_distributed): + prev_torch_state = np.array(torch.get_rng_state()) + prev_numpy_state = np.random.get_state()[1] + prev_random_state = np.array(random.getstate()[1]) + + states = seed(42) + assert states['seed'] == 42 + assert (states['torch_state'] != prev_torch_state).any() + assert (states['numpy_state'] != prev_numpy_state).any() + assert (states['random_state'] != prev_random_state).any() + + # advance the RNG + torch.randint(0, 10, (1,)) + np.random.randint(0, 10) + + # reseeding should reset the RNG + new_states = seed(42) + assert new_states['seed'] == 42 + assert (new_states['torch_state'] == states['torch_state']).all() + assert (new_states['numpy_state'] == states['numpy_state']).all() + assert (new_states['random_state'] == states['random_state']).all() + + def test_input_validation(self, torch_distributed): + with pytest.raises(RuntimeError): + dml.seed(2**80) + assert dml.seed(2**64 - 1) == 2**64 - 1 + + def test_single_worker_random(self, torch_distributed): + prev_torch_state = np.array(torch.get_rng_state()) + prev_numpy_state = np.random.get_state()[1] + prev_random_state = np.array(random.getstate()[1]) + + states = seed() + assert type(states['seed']) is int + assert (states['torch_state'] != prev_torch_state).any() + assert (states['numpy_state'] != prev_numpy_state).any() + assert (states['random_state'] != prev_random_state).any() + + # reseeding should yield different states + new_states = seed() + assert new_states['seed'] != states['seed'] + assert (new_states['torch_state'] != states['torch_state']).any() + assert (new_states['numpy_state'] != states['numpy_state']).any() + assert (new_states['random_state'] != states['random_state']).any() + + def test_multi_worker_deterministic(self, distributed_environment): + states = distributed_environment(4).start(seed, 42) + assert [s['seed'] for s in states] == [42, 42, 42, 42] + + # workers should have different states + assert all((s['torch_state'] != states[0]['torch_state']).any() for s in states[1:]) + assert all((s['numpy_state'] != states[0]['numpy_state']).any() for s in states[1:]) + assert all((s['random_state'] != states[0]['random_state']).any() for s in states[1:]) + + # same seed should yield same states + new_states = distributed_environment(4).start(seed, 42) + assert [s['seed'] for s in new_states] == [42, 42, 42, 42] + assert all((s1['torch_state'] == s2['torch_state']).all() for s1, s2 in zip(states, new_states)) + assert all((s1['numpy_state'] == s2['numpy_state']).all() for s1, s2 in zip(states, new_states)) + assert all((s1['random_state'] == s2['random_state']).all() for s1, s2 in zip(states, new_states)) + + # different seed should yield different states + new_states = distributed_environment(4).start(seed, 11) + assert [s['seed'] for s in new_states] == [11, 11, 11, 11] + assert all((s1['torch_state'] != s2['torch_state']).any() for s1, s2 in zip(states, new_states)) + assert all((s1['numpy_state'] != s2['numpy_state']).any() for s1, s2 in zip(states, new_states)) + assert all((s1['random_state'] != s2['random_state']).any() for s1, s2 in zip(states, new_states)) + + def test_multi_worker_random(self, distributed_environment): + # all workers should have same seeds + states = distributed_environment(4).start(seed) + assert [s['seed'] for s in states] == [states[0]['seed']] * 4 + + # workers should have different states + assert all((s['torch_state'] != states[0]['torch_state']).any() for s in states[1:]) + assert all((s['numpy_state'] != states[0]['numpy_state']).any() for s in states[1:]) + assert all((s['random_state'] != states[0]['random_state']).any() for s in states[1:]) + + # reseeding should yield different states and seeds + new_states = distributed_environment(4).start(seed) + assert [s['seed'] for s in new_states] != [s['seed'] for s in states] + assert all((s1['torch_state'] != s2['torch_state']).any() for s1, s2 in zip(states, new_states)) + assert all((s1['numpy_state'] != s2['numpy_state']).any() for s1, s2 in zip(states, new_states)) + assert all((s1['random_state'] != s2['random_state']).any() for s1, s2 in zip(states, new_states)) + + +if __name__ == '__main__': + sys.exit(pytest.main([__file__]))