Skip to content

Commit

Permalink
Add attention backend tests to more-tests.yml
Browse files Browse the repository at this point in the history
Tests from pytorch#1477 only, without the generate.py refactor
  • Loading branch information
mikekgfb authored Jan 27, 2025
1 parent 5684175 commit b52d1b4
Showing 1 changed file with 62 additions and 0 deletions.
62 changes: 62 additions & 0 deletions .github/workflows/more-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ jobs:
gpu-arch-version: "12.4"
timeout: 60
script: |
set -xeou pipefail
echo "::group::Print machine info"
uname -a
echo "::endgroup::"
Expand Down Expand Up @@ -83,3 +84,64 @@ jobs:
echo "tests complete"
echo "******************************************"
echo "::endgroup::"
test-sdpa-backends:
permissions:
id-token: write
contents: read
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
with:
runner: linux.g5.4xlarge.nvidia.gpu
gpu-arch-type: cuda
gpu-arch-version: "12.4"
timeout: 60
script: |
set -xeou pipefail
echo "::group::Print machine info"
uname -a
echo "::endgroup::"
echo "::group::Download checkpoints"
# Install requirements
./install/install_requirements.sh cuda
pip3 list
python3 -c 'import torch;print(f"torch: {torch.__version__, torch.version.git_version}")'
echo "::endgroup::"
echo "::group::Download checkpoints"
mkdir -p checkpoints/stories15M
pushd checkpoints/stories15M
wget https://huggingface.co/karpathy/tinyllamas/resolve/main/stories15M.pt
wget https://github.com/karpathy/llama2.c/raw/master/tokenizer.model
popd
echo "::endgroup::"
echo "::group::Run inference"
export MODEL_PATH=checkpoints/stories15M/stories15M.pt
export MODEL_NAME=stories15M
export MODEL_DIR=/tmp
for DEVICE in cpu cuda; do
for DTYPE in bfloat16 float16 float32; do
for SDPA in 'math' 'flash_attention' 'efficient_attention' 'cudnn_attention'; do
echo "******************************************************************"
echo "******* $DEVICE $DTYPE $SDPA "
###################################################################
# Python execution interpreted vanilla
python torchchat.py generate --checkpoint-path ${MODEL_PATH} --attention-backend ${SDPA} --device ${DEVICE} --dtype ${DTYPE} --temperature 0
###################################################################
# prefill, and compile and prefill compile
python torchchat.py generate --checkpoint-path ${MODEL_PATH} --attention-backend ${SDPA} --device ${DEVICE} --dtype ${DTYPE} --temperature 0 --compile --compile-prefill
###################################################################
# sequential prefill
python torchchat.py generate --checkpoint-path ${MODEL_PATH} --attention-backend ${SDPA} --device ${DEVICE} --dtype ${DTYPE} --temperature 0 --sequential-prefill
###################################################################
# prefill, and compile
python torchchat.py generate --checkpoint-path ${MODEL_PATH} --attention-backend ${SDPA} --device ${DEVICE} --dtype ${DTYPE} --temperature 0 --sequential-prefill --compile
done
done
done
echo "tests complete"
echo "******************************************"
echo "::endgroup::"

0 comments on commit b52d1b4

Please sign in to comment.