End-to-end tests for torch_xla llama and mixtral #100

Merged (3 commits) on Feb 11, 2025
3 changes: 3 additions & 0 deletions .dockerignore
@@ -65,3 +65,6 @@ htmlcov/

# torchprime config
.config

# Google cloud credentials generated during CI
gha-creds-*.json
71 changes: 71 additions & 0 deletions .github/actions/e2e-setup/action.yml
@@ -0,0 +1,71 @@
name: E2E setup
description: Setup torchprime, docker, and gcloud for E2E tests
inputs:
  gcp_project:
    required: true
    type: string
  gcp_zone:
    required: true
    type: string
  xpk_cluster_name:
    required: true
    type: string
  tpu_type:
    required: true
    type: string
  artifact_dir:
    required: true
    type: string
  gcp_sa_key:
    description: GCP service account key
    required: true
    type: string
runs:
  using: "composite"
  steps:
    - name: Use Docker in rootless mode
      uses: ScribeMD/[email protected]
    - name: Add user to docker group
      run: |
        sudo usermod -aG docker $USER
        newgrp docker
      shell: bash
    - uses: actions/setup-python@v5
      with:
        python-version: '3.10'
        cache: 'pip'
    - name: Install dev dependencies
      run: |
        python -m pip install --upgrade pip
        pip install -e '.[dev]'
      shell: bash
    # Googlers: if this fails, follow http://shortn/_61iSj31q1b to debug.
    - uses: google-github-actions/auth@v2
      with:
        credentials_json: '${{ inputs.gcp_sa_key }}'
    - uses: google-github-actions/setup-gcloud@v2
      with:
        version: '>= 363.0.0'
        install_components: 'beta,gke-gcloud-auth-plugin'
    - name: Verify GCP setup
      run: gcloud info
      shell: bash
    - name: Authenticate Docker
      run: gcloud auth configure-docker --quiet
      shell: bash
    - name: Activate SA credentials
      run: gcloud auth activate-service-account --key-file=$GOOGLE_APPLICATION_CREDENTIALS
      shell: bash
    - name: tp doctor
      run: tp doctor
      shell: bash
    - name: tp use
      run: >
        tp use
        --project '${{ inputs.gcp_project }}'
        --zone '${{ inputs.gcp_zone }}'
        --cluster '${{ inputs.xpk_cluster_name }}'
        --num-slices 1
        --artifact-dir '${{ inputs.artifact_dir }}'
        --tpu-type '${{ inputs.tpu_type }}'
      shell: bash
2 changes: 1 addition & 1 deletion .github/workflows/cpu_test.yml
@@ -5,7 +5,7 @@ on:
    branches:
      - main
  pull_request:
-  schedule:
+  schedule: # Schedule the job run at 12AM PST daily.
    - cron: "0 8 * * *"

jobs:
99 changes: 99 additions & 0 deletions .github/workflows/e2e_test.yml
@@ -0,0 +1,99 @@
name: E2E tests

on:
  push:
    branches:
      - main
  pull_request:
  schedule:
    - cron: "0 8 * * *" # Run daily at 12AM PST (adjusted for UTC)

jobs:
  tp-run:
    name: Submit workloads
    runs-on: ubuntu-22.04
    env:
      ARTIFACT_DIR: gs://torchprime-e2e-tests/${{ github.job }}/${{ github.run_id }}-${{ github.run_attempt }}
    outputs:
      llama-3-8b-name: ${{ steps.run-llama-3-8b.outputs.name }}
      mixtral-8x7b-name: ${{ steps.run-mixtral-8x7b.outputs.name }}
      artifact-dir: ${{ steps.artifacts.outputs.artifact_dir }}
    steps:
      - name: Record artifact dir
        id: artifacts
        run: |
          echo "Artifact dir: $ARTIFACT_DIR"
          echo "artifact_dir=$ARTIFACT_DIR" >> "$GITHUB_OUTPUT"
      - name: Maximize build space
        uses: AdityaGarg8/[email protected]
        with:
          remove-dotnet: 'true'
          remove-android: 'true'
          remove-haskell: 'true'
          remove-codeql: 'true'
      - uses: actions/checkout@v4
      - uses: ./.github/actions/e2e-setup
        with:
          gcp_project: ${{ vars.GCP_PROJECT }}
          gcp_zone: ${{ vars.GCP_ZONE }}
          xpk_cluster_name: ${{ vars.XPK_CLUSTER_NAME }}
          tpu_type: ${{ vars.TPU_TYPE }}
          artifact_dir: ${{ env.ARTIFACT_DIR }}
          gcp_sa_key: ${{ secrets.GCP_SA_KEY }}

      - name: Run Llama 3.0 8B
        id: run-llama-3-8b
        env:
          HF_TOKEN: ${{ secrets.HF_TOKEN }}
          XLA_IR_DEBUG: 1
          XLA_HLO_DEBUG: 1
        run: |
          name=$(e2e_testing/gen_name.py llama-3-8b)
          echo "name=$name" >> "$GITHUB_OUTPUT"
          tp run \
            --name $name \
            torchprime/torch_xla_models/train.py \
            model=llama-3-8b \
            global_batch_size=8 \
            mesh.fsdp=4 \
            dataset_config_name=wikitext-2-raw-v1 \
            profile_step=3 \
            max_steps=15

      - name: Run Mixtral 8x7B
        id: run-mixtral-8x7b
        env:
          HF_TOKEN: ${{ secrets.HF_TOKEN }}
          XLA_IR_DEBUG: 1
          XLA_HLO_DEBUG: 1
        run: |
          name=$(e2e_testing/gen_name.py mixtral-8x7b)
          echo "name=$name" >> "$GITHUB_OUTPUT"
          tp run \
            --name $name \
            torchprime/torch_xla_models/train.py \
            model=mixtral-8x7b \
            model.num_hidden_layers=16 \
            global_batch_size=8 \
            mesh.fsdp=4 \
            dataset_config_name=wikitext-2-raw-v1 \
            profile_step=3 \
            max_steps=15

  llama-3-8b:
    name: Llama 3.0 8B
    needs: tp-run
    uses: ./.github/workflows/reusable_e2e_check.yml
    with:
      jobset_name: ${{ needs.tp-run.outputs.llama-3-8b-name }}
      artifact_dir: ${{ needs.tp-run.outputs.artifact-dir }}
    secrets: inherit

  mixtral-8x7b:
    name: Mixtral 8x7B
    needs: tp-run
    uses: ./.github/workflows/reusable_e2e_check.yml
    with:
      jobset_name: ${{ needs.tp-run.outputs.mixtral-8x7b-name }}
      artifact_dir: ${{ needs.tp-run.outputs.artifact-dir }}
    secrets: inherit
54 changes: 54 additions & 0 deletions .github/workflows/reusable_e2e_check.yml
@@ -0,0 +1,54 @@
name: Reusable E2E Check Workflow

on:
  workflow_call:
    inputs:
      jobset_name:
        description: "The jobset name to check (e.g. llama-3-8b-XXXX)"
        required: true
        type: string
      artifact_dir:
        description: "GCS artifact directory to use for the run"
        required: true
        type: string
    secrets:
      GCP_SA_KEY:
        required: true
      # TODO(https://github.com/AI-Hypercomputer/torchprime/issues/14): Remove and burn the token.
      HF_TOKEN:
        required: true

jobs:
  results:
    runs-on: ubuntu-22.04
    steps:
      - uses: actions/checkout@v4
      - uses: ./.github/actions/e2e-setup
        with:
          gcp_project: ${{ vars.GCP_PROJECT }}
          gcp_zone: ${{ vars.GCP_ZONE }}
          xpk_cluster_name: ${{ vars.XPK_CLUSTER_NAME }}
          tpu_type: ${{ vars.TPU_TYPE }}
          artifact_dir: ${{ inputs.artifact_dir }}
          gcp_sa_key: ${{ secrets.GCP_SA_KEY }}
      - name: Get GKE credentials
        run: |
          gcloud container clusters get-credentials ${{ vars.XPK_CLUSTER_NAME }} --region=${{ vars.GCP_ZONE }} --project=${{ vars.GCP_PROJECT }}
          kubectl config view
          kubectl config set-context --current --namespace=default
      - name: Stream logs
        run: |
          pod_name=$(kubectl get pods -l jobset.sigs.k8s.io/jobset-name=${{ inputs.jobset_name }} -o json | jq --raw-output '.items[0].metadata.name')
          # Save logs to a file for later checks
          kubectl logs -c jax-tpu -f $pod_name | tee /tmp/pod-$pod_name.log
      - name: Wait for workload to complete
        run: |
          xpk workload list --cluster ${{ vars.XPK_CLUSTER_NAME }} --wait-for-job-completion=${{ inputs.jobset_name }}
      - name: Validate logs
        run: |
          pod_name=$(kubectl get pods -l jobset.sigs.k8s.io/jobset-name=${{ inputs.jobset_name }} -o json | jq --raw-output '.items[0].metadata.name')
          e2e_testing/check_logs.py /tmp/pod-$pod_name.log
      - name: Validate profile
        run: |
          profile_dir="${{ inputs.artifact_dir }}/${{ inputs.jobset_name }}/profile/0-0"
          e2e_testing/check_profile.py "$profile_dir"
3 changes: 3 additions & 0 deletions .gitignore
@@ -68,3 +68,6 @@ htmlcov/

# torchprime config
.config

# Google cloud credentials generated during CI
gha-creds-*.json
25 changes: 25 additions & 0 deletions e2e_testing/README.md
@@ -0,0 +1,25 @@
# E2E testing

These scripts are used during the [E2E test][e2e-test] GitHub action to run some
models and validate the results.

## E2E test design

The workflow in [e2e_test.yml][e2e-test] does a few things (condensed into the sketch after this list):

- Set up `gcloud` credentials from a Service Account key managed in repo secrets.
- Install `torchprime`.
- Test `tp use` and point it to an XPK cluster hosted internally.
- Test `tp run` on a few models.
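
Roughly, the commands below are what the workflow and the `e2e-setup` action run. This is a sketch condensed from [e2e_test.yml][e2e-test], not the exact invocation; placeholder values in angle brackets stand in for repository variables and secrets.

```bash
# Install torchprime with dev extras, then point it at the test XPK cluster.
pip install -e '.[dev]'

tp use \
  --project '<gcp-project>' --zone '<gcp-zone>' \
  --cluster '<xpk-cluster>' --num-slices 1 \
  --tpu-type '<tpu-type>' --artifact-dir 'gs://<bucket>/<run-id>'

# Submit one of the training workloads under a unique jobset name.
tp run \
  --name llama-3-8b-<random-suffix> \
  torchprime/torch_xla_models/train.py \
  model=llama-3-8b global_batch_size=8 mesh.fsdp=4 \
  dataset_config_name=wikitext-2-raw-v1 profile_step=3 max_steps=15
```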

After kicking off training for these models, the workflow starts a parallel job for
each model and runs a few checks; a sketch of invoking the checks by hand follows the
list below. The checks are implemented in [reusable_e2e_check.yml][e2e-check]:

- Stream the logs.
- Check workload exit code.
- Check for specific log strings that indicate training success.
- Check that there is a profile `.pb` file.
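
For reference, a minimal sketch of running the two check scripts by hand, assuming a locally saved pod log and `gcloud` access to the artifact bucket (the paths below are illustrative):

```bash
# Verify the log contains "Finished training run" and a "Step duration: ...s" line.
e2e_testing/check_logs.py /tmp/pod-<pod-name>.log

# Verify exactly one .xplane.pb profile exists under the run's profile directory.
e2e_testing/check_profile.py "gs://<artifact-bucket>/<run-id>/<jobset-name>/profile/0-0"
```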

[e2e-test]: /.github/workflows/e2e_test.yml
[e2e-check]: /.github/workflows/reusable_e2e_check.yml
33 changes: 33 additions & 0 deletions e2e_testing/check_logs.py
@@ -0,0 +1,33 @@
#!/usr/bin/env python3
import re
import sys


def check_logs(file_path):
  try:
    with open(file_path) as f:
      log_data = f.read()
  except Exception as e:
    print(f"Error reading log file {file_path}: {e}")
    return 1

  # Validate that the log contains the expected patterns.
  if not re.search(r"Finished training run", log_data):
    print("Error: 'Finished training run' not found in logs")
    return 1

  step_duration = re.search(r"Step duration:.*s", log_data)
  if not step_duration:
    print("Error: 'Step duration' not found in logs")
    return 1

  print(step_duration.group())
  print("Logs check passed.")
  return 0


if __name__ == "__main__":
  if len(sys.argv) != 2:
    print("Usage: check_logs.py <log_file>")
    sys.exit(1)
  sys.exit(check_logs(sys.argv[1]))
39 changes: 39 additions & 0 deletions e2e_testing/check_profile.py
@@ -0,0 +1,39 @@
#!/usr/bin/env python3
import subprocess
import sys


def check_profile(profile_dir):
  try:
    # List all .xplane.pb files recursively in the given directory.
    result = subprocess.run(
      ["gcloud", "storage", "ls", "-r", f"{profile_dir}/**/*.xplane.pb"],
      capture_output=True,
      text=True,
      check=True,
    )
  except subprocess.CalledProcessError as e:
    print(f"Error running gcloud storage ls: {e}", file=sys.stderr)
    return 1

  # Count the number of matching files.
  files = [line for line in result.stdout.splitlines() if line.strip()]
  count = len(files)
  if count != 1:
    print(
      f"Error: Expected exactly one .xplane.pb file in {profile_dir}, found {count}.",
      file=sys.stderr,
    )
    print("Files found:", result.stdout)
    return 1

  print(f"Found profile at: {files[0]}")
  print("Profile check passed.")
  return 0


if __name__ == "__main__":
  if len(sys.argv) != 2:
    print("Usage: check_profile.py <profile_dir>")
    sys.exit(1)
  sys.exit(check_profile(sys.argv[1]))
17 changes: 17 additions & 0 deletions e2e_testing/gen_name.py
@@ -0,0 +1,17 @@
#!/usr/bin/env python3
import random
import sys


def gen_name(name: str | None):
  # 8 random characters from a-z and 0-9
  random_id = "".join(random.choices("abcdefghijklmnopqrstuvwxyz0123456789", k=8))
  name = random_id if name is None else f"{name}-{random_id}"
  return name


if __name__ == "__main__":
  if len(sys.argv) > 2:
    print("Usage: gen_name.py [name]")
    sys.exit(1)
  print(gen_name(sys.argv[1] if len(sys.argv) > 1 else None), flush=True, end="")
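
Since the script prints the generated name without a trailing newline (`end=""`), it can be captured directly with shell command substitution, as the workflows above do. A minimal usage sketch (the suffix shown is illustrative):

```bash
name=$(e2e_testing/gen_name.py llama-3-8b)
echo "$name"  # e.g. llama-3-8b-a1b2c3d4
```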
3 changes: 1 addition & 2 deletions torchprime/launcher/Dockerfile
@@ -5,7 +5,6 @@ FROM us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.6.0_3.10_tpuvm

ARG USE_TRANSFORMERS=false
# Install system dependencies
-RUN apt-get update && apt-get upgrade -y
RUN apt-get update && apt-get install -y curl gnupg

# Add the Google Cloud SDK package repository
@@ -25,7 +24,7 @@ RUN update-alternatives --install /usr/bin/python3 python3 /usr/local/bin/python
WORKDIR /workspaces

# Install torchax
-RUN git clone https://github.com/pytorch/xla.git
+RUN git clone --depth 1 https://github.com/pytorch/xla.git
WORKDIR /workspaces/xla/torchax
RUN pip install torch_xla[pallas] \
  -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \