diff --git a/.github/workflows/more-tests.yml b/.github/workflows/more-tests.yml
index f772382d1..c2502e7e4 100644
--- a/.github/workflows/more-tests.yml
+++ b/.github/workflows/more-tests.yml
@@ -19,6 +19,7 @@ jobs:
       gpu-arch-version: "12.4"
       timeout: 60
       script: |
+        set -xeou pipefail
         echo "::group::Print machine info"
         uname -a
         echo "::endgroup::"
@@ -83,3 +84,64 @@ jobs:
         echo "tests complete"
         echo "******************************************"
         echo "::endgroup::"
+
+  test-sdpa-backends:
+    permissions:
+      id-token: write
+      contents: read
+    uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
+    with:
+      runner: linux.g5.4xlarge.nvidia.gpu
+      gpu-arch-type: cuda
+      gpu-arch-version: "12.4"
+      timeout: 60
+      script: |
+        set -xeou pipefail
+        echo "::group::Print machine info"
+        uname -a
+        echo "::endgroup::"
+
+        echo "::group::Install requirements"
+        # Install requirements
+        ./install/install_requirements.sh cuda
+        pip3 list
+        python3 -c 'import torch;print(f"torch: {torch.__version__, torch.version.git_version}")'
+        echo "::endgroup::"
+
+        echo "::group::Download checkpoints"
+        mkdir -p checkpoints/stories15M
+        pushd checkpoints/stories15M
+        wget https://huggingface.co/karpathy/tinyllamas/resolve/main/stories15M.pt
+        wget https://github.com/karpathy/llama2.c/raw/master/tokenizer.model
+        popd
+        echo "::endgroup::"
+
+        echo "::group::Run inference"
+        export MODEL_PATH=checkpoints/stories15M/stories15M.pt
+        export MODEL_NAME=stories15M
+        export MODEL_DIR=/tmp
+
+        for DEVICE in cpu cuda; do
+          for DTYPE in bfloat16 float16 float32; do
+            for SDPA in 'math' 'flash_attention' 'efficient_attention' 'cudnn_attention'; do
+              echo "******************************************************************"
+              echo "******* $DEVICE $DTYPE $SDPA "
+              ###################################################################
+              # Python execution interpreted vanilla
+              python torchchat.py generate --checkpoint-path ${MODEL_PATH} --attention-backend ${SDPA} --device ${DEVICE} --dtype ${DTYPE} --temperature 0
+              ###################################################################
+              # prefill, and compile and prefill compile
+              python torchchat.py generate --checkpoint-path ${MODEL_PATH} --attention-backend ${SDPA} --device ${DEVICE} --dtype ${DTYPE} --temperature 0 --compile --compile-prefill
+              ###################################################################
+              # sequential prefill
+              python torchchat.py generate --checkpoint-path ${MODEL_PATH} --attention-backend ${SDPA} --device ${DEVICE} --dtype ${DTYPE} --temperature 0 --sequential-prefill
+              ###################################################################
+              # sequential prefill, and compile
+              python torchchat.py generate --checkpoint-path ${MODEL_PATH} --attention-backend ${SDPA} --device ${DEVICE} --dtype ${DTYPE} --temperature 0 --sequential-prefill --compile
+            done
+          done
+        done
+
+        echo "tests complete"
+        echo "******************************************"
+        echo "::endgroup::"