Skip to content

Commit

Permalink
Merge pull request #26325 from hawkinsp:tpu2
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 723298028
  • Loading branch information
Google-ML-Automation committed Feb 5, 2025
2 parents 307006e + b1a2c27 commit 781172c
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
* Changes
* `JAX_CPU_COLLECTIVES_IMPLEMENTATION` and `JAX_NUM_CPU_DEVICES` now work as
env vars. Before they could only be specified via jax.config or flags.
* The `jax[tpu]` TPU extra no longer depends on the `libtpu-nightly` package.
This package may safely be removed if it is present on your machine; JAX now
uses `libtpu` instead.

## jax 0.5.0 (Jan 17, 2025)

Expand Down
3 changes: 0 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
_latest_jaxlib_version_on_pypi = '0.5.0'

_libtpu_version = '0.0.8'
_libtpu_nightly_terminal_version = '0.1.dev20241010+nightly.cleanup'

def load_version_module(pkg_path):
spec = importlib.util.spec_from_file_location(
Expand Down Expand Up @@ -77,8 +76,6 @@ def load_version_module(pkg_path):
# $ pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
'tpu': [
f'jaxlib>={_current_jaxlib_version},<={_jax_version}',
# TODO(phawkins): remove the libtpu-nightly dependency in Q1 2025.
f'libtpu-nightly=={_libtpu_nightly_terminal_version}',
f'libtpu=={_libtpu_version}',
'requests', # necessary for jax.distributed.initialize
],
Expand Down

0 comments on commit 781172c

Please sign in to comment.