Skip to content

Commit

Permalink
feat: dml.seed(), closes #23
Browse files Browse the repository at this point in the history
  • Loading branch information
sehoffmann committed Jan 8, 2025
1 parent 399cc4b commit e801e0f
Show file tree
Hide file tree
Showing 5 changed files with 210 additions and 35 deletions.
68 changes: 35 additions & 33 deletions dmlcloud/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,15 @@
from .core.pipeline import Pipeline

__all__ += [
Pipeline,
'Pipeline',
]

# Stage

from .core.stage import Stage

__all__ += [
Stage,
'Stage',
]

# Callbacks
Expand Down Expand Up @@ -66,26 +66,28 @@
rank,
root_first,
root_only,
seed,
world_size,
)

__all__ += [
has_slurm,
has_environment,
has_mpi,
is_root,
root_only,
root_first,
rank,
world_size,
local_rank,
local_world_size,
local_node,
all_gather_object,
gather_object,
broadcast_object,
init,
deinitialize_torch_distributed,
'has_slurm',
'has_environment',
'has_mpi',
'is_root',
'root_only',
'root_first',
'rank',
'world_size',
'local_rank',
'local_world_size',
'local_node',
'all_gather_object',
'gather_object',
'broadcast_object',
'init',
'deinitialize_torch_distributed',
'seed',
]

# Metrics
Expand Down Expand Up @@ -114,24 +116,24 @@
)

__all__ += [
logger,
setup_logger,
reset_logger,
flush_logger,
print_root,
print_worker,
log,
debug,
info,
warning,
error,
critical,
'logger',
'setup_logger',
'reset_logger',
'flush_logger',
'print_root',
'print_worker',
'log',
'debug',
'info',
'warning',
'error',
'critical',
]

from .core.model import count_parameters, scale_lr, wrap_ddp

__all__ += [
wrap_ddp,
scale_lr,
count_parameters,
'wrap_ddp',
'scale_lr',
'count_parameters',
]
47 changes: 47 additions & 0 deletions dmlcloud/core/distributed.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import inspect
import os
import random
import sys
from contextlib import contextmanager
from datetime import timedelta
from functools import wraps
from typing import Callable, TYPE_CHECKING

import numpy as np
import torch
import torch.distributed
import torch.distributed as dist
Expand Down Expand Up @@ -34,6 +36,7 @@
'broadcast_object',
'init',
'deinitialize_torch_distributed',
'seed',
]


Expand Down Expand Up @@ -557,3 +560,47 @@ def deinitialize_torch_distributed():
_WorkerInfo.LOCAL_WORLD_SIZE = None
_WorkerInfo.NODE_ID = None
dist.destroy_process_group()


def seed(seed: int | None = None, group: dist.ProcessGroup | None = None) -> int:
    """
    Share the seed from the root rank with all ranks in the group and seed the random number generators.

    The following libraries are seeded:
        - random
        - numpy
        - torch
        - tensorflow (if installed and imported)

    Different ranks will be seeded differently (``seed + rank``), so that they do not generate the same
    random numbers.

    Args:
        seed: The seed to share. If None, a random seed is generated. Default is None.
            Value must be within the inclusive range `[-0x8000_0000_0000_0000, 0xffff_ffff_ffff_ffff]`.
            Negative inputs are remapped to positive values with the formula `0xffff_ffff_ffff_ffff + seed`.
        group: The process group to work on. If None, the default process group will be used. Default is None.

    Returns:
        The original seed that was provided by the root rank. 64-bit integer.

    Raises:
        RuntimeError: If ``seed`` is not within the inclusive range `[-0x8000_0000_0000_0000, 0xffff_ffff_ffff_ffff]`.
    """
    min_seed = -0x8000_0000_0000_0000
    max_seed = 0xFFFF_FFFF_FFFF_FFFF

    if seed is None:
        seed = torch.seed()
    seed = broadcast_object(seed, group=group)

    # Validate the shared seed itself rather than relying on torch.manual_seed(seed + rank):
    # every rank sees the same broadcast value, so either all ranks raise or none does.
    if not min_seed <= seed <= max_seed:
        raise RuntimeError(
            f'seed must be within the inclusive range [-0x8000_0000_0000_0000, 0xffff_ffff_ffff_ffff], got {seed}'
        )

    worker_seed = seed + rank()
    if worker_seed > max_seed:
        # A valid seed close to the upper bound would overflow on higher ranks; wrap around into range.
        worker_seed -= max_seed + 1

    torch.manual_seed(worker_seed)
    random.seed(worker_seed)

    # numpy only supports 32bits, so we use torch to generate a 32bit seed
    np_seed = torch.randint(0, 2**31, (1,)).item()
    np.random.seed(np_seed)

    # Only touch tensorflow if the user already imported it; never trigger the import ourselves.
    if 'tensorflow' in sys.modules:
        import tensorflow as tf

        tf.random.set_seed(worker_seed)

    return seed
2 changes: 2 additions & 0 deletions doc/dmlcloud.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ dmlcloud provides a set of helper functions to simplify the use of torch.distrib
:toctree: generated

init
seed
deinitialize_torch_distributed

is_root
root_only
Expand Down
19 changes: 17 additions & 2 deletions examples/mnist.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import argparse

import dmlcloud as dml
import torch
import torchmetrics
Expand Down Expand Up @@ -96,8 +98,21 @@ def _val_epoch(self):


def main():
pipe = dml.Pipeline(name='MNIST')
pipe.append(MNISTStage(epochs=3))
dml.init()

parser = argparse.ArgumentParser()
parser.add_argument('--epochs', type=int, default=4)
parser.add_argument('--seed', type=int)
args = parser.parse_args()

seed = dml.seed(args.seed) # This is a helper function to set the seed for all devices
config = {
'seed': seed,
'epochs': args.epochs,
}

pipe = dml.Pipeline(config, name='MNIST')
pipe.append(MNISTStage(epochs=args.epochs))
pipe.enable_checkpointing('checkpoints')
pipe.enable_wandb()
pipe.run()
Expand Down
109 changes: 109 additions & 0 deletions test/test_seed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
import random
import sys

import dmlcloud as dml
import numpy as np
import pytest
import torch


def seed(seed=None):
    """Seed every RNG via ``dml.seed`` and return a snapshot of the resulting generator states.

    The returned dict holds the seed reported by ``dml.seed`` together with the current
    torch, numpy and stdlib ``random`` generator states as arrays, so callers can compare them.
    """
    seed = dml.seed(seed)
    return {
        'seed': seed,
        'torch_state': np.array(torch.get_rng_state()),
        'numpy_state': np.random.get_state()[1],
        'random_state': np.array(random.getstate()[1]),
    }


class TestSeed:
    """Tests for ``dml.seed`` in single-worker and multi-worker setups.

    NOTE(review): the ``torch_distributed`` and ``distributed_environment`` fixtures are defined
    outside this file (presumably in conftest.py) -- confirm their exact semantics there.
    """

    def test_single_worker_deterministic(self, torch_distributed):
        """Seeding with a fixed value changes all RNG states and is reproducible."""
        prev_torch_state = np.array(torch.get_rng_state())
        prev_numpy_state = np.random.get_state()[1]
        prev_random_state = np.array(random.getstate()[1])

        states = seed(42)
        assert states['seed'] == 42
        assert (states['torch_state'] != prev_torch_state).any()
        assert (states['numpy_state'] != prev_numpy_state).any()
        assert (states['random_state'] != prev_random_state).any()

        # advance the RNG (all three generators, otherwise the reset check below is vacuous)
        torch.randint(0, 10, (1,))
        np.random.randint(0, 10)
        random.randint(0, 10)

        # reseeding should reset the RNG
        new_states = seed(42)
        assert new_states['seed'] == 42
        assert (new_states['torch_state'] == states['torch_state']).all()
        assert (new_states['numpy_state'] == states['numpy_state']).all()
        assert (new_states['random_state'] == states['random_state']).all()

    def test_input_validation(self, torch_distributed):
        """Out-of-range seeds raise RuntimeError; the largest 64-bit value is accepted."""
        with pytest.raises(RuntimeError):
            dml.seed(2**80)
        assert dml.seed(2**64 - 1) == 2**64 - 1

    def test_single_worker_random(self, torch_distributed):
        """Seeding without an explicit value yields a fresh integer seed on every call."""
        prev_torch_state = np.array(torch.get_rng_state())
        prev_numpy_state = np.random.get_state()[1]
        prev_random_state = np.array(random.getstate()[1])

        states = seed()
        assert type(states['seed']) is int
        assert (states['torch_state'] != prev_torch_state).any()
        assert (states['numpy_state'] != prev_numpy_state).any()
        assert (states['random_state'] != prev_random_state).any()

        # reseeding should yield different states
        new_states = seed()
        assert new_states['seed'] != states['seed']
        assert (new_states['torch_state'] != states['torch_state']).any()
        assert (new_states['numpy_state'] != states['numpy_state']).any()
        assert (new_states['random_state'] != states['random_state']).any()

    def test_multi_worker_deterministic(self, distributed_environment):
        """With a fixed seed, workers differ from each other but runs are reproducible."""
        states = distributed_environment(4).start(seed, 42)
        assert [s['seed'] for s in states] == [42, 42, 42, 42]

        # workers should have different states
        assert all((s['torch_state'] != states[0]['torch_state']).any() for s in states[1:])
        assert all((s['numpy_state'] != states[0]['numpy_state']).any() for s in states[1:])
        assert all((s['random_state'] != states[0]['random_state']).any() for s in states[1:])

        # same seed should yield same states
        new_states = distributed_environment(4).start(seed, 42)
        assert [s['seed'] for s in new_states] == [42, 42, 42, 42]
        assert all((s1['torch_state'] == s2['torch_state']).all() for s1, s2 in zip(states, new_states))
        assert all((s1['numpy_state'] == s2['numpy_state']).all() for s1, s2 in zip(states, new_states))
        assert all((s1['random_state'] == s2['random_state']).all() for s1, s2 in zip(states, new_states))

        # different seed should yield different states
        new_states = distributed_environment(4).start(seed, 11)
        assert [s['seed'] for s in new_states] == [11, 11, 11, 11]
        assert all((s1['torch_state'] != s2['torch_state']).any() for s1, s2 in zip(states, new_states))
        assert all((s1['numpy_state'] != s2['numpy_state']).any() for s1, s2 in zip(states, new_states))
        assert all((s1['random_state'] != s2['random_state']).any() for s1, s2 in zip(states, new_states))

    def test_multi_worker_random(self, distributed_environment):
        """Without a fixed seed, all workers agree on one seed that changes between runs."""
        # all workers should have same seeds
        states = distributed_environment(4).start(seed)
        assert [s['seed'] for s in states] == [states[0]['seed']] * 4

        # workers should have different states
        assert all((s['torch_state'] != states[0]['torch_state']).any() for s in states[1:])
        assert all((s['numpy_state'] != states[0]['numpy_state']).any() for s in states[1:])
        assert all((s['random_state'] != states[0]['random_state']).any() for s in states[1:])

        # reseeding should yield different states and seeds
        new_states = distributed_environment(4).start(seed)
        assert [s['seed'] for s in new_states] != [s['seed'] for s in states]
        assert all((s1['torch_state'] != s2['torch_state']).any() for s1, s2 in zip(states, new_states))
        assert all((s1['numpy_state'] != s2['numpy_state']).any() for s1, s2 in zip(states, new_states))
        assert all((s1['random_state'] != s2['random_state']).any() for s1, s2 in zip(states, new_states))


if __name__ == '__main__':
sys.exit(pytest.main([__file__]))

0 comments on commit e801e0f

Please sign in to comment.