IndexError: Too many indices: 0-dimensional array indexed with 1 regular index, while migrating from jax.random.PRNGKey to jax.random.key #815

Open
init-22 opened this issue Nov 21, 2024 · 3 comments

@init-22
init-22 commented Nov 21, 2024

System Info:
Ubuntu 20.04
Python 3.11
NVIDIA 3080 Ti

Jax Versions:
jax==0.4.35
jax-cuda12-pjrt==0.4.35
jax-cuda12-plugin==0.4.35
jaxlib==0.4.35

I am getting the following error while migrating from jax.random.PRNGKey to jax.random.key.

Here is the full traceback:

Traceback (most recent call last):
  File "/home/isaac/Desktop/opensource/isaac_algorithmic_efficiency/algorithmic-efficiency/submission_runner.py", line 714, in <module>
    app.run(main)
  File "/home/isaac/miniconda3/envs/main/lib/python3.11/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/home/isaac/miniconda3/envs/main/lib/python3.11/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
             ^^^^^^^^^^
  File "/home/isaac/Desktop/opensource/isaac_algorithmic_efficiency/algorithmic-efficiency/submission_runner.py", line 682, in main
    score = score_submission_on_workload(
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/isaac/Desktop/opensource/isaac_algorithmic_efficiency/algorithmic-efficiency/submission_runner.py", line 587, in score_submission_on_workload
    timing, metrics = train_once(workload, workload_name,
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/isaac/Desktop/opensource/isaac_algorithmic_efficiency/algorithmic-efficiency/submission_runner.py", line 221, in train_once
    input_queue = workload._build_input_queue(
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/isaac/Desktop/opensource/isaac_algorithmic_efficiency/algorithmic-efficiency/algorithmic_efficiency/workloads/mnist/workload.py", line 155, in _build_input_queue
    ds = _build_mnist_dataset(
         ^^^^^^^^^^^^^^^^^^^^^
  File "/home/isaac/Desktop/opensource/isaac_algorithmic_efficiency/algorithmic-efficiency/algorithmic_efficiency/workloads/mnist/workload.py", line 58, in _build_mnist_dataset
    ds = ds.shuffle(16 * global_batch_size, seed=data_rng[0])
                                                 ~~~~~~~~^^^
  File "/home/isaac/miniconda3/envs/main/lib/python3.11/site-packages/jax/_src/numpy/array_methods.py", line 646, in _getitem
    return lax_numpy._rewriting_take(self, item)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/isaac/miniconda3/envs/main/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py", line 11411, in _rewriting_take
    return _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/isaac/miniconda3/envs/main/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py", line 11420, in _gather
    indexer = _index_to_gather(shape(arr), idx)  # shared with _scatter_update
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/isaac/miniconda3/envs/main/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py", line 11528, in _index_to_gather
    idx = _canonicalize_tuple_index(len(x_shape), idx)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/isaac/miniconda3/envs/main/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py", line 11852, in _canonicalize_tuple_index
    raise IndexError(
IndexError: Too many indices: 0-dimensional array indexed with 1 regular index.
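
For reference, a minimal sketch that should reproduce the same IndexError outside the repository. It assumes JAX's default PRNG implementation (threefry), where a legacy key is a uint32 array of shape (2,) and a new-style key is a scalar key array; the variable names are illustrative and do not come from submission_runner.py.

```python
import jax

# Legacy key: a plain uint32 array, shape (2,) under the default threefry impl.
legacy_key = jax.random.PRNGKey(0)
# New-style typed key: a scalar key array with shape ().
typed_key = jax.random.key(0)

print(legacy_key.shape, legacy_key.dtype)
print(typed_key.shape, typed_key.dtype)

# Indexing the legacy key works because it has one dimension.
first_word = legacy_key[0]

# Indexing the typed key fails because it has zero dimensions.
try:
    typed_key[0]
    index_error = None
except IndexError as err:
    index_error = err

print(index_error)
```

This is the same pattern as seed=data_rng[0] in the traceback: the expression is valid for a legacy key but not for a key created with jax.random.key.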

@init-22
Author

init-22 commented Nov 23, 2024

@priyakasimbeg

I can resolve this by using jax.random.key_data(key).
I've added the following lines to submission_runner.py:
data_rng = jax.random.key_data(data_rng) at line 213
and
eval_rng = jax.random.key_data(eval_rng) at line 339

I did not encounter the error with the other keys, so maybe they don't need to be changed.
Can you please tell me if this seems okay?
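
A minimal sketch of this workaround, assuming the conversion is done only at the point where an integer seed is needed. The helper seed_from_key is hypothetical and is not part of the repository; it accepts either key format so that call sites do not need to know which one they hold.

```python
import jax
import jax.numpy as jnp


def seed_from_key(key):
    """Hypothetical helper: derive a Python int seed from a JAX key.

    Accepts a new-style typed key (from jax.random.key) or a legacy
    uint32 key array (from jax.random.PRNGKey).
    """
    # Typed keys have a dtype that is a subtype of jax.dtypes.prng_key.
    if jnp.issubdtype(key.dtype, jax.dtypes.prng_key):
        # Unwrap to the underlying uint32 key data.
        key = jax.random.key_data(key)
    # Use the first word of the key data, as data_rng[0] did before.
    return int(key[0])


# Splitting a typed key yields typed keys, which can be unpacked as before.
data_rng, eval_rng = jax.random.split(jax.random.key(0))

seed = seed_from_key(data_rng)
# The same helper works on key data that has already been unwrapped.
legacy_seed = seed_from_key(jax.random.key_data(data_rng))
print(seed, legacy_seed)
```

With a helper like this, the keys themselves can stay in the new format everywhere inside JAX code, and only the value handed to a non-JAX consumer is converted.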

@priyakasimbeg
Contributor

Hi Isaac, I think what we want to do instead is change all the places where we index into the rng to just use the array.

E.g. from the above traceback you'll notice in our code:

 File "/home/isaac/Desktop/opensource/isaac_algorithmic_efficiency/algorithmic-efficiency/algorithmic_efficiency/workloads/mnist/workload.py", line 58, in _build_mnist_dataset
    ds = ds.shuffle(16 * global_batch_size, seed=data_rng[0])

We want to change data_rng[0] to just data_rng, since after migrating to random.key the keys are 0-dimensional arrays.
I would search for rng[0] to find the pieces of code to correct.

@init-22
Author

init-22 commented Nov 30, 2024

I tried doing this, but it seems the key dtype can't be used directly by tfds.
The type of data_rng is <class 'jax._src.prng.PRNGKeyArray'> and
the type of jax.random.key_data(data_rng)[0] is <class 'jaxlib.xla_extension.ArrayImpl'>,
and tfds accepts the latter.
By just changing data_rng[0] to data_rng, I get the following error:
TypeError: remainder does not accept dtypes key, int32

This is mentioned in the documentation:
"""
To convert between the two, use jax.random.key_data() and jax.random.wrap_key_data(). The legacy key format may be needed when interfacing with systems outside of JAX (e.g. exporting arrays to a serializable format), or when passing keys to JAX-based libraries that assume the legacy format.
"""
Let me know your thoughts on this.
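
A short sketch of the conversion described in the quoted documentation, assuming the default PRNG implementation. It only illustrates the round trip between the two key formats and is not taken from the repository.

```python
import jax

key = jax.random.key(0)

# Typed key -> legacy format: a uint32 array holding the raw key bits.
raw = jax.random.key_data(key)

# Legacy format -> typed key (uses the default PRNG implementation).
restored = jax.random.wrap_key_data(raw)

# The round trip preserves the key bits, so both keys give the same draws.
same_bits = bool((jax.random.key_data(restored) == raw).all())
same_draws = bool(
    (jax.random.normal(key, (3,)) == jax.random.normal(restored, (3,))).all()
)
print(raw.shape, raw.dtype, restored.shape, same_bits, same_draws)
```

Because the round trip is lossless, converting with jax.random.key_data only where an external library needs raw integers should not change the random streams used elsewhere.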