Skip to content

Commit

Permalink
fix: test race condition
Browse files (browse the repository at this point in the history)
  • Loading branch information
sehoffmann committed Jan 5, 2025
1 parent 5fa449b commit aa550f2
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 24 deletions.
2 changes: 1 addition & 1 deletion dmlcloud/core/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ class CbPriority(IntEnum):
STAGE_TIMER = -180
DIAGNOSTICS = -170
GIT = -160
METRIC_REDUCTION = -160
METRIC_REDUCTION = -150

OBJECT_METHODS = 0

Expand Down
46 changes: 23 additions & 23 deletions test/test_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,25 +20,25 @@ def __init__(self, idx):
self.t_post_epoch = []

def pre_run(self, pipe):
self.t_pre_run.append(time.time())
self.t_pre_run.append(time.monotonic_ns())

def post_run(self, pipe):
self.t_post_run.append(time.time())
self.t_post_run.append(time.monotonic_ns())

def pre_stage(self, stage):
self.t_pre_stage.append(time.time())
self.t_pre_stage.append(time.monotonic_ns())

def post_stage(self, stage):
self.t_post_stage.append(time.time())
self.t_post_stage.append(time.monotonic_ns())

def cleanup(self, pipe, exc_type, exc_value, traceback):
self.t_cleanup.append(time.time())
self.t_cleanup.append(time.monotonic_ns())

def pre_epoch(self, stage):
self.t_pre_epoch.append(time.time())
self.t_pre_epoch.append(time.monotonic_ns())

def post_epoch(self, stage):
self.t_post_epoch.append(time.time())
self.t_post_epoch.append(time.monotonic_ns())


class DummyStage(dml.Stage):
Expand All @@ -50,16 +50,16 @@ def __init__(self, name, epochs):
self.t_post_epoch = []

def pre_stage(self):
self.t_pre_stage.append(time.time())
self.t_pre_stage.append(time.monotonic_ns())

def post_stage(self):
self.t_post_stage.append(time.time())
self.t_post_stage.append(time.monotonic_ns())

def pre_epoch(self):
self.t_pre_epoch.append(time.time())
self.t_pre_epoch.append(time.monotonic_ns())

def post_epoch(self):
self.t_post_epoch.append(time.time())
self.t_post_epoch.append(time.monotonic_ns())

def run_epoch(self):
pass
Expand Down Expand Up @@ -122,11 +122,11 @@ def test_stage_methods(self, torch_distributed):
assert len(stage1.t_pre_epoch) == 2
assert len(stage1.t_post_epoch) == 2

assert stage1.t_pre_stage[0] < stage1.t_pre_epoch[0]
assert stage1.t_pre_epoch[0] < stage1.t_post_epoch[0]
assert stage1.t_post_epoch[0] < stage1.t_pre_epoch[1]
assert stage1.t_pre_epoch[1] < stage1.t_post_epoch[1]
assert stage1.t_post_epoch[1] < stage1.t_post_stage[0]
assert stage1.t_pre_stage[0] <= stage1.t_pre_epoch[0]
assert stage1.t_pre_epoch[0] <= stage1.t_post_epoch[0]
assert stage1.t_post_epoch[0] <= stage1.t_pre_epoch[1]
assert stage1.t_pre_epoch[1] <= stage1.t_post_epoch[1]
assert stage1.t_post_epoch[1] <= stage1.t_post_stage[0]

def test_stage_callback(self, torch_distributed):
pipe = dml.Pipeline()
Expand All @@ -148,8 +148,8 @@ def test_stage_callback(self, torch_distributed):
assert len(cb.t_pre_run) == 0
assert len(cb.t_post_run) == 0

assert stage1.t_pre_stage[0] < cb.t_pre_stage[0]
assert stage1.t_post_stage[0] < cb.t_post_stage[0]
assert stage1.t_pre_stage[0] <= cb.t_pre_stage[0]
assert stage1.t_post_stage[0] <= cb.t_post_stage[0]

def test_stage_callback_priority(self, torch_distributed):
pipe = dml.Pipeline()
Expand All @@ -171,8 +171,8 @@ def test_stage_callback_priority(self, torch_distributed):
assert len(cb.t_pre_run) == 0
assert len(cb.t_post_run) == 0

assert cb.t_pre_stage[0] < stage1.t_pre_stage[0]
assert cb.t_post_stage[0] < stage1.t_post_stage[0]
assert cb.t_pre_stage[0] <= stage1.t_pre_stage[0]
assert cb.t_post_stage[0] <= stage1.t_post_stage[0]

def test_pipeline_callback(self, torch_distributed):
pipe = dml.Pipeline()
Expand All @@ -194,9 +194,9 @@ def test_pipeline_callback(self, torch_distributed):
assert len(cb.t_pre_epoch) == 2
assert len(cb.t_post_epoch) == 2

assert cb.t_pre_run[0] < cb.t_pre_stage[0]
assert cb.t_post_stage[0] < cb.t_post_run[0]
assert cb.t_post_run[0] < cb.t_cleanup[0]
assert cb.t_pre_run[0] <= cb.t_pre_stage[0]
assert cb.t_post_stage[0] <= cb.t_post_run[0]
assert cb.t_post_run[0] <= cb.t_cleanup[0]


if __name__ == '__main__':
Expand Down

0 comments on commit aa550f2

Please sign in to comment.