Skip to content

Commit

Permalink
remove extracting key_data
Browse files Browse the repository at this point in the history
  • Loading branch information
priyakasimbeg committed Jan 22, 2025
1 parent 3fb722d commit 9b7cee4
Showing 1 changed file with 0 additions and 3 deletions.
3 changes: 0 additions & 3 deletions submission_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,6 @@ def train_once(
_reset_cuda_mem()
data_rng, opt_init_rng, model_init_rng, rng = prng.split(rng, 4)

data_rng = jax.random.key_data(data_rng)
# Workload setup.
logging.info('Initializing dataset.')
if hasattr(workload, '_eval_num_workers'):
Expand Down Expand Up @@ -347,8 +346,6 @@ def train_once(
data_select_rng, update_rng, prep_eval_rng, eval_rng = \
prng.split(step_rng, 4)

eval_rng = jax.random.key_data(eval_rng)

with profiler.profile('Data selection'):
batch = data_selection(workload,
input_queue,
Expand Down

0 comments on commit 9b7cee4

Please sign in to comment.