Skip to content

Commit

Permalink
fix(test): DistributedEnvironment failures on macOS due to pickle errors
Browse files Browse the repository at this point in the history
  • Loading branch information
sehoffmann committed Jan 7, 2025
1 parent 120376c commit 4faffce
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 43 deletions.
31 changes: 18 additions & 13 deletions test/conftest.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from multiprocessing import Pipe, Process
import multiprocessing as mp

import pytest
import torch
Expand All @@ -21,15 +21,10 @@ def init(world_size, timeout=5 * 60, daemon=True):

return init

def __init__(self, world_size: int, timeout: int = 5 * 60, daemon: bool = True, file: str = None):
self.world_size = world_size
self.timeout = timeout
self.daemon = daemon
self.file = str(file)

def _run(self, rank, conn, func, *args, **kwargs):
store = torch.distributed.FileStore(self.file, self.world_size)
torch.distributed.init_process_group(backend='gloo', world_size=self.world_size, rank=rank, store=store)
@staticmethod # important to be staticmethod, otherwise pickle will fail
def _run(rank, world_size, file, conn, func, *args, **kwargs):
store = torch.distributed.FileStore(file, world_size)
torch.distributed.init_process_group(backend='gloo', world_size=world_size, rank=rank, store=store)

torch.distributed.barrier()
ret = func(*args, **kwargs) # TODO: need to handle exceptions
Expand All @@ -39,14 +34,24 @@ def _run(self, rank, conn, func, *args, **kwargs):

torch.distributed.destroy_process_group()

def __init__(self, world_size: int, timeout: int = 5 * 60, daemon: bool = True, file: str = None):
self.world_size = world_size
self.timeout = timeout
self.daemon = daemon
self.file = str(file)

def start(self, func, *args, **kwargs):
ctx = mp.get_context('spawn')

self.processes = []
self.conns = []
for rank in range(self.world_size):
recv_conn, send_conn = Pipe()
process_args = (rank, send_conn, func) + args
recv_conn, send_conn = ctx.Pipe()
process_args = (rank, self.world_size, self.file, send_conn, func) + args
process_kwargs = dict(kwargs)
process = Process(target=self._run, args=process_args, kwargs=process_kwargs, daemon=self.daemon)
process = ctx.Process(
target=DistributedEnvironment._run, args=process_args, kwargs=process_kwargs, daemon=self.daemon
)
self.conns.append(recv_conn)
self.processes.append(process)

Expand Down
64 changes: 34 additions & 30 deletions test/test_root_only.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,15 +121,16 @@ def test_function(self, distributed_environment):
assert return_root_rank.__name__ == 'return_root_rank'
assert return_root_rank.__doc__ == 'TEST_DOC_STRING'

def test_stage(self, distributed_environment):
def run():
stage = RootOnlyStage(epochs=1)
pipe = dml.Pipeline()
pipe.append(stage)
pipe.run()
return stage.cb_executed
@staticmethod
def _test_stage_run():
stage = RootOnlyStage(epochs=1)
pipe = dml.Pipeline()
pipe.append(stage)
pipe.run()
return stage.cb_executed

results = distributed_environment(3).start(run)
def test_stage(self, distributed_environment):
results = distributed_environment(3).start(TestRootOnly._test_stage_run)

assert [r['pre_stage'] for r in results] == [True, False, False]
assert [r['post_stage'] for r in results] == [True, False, False]
Expand All @@ -155,15 +156,16 @@ def run():
assert RootOnlyStage.run_epoch.__name__ == 'run_epoch'
assert RootOnlyStage.run_epoch.__doc__ == 'TEST_DOC_STRING'

def test_partial_stage(self, distributed_environment):
def run():
stage = PartialRootOnlyStage(epochs=1)
pipe = dml.Pipeline()
pipe.append(stage)
pipe.run()
return stage.cb_executed
@staticmethod
def _test_partial_stage_run():
stage = PartialRootOnlyStage(epochs=1)
pipe = dml.Pipeline()
pipe.append(stage)
pipe.run()
return stage.cb_executed

results = distributed_environment(3).start(run)
def test_partial_stage(self, distributed_environment):
results = distributed_environment(3).start(TestRootOnly._test_partial_stage_run)

assert [r['pre_stage'] for r in results] == [True, True, True]
assert [r['post_stage'] for r in results] == [True, False, False]
Expand All @@ -177,14 +179,15 @@ def run():
assert PartialRootOnlyStage.pre_epoch.__name__ == 'pre_epoch'
assert PartialRootOnlyStage.pre_epoch.__doc__ == 'TEST_DOC_STRING'

def test_pipeline(self, distributed_environment):
def run():
pipe = RootOnlyPipeline()
pipe.append(RootOnlyStage(epochs=1))
pipe.run()
return pipe.cb_executed
@staticmethod
def _test_pipeline_run():
pipe = RootOnlyPipeline()
pipe.append(RootOnlyStage(epochs=1))
pipe.run()
return pipe.cb_executed

results = distributed_environment(3).start(run)
def test_pipeline(self, distributed_environment):
results = distributed_environment(3).start(TestRootOnly._test_pipeline_run)

assert [r['pre_run'] for r in results] == [True, False, False]
assert [r['post_run'] for r in results] == [True, False, False]
Expand All @@ -198,14 +201,15 @@ def run():
assert RootOnlyPipeline.post_run.__name__ == 'post_run'
assert RootOnlyPipeline.post_run.__doc__ == 'TEST_DOC_STRING'

def test_partial_pipeline(self, distributed_environment):
def run():
pipe = PartialRootOnlyPipeline()
pipe.append(RootOnlyStage(epochs=1))
pipe.run()
return pipe.cb_executed
@staticmethod
def _test_partial_pipeline_run():
pipe = PartialRootOnlyPipeline()
pipe.append(RootOnlyStage(epochs=1))
pipe.run()
return pipe.cb_executed

results = distributed_environment(3).start(run)
def test_partial_pipeline(self, distributed_environment):
results = distributed_environment(3).start(TestRootOnly._test_partial_pipeline_run)

assert [r['pre_run'] for r in results] == [True, False, False]
assert [r['post_run'] for r in results] == [True, True, True]
Expand Down

0 comments on commit 4faffce

Please sign in to comment.