From aa550f2564ec5b8a8e21b9c54fbad0950ba7fbc3 Mon Sep 17 00:00:00 2001 From: Sebastian Hoffmann Date: Sun, 5 Jan 2025 16:15:39 +0100 Subject: [PATCH] fix: test race condition --- dmlcloud/core/callbacks.py | 2 +- test/test_callback.py | 46 +++++++++++++++++++------------------- 2 files changed, 24 insertions(+), 24 deletions(-) diff --git a/dmlcloud/core/callbacks.py b/dmlcloud/core/callbacks.py index 1623f06..d347033 100644 --- a/dmlcloud/core/callbacks.py +++ b/dmlcloud/core/callbacks.py @@ -91,7 +91,7 @@ class CbPriority(IntEnum): STAGE_TIMER = -180 DIAGNOSTICS = -170 GIT = -160 - METRIC_REDUCTION = -160 + METRIC_REDUCTION = -150 OBJECT_METHODS = 0 diff --git a/test/test_callback.py b/test/test_callback.py index fcc3103..94a27e2 100644 --- a/test/test_callback.py +++ b/test/test_callback.py @@ -20,25 +20,25 @@ def __init__(self, idx): self.t_post_epoch = [] def pre_run(self, pipe): - self.t_pre_run.append(time.time()) + self.t_pre_run.append(time.monotonic_ns()) def post_run(self, pipe): - self.t_post_run.append(time.time()) + self.t_post_run.append(time.monotonic_ns()) def pre_stage(self, stage): - self.t_pre_stage.append(time.time()) + self.t_pre_stage.append(time.monotonic_ns()) def post_stage(self, stage): - self.t_post_stage.append(time.time()) + self.t_post_stage.append(time.monotonic_ns()) def cleanup(self, pipe, exc_type, exc_value, traceback): - self.t_cleanup.append(time.time()) + self.t_cleanup.append(time.monotonic_ns()) def pre_epoch(self, stage): - self.t_pre_epoch.append(time.time()) + self.t_pre_epoch.append(time.monotonic_ns()) def post_epoch(self, stage): - self.t_post_epoch.append(time.time()) + self.t_post_epoch.append(time.monotonic_ns()) class DummyStage(dml.Stage): @@ -50,16 +50,16 @@ def __init__(self, name, epochs): self.t_post_epoch = [] def pre_stage(self): - self.t_pre_stage.append(time.time()) + self.t_pre_stage.append(time.monotonic_ns()) def post_stage(self): - self.t_post_stage.append(time.time()) + self.t_post_stage.append(time.monotonic_ns()) def pre_epoch(self): - self.t_pre_epoch.append(time.time()) + self.t_pre_epoch.append(time.monotonic_ns()) def post_epoch(self): - self.t_post_epoch.append(time.time()) + self.t_post_epoch.append(time.monotonic_ns()) def run_epoch(self): pass @@ -122,11 +122,11 @@ def test_stage_methods(self, torch_distributed): assert len(stage1.t_pre_epoch) == 2 assert len(stage1.t_post_epoch) == 2 - assert stage1.t_pre_stage[0] < stage1.t_pre_epoch[0] - assert stage1.t_pre_epoch[0] < stage1.t_post_epoch[0] - assert stage1.t_post_epoch[0] < stage1.t_pre_epoch[1] - assert stage1.t_pre_epoch[1] < stage1.t_post_epoch[1] - assert stage1.t_post_epoch[1] < stage1.t_post_stage[0] + assert stage1.t_pre_stage[0] <= stage1.t_pre_epoch[0] + assert stage1.t_pre_epoch[0] <= stage1.t_post_epoch[0] + assert stage1.t_post_epoch[0] <= stage1.t_pre_epoch[1] + assert stage1.t_pre_epoch[1] <= stage1.t_post_epoch[1] + assert stage1.t_post_epoch[1] <= stage1.t_post_stage[0] def test_stage_callback(self, torch_distributed): pipe = dml.Pipeline() @@ -148,8 +148,8 @@ def test_stage_callback(self, torch_distributed): assert len(cb.t_pre_run) == 0 assert len(cb.t_post_run) == 0 - assert stage1.t_pre_stage[0] < cb.t_pre_stage[0] - assert stage1.t_post_stage[0] < cb.t_post_stage[0] + assert stage1.t_pre_stage[0] <= cb.t_pre_stage[0] + assert stage1.t_post_stage[0] <= cb.t_post_stage[0] def test_stage_callback_priority(self, torch_distributed): pipe = dml.Pipeline() @@ -171,8 +171,8 @@ def test_stage_callback_priority(self, torch_distributed): assert len(cb.t_pre_run) == 0 assert len(cb.t_post_run) == 0 - assert cb.t_pre_stage[0] < stage1.t_pre_stage[0] - assert cb.t_post_stage[0] < stage1.t_post_stage[0] + assert cb.t_pre_stage[0] <= stage1.t_pre_stage[0] + assert cb.t_post_stage[0] <= stage1.t_post_stage[0] def test_pipeline_callback(self, torch_distributed): pipe = dml.Pipeline() @@ -194,9 +194,9 @@ def test_pipeline_callback(self, torch_distributed): assert len(cb.t_pre_epoch) == 2 assert len(cb.t_post_epoch) == 2 - assert cb.t_pre_run[0] < cb.t_pre_stage[0] - assert cb.t_post_stage[0] < cb.t_post_run[0] - assert cb.t_post_run[0] < cb.t_cleanup[0] + assert cb.t_pre_run[0] <= cb.t_pre_stage[0] + assert cb.t_post_stage[0] <= cb.t_post_run[0] + assert cb.t_post_run[0] <= cb.t_cleanup[0] if __name__ == '__main__':