Skip to content

Commit

Permalink
Add startup_random_jitter_max_secs flag to pax main.py.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 647018797
  • Loading branch information
zhangqiaorjc authored and pax authors committed Jun 26, 2024
1 parent 765e2fe commit 8944ff9
Showing 1 changed file with 14 additions and 3 deletions.
17 changes: 14 additions & 3 deletions paxml/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,11 @@
'Timeout in seconds for asynchronous save operations. `None` indicates that'
' no timeout is set.',
)
flags.DEFINE_integer(
'startup_random_jitter_max_secs',
60,
'Random jitter in seconds to avoid thundering herd RPC calls at startup.',
)
flags.DEFINE_string(
'jax_traceback_filtering_option', 'auto',
'Controls how JAX filters internal frames out of tracebacks: '
Expand Down Expand Up @@ -425,6 +430,7 @@ def _setup_xm_work_unit():
def run(
experiment_config: base_experiment.BaseExperiment,
enable_checkpoint_saving: bool = True,
startup_random_jitter_max_secs: int = 60,
):
"""Run an experiment.
Expand All @@ -436,12 +442,14 @@ def run(
Args:
experiment_config: The experiment to run.
enable_checkpoint_saving: Whether to perform checkpoint saving or not.
startup_random_jitter_max_secs: The max seconds to wait for random jitter at
startup.
"""

# Add a note so that we can tell which Borg task is which JAX host.
# (Borg task 0 is not guaranteed to be host 0)
if jax.process_count() > 128:
wait_with_random_jitter(min_secs=0, max_secs=60)
wait_with_random_jitter(min_secs=0, max_secs=startup_random_jitter_max_secs)
work_unit = _setup_xm_work_unit()

# Start jax.profiler for TensorBoard and profiling in open source.
Expand Down Expand Up @@ -543,8 +551,11 @@ def _main(argv: Sequence[str]) -> None:
)

experiment_config.validate()
run(experiment_config=experiment_config,
enable_checkpoint_saving=FLAGS.enable_checkpoint_saving)
run(
experiment_config=experiment_config,
enable_checkpoint_saving=FLAGS.enable_checkpoint_saving,
startup_random_jitter_max_secs=FLAGS.startup_random_jitter_max_secs,
)


_TASK_HANDLE_RE = re.compile(r'(?:logs\.)?(\d+)\.(.*)\.([^.]+)\.\d+')
Expand Down

0 comments on commit 8944ff9

Please sign in to comment.