Merge pull request #2 from AI-Hypercomputer/Mixtral_pytorch_v6e
Add training instructions for Mixtral using PyTorch on trillium
bhavya01 authored Nov 1, 2024
2 parents 0838d5e + a9857aa commit 48ba2e0
Showing 13 changed files with 417 additions and 0 deletions.
62 changes: 62 additions & 0 deletions training/trillium/Mixtral-8x7B-Pytorch/GCE/README.md
@@ -0,0 +1,62 @@
# Instructions for training Mixtral-8x7B on Trillium (v6e) TPU


This guide covers the essential steps required to run Hugging Face (HF) Mixtral training on Cloud TPUs.


## Environment Setup

First, follow the user guide for your TPU generation to set up the GCE TPUs.

Then replace every `your-*` value below with your TPU's information.

```
export TPU_NAME=your-tpu-name
export ZONE=your-tpu-zone
export PROJECT=your-tpu-project
```

You can create a 256-chip v6e slice with this command:

```
gcloud alpha compute tpus tpu-vm create $TPU_NAME \
--accelerator-type v6e-256 --project $PROJECT --zone $ZONE \
--version v2-alpha-tpuv6e
```

## Steps to Run HF Mixtral 8x7B

The following setup runs Mixtral 8x7B training on GCE TPUs using a prebuilt Docker image (see `DOCKER_IMAGE` in `host.sh`). The image uses the PyTorch and torch_xla nightly builds from 10/28/2024 and has all package dependencies needed for model training installed. Run all commands below from your own machine (not the TPU host you created).

1. Clone the repository and navigate to this README's directory:
```bash
git clone https://github.com/AI-Hypercomputer/tpu-recipes.git
cd tpu-recipes/training/trillium/Mixtral-8x7B-Pytorch/GCE
```
2. Edit `env.sh` to add your Hugging Face token and/or set the training parameters.
```bash
# add your hugging face token
HF_TOKEN=hf_***
```
3. Edit `host.sh` to set the Docker image URL if the default Docker image is not accessible to you.
```bash
# docker image URL to use for the training
DOCKER_IMAGE=us-central1-docker.pkg.dev/deeplearning-images/reproducibility/pytorch-tpu-mixtral:v1
```
4. Run the training script:
```bash
./benchmark.sh
```
`benchmark.sh` uploads (1) the environment parameters in `env.sh`, (2) the model-related configs `config.json` and `fsdp_config.json`, (3) the Docker launch script `host.sh`, and (4) the Python training command in `train.sh` to all TPU workers, then starts the training. When all training steps complete, each worker prints its training metrics to the terminal, as below:
```
***** train metrics *****
[worker :3] ***** train metrics *****
[worker :3] epoch = 0.0391
[worker :3] total_flos = 216428520GF
[worker :3] train_loss = 8.443
[worker :3] train_runtime = 0:04:23.15
[worker :3] train_samples = 32816
[worker :3] train_samples_per_second = 4.864
```
In addition, it copies the training output back to your machine under `output/`.
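
If you save the terminal output to a file, the per-worker metrics can be collected programmatically. The following is a minimal sketch, not part of this recipe: the function name `parse_train_metrics` and the regular expression are assumptions made for illustration, and the sample text is copied from the output above.

```python
import re

# Sample lines copied from the metrics output shown above.
SAMPLE_LOG = """\
[worker :3] ***** train metrics *****
[worker :3] epoch = 0.0391
[worker :3] total_flos = 216428520GF
[worker :3] train_loss = 8.443
[worker :3] train_runtime = 0:04:23.15
[worker :3] train_samples = 32816
[worker :3] train_samples_per_second = 4.864
"""

# Hypothetical pattern: "[worker :<id>] <name> = <value>" (the colon is optional).
METRIC_RE = re.compile(r"^\[worker\s*:?(\d*)\]\s+(\w+)\s*=\s*(\S+)\s*$")


def parse_train_metrics(log_text):
    """Return {worker_id: {metric_name: raw_string_value}}."""
    metrics = {}
    for line in log_text.splitlines():
        match = METRIC_RE.match(line)
        if match is None:
            continue  # skip banner lines and unrelated output
        worker, key, value = match.groups()
        metrics.setdefault(worker, {})[key] = value
    return metrics


parsed = parse_train_metrics(SAMPLE_LOG)
print(parsed["3"]["train_loss"])  # prints 8.443
```

Values are kept as strings because some of them (`total_flos`, `train_runtime`) are not plain numbers.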
10 changes: 10 additions & 0 deletions training/trillium/Mixtral-8x7B-Pytorch/GCE/benchmark.sh
@@ -0,0 +1,10 @@
#!/bin/bash

# SCP the environment setup to all instances.
gcloud compute tpus tpu-vm scp config.json fsdp_config.json train.sh host.sh env.sh "$TPU_NAME:~" --worker=all --project $PROJECT --zone=$ZONE

# Actually runs the benchmark.
gcloud compute tpus tpu-vm ssh $TPU_NAME --project $PROJECT --zone=$ZONE --worker=all --command="$(cat host.sh)"

# Copy the profile and output back
gcloud compute tpus tpu-vm scp --recurse $TPU_NAME:~/output ./ --project=$PROJECT --zone=$ZONE
29 changes: 29 additions & 0 deletions training/trillium/Mixtral-8x7B-Pytorch/GCE/config.json
@@ -0,0 +1,29 @@
{
"architectures": [
"MixtralForCausalLM"
],
"attention_dropout": 0.0,
"bos_token_id": 1,
"eos_token_id": 2,
"hidden_act": "silu",
"hidden_size": 4096,
"initializer_range": 0.02,
"intermediate_size": 14336,
"max_position_embeddings": 32768,
"model_type": "mixtral",
"num_attention_heads": 32,
"num_experts_per_tok": 2,
"num_hidden_layers": 32,
"num_key_value_heads": 8,
"num_local_experts": 8,
"output_router_logits": false,
"rms_norm_eps": 1e-05,
"rope_theta": 1000000.0,
"router_aux_loss_coef": 0.02,
"sliding_window": null,
"tie_word_embeddings": false,
"torch_dtype": "bfloat16",
"transformers_version": "4.36.0.dev0",
"use_cache": false,
"vocab_size": 32000
}
16 changes: 16 additions & 0 deletions training/trillium/Mixtral-8x7B-Pytorch/GCE/env.sh
@@ -0,0 +1,16 @@
# Uncomment the line below to set your Hugging Face token
# HF_TOKEN=hf_***
PJRT_DEVICE=TPU
XLA_IR_DEBUG=1
XLA_HLO_DEBUG=1
PROFILE_EPOCH=0
PROFILE_STEP=3
PROFILE_DURATION_MS=120000
XLA_USE_SPMD=1
MAX_STEPS=20
SEQ_LENGTH=4096

GLOBAL_BATCH_SIZE=1024

# XLA flags
LIBTPU_INIT_ARGS=--xla_tpu_enable_flash_attention=false --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true --xla_tpu_scoped_vmem_limit_kib=81920
8 changes: 8 additions & 0 deletions training/trillium/Mixtral-8x7B-Pytorch/GCE/fsdp_config.json
@@ -0,0 +1,8 @@
{
"fsdp_transformer_layer_cls_to_wrap": [
"MixtralDecoderLayer"
],
"xla": true,
"xla_fsdp_v2": true,
"xla_fsdp_grad_ckpt": true
}
32 changes: 32 additions & 0 deletions training/trillium/Mixtral-8x7B-Pytorch/GCE/host.sh
@@ -0,0 +1,32 @@
#!/bin/bash

DOCKER_IMAGE=us-central1-docker.pkg.dev/deeplearning-images/reproducibility/pytorch-tpu-mixtral:v1

worker_id=$(curl -s "http://metadata.google.internal/computeMetadata/v1/instance/attributes/agent-worker-number" -H 'Metadata-Flavor: Google')

cat >> /dev/null <<EOF
EOF

stdbuf -oL bash <<-PIPE_EOF 2>&1 | sed "s/^/[worker $worker_id] /g" | tee runlog
set -o xtrace
# Configure docker
sudo groupadd docker
sudo usermod -aG docker $USER
# newgrp applies updated group permissions
newgrp - docker
gcloud auth configure-docker us-central1-docker.pkg.dev --quiet
# Kill any running benchmarks
docker kill $USER-test
docker pull $DOCKER_IMAGE
docker run --rm \
--name $USER-test \
--privileged \
--env-file env.sh \
-v /home/$USER:/tmp/home \
--shm-size=16G \
--net host \
-u root \
--entrypoint /bin/bash $DOCKER_IMAGE \
/tmp/home/train.sh
PIPE_EOF
38 changes: 38 additions & 0 deletions training/trillium/Mixtral-8x7B-Pytorch/GCE/train.sh
@@ -0,0 +1,38 @@
#!/bin/bash
# Remove existing repo and old data.
LOCAL_DIR=/tmp/home/
rm -rf "${LOCAL_DIR}/output"
rm -rf "${LOCAL_DIR}/plugins"
rm -rf "${LOCAL_DIR}/cache"
mkdir -p "${LOCAL_DIR}/output"
mkdir -p "${LOCAL_DIR}/plugins"
mkdir -p "${LOCAL_DIR}/cache"

unset LD_PRELOAD


cd transformers/


python3 examples/pytorch/language-modeling/run_clm.py \
--dataset_name wikitext \
--dataset_config_name wikitext-103-raw-v1 \
--per_device_train_batch_size "${GLOBAL_BATCH_SIZE}" \
--do_train \
--output_dir "${LOCAL_DIR}/output/test-clm" \
--overwrite_output_dir \
--config_name "${LOCAL_DIR}/config.json" \
--cache_dir "${LOCAL_DIR}/cache" \
--tokenizer_name mistralai/Mixtral-8x7B-v0.1 \
--block_size "$SEQ_LENGTH" \
--optim adafactor \
--save_strategy no \
--logging_strategy no \
--fsdp "full_shard" \
--fsdp_config "${LOCAL_DIR}/fsdp_config.json" \
--torch_dtype bfloat16 \
--dataloader_drop_last yes \
--flash_attention \
--num_train_epochs 1 \
--max_steps "$MAX_STEPS" \
--gmm
113 changes: 113 additions & 0 deletions training/trillium/Mixtral-8x7B-Pytorch/XPK/README.md
@@ -0,0 +1,113 @@


# Instructions for training Mixtral 8x7B on Trillium TPU on multipod using XPK

## Environment Setup
---
### 1. [Optional but suggested] Create virtual env
```bash
sudo apt-get update && sudo apt install python3.10-venv
python3.10 -m venv myenv
source myenv/bin/activate
```
---
### 2. Clone XPK repository and install XPK package
```bash
pushd ./
git clone https://github.com/google/xpk.git
cd xpk
pip install .
popd
```
---
### 3. Update and export environment variables
Modify the environment variables in `env.sh` to match your gcloud resources and the experiment's model config, then source the script for later use.
```bash
source env.sh
```
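
As a quick sanity check on the values in `env.sh`, the sketch below works out the global batch size and tokens per step they imply. It rests on assumptions that are not stated in this README: that the number after the dash in `TPU_TYPE` is the chip count, that each chip is one device, and that `BATCH_PER_DEVICE` is applied per device. The function name `global_batch` is hypothetical.

```python
def global_batch(tpu_type: str, num_slices: int, batch_per_device: int) -> int:
    """Global batch size, assuming one device per chip (an assumption)."""
    # Assumption: "v6e-256" means 256 chips in the slice.
    chips_per_slice = int(tpu_type.rsplit("-", 1)[1])
    return chips_per_slice * num_slices * batch_per_device


# Defaults taken from env.sh: TPU_TYPE, NUM_SLICE, BATCH_PER_DEVICE, SEQUENCE_LENGTH.
batch = global_batch("v6e-256", 1, 4)
tokens_per_step = batch * 4096
print(batch, tokens_per_step)  # prints 1024 4194304
```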

---
### 4. [Optional, skip if using existing XPK cluster] Create the XPK clusters
Follow the corresponding XPK user guide to create the XPK cluster first. If the cluster already exists, skip to Step 5.
```bash

NETWORK_NAME=${CLUSTER_NAME}-mtu9k
NETWORK_FW_NAME=${NETWORK_NAME}-fw

# Use a custom network for better performance and to avoid overloading the default network.
gcloud compute networks create ${NETWORK_NAME} --mtu=8896 --project=${PROJECT} --subnet-mode=auto --bgp-routing-mode=regional
gcloud compute firewall-rules create ${NETWORK_FW_NAME} --network ${NETWORK_NAME} --allow tcp,icmp,udp --project=${PROJECT}
export CLUSTER_ARGUMENTS="--network=${NETWORK_NAME} --subnetwork=${NETWORK_NAME}"

python3 xpk.py cluster create --cluster $CLUSTER_NAME --cluster-cpu-machine-type=n1-standard-8 --num-slices=$NUM_SLICE --tpu-type=$TPU_TYPE --zone=$ZONE --project=$PROJECT --on-demand --custom-cluster-arguments="${CLUSTER_ARGUMENTS}" --create-vertex-tensorboard --gke-version=1.31.1-gke.1678000
```
Note that if this `gke-version` is no longer available, pick an available one from the error message in the terminal output.

---
### 5. Launch the training workload to XPK cluster.
```
bash benchmark.sh
```

Below is part of the sample output from `benchmark.sh`:
```
...
[XPK] Waiting for `Upload Docker Image`, for 7 seconds
sqpu-2024-11-01-01-15-40: digest: sha256:3fe8b828bc6f96b1c74220d90273147ee188601781330d3592bbffc4fa0897af size: 4951
[XPK] Task: `Upload Docker Image` terminated with code `0`
[XPK] Task: `Creating Workload` is implemented by `kubectl apply -f /tmp/tmpc65ikqh3`, streaming output live.
[XPK] Waiting for `Creating Workload`, for 0 seconds
jobset.jobset.x-k8s.io/piz-xpk-v6e-256 created
[XPK] Task: `Creating Workload` terminated with code `0`
[XPK] Task: `GKE Dashboard List` is implemented by `gcloud monitoring dashboards list --project=tpu-prod-env-automated --filter="displayName:'GKE - TPU Monitoring Dashboard'" --format="value(name)" --verbosity=error`, hiding output unless there is an error.
[XPK] No dashboard with displayName:'GKE - TPU Monitoring Dashboard' found in the project:tpu-prod-env-automated.
[XPK] Follow https://github.com/google/cloud-tpu-monitoring-debugging to deploy monitoring dashboard to view statistics and outlier mode of GKE metrics.
[XPK] Follow your workload here: https://console.cloud.google.com/kubernetes/service/us-east5/bodaborg-v6e-256/default/piz-xpk-v6e-256/details?project=tpu-prod-env-automated
[XPK] Exiting XPK cleanly
```
This will point you to a workload link `https://console.cloud.google.com/kubernetes/service/...`. Follow the workload link and check the log. If the training works correctly, you should see output like the following in the log explorer:
```
...
INFO 2024-10-31T11:23:30.060710856Z ***** train metrics *****
INFO 2024-10-31T11:23:30.060713436Z epoch = 3.125
INFO 2024-10-31T11:23:30.060715786Z total_flos = 109152470400GF
INFO 2024-10-31T11:23:30.060718096Z train_loss = 7.4942
INFO 2024-10-31T11:23:30.060720436Z train_runtime = 0:49:59.16
INFO 2024-10-31T11:23:30.060722736Z train_samples = 32816
INFO 2024-10-31T11:23:30.060725156Z train_samples_per_second = 34.143
INFO 2024-10-31T11:23:30.060727556Z train_steps_per_second = 0.033
...
EXIT_CODE=0
XPK End: Thu Oct 31 02:03:01 UTC 2024
```
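
The logged metrics can be cross-checked against each other. The sketch below converts `train_runtime` to seconds, multiplies it by `train_samples_per_second`, and divides by a global batch size to estimate how many steps the run took. The global batch size of 1024 is an assumption (4 samples per device on 256 chips); it is not printed in the log.

```python
def runtime_to_seconds(text):
    """Convert an 'H:MM:SS.ss' runtime string into seconds."""
    hours, minutes, seconds = text.split(":")
    return int(hours) * 3600 + int(minutes) * 60 + float(seconds)


# Values copied from the sample log above.
runtime_s = runtime_to_seconds("0:49:59.16")
samples_per_second = 34.143

# Assumption: global batch = 4 per device x 256 chips.
ASSUMED_GLOBAL_BATCH = 1024

total_samples = runtime_s * samples_per_second
implied_steps = total_samples / ASSUMED_GLOBAL_BATCH
print(round(total_samples), round(implied_steps, 2))  # prints 102400 100.0
```

Under that assumption the sample log corresponds to a run of about 100 steps, so your own numbers may differ if you keep the `MAX_STEP` value from `env.sh`.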
---
### 6. [Optional] Metric processing
You can use the profile to compute the average step time:
```
# download the profile from gcp bucket to local
gsutil cp -r $PROFILE_LOG_DIR ./
# feed in the xplane.pb file, e.g.,
python utils/profile_convert.py ${PROFILE_LOG_DIR}/plugins/profile/2024_10_31_02_00_47/127.0.0.1_9012.xplane.pb
```

You will see output like the following, which reports the average step time in seconds:
```
Parsing plugins/profile/2024_10_31_00_44_09/127.0.0.1_9012.xplane.pb
Plane ID: 2, Name: /device:TPU:0
Line ID: 2, Name: XLA Modules
Event Metadata Name: SyncTensorsGraph.65923(1604004898989247534), ID: 36337, Duration: 16.780938099922 s
Event Metadata Name: SyncTensorsGraph.65924(16619407271639597682), ID: 72675, Duration: 1.846361047078 s
Event Metadata Name: SyncTensorsGraph.65924(16619407271639597682), ID: 72675, Duration: 1.845788159422 s
Event Metadata Name: SyncTensorsGraph.65924(16619407271639597682), ID: 72675, Duration: 1.84276413525 s
Event Metadata Name: SyncTensorsGraph.65924(16619407271639597682), ID: 72675, Duration: 1.838797222828 s
Event Metadata Name: SyncTensorsGraph.65924(16619407271639597682), ID: 72675, Duration: 1.850977674094 s
Event Metadata Name: SyncTensorsGraph.65924(16619407271639597682), ID: 72675, Duration: 1.862297948406 s
Event Metadata Name: SyncTensorsGraph.65924(16619407271639597682), ID: 72675, Duration: 1.838890659 s
Event Metadata Name: SyncTensorsGraph.65924(16619407271639597682), ID: 72675, Duration: 1.837627439172 s
Event Metadata Name: SyncTensorsGraph.65924(16619407271639597682), ID: 72675, Duration: 1.835626750328 s
Got 10 iterations
1.8454
```
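
To make the final number easier to interpret, the sketch below recomputes an average from the durations listed above. In this sample, the first event is much longer than the rest (most likely it includes compilation), and averaging the durations after dropping the first and last events reproduces the printed value. This is an observation about the sample output, not a description of how `utils/profile_convert.py` is implemented.

```python
# Durations in seconds, copied from the sample output above.
durations = [
    16.780938099922,
    1.846361047078,
    1.845788159422,
    1.84276413525,
    1.838797222828,
    1.850977674094,
    1.862297948406,
    1.838890659,
    1.837627439172,
    1.835626750328,
]

# Assumption: skip the first event (warm-up) and the last one.
steady = durations[1:-1]
average = sum(steady) / len(steady)
print(len(durations), round(average, 4))  # prints 10 1.8454
```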

15 changes: 15 additions & 0 deletions training/trillium/Mixtral-8x7B-Pytorch/XPK/benchmark.sh
@@ -0,0 +1,15 @@
#!/bin/bash

source env.sh

python3 xpk.py workload create \
--cluster ${CLUSTER_NAME} \
--base-docker-image=${BASE_DOCKER_IMAGE} \
--workload=${WORKLOAD_NAME} \
--tpu-type=${TPU_TYPE} \
--num-slices=${NUM_SLICE} \
--on-demand \
--zone=$ZONE \
--project=$PROJECT \
--enable-debug-logs \
--command="bash train.sh"
29 changes: 29 additions & 0 deletions training/trillium/Mixtral-8x7B-Pytorch/XPK/config.json
@@ -0,0 +1,29 @@
{
"architectures": [
"MixtralForCausalLM"
],
"attention_dropout": 0.0,
"bos_token_id": 1,
"eos_token_id": 2,
"hidden_act": "silu",
"hidden_size": 4096,
"initializer_range": 0.02,
"intermediate_size": 14336,
"max_position_embeddings": 32768,
"model_type": "mixtral",
"num_attention_heads": 32,
"num_experts_per_tok": 2,
"num_hidden_layers": 32,
"num_key_value_heads": 8,
"num_local_experts": 8,
"output_router_logits": false,
"rms_norm_eps": 1e-05,
"rope_theta": 1000000.0,
"router_aux_loss_coef": 0.02,
"sliding_window": null,
"tie_word_embeddings": false,
"torch_dtype": "bfloat16",
"transformers_version": "4.36.0.dev0",
"use_cache": false,
"vocab_size": 32000
}
17 changes: 17 additions & 0 deletions training/trillium/Mixtral-8x7B-Pytorch/XPK/env.sh
@@ -0,0 +1,17 @@
#!/bin/bash

# Environment variables associated with XPK on GCP.
export ZONE=...
export PROJECT=...
export TPU_TYPE=v6e-256
export NUM_SLICE=1
export CLUSTER_NAME=xpk-$USER-... # use your existing cluster name if you have one

# Environment variables associated with training config.
export BATCH_PER_DEVICE=4
export SEQUENCE_LENGTH=4096
export MAX_STEP=50
export WORKLOAD_NAME=${USER}-xpk-${TPU_TYPE}-... # Your workload name. Update it for each run.
export BASE_DOCKER_IMAGE=us-central1-docker.pkg.dev/deeplearning-images/reproducibility/pytorch-tpu-mixtral:v1
export PROFILE_LOG_DIR=... # GCS bucket to store profile in form of gs://...
export HF_TOKEN=... # Add your own Hugging Face token to download the model
8 changes: 8 additions & 0 deletions training/trillium/Mixtral-8x7B-Pytorch/XPK/fsdp_config.json
@@ -0,0 +1,8 @@
{
"fsdp_transformer_layer_cls_to_wrap": [
"MixtralDecoderLayer"
],
"xla": true,
"xla_fsdp_v2": true,
"xla_fsdp_grad_ckpt": true
}