Merge pull request #2 from AI-Hypercomputer/Mixtral_pytorch_v6e
Add training instructions for Mixtral using PyTorch on trillium
Showing
13 changed files
with
417 additions
and
0 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
# Instructions for training Mixtral-8x7B on Trillium (v6e) TPU | ||
|
||
|
||
This user guide provides a concise overview of the essential steps required to run HuggingFace (HF) Mixtral training on Cloud TPUs. | ||
|
||
|
||
## Environment Setup | ||
|
||
First, follow the user guide for your TPU generation to set up the GCE TPUs. | ||
|
||
Replace every `your-*` placeholder with your TPU's information. | ||
|
||
``` | ||
export TPU_NAME=your-tpu-name | ||
export ZONE=your-tpu-zone | ||
export PROJECT=your-tpu-project | ||
``` | ||
|
||
You may use this command to create a 256-chip v6e slice: | ||
|
||
``` | ||
gcloud alpha compute tpus tpu-vm create $TPU_NAME \ | ||
--accelerator-type v6e-256 --project $PROJECT --zone $ZONE \ | ||
--version v2-alpha-tpuv6e | ||
``` | ||
|
||
## Steps to Run HF Mixtral 8x7B | ||
|
||
The following setup runs the Mixtral 8x7B training job on GCE TPUs using the Docker image from this registry (``). The image uses the PyTorch and torch_xla nightly builds from 10/28/2024 and has all package dependencies needed for model training installed. Run all commands below from your own machine, not from the TPU host you created. | ||
|
||
1. Clone the repository and navigate to this README's directory: | ||
```bash | ||
git clone https://github.com/AI-Hypercomputer/tpu-recipes.git | ||
cd tpu-recipes/training/trillium/Mixtral-8x7B-PyTorch | ||
``` | ||
2. Edit `env.sh` to add your Hugging Face token and, if needed, adjust the training parameters. | ||
```bash | ||
# add your hugging face token | ||
HF_TOKEN=hf_*** | ||
``` | ||
3. Edit `host.sh` to change the Docker image URL if the default Docker image is not accessible to you. | ||
```bash | ||
# docker image URL to use for the training | ||
DOCKER_IMAGE=us-central1-docker.pkg.dev/deeplearning-images/reproducibility/pytorch-tpu-mixtral:v1 | ||
``` | ||
4. Run the training script: | ||
```bash | ||
./benchmark.sh | ||
``` | ||
The `benchmark.sh` script uploads 1) the environment parameters in `env.sh`, 2) the model configs in `config.json` and `fsdp_config.json`, 3) the Docker launch script `host.sh`, and 4) the Python training command in `train.sh` to all TPU workers, then starts the training. When all training steps complete, it prints each worker's training metrics in the terminal, as below: | ||
``` | ||
***** train metrics ***** | ||
[worker :3] ***** train metrics ***** | ||
[worker :3] epoch = 0.0391 | ||
[worker :3] total_flos = 216428520GF | ||
[worker :3] train_loss = 8.443 | ||
[worker :3] train_runtime = 0:04:23.15 | ||
[worker :3] train_samples = 32816 | ||
[worker :3] train_samples_per_second = 4.864 | ||
``` | ||
In addition, it copies the trained model back to your machine under `output/*`.
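
The per-worker metric lines printed above can be collected into a dictionary for later comparison across runs. The following is a minimal sketch, not part of this commit: the function name `parse_train_metrics` and the regular expression are assumptions, and the pattern accepts both the `[worker :3]` prefix shown in the sample and the `[worker 3]` prefix that the `sed` command in `host.sh` would produce.

```python
import re

# Matches lines such as "[worker :3] train_loss = 8.443".
# The worker prefix is optional so plain "key = value" lines also parse.
METRIC_LINE = re.compile(
    r"^(?:\[worker\s*:?\s*(?P<worker>\d+)\]\s*)?(?P<key>\w+)\s*=\s*(?P<value>\S+)\s*$"
)


def parse_train_metrics(log_text):
    """Return {worker_id: {metric_name: raw_string_value}} from run log text."""
    metrics = {}
    for line in log_text.splitlines():
        match = METRIC_LINE.match(line.strip())
        if match is None:
            continue  # banner lines and unrelated output are skipped
        worker = match.group("worker") or "unknown"
        metrics.setdefault(worker, {})[match.group("key")] = match.group("value")
    return metrics


sample = """\
[worker :3] ***** train metrics *****
[worker :3] epoch = 0.0391
[worker :3] train_loss = 8.443
[worker :3] train_runtime = 0:04:23.15
[worker :3] train_samples_per_second = 4.864
"""
result = parse_train_metrics(sample)
print(result["3"]["train_samples_per_second"])
```

Values are kept as strings because some metrics (`train_runtime`, `total_flos`) are not plain numbers; convert them where needed.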
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
#!/bin/bash | ||
|
||
# SCP the environment setup to all instances. | ||
gcloud compute tpus tpu-vm scp config.json fsdp_config.json train.sh host.sh env.sh "$TPU_NAME:~" --worker=all --project="$PROJECT" --zone="$ZONE" | ||
|
||
# Run the benchmark on all workers. | ||
gcloud compute tpus tpu-vm ssh "$TPU_NAME" --project="$PROJECT" --zone="$ZONE" --worker=all --command="$(cat host.sh)" | ||
|
||
# Copy the profile and output back. | ||
gcloud compute tpus tpu-vm scp --recurse "$TPU_NAME:~/output" ./ --project="$PROJECT" --zone="$ZONE"
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
{ | ||
"architectures": [ | ||
"MixtralForCausalLM" | ||
], | ||
"attention_dropout": 0.0, | ||
"bos_token_id": 1, | ||
"eos_token_id": 2, | ||
"hidden_act": "silu", | ||
"hidden_size": 4096, | ||
"initializer_range": 0.02, | ||
"intermediate_size": 14336, | ||
"max_position_embeddings": 32768, | ||
"model_type": "mixtral", | ||
"num_attention_heads": 32, | ||
"num_experts_per_tok": 2, | ||
"num_hidden_layers": 32, | ||
"num_key_value_heads": 8, | ||
"num_local_experts": 8, | ||
"output_router_logits": false, | ||
"rms_norm_eps": 1e-05, | ||
"rope_theta": 1000000.0, | ||
"router_aux_loss_coef": 0.02, | ||
"sliding_window": null, | ||
"tie_word_embeddings": false, | ||
"torch_dtype": "bfloat16", | ||
"transformers_version": "4.36.0.dev0", | ||
"use_cache": false, | ||
"vocab_size": 32000 | ||
} |
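
As a sanity check on this config, the parameter count it implies can be worked out by hand. The sketch below is illustrative only and assumes the usual Hugging Face Mixtral layout: bias-free `q/k/v/o` projections, three bias-free matrices per expert, a bias-free router, and two RMSNorm weight vectors per layer plus one final norm. The helper name `mixtral_param_counts` is hypothetical.

```python
# Subset of config.json needed for the count (values copied from the file above).
config = {
    "hidden_size": 4096,
    "intermediate_size": 14336,
    "num_attention_heads": 32,
    "num_key_value_heads": 8,
    "num_hidden_layers": 32,
    "num_local_experts": 8,
    "num_experts_per_tok": 2,
    "tie_word_embeddings": False,
    "vocab_size": 32000,
}


def mixtral_param_counts(cfg):
    """Return (total, active-per-token) parameter counts under the assumed layout."""
    h = cfg["hidden_size"]
    head_dim = h // cfg["num_attention_heads"]
    kv_dim = head_dim * cfg["num_key_value_heads"]
    attention = 2 * h * h + 2 * h * kv_dim      # q_proj and o_proj, then k_proj and v_proj
    expert = 3 * h * cfg["intermediate_size"]   # three projection matrices per expert
    router = h * cfg["num_local_experts"]       # gating linear layer
    norms = 2 * h                               # two RMSNorm weights per layer
    embed = cfg["vocab_size"] * h
    lm_head = 0 if cfg["tie_word_embeddings"] else embed

    def total_with(num_experts):
        per_layer = attention + num_experts * expert + router + norms
        # "+ h" accounts for the final RMSNorm after the last layer.
        return cfg["num_hidden_layers"] * per_layer + embed + lm_head + h

    return total_with(cfg["num_local_experts"]), total_with(cfg["num_experts_per_tok"])


total_params, active_params = mixtral_param_counts(config)
print(f"total: {total_params / 1e9:.1f}B, active per token: {active_params / 1e9:.1f}B")
```

Under these assumptions the config describes roughly 46.7B parameters in total, of which about 12.9B are used per token because only 2 of the 8 experts are routed to.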
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
# Uncomment below to set the Hugging Face token | ||
# HF_TOKEN=hf_*** | ||
PJRT_DEVICE=TPU | ||
XLA_IR_DEBUG=1 | ||
XLA_HLO_DEBUG=1 | ||
PROFILE_EPOCH=0 | ||
PROFILE_STEP=3 | ||
PROFILE_DURATION_MS=120000 | ||
XLA_USE_SPMD=1 | ||
MAX_STEPS=20 | ||
SEQ_LENGTH=4096 | ||
|
||
GLOBAL_BATCH_SIZE=1024 | ||
|
||
# XLA flags | ||
LIBTPU_INIT_ARGS=--xla_tpu_enable_flash_attention=false --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true --xla_tpu_scoped_vmem_limit_kib=81920 |
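
This file is passed to `docker run --env-file`, so it follows Docker's env-file rules rather than shell syntax: `#` lines are comments, each remaining line is `KEY=VALUE`, and the value is taken verbatim, which is why `LIBTPU_INIT_ARGS` can contain spaces without quotes. The sketch below approximates those rules in a hypothetical `parse_env_file` helper and derives the tokens processed per step; the shortened `LIBTPU_INIT_ARGS` value is for illustration only.

```python
def parse_env_file(text):
    """Approximate Docker's --env-file parsing: KEY=VALUE lines, '#' comments."""
    env = {}
    for raw_line in text.splitlines():
        line = raw_line.strip()
        if not line or line.startswith("#"):
            continue
        key, separator, value = line.partition("=")  # split on the first '=' only
        if not separator:
            continue  # a bare KEY would be inherited from the host; ignored here
        env[key] = value
    return env


ENV_TEXT = """\
# Uncomment below to set the Hugging Face token
# HF_TOKEN=hf_***
PJRT_DEVICE=TPU
MAX_STEPS=20
SEQ_LENGTH=4096

GLOBAL_BATCH_SIZE=1024
LIBTPU_INIT_ARGS=--xla_tpu_enable_flash_attention=false --xla_tpu_scoped_vmem_limit_kib=81920
"""

env = parse_env_file(ENV_TEXT)
tokens_per_step = int(env["GLOBAL_BATCH_SIZE"]) * int(env["SEQ_LENGTH"])
total_tokens = tokens_per_step * int(env["MAX_STEPS"])
print(tokens_per_step, total_tokens)
```

Because the values are not shell-quoted, sourcing this file from bash would not give the same result for `LIBTPU_INIT_ARGS`; it is meant for Docker only.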
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
{ | ||
"fsdp_transformer_layer_cls_to_wrap": [ | ||
"MixtralDecoderLayer" | ||
], | ||
"xla": true, | ||
"xla_fsdp_v2": true, | ||
"xla_fsdp_grad_ckpt": true | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
#!/bin/bash | ||
|
||
DOCKER_IMAGE=us-central1-docker.pkg.dev/deeplearning-images/reproducibility/pytorch-tpu-mixtral:v1 | ||
|
||
worker_id=$(curl -s "http://metadata.google.internal/computeMetadata/v1/instance/attributes/agent-worker-number" -H 'Metadata-Flavor: Google') | ||
|
||
cat >> /dev/null <<EOF | ||
EOF | ||
|
||
stdbuf -oL bash <<-PIPE_EOF 2>&1 | sed "s/^/[worker $worker_id] /g" | tee runlog | ||
set -o xtrace | ||
# Configure docker | ||
sudo groupadd docker | ||
sudo usermod -aG docker $USER | ||
# newgrp applies updated group permissions | ||
newgrp - docker | ||
gcloud auth configure-docker us-central1-docker.pkg.dev --quiet | ||
# Kill any running benchmarks | ||
docker kill $USER-test | ||
docker pull $DOCKER_IMAGE | ||
docker run --rm \ | ||
--name $USER-test \ | ||
--privileged \ | ||
--env-file env.sh \ | ||
-v /home/$USER:/tmp/home \ | ||
--shm-size=16G \ | ||
--net host \ | ||
-u root \ | ||
--entrypoint /bin/bash $DOCKER_IMAGE \ | ||
/tmp/home/train.sh | ||
PIPE_EOF |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
#!/bin/bash | ||
# Remove existing repo and old data. | ||
LOCAL_DIR=/tmp/home | ||
rm -rf "${LOCAL_DIR}/output" | ||
rm -rf "${LOCAL_DIR}/plugins" | ||
rm -rf "${LOCAL_DIR}/cache" | ||
mkdir -p "${LOCAL_DIR}/output" | ||
mkdir -p "${LOCAL_DIR}/plugins" | ||
mkdir -p "${LOCAL_DIR}/cache" | ||
|
||
unset LD_PRELOAD | ||
|
||
|
||
cd transformers/ | ||
|
||
|
||
python3 examples/pytorch/language-modeling/run_clm.py \ | ||
--dataset_name wikitext \ | ||
--dataset_config_name wikitext-103-raw-v1 \ | ||
--per_device_train_batch_size "${GLOBAL_BATCH_SIZE}" \ | ||
--do_train \ | ||
--output_dir "${LOCAL_DIR}/output/test-clm" \ | ||
--overwrite_output_dir \ | ||
--config_name "${LOCAL_DIR}/config.json" \ | ||
--cache_dir "${LOCAL_DIR}/cache" \ | ||
--tokenizer_name mistralai/Mixtral-8x7B-v0.1 \ | ||
--block_size "$SEQ_LENGTH" \ | ||
--optim adafactor \ | ||
--save_strategy no \ | ||
--logging_strategy no \ | ||
--fsdp "full_shard" \ | ||
--fsdp_config "${LOCAL_DIR}/fsdp_config.json" \ | ||
--torch_dtype bfloat16 \ | ||
--dataloader_drop_last yes \ | ||
--flash_attention \ | ||
--num_train_epochs 1 \ | ||
--max_steps "$MAX_STEPS" \ | ||
--gmm |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,113 @@ | ||
|
||
|
||
# Instructions for training Mixtral 8x7B on Trillium TPU on multipod using XPK | ||
|
||
## Environment Setup | ||
--- | ||
### 1. [Optional but suggested] Create virtual env | ||
```bash | ||
sudo apt-get update && sudo apt install python3.10-venv | ||
python3.10 -m venv myenv | ||
source myenv/bin/activate | ||
``` | ||
--- | ||
### 2. Clone XPK repository and install XPK package | ||
```bash | ||
pushd ./ | ||
git clone https://github.com/google/xpk.git | ||
cd xpk | ||
pip install . | ||
popd | ||
``` | ||
--- | ||
### 3. Update and export environment variables | ||
Modify the environment variables in `env.sh` to target your gcloud resources and the experiment's model config. Source the script for future use. | ||
```bash | ||
source env.sh | ||
``` | ||
|
||
--- | ||
### 4. [Optional, skip if using existing XPK cluster] Create the XPK cluster | ||
Follow the corresponding XPK user guide to create the XPK cluster first. If the cluster already exists, skip to Step 5. | ||
```bash | ||
|
||
NETWORK_NAME=${CLUSTER_NAME}-mtu9k | ||
NETWORK_FW_NAME=${NETWORK_NAME}-fw | ||
|
||
# Use a custom network for better performance and to avoid overloading the default network. | ||
gcloud compute networks create ${NETWORK_NAME} --mtu=8896 --project=${PROJECT} --subnet-mode=auto --bgp-routing-mode=regional | ||
gcloud compute firewall-rules create ${NETWORK_FW_NAME} --network ${NETWORK_NAME} --allow tcp,icmp,udp --project=${PROJECT} | ||
export CLUSTER_ARGUMENTS="--network=${NETWORK_NAME} --subnetwork=${NETWORK_NAME}" | ||
|
||
python3 xpk.py cluster create --cluster $CLUSTER_NAME --cluster-cpu-machine-type=n1-standard-8 --num-slices=$NUM_SLICE --tpu-type=$TPU_TYPE --zone=$ZONE --project=$PROJECT --on-demand --custom-cluster-arguments="${CLUSTER_ARGUMENTS}" --create-vertex-tensorboard --gke-version=1.31.1-gke.1678000 | ||
``` | ||
Note that if this `gke-version` is no longer available, pick an available version from the error message in the terminal output. | ||
|
||
--- | ||
### 5. Launch the training workload on the XPK cluster | ||
``` | ||
bash benchmark.sh | ||
``` | ||
|
||
Below is part of the sample output: | ||
``` | ||
... | ||
[XPK] Waiting for `Upload Docker Image`, for 7 seconds | ||
sqpu-2024-11-01-01-15-40: digest: sha256:3fe8b828bc6f96b1c74220d90273147ee188601781330d3592bbffc4fa0897af size: 4951 | ||
[XPK] Task: `Upload Docker Image` terminated with code `0` | ||
[XPK] Task: `Creating Workload` is implemented by `kubectl apply -f /tmp/tmpc65ikqh3`, streaming output live. | ||
[XPK] Waiting for `Creating Workload`, for 0 seconds | ||
jobset.jobset.x-k8s.io/piz-xpk-v6e-256 created | ||
[XPK] Task: `Creating Workload` terminated with code `0` | ||
[XPK] Task: `GKE Dashboard List` is implemented by `gcloud monitoring dashboards list --project=tpu-prod-env-automated --filter="displayName:'GKE - TPU Monitoring Dashboard'" --format="value(name)" --verbosity=error`, hiding output unless there is an error. | ||
[XPK] No dashboard with displayName:'GKE - TPU Monitoring Dashboard' found in the project:tpu-prod-env-automated. | ||
[XPK] Follow https://github.com/google/cloud-tpu-monitoring-debugging to deploy monitoring dashboard to view statistics and outlier mode of GKE metrics. | ||
[XPK] Follow your workload here: https://console.cloud.google.com/kubernetes/service/us-east5/bodaborg-v6e-256/default/piz-xpk-v6e-256/details?project=tpu-prod-env-automated | ||
[XPK] Exiting XPK cleanly | ||
``` | ||
The output points you to a workload link, `https://console.cloud.google.com/kubernetes/service/...`. Follow the link and check the log. If the training works correctly, you should see the following in the Logs Explorer: | ||
``` | ||
... | ||
INFO 2024-10-31T11:23:30.060710856Z ***** train metrics ***** | ||
INFO 2024-10-31T11:23:30.060713436Z epoch = 3.125 | ||
INFO 2024-10-31T11:23:30.060715786Z total_flos = 109152470400GF | ||
INFO 2024-10-31T11:23:30.060718096Z train_loss = 7.4942 | ||
INFO 2024-10-31T11:23:30.060720436Z train_runtime = 0:49:59.16 | ||
INFO 2024-10-31T11:23:30.060722736Z train_samples = 32816 | ||
INFO 2024-10-31T11:23:30.060725156Z train_samples_per_second = 34.143 | ||
INFO 2024-10-31T11:23:30.060727556Z train_steps_per_second = 0.033 | ||
... | ||
EXIT_CODE=0 | ||
XPK End: Thu Oct 31 02:03:01 UTC 2024 | ||
``` | ||
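
The metrics in this log can be cross-checked against each other. The sketch below is a rough consistency check, not part of the recipe: it assumes a global batch of 1024 samples (`BATCH_PER_DEVICE=4` across 256 chips, which the log itself does not state), and the helper `runtime_to_seconds` is hypothetical.

```python
def runtime_to_seconds(text):
    """Convert an H:MM:SS.ss string such as '0:49:59.16' to seconds."""
    hours, minutes, seconds = text.split(":")
    return int(hours) * 3600 + int(minutes) * 60 + float(seconds)


# Values copied from the sample log above.
train_runtime = runtime_to_seconds("0:49:59.16")
samples_per_second = 34.143

# Assumption: 4 samples per device on 256 chips in a single slice.
assumed_global_batch = 4 * 256

samples_processed = samples_per_second * train_runtime
implied_steps = samples_processed / assumed_global_batch
print(f"runtime {train_runtime:.2f} s, about {samples_processed:.0f} samples, "
      f"about {implied_steps:.1f} steps")
```

Under that assumption the sample log corresponds to roughly 100 optimizer steps, so it may come from a run whose settings differ from the defaults in `env.sh`.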
--- | ||
### 6. [Optional] Metric processing | ||
You can use the profile to compute the average step time: | ||
``` | ||
# download the profile from gcp bucket to local | ||
gsutil cp -r $PROFILE_LOG_DIR ./ | ||
# feed in the xplane.pb file, e.g., | ||
python utils/profile_convert.py ${PROFILE_LOG_DIR}/plugins/profile/2024_10_31_02_00_47/127.0.0.1_9012.xplane.pb | ||
``` | ||
|
||
The output reports the average step time in seconds: | ||
``` | ||
Parsing plugins/profile/2024_10_31_00_44_09/127.0.0.1_9012.xplane.pb | ||
Plane ID: 2, Name: /device:TPU:0 | ||
Line ID: 2, Name: XLA Modules | ||
Event Metadata Name: SyncTensorsGraph.65923(1604004898989247534), ID: 36337, Duration: 16.780938099922 s | ||
Event Metadata Name: SyncTensorsGraph.65924(16619407271639597682), ID: 72675, Duration: 1.846361047078 s | ||
Event Metadata Name: SyncTensorsGraph.65924(16619407271639597682), ID: 72675, Duration: 1.845788159422 s | ||
Event Metadata Name: SyncTensorsGraph.65924(16619407271639597682), ID: 72675, Duration: 1.84276413525 s | ||
Event Metadata Name: SyncTensorsGraph.65924(16619407271639597682), ID: 72675, Duration: 1.838797222828 s | ||
Event Metadata Name: SyncTensorsGraph.65924(16619407271639597682), ID: 72675, Duration: 1.850977674094 s | ||
Event Metadata Name: SyncTensorsGraph.65924(16619407271639597682), ID: 72675, Duration: 1.862297948406 s | ||
Event Metadata Name: SyncTensorsGraph.65924(16619407271639597682), ID: 72675, Duration: 1.838890659 s | ||
Event Metadata Name: SyncTensorsGraph.65924(16619407271639597682), ID: 72675, Duration: 1.837627439172 s | ||
Event Metadata Name: SyncTensorsGraph.65924(16619407271639597682), ID: 72675, Duration: 1.835626750328 s | ||
Got 10 iterations | ||
1.8454 | ||
``` | ||
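
If you want to post-process this text output yourself, the per-event durations are easy to extract. The following is a minimal sketch that works on the printed lines only; it is not how `utils/profile_convert.py` computes its final number, and treating the first, much longer event as warm-up or compilation is an assumption.

```python
import re
import statistics

# Captures the number in "... Duration: 1.846361047078 s".
DURATION = re.compile(r"Duration:\s*([0-9.]+)\s*s")


def step_durations(profile_text):
    """Return every event duration (in seconds) found in the converter output."""
    return [float(match.group(1)) for match in DURATION.finditer(profile_text)]


sample = """\
Event Metadata Name: SyncTensorsGraph.65923(1604004898989247534), ID: 36337, Duration: 16.780938099922 s
Event Metadata Name: SyncTensorsGraph.65924(16619407271639597682), ID: 72675, Duration: 1.846361047078 s
Event Metadata Name: SyncTensorsGraph.65924(16619407271639597682), ID: 72675, Duration: 1.845788159422 s
Event Metadata Name: SyncTensorsGraph.65924(16619407271639597682), ID: 72675, Duration: 1.84276413525 s
"""

durations = step_durations(sample)
steady_state = durations[1:]  # assumption: the first event is a warm-up graph
print(f"{len(durations)} events, steady-state mean "
      f"{statistics.mean(steady_state):.4f} s, "
      f"median of all {statistics.median(durations):.4f} s")
```

Reporting both a steady-state mean and a median makes it easier to spot a slow first iteration skewing the average.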
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
#!/bin/bash | ||
|
||
source env.sh | ||
|
||
python3 xpk.py workload create \ | ||
--cluster ${CLUSTER_NAME} \ | ||
--base-docker-image=${BASE_DOCKER_IMAGE} \ | ||
--workload=${WORKLOAD_NAME} \ | ||
--tpu-type=${TPU_TYPE} \ | ||
--num-slices=${NUM_SLICE} \ | ||
--on-demand \ | ||
--zone=$ZONE \ | ||
--project=$PROJECT \ | ||
--enable-debug-logs \ | ||
--command="bash train.sh" |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
{ | ||
"architectures": [ | ||
"MixtralForCausalLM" | ||
], | ||
"attention_dropout": 0.0, | ||
"bos_token_id": 1, | ||
"eos_token_id": 2, | ||
"hidden_act": "silu", | ||
"hidden_size": 4096, | ||
"initializer_range": 0.02, | ||
"intermediate_size": 14336, | ||
"max_position_embeddings": 32768, | ||
"model_type": "mixtral", | ||
"num_attention_heads": 32, | ||
"num_experts_per_tok": 2, | ||
"num_hidden_layers": 32, | ||
"num_key_value_heads": 8, | ||
"num_local_experts": 8, | ||
"output_router_logits": false, | ||
"rms_norm_eps": 1e-05, | ||
"rope_theta": 1000000.0, | ||
"router_aux_loss_coef": 0.02, | ||
"sliding_window": null, | ||
"tie_word_embeddings": false, | ||
"torch_dtype": "bfloat16", | ||
"transformers_version": "4.36.0.dev0", | ||
"use_cache": false, | ||
"vocab_size": 32000 | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
#!/bin/bash | ||
|
||
# Environment variables associated with XPK on GCP. | ||
export ZONE=... | ||
export PROJECT=... | ||
export TPU_TYPE=v6e-256 | ||
export NUM_SLICE=1 | ||
export CLUSTER_NAME=xpk-$USER-... # Use your existing cluster name if you have one. | ||
|
||
# Environment variables associated with training config. | ||
export BATCH_PER_DEVICE=4 | ||
export SEQUENCE_LENGTH=4096 | ||
export MAX_STEP=50 | ||
export WORKLOAD_NAME=${USER}-xpk-${TPU_TYPE}-... # Your workload name. Update it for each run. | ||
export BASE_DOCKER_IMAGE=us-central1-docker.pkg.dev/deeplearning-images/reproducibility/pytorch-tpu-mixtral:v1 | ||
export PROFILE_LOG_DIR=... # GCS bucket to store profile in form of gs://... | ||
export HF_TOKEN=... # Add your own Hugging Face token to download the model
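
The training settings above determine how many tokens each step processes. The sketch below is a back-of-the-envelope calculation, not part of the recipe: it assumes the global batch is `BATCH_PER_DEVICE` times the chip count times the slice count, that the number after the dash in a v6e type such as `v6e-256` is the chip count, and it reuses the 1.8454 s step time from the sample profile output in the README. The helper `chips_in_slice` is hypothetical.

```python
import re


def chips_in_slice(tpu_type):
    """Read the trailing chip count from a v6e type string such as 'v6e-256'.

    Assumption: the suffix is the number of chips, as for the 256-chip
    v6e-256 slice; other TPU generations may use a different convention.
    """
    match = re.fullmatch(r"v6e-(\d+)", tpu_type)
    if match is None:
        raise ValueError(f"unsupported TPU type for this sketch: {tpu_type!r}")
    return int(match.group(1))


# Values copied from env.sh above.
batch_per_device = 4
sequence_length = 4096
num_slice = 1
tpu_type = "v6e-256"

chips = chips_in_slice(tpu_type) * num_slice
global_batch = batch_per_device * chips
tokens_per_step = global_batch * sequence_length

# Step time taken from the sample profile output; substitute your own measurement.
step_time_seconds = 1.8454
tokens_per_second = tokens_per_step / step_time_seconds
print(f"global batch {global_batch}, {tokens_per_step} tokens/step, "
      f"about {tokens_per_second / chips:.0f} tokens/s per chip")
```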
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
{ | ||
"fsdp_transformer_layer_cls_to_wrap": [ | ||
"MixtralDecoderLayer" | ||
], | ||
"xla": true, | ||
"xla_fsdp_v2": true, | ||
"xla_fsdp_grad_ckpt": true | ||
} |