Jetstream support (GoogleCloudPlatform#677)

Showing 23 changed files with 386 additions and 19 deletions.

benchmarks/benchmark/tools/locust-load-inference/sample-tfvars/jetstream-sample.tfvars (new file, 29 additions)
```
credentials_config = {
  fleet_host = "https://connectgateway.googleapis.com/v1/projects/PROJECT_NUMBER/locations/global/gkeMemberships/ai-benchmark"
}

project_id = "PROJECT_ID"

namespace    = "default"
ksa          = "benchmark-sa"
request_type = "grpc"

k8s_hf_secret = "hf-token"

# Locust service configuration
artifact_registry                        = "REGISTRY_LOCATION"
inference_server_service                 = "jetstream-svc:9000"
locust_runner_kubernetes_service_account = "sample-runner-sa"
output_bucket                            = "${PROJECT_ID}-benchmark-output-bucket-01"
gcs_path                                 = "PATH_TO_PROMPT_BUCKET"

# Benchmark configuration for Locust Docker accessing inference server
inference_server_framework = "jetstream"
tokenizer                  = "google/gemma-7b"

# Benchmark configuration for triggering single test via Locust Runner
test_duration = 60
# Increase test_users to allow more parallelism (especially when testing HPA)
test_users = 1
test_rate  = 5
```
File renamed without changes.

# AI on GKE Benchmarking for JetStream

Deploying and benchmarking JetStream on TPU is broadly similar to the standard GPU path, but the differences are distinct enough to warrant a separate README. If you are familiar with deploying on GPU, much of this will look familiar. For a more detailed explanation of each step, refer to our primary benchmarking [README](https://github.com/GoogleCloudPlatform/ai-on-gke/tree/main/benchmarks).

## Pre-requisites
- [Kaggle user/token](https://www.kaggle.com/docs/api)
- [Hugging Face user/token](https://huggingface.co/docs/hub/en/security-tokens)

### Creating K8s infra

To create our TPU cluster, run:

```
# Stage 1 creates the cluster.
cd infra/stage-1
# Copy the sample variables and update the project ID, cluster name and other
# parameters as needed in the `terraform.tfvars` file.
cp sample-tfvars/jetstream-sample.tfvars terraform.tfvars
# Initialize the Terraform modules.
terraform init
# Run plan to see the changes that will be made.
terraform plan
# Run apply if the changes look good by confirming the prompt.
terraform apply
```
To verify that the cluster has been set up correctly, run:
```
# Get credentials using fleet membership
gcloud container fleet memberships get-credentials <cluster-name>
# Run a kubectl command to verify
kubectl get nodes
```

## Configure the cluster

To configure the cluster to run inference workloads, we need to set up workload identity and GCS Fuse.
```
# Stage 2 configures the cluster for running inference workloads.
cd infra/stage-2
# Copy the sample variables and update the project number and cluster name in
# the fleet_host variable "https://connectgateway.googleapis.com/v1/projects/<project-number>/locations/global/gkeMemberships/<cluster-name>"
# and the project name and bucket name parameters as needed in the
# `terraform.tfvars` file. You can specify a new bucket name in which case it
# will be created.
cp sample-tfvars/jetstream-sample.tfvars terraform.tfvars
# Initialize the Terraform modules.
terraform init
# Run plan to see the changes that will be made.
terraform plan
# Run apply if the changes look good by confirming the prompt.
terraform apply
```

### Convert Gemma model weights to MaxText weights

JetStream has [two engine implementations](https://github.com/google/JetStream?tab=readme-ov-file#jetstream-engine-implementation): a Jax variant (via MaxText) and a PyTorch variant. This guide uses the Jax backend.

JetStream currently requires that models be converted to MaxText weights. This example deploys a Gemma-7b model. Much of this information is also covered in [this guide](https://cloud.google.com/kubernetes-engine/docs/tutorials/serve-gemma-tpu-jetstream#convert-checkpoints).

*SKIP IF ALREADY COMPLETED*

Create the Kaggle secret:
```
kubectl create secret generic kaggle-secret \
    --from-file=kaggle.json
```
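
To confirm the secret was created (an optional sanity check):
```
kubectl get secret kaggle-secret
```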

Replace `GEMMA_BUCKET_NAME` in `model-conversion/kaggle_converter.yaml` with the name of the bucket where you would like the model to be stored.

***NOTE:*** If you are using a different bucket than the ones you created, give the service account Storage Admin permissions on that bucket. This can be done in the UI or by running:
```
gcloud projects add-iam-policy-binding PROJECT_ID \
  --member "serviceAccount:SA_NAME@PROJECT_ID.iam.gserviceaccount.com" \
  --role roles/storage.admin
```
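
One way to make the bucket substitution in place is with `sed`. This is only a sketch; `my-model-bucket` is a hypothetical name, so substitute your own bucket:
```
# "my-model-bucket" is a placeholder -- use your own bucket name.
# (On macOS, use `sed -i ''` instead of `sed -i`.)
sed -i "s/GEMMA_BUCKET_NAME/my-model-bucket/g" model-conversion/kaggle_converter.yaml
```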

Run:
```
kubectl apply -f model-conversion/kaggle_converter.yaml
```

This should take ~10 minutes to complete.
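
To follow the conversion while it runs, you can tail the job's logs (the job name `data-loader-7b` comes from `kaggle_converter.yaml` below):
```
kubectl logs -f job/data-loader-7b
```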

### Deploy JetStream

Replace `GEMMA_BUCKET_NAME` in `jetstream.yaml` with the same bucket name as above.

Run:
```
kubectl apply -f jetstream.yaml
```

Verify the pod is running with:
```
kubectl get pods
```
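
You can also tail the server logs to confirm the checkpoint loaded; the deployment and container names come from `jetstream.yaml` below:
```
kubectl logs -f deployment/maxengine-server -c maxengine-server
```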

Get the external IP with:

```
kubectl get services
```

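To capture just the external IP of the `jetstream-svc` LoadBalancer in a shell variable (a sketch; it assumes the external IP has already been assigned):
```
JETSTREAM_EXTERNAL_IP=$(kubectl get service jetstream-svc \
  -o jsonpath='{.status.loadBalancer.ingress[0].ip}')
echo "$JETSTREAM_EXTERNAL_IP"
```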

And you can send a prompt request with (substituting the external IP from above for `JETSTREAM_EXTERNAL_IP`):
```
curl --request POST \
--header "Content-type: application/json" \
-s \
JETSTREAM_EXTERNAL_IP:8000/generate \
--data \
'{
    "prompt": "What is a TPU?",
    "max_tokens": 200
}'
```

### Deploy the benchmark

To prepare the dataset for the Locust inference benchmark, view the README.md file in:
```
cd benchmark/dataset/ShareGPT_v3_unflitered_cleaned_split
```

To deploy the Locust inference benchmark with the above model, run:
```
cd benchmark/tools/locust-load-inference
# Copy the sample variables and update the project number and cluster name in
# the fleet_host variable "https://connectgateway.googleapis.com/v1/projects/<project-number>/locations/global/gkeMemberships/<cluster-name>"
# in the `terraform.tfvars` file.
cp sample-tfvars/jetstream-sample.tfvars terraform.tfvars
# Initialize the Terraform modules.
terraform init
# Run plan to see the changes that will be made.
terraform plan
# Run apply if the changes look good by confirming the prompt.
terraform apply
```

To further interact with the Locust inference benchmark, view the README.md file in `benchmark/tools/locust-load-inference`.

jetstream.yaml (new file, 63 additions)
```
apiVersion: apps/v1
kind: Deployment
metadata:
  name: maxengine-server
spec:
  replicas: 1
  selector:
    matchLabels:
      app: maxengine-server
  template:
    metadata:
      labels:
        app: maxengine-server
    spec:
      serviceAccountName: benchmark-sa
      nodeSelector:
        cloud.google.com/gke-tpu-topology: 2x2
        cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice
      containers:
      - name: maxengine-server
        image: us-docker.pkg.dev/cloud-tpu-images/inference/maxengine-server:v0.2.0
        args:
        - model_name=gemma-7b
        - tokenizer_path=assets/tokenizer.gemma
        - per_device_batch_size=4
        - max_prefill_predict_length=1024
        - max_target_length=2048
        - async_checkpointing=false
        - ici_fsdp_parallelism=1
        - ici_autoregressive_parallelism=-1
        - ici_tensor_parallelism=1
        - scan_layers=false
        - weight_dtype=bfloat16
        - load_parameters_path=gs://GEMMA_BUCKET_NAME/final/unscanned/gemma_7b-it/0/checkpoints/0/items
        ports:
        - containerPort: 9000
        resources:
          requests:
            google.com/tpu: 4
          limits:
            google.com/tpu: 4
      - name: jetstream-http
        image: us-docker.pkg.dev/cloud-tpu-images/inference/jetstream-http:v0.2.0
        ports:
        - containerPort: 8000
---
apiVersion: v1
kind: Service
metadata:
  name: jetstream-svc
spec:
  selector:
    app: maxengine-server
  ports:
  - protocol: TCP
    name: http
    port: 8000
    targetPort: 8000
  - protocol: TCP
    name: grpc
    port: 9000
    targetPort: 9000
  type: LoadBalancer
```

benchmarks/inference-server/jetstream/model-conversion/kaggle_converter.yaml (new file, 33 additions)
```
apiVersion: batch/v1
kind: Job
metadata:
  name: data-loader-7b
spec:
  ttlSecondsAfterFinished: 30
  template:
    spec:
      serviceAccountName: benchmark-sa
      restartPolicy: Never
      containers:
      - name: inference-checkpoint
        image: us-docker.pkg.dev/cloud-tpu-images/inference/inference-checkpoint:v0.2.0
        args:
        - -b=GEMMA_BUCKET_NAME
        - -m=google/gemma/maxtext/7b-it/2
        volumeMounts:
        - mountPath: "/kaggle/"
          name: kaggle-credentials
          readOnly: true
        resources:
          requests:
            google.com/tpu: 4
          limits:
            google.com/tpu: 4
      nodeSelector:
        cloud.google.com/gke-tpu-topology: 2x2
        cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice
      volumes:
      - name: kaggle-credentials
        secret:
          defaultMode: 0400
          secretName: kaggle-secret
```