I am getting the following error while migrating from jax.random.PRNGKey to jax.random.key.
Here is the full traceback:
Traceback (most recent call last):
File "/home/isaac/Desktop/opensource/isaac_algorithmic_efficiency/algorithmic-efficiency/submission_runner.py", line 714, in <module>
app.run(main)
File "/home/isaac/miniconda3/envs/main/lib/python3.11/site-packages/absl/app.py", line 308, in run
_run_main(main, args)
File "/home/isaac/miniconda3/envs/main/lib/python3.11/site-packages/absl/app.py", line 254, in _run_main
sys.exit(main(argv))
^^^^^^^^^^
File "/home/isaac/Desktop/opensource/isaac_algorithmic_efficiency/algorithmic-efficiency/submission_runner.py", line 682, in main
score = score_submission_on_workload(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/isaac/Desktop/opensource/isaac_algorithmic_efficiency/algorithmic-efficiency/submission_runner.py", line 587, in score_submission_on_workload
timing, metrics = train_once(workload, workload_name,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/isaac/Desktop/opensource/isaac_algorithmic_efficiency/algorithmic-efficiency/submission_runner.py", line 221, in train_once
input_queue = workload._build_input_queue(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/isaac/Desktop/opensource/isaac_algorithmic_efficiency/algorithmic-efficiency/algorithmic_efficiency/workloads/mnist/workload.py", line 155, in _build_input_queue
ds = _build_mnist_dataset(
^^^^^^^^^^^^^^^^^^^^^
File "/home/isaac/Desktop/opensource/isaac_algorithmic_efficiency/algorithmic-efficiency/algorithmic_efficiency/workloads/mnist/workload.py", line 58, in _build_mnist_dataset
ds = ds.shuffle(16 * global_batch_size, seed=data_rng[0])
~~~~~~~~^^^
File "/home/isaac/miniconda3/envs/main/lib/python3.11/site-packages/jax/_src/numpy/array_methods.py", line 646, in _getitem
return lax_numpy._rewriting_take(self, item)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/isaac/miniconda3/envs/main/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py", line 11411, in _rewriting_take
return _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/isaac/miniconda3/envs/main/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py", line 11420, in _gather
indexer = _index_to_gather(shape(arr), idx) # shared with _scatter_update
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/isaac/miniconda3/envs/main/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py", line 11528, in _index_to_gather
idx = _canonicalize_tuple_index(len(x_shape), idx)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/isaac/miniconda3/envs/main/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py", line 11852, in _canonicalize_tuple_index
raise IndexError(
IndexError: Too many indices: 0-dimensional array indexed with 1 regular index.
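For readers who want to reproduce this outside the repository, here is a minimal sketch of what the traceback appears to show. It assumes only a stock JAX install with the default PRNG implementation; the variable names are illustrative and not taken from the codebase.

```python
import jax

# Legacy-style key: a plain array of raw uint32 words, so indexing works.
legacy = jax.random.PRNGKey(0)

# New-style typed key: a 0-dimensional array with a special key dtype.
typed = jax.random.key(0)

print("legacy:", legacy.shape, legacy.dtype)
print("typed: ", typed.shape, typed.dtype)

first_word = legacy[0]  # fine on the legacy key

try:
    typed[0]  # indexing a 0-dimensional key array
    raised = False
except IndexError as err:
    raised = True
    print("IndexError:", err)
```

The last statement fails in the same way as `data_rng[0]` in the traceback, because a single typed key has no dimension to index into.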
I can resolve this by using jax.random.key_data(key).
I've added the following lines to submission_runner.py:
data_rng = jax.random.key_data(data_rng) at line 213
and
eval_rng = jax.random.key_data(eval_rng) at line 339
I did not encounter the error with the other keys, so maybe they don't need to be changed.
Can you please tell me if this seems okay?
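One way to keep this conversion in a single place, rather than at each call site, is a small helper. This is only a sketch: `seed_from_key` is a hypothetical name that does not exist in the repository, and it assumes the consumer wants an integer seed taken from the first word of the key data, as the existing `data_rng[0]` code does.

```python
import jax
import jax.numpy as jnp


def seed_from_key(rng):
    """Hypothetical helper: return a Python int seed from either key style."""
    # Typed keys (from jax.random.key) carry a dtype under jax.dtypes.prng_key.
    if jnp.issubdtype(rng.dtype, jax.dtypes.prng_key):
        # Unwrap the typed key into its raw uint32 words.
        rng = jax.random.key_data(rng)
    # Legacy keys are already raw uint32 arrays, so indexing works directly.
    return int(rng[0])


new_seed = seed_from_key(jax.random.key(42))
old_seed = seed_from_key(jax.random.PRNGKey(42))
print(new_seed, old_seed)
```

With a helper like this, code that still receives legacy keys keeps working during the migration, and the call site would pass `seed=seed_from_key(data_rng)` instead of indexing the key itself.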
Hi Isaac, I think what we want to do instead is change all the places where we index into the rng to just use the array.
E.g. from the above traceback you'll notice in our code:
File "/home/isaac/Desktop/opensource/isaac_algorithmic_efficiency/algorithmic-efficiency/algorithmic_efficiency/workloads/mnist/workload.py", line 58, in _build_mnist_dataset
ds = ds.shuffle(16 * global_batch_size, seed=data_rng[0])
We want to change data_rng[0] to just data_rng, since after migrating, jax.random.key returns 0-dimensional arrays.
I would do a search for rng[0] to find the pieces of code to correct.
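The search can be done with any grep-like tool. As a rough sketch, the script below walks a directory tree and reports lines that index a variable whose name ends in `rng` with `[0]`. The function name `find_rng_indexing` and the exact pattern are assumptions made for illustration; widen the pattern if the codebase uses other names.

```python
import re
from pathlib import Path

# Matches names ending in "rng" indexed with [0], e.g. data_rng[0] or rng[0].
PATTERN = re.compile(r"\b\w*rng\[0\]")


def find_rng_indexing(root):
    """Return (path, line number, stripped line) for each match under root."""
    hits = []
    for path in sorted(Path(root).rglob("*.py")):
        text = path.read_text(encoding="utf-8", errors="ignore")
        for lineno, line in enumerate(text.splitlines(), start=1):
            if PATTERN.search(line):
                hits.append((str(path), lineno, line.strip()))
    return hits
```

Calling `find_rng_indexing(".")` from the repository root and printing each tuple gives a list of call sites to review by hand.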
I tried doing this, but it seems the key dtype can't be passed directly to tfds:
the type of data_rng is <class 'jax._src.prng.PRNGKeyArray'> and
the type of jax.random.key_data(data_rng)[0] is <class 'jaxlib.xla_extension.ArrayImpl'>,
and tfds accepts the latter.
After just changing data_rng[0] to data_rng, I get the following error:
TypeError: remainder does not accept dtypes key, int32
This is mentioned in the documentation:
"""
To convert between the two, use jax.random.key_data() and jax.random.wrap_key_data(). The legacy key format may be needed when interfacing with systems outside of JAX (e.g. exporting arrays to a serializable format), or when passing keys to JAX-based libraries that assume the legacy format.
"""
Let me know your thoughts on this.
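To illustrate the conversion the documentation describes, here is a small sketch of the round trip between a typed key and its raw data. It assumes the default PRNG implementation, and the variable names are illustrative only.

```python
import jax
import jax.numpy as jnp

key = jax.random.key(0)

# Typed key -> raw uint32 words, suitable for systems outside of JAX.
raw = jax.random.key_data(key)

# Raw words -> typed key again, for use with jax.random functions.
restored = jax.random.wrap_key_data(raw)

# Both keys should drive the same random stream.
a = jax.random.uniform(key, (3,))
b = jax.random.uniform(restored, (3,))
same = bool(jnp.array_equal(a, b))
print("raw dtype:", raw.dtype, "round trip matches:", same)
```

This suggests the typed key can stay the default inside JAX code, with `key_data` applied only at the boundary where a non-JAX consumer such as the dataset pipeline needs raw integers.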
System Info:
Ubuntu 20.04,
Python 3.11,
Nvidia 3080 Ti
Jax Versions:
jax==0.4.35
jax-cuda12-pjrt==0.4.35
jax-cuda12-plugin==0.4.35
jaxlib==0.4.35