
TypeError: JAX encountered invalid PRNG key data: expected key_data.dtype = uint32; got dtype=int32 #801

Closed
init-22 opened this issue Oct 26, 2024 · 6 comments


@init-22

init-22 commented Oct 26, 2024

I was trying to run submission_runner.py inside the Docker container and got a TypeError.
Use these commands to reproduce the error:

sudo docker run -it -v <PATH>/algorithmic-efficiency:/algorithmic-efficiency --runtime=nvidia algoperf_pytorch /bin/bash

cd algorithmic-efficiency

python3 submission_runner.py \
    --framework=pytorch \
    --workload=mnist \
    --experiment_dir=$HOME/experiments \
    --experiment_name=my_first_experiment \
    --submission_path=reference_algorithms/paper_baselines/adamw/jax/submission.py \
    --tuning_search_space=reference_algorithms/paper_baselines/adamw/tuning_search_space.json

Here is the traceback:

Traceback (most recent call last):
  File "submission_runner.py", line 714, in <module>
    app.run(main)
  File "/usr/local/lib/python3.8/dist-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.8/dist-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "submission_runner.py", line 682, in main
    score = score_submission_on_workload(
  File "submission_runner.py", line 587, in score_submission_on_workload
    timing, metrics = train_once(workload, workload_name,
  File "submission_runner.py", line 351, in train_once
    optimizer_state, model_params, model_state = update_params(
  File "/algorithmic-efficiency/reference_algorithms/paper_baselines/adamw/jax/submission.py", line 130, in update_params
    per_device_rngs = jax.random.split(rng, jax.local_device_count())
  File "/usr/local/lib/python3.8/dist-packages/jax/_src/random.py", line 217, in split
    key, wrapped = _check_prng_key(key)
  File "/usr/local/lib/python3.8/dist-packages/jax/_src/random.py", line 79, in _check_prng_key
    return prng.random_wrap(key, impl=default_prng_impl()), True
  File "/usr/local/lib/python3.8/dist-packages/jax/_src/prng.py", line 907, in random_wrap
    _check_prng_key_data(impl, base_arr)
  File "/usr/local/lib/python3.8/dist-packages/jax/_src/prng.py", line 119, in _check_prng_key_data
    raise TypeError("JAX encountered invalid PRNG key data: expected key_data.dtype = uint32; "
TypeError: JAX encountered invalid PRNG key data: expected key_data.dtype = uint32; got dtype=int32

Am I missing something?

@priyakasimbeg
Contributor

I noticed you ran with a Docker container called algoperf_pytorch. Does it have the correct JAX dependencies installed?
Could you attach the output of pip freeze from the container environment?
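If the full pip freeze output is unwieldy, a minimal Python sketch like the one below reports only the interpreter version and the packages most relevant to this issue. The package list is an assumption (extend it as needed); importlib.metadata is the standard-library way to read installed distribution versions on Python 3.8+.

```python
import importlib.metadata
import platform

# Packages assumed to be relevant here; adjust the tuple for your setup.
PACKAGES = ("jax", "jaxlib", "torch", "numpy")


def environment_report(packages=PACKAGES):
    """Return the Python version plus the installed version (or None) of each package."""
    report = {"python": platform.python_version()}
    for name in packages:
        try:
            report[name] = importlib.metadata.version(name)
        except importlib.metadata.PackageNotFoundError:
            report[name] = None  # package is not installed in this environment
    return report


if __name__ == "__main__":
    for name, version in environment_report().items():
        print(f"{name}: {version}")
```

Reporting the Python version alongside the package versions matters here, since later comments in this thread trace the behavior to differences between Python 3.8 and 3.11.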

@init-22
Author

init-22 commented Oct 29, 2024

Oh yeah, I created one for PyTorch. I just rebuilt it with both frameworks but got the same error.
Please check this:
pip_deps.txt

@priyakasimbeg
Contributor

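Correct command (with --framework=pytorch, --submission_path must point to the PyTorch submission, not the JAX one):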

python3 submission_runner.py \
    --framework=pytorch \
    --workload=mnist \
    --experiment_dir=$HOME/experiments \
    --experiment_name=my_first_experiment \
    --submission_path=reference_algorithms/paper_baselines/adamw/pytorch/submission.py \
    --tuning_search_space=reference_algorithms/paper_baselines/adamw/tuning_search_space.json
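The only difference from the original command is the submission path. The original run paired --framework=pytorch with the JAX submission, which likely explains why that submission ended up calling jax.random.split on int32 key data, as the traceback shows. Below is a minimal sketch of a pre-flight check for this kind of mismatch. check_submission_framework is a hypothetical helper, not part of submission_runner.py, and it assumes reference submissions live under a directory named jax or pytorch, as in the paths above.

```python
from pathlib import PurePosixPath

# Framework directory names used by the reference submission paths in this thread.
KNOWN_FRAMEWORKS = ("jax", "pytorch")


def check_submission_framework(submission_path: str, framework: str) -> None:
    """Raise ValueError if submission_path sits under another framework's directory.

    Hypothetical pre-flight check: it only inspects path components, so a
    submission stored outside a jax/ or pytorch/ directory is not checked.
    """
    parts = PurePosixPath(submission_path).parts
    path_frameworks = [p for p in parts if p in KNOWN_FRAMEWORKS]
    if path_frameworks and path_frameworks[-1] != framework:
        raise ValueError(
            f"--framework={framework}, but --submission_path looks like a "
            f"{path_frameworks[-1]} submission: {submission_path}"
        )


if __name__ == "__main__":
    # The original command mixed --framework=pytorch with the JAX submission.
    try:
        check_submission_framework(
            "reference_algorithms/paper_baselines/adamw/jax/submission.py", "pytorch")
    except ValueError as err:
        print(err)

    # The corrected command passes the check silently.
    check_submission_framework(
        "reference_algorithms/paper_baselines/adamw/pytorch/submission.py", "pytorch")
```

Matching on path components rather than substrings avoids false positives from directory names that merely contain "jax" or "pytorch".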

@priyakasimbeg
Contributor

priyakasimbeg commented Jan 18, 2025

After some debugging:

  1. random.Random(arg) raises an error with Python 3.11 if type(arg) is np.uint32, but works with Python 3.8.
    I think we can undo the changes in PR 810 that converted the types from int32 to uint32 and revert to np.int32.

  2. x % 2**32, where type(x) is np.int32, raises an overflow error with Python 3.11, but works with Python 3.8.
    We should change the max and min constants to 2**31-1 and 0.
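Here is a minimal sketch of both points, assuming seeds are held as NumPy integer scalars. The helper names (to_int32_seed, python_rng) and the constant names are hypothetical, not the repository's code. The sketch also swaps in a slightly different technique from the revert proposed above: it converts to a plain Python int before taking the modulus and before seeding random.Random, so neither step depends on how a given Python/NumPy combination treats NumPy scalars.

```python
import random

import numpy as np

# Bounds proposed above: seeds in [0, 2**31 - 1] always fit in np.int32.
MIN_SEED = 0
MAX_SEED = 2**31 - 1


def to_int32_seed(x) -> np.int32:
    """Map any Python or NumPy integer into [MIN_SEED, MAX_SEED] as an np.int32.

    The modulus is computed on a plain Python int, which cannot overflow,
    rather than on a fixed-width NumPy scalar.
    """
    return np.int32(int(x) % (MAX_SEED + 1))


def python_rng(seed) -> random.Random:
    """Seed random.Random with a plain int, converting NumPy scalars explicitly."""
    return random.Random(int(seed))
```

For example, to_int32_seed(-1) returns 2147483647 (2**31 - 1), the largest value the proposed bounds allow, and python_rng(np.int32(7)) produces the same stream as random.Random(7).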

@priyakasimbeg
Contributor

In the above PR, I changed the min and max constants to 0 and 2**31-1 (the maximum int32 value).
I also reverted the types from np.uint32 back to np.int32.

@init-22
Author

init-22 commented Jan 18, 2025

Sure, thanks Priya!
