From 01eb8819dbe36d8b54987758706247c63b3f73df Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 16 Jan 2025 23:16:02 +0000 Subject: [PATCH] change jax version --- setup.cfg | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/setup.cfg b/setup.cfg index 43458cb07..040d1e26a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -128,10 +128,10 @@ jax_cpu = # JAX GPU # Note this installs both jax and jaxlib. jax_gpu = - jax==0.4.38 - jaxlib==0.4.38 - jax-cuda12-plugin[with_cuda]==0.4.38 - jax-cuda12-pjrt==0.4.38 + jax==0.4.36 + jaxlib==0.4.36 + jax-cuda12-plugin[with_cuda]==0.4.36 + jax-cuda12-pjrt==0.4.36 %(jax_core_deps)s # PyTorch CPU