Skip to content

Commit

Permalink
Merge pull request #1249 from AI-Hypercomputer:parambole/jsts_gpu_pp
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 725387669
  • Loading branch information
maxtext authors committed Feb 11, 2025
2 parents e9b0b71 + 3ea3888 commit aaf467e
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 6 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/UploadDockerImages.yml
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ jobs:
bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxtext_gpu_jax_pinned MODE=pinned DEVICE=gpu PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxtext_gpu_local_jax_pinned
- name: build jax stable stack image
run : |
bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxtext_gpu_jax_stable_stack MODE=stable_stack DEVICE=gpu PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxtext_gpu_jax_stable_stack BASEIMAGE=us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/gpu:latest MAXTEXT_REQUIREMENTS_FILE=requirements_with_jax_stable_stack.txt
bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxtext_gpu_jax_stable_stack MODE=stable_stack DEVICE=gpu PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxtext_gpu_jax_stable_stack BASEIMAGE=us-central1-docker.pkg.dev/deeplearning-images/jax-stable-stack/gpu:latest MAXTEXT_REQUIREMENTS_FILE=requirements_with_jax_stable_stack.txt
- name: build image with stable stack nightly jax
run: |
bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxtext_gpu_stable_stack_nightly_jax MODE=stable_stack DEVICE=gpu PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxtext_gpu_jax_stable_stack_nightly BASEIMAGE=us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/gpu/jax_nightly:latest MAXTEXT_REQUIREMENTS_FILE=requirements_with_jax_stable_stack.txt
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -193,3 +193,9 @@ As in the TPU case, note that the compilation environment must match the executi

## Automatically Upload Logs to Vertex Tensorboard
MaxText supports automatic upload of logs collected in a directory to a Tensorboard instance in Vertex AI. Follow [user guide](getting_started/Use_Vertex_AI_Tensorboard.md) to know more.

## Announcement

### February 2025

* (Preview): We're excited to announce the preview of building Maxtext Docker images using the JAX Stable Stack base image, available for both TPUs and GPUs. This provides a more reliable and consistent build environment. Learn more [Here](getting_started/Run_MaxText_via_xpk.md)
21 changes: 16 additions & 5 deletions getting_started/Run_MaxText_via_xpk.md
Original file line number Diff line number Diff line change
Expand Up @@ -79,22 +79,33 @@ after which log out and log back in to the machine.
bash docker_build_dependency_image.sh
```

#### New: Build Maxtext Docker Image with JAX Stable Stack
We're excited to announce that you can build the Maxtext Docker image using the JAX Stable Stack base image. This provides a more reliable and consistent build environment.
#### Build Maxtext Docker Image with JAX Stable Stack (Preview)
We're excited to announce the preview of building Maxtext Docker images using the JAX Stable Stack base image, available for both TPUs and GPUs. This provides a more reliable and consistent build environment.
###### What is JAX Stable Stack?
JAX Stable Stack provides a consistent environment for Maxtext by bundling JAX with core packages like `orbax`, `flax`, and `optax`, along with Google Cloud utilities and other essential tools. These libraries are tested to ensure compatibility, providing a stable foundation for building and running Maxtext and eliminating potential conflicts due to incompatible package versions.
###### How to Use It
To build the Maxtext Docker image with JAX Stable Stack, simply set the MODE to `stable_stack` and specify the desired `BASEIMAGE` in the `docker_build_dependency_image.sh` script:
Use the `docker_build_dependency_image.sh` script to build your Maxtext Docker image with JAX Stable Stack. Set MODE to `stable_stack` and specify the desired `BASEIMAGE`. The `DEVICE` variable determines whether to build for TPUs or GPUs.
###### For TPUs:
```
# Example bash docker_build_dependency_image.sh MODE=stable_stack BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.4.33-rev1
bash docker_build_dependency_image.sh MODE=stable_stack BASEIMAGE={{JAX_STABLE_STACK_BASEIMAGE}}
# Example: bash docker_build_dependency_image.sh DEVICE=tpu MODE=stable_stack BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.4.37-rev1
bash docker_build_dependency_image.sh DEVICE=tpu MODE=stable_stack BASEIMAGE={{JAX_STABLE_STACK_TPU_BASEIMAGE}}
```
You can find a list of available JAX Stable Stack base images [here](https://us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu).
###### [New] For GPUs:
```
# Example bash docker_build_dependency_image.sh DEVICE=gpu MODE=stable_stack BASEIMAGE=us-central1-docker.pkg.dev/deeplearning-images/jax-stable-stack/gpu:jax0.4.37-cuda_dl24.10-rev1
bash docker_build_dependency_image.sh MODE=stable_stack BASEIMAGE={{JAX_STABLE_STACK_BASEIMAGE}}
```
You can find a list of available JAX Stable Stack base images [here](us-central1-docker.pkg.dev/deeplearning-images/jax-stable-stack/gpu).
**Important Note:** The JAX Stable Stack is currently in the experimental phase. We encourage you to try it out and provide feedback.
3. After building the dependency image `maxtext_base_image`, xpk can handle updates to the working directory when running `xpk workload create` and using `--base-docker-image`.
Expand Down

0 comments on commit aaf467e

Please sign in to comment.