diff --git a/test/conftest.py b/test/conftest.py
index 0764c5d..c5c9aa0 100644
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -1,4 +1,4 @@
-from multiprocessing import Pipe, Process
+import multiprocessing as mp
 
 import pytest
 import torch
@@ -21,15 +21,10 @@ def init(world_size, timeout=5 * 60, daemon=True):
 
     return init
 
-    def __init__(self, world_size: int, timeout: int = 5 * 60, daemon: bool = True, file: str = None):
-        self.world_size = world_size
-        self.timeout = timeout
-        self.daemon = daemon
-        self.file = str(file)
-
-    def _run(self, rank, conn, func, *args, **kwargs):
-        store = torch.distributed.FileStore(self.file, self.world_size)
-        torch.distributed.init_process_group(backend='gloo', world_size=self.world_size, rank=rank, store=store)
+    @staticmethod  # important to be staticmethod, otherwise pickle will fail
+    def _run(rank, world_size, file, conn, func, *args, **kwargs):
+        store = torch.distributed.FileStore(file, world_size)
+        torch.distributed.init_process_group(backend='gloo', world_size=world_size, rank=rank, store=store)
         torch.distributed.barrier()
 
         ret = func(*args, **kwargs)  # TODO: need to handle exceptions
@@ -39,14 +34,24 @@ def _run(self, rank, conn, func, *args, **kwargs):
 
         torch.distributed.destroy_process_group()
 
+    def __init__(self, world_size: int, timeout: int = 5 * 60, daemon: bool = True, file: str = None):
+        self.world_size = world_size
+        self.timeout = timeout
+        self.daemon = daemon
+        self.file = str(file)
+
     def start(self, func, *args, **kwargs):
+        ctx = mp.get_context('spawn')
+
         self.processes = []
         self.conns = []
         for rank in range(self.world_size):
-            recv_conn, send_conn = Pipe()
-            process_args = (rank, send_conn, func) + args
+            recv_conn, send_conn = ctx.Pipe()
+            process_args = (rank, self.world_size, self.file, send_conn, func) + args
             process_kwargs = dict(kwargs)
-            process = Process(target=self._run, args=process_args, kwargs=process_kwargs, daemon=self.daemon)
+            process = ctx.Process(
+                target=DistributedEnvironment._run, args=process_args, kwargs=process_kwargs, daemon=self.daemon
+            )
             self.conns.append(recv_conn)
             self.processes.append(process)
 
diff --git a/test/test_root_only.py b/test/test_root_only.py
index 40bc9fb..7bed4a0 100644
--- a/test/test_root_only.py
+++ b/test/test_root_only.py
@@ -121,15 +121,16 @@ def test_function(self, distributed_environment):
         assert return_root_rank.__name__ == 'return_root_rank'
         assert return_root_rank.__doc__ == 'TEST_DOC_STRING'
 
-    def test_stage(self, distributed_environment):
-        def run():
-            stage = RootOnlyStage(epochs=1)
-            pipe = dml.Pipeline()
-            pipe.append(stage)
-            pipe.run()
-            return stage.cb_executed
+    @staticmethod
+    def _test_stage_run():
+        stage = RootOnlyStage(epochs=1)
+        pipe = dml.Pipeline()
+        pipe.append(stage)
+        pipe.run()
+        return stage.cb_executed
 
-        results = distributed_environment(3).start(run)
+    def test_stage(self, distributed_environment):
+        results = distributed_environment(3).start(TestRootOnly._test_stage_run)
 
         assert [r['pre_stage'] for r in results] == [True, False, False]
         assert [r['post_stage'] for r in results] == [True, False, False]
@@ -155,15 +156,16 @@ def run():
         assert RootOnlyStage.run_epoch.__name__ == 'run_epoch'
         assert RootOnlyStage.run_epoch.__doc__ == 'TEST_DOC_STRING'
 
-    def test_partial_stage(self, distributed_environment):
-        def run():
-            stage = PartialRootOnlyStage(epochs=1)
-            pipe = dml.Pipeline()
-            pipe.append(stage)
-            pipe.run()
-            return stage.cb_executed
+    @staticmethod
+    def _test_partial_stage_run():
+        stage = PartialRootOnlyStage(epochs=1)
+        pipe = dml.Pipeline()
+        pipe.append(stage)
+        pipe.run()
+        return stage.cb_executed
 
-        results = distributed_environment(3).start(run)
+    def test_partial_stage(self, distributed_environment):
+        results = distributed_environment(3).start(TestRootOnly._test_partial_stage_run)
 
         assert [r['pre_stage'] for r in results] == [True, True, True]
         assert [r['post_stage'] for r in results] == [True, False, False]
@@ -177,14 +179,15 @@ def run():
         assert PartialRootOnlyStage.pre_epoch.__name__ == 'pre_epoch'
         assert PartialRootOnlyStage.pre_epoch.__doc__ == 'TEST_DOC_STRING'
 
-    def test_pipeline(self, distributed_environment):
-        def run():
-            pipe = RootOnlyPipeline()
-            pipe.append(RootOnlyStage(epochs=1))
-            pipe.run()
-            return pipe.cb_executed
+    @staticmethod
+    def _test_pipeline_run():
+        pipe = RootOnlyPipeline()
+        pipe.append(RootOnlyStage(epochs=1))
+        pipe.run()
+        return pipe.cb_executed
 
-        results = distributed_environment(3).start(run)
+    def test_pipeline(self, distributed_environment):
+        results = distributed_environment(3).start(TestRootOnly._test_pipeline_run)
 
         assert [r['pre_run'] for r in results] == [True, False, False]
         assert [r['post_run'] for r in results] == [True, False, False]
@@ -198,14 +201,15 @@ def run():
         assert RootOnlyPipeline.post_run.__name__ == 'post_run'
         assert RootOnlyPipeline.post_run.__doc__ == 'TEST_DOC_STRING'
 
-    def test_partial_pipeline(self, distributed_environment):
-        def run():
-            pipe = PartialRootOnlyPipeline()
-            pipe.append(RootOnlyStage(epochs=1))
-            pipe.run()
-            return pipe.cb_executed
+    @staticmethod
+    def _test_partial_pipeline_run():
+        pipe = PartialRootOnlyPipeline()
+        pipe.append(RootOnlyStage(epochs=1))
+        pipe.run()
+        return pipe.cb_executed
 
-        results = distributed_environment(3).start(run)
+    def test_partial_pipeline(self, distributed_environment):
+        results = distributed_environment(3).start(TestRootOnly._test_partial_pipeline_run)
 
         assert [r['pre_run'] for r in results] == [True, False, False]
         assert [r['post_run'] for r in results] == [True, True, True]
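
Background for the "important to be staticmethod, otherwise pickle will fail" comment and for hoisting the nested run() closures into _test_*_run staticmethods: with the 'spawn' start method, multiprocessing pickles the Process target in the parent, so the target has to be importable by name (a module-level function or a class-level staticmethod); a closure defined inside a test method is not. The following minimal sketch (illustrative only, not part of the patch; Python 3.8+, hypothetical names) shows that constraint directly with pickle:

    # sketch_spawn_pickling.py -- illustrative only, not part of the patch
    import pickle


    class Worker:
        @staticmethod
        def run(rank):
            # Importable as Worker.run, so pickle can serialize a reference to it,
            # which is what the 'spawn' start method needs for the Process target.
            return rank * 2


    def make_closure():
        def run(rank):
            return rank * 2
        return run


    if __name__ == '__main__':
        pickle.dumps(Worker.run)  # works: resolved via its qualified name

        try:
            pickle.dumps(make_closure())  # fails: local objects cannot be pickled
        except (AttributeError, pickle.PicklingError) as err:
            print('closure is not picklable:', err)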