Skip to content

Commit

Permalink
fix jax versions
Browse files Browse the repository at this point in the history
  • Loading branch information
priyakasimbeg committed Feb 8, 2025
1 parent 082be03 commit 39bb876
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -103,13 +103,13 @@ jax_core_deps = [
"protobuf==4.25.5",
]
jax_cpu = [
"jax==0.4.26",
"jaxlib==0.4.26",
"jax==0.4.28",
"jaxlib==0.4.28",
"algorithmic_efficiency[jax_core_deps]",
]
jax_gpu = [
"jax==0.4.26",
"jaxlib==0.4.26",
"jax==0.4.28",
"jaxlib==0.4.28",
"jax-cuda12-plugin[with_cuda]==0.4.28",
"jax-cuda12-pjrt==0.4.28",
"algorithmic_efficiency[jax_core_deps]",
Expand Down

0 comments on commit 39bb876

Please sign in to comment.