
Commit

upd
yzh119 committed Feb 2, 2024
1 parent f653fe9 commit 6e66565
Showing 10 changed files with 32 additions and 30 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -1,3 +1,4 @@
*.pdf
*.py
*.pyc
__pycache__
2 changes: 2 additions & 0 deletions Gemfile
@@ -31,3 +31,5 @@ gem "wdm", "~> 0.1.1", :platforms => [:mingw, :x64_mingw, :mswin]
# Lock `http_parser.rb` gem to `v0.6.x` on JRuby builds since newer versions of the gem
# do not have a Java counterpart.
gem "http_parser.rb", "~> 0.6.0", :platforms => [:jruby]

gem 'jekyll-redirect-from'
3 changes: 3 additions & 0 deletions Gemfile.lock
@@ -33,6 +33,8 @@ GEM
webrick (~> 1.7)
jekyll-feed (0.17.0)
jekyll (>= 3.7, < 5.0)
jekyll-redirect-from (0.16.0)
jekyll (>= 3.3, < 5.0)
jekyll-sass-converter (3.0.0)
sass-embedded (~> 1.54)
jekyll-seo-tag (2.8.0)
@@ -75,6 +77,7 @@ DEPENDENCIES
http_parser.rb (~> 0.6.0)
jekyll (~> 4.3.2)
jekyll-feed (~> 0.12)
jekyll-redirect-from
minima (~> 2.5)
tzinfo (>= 1, < 3)
tzinfo-data
1 change: 1 addition & 0 deletions _config.yml
@@ -30,6 +30,7 @@ github_username: flashinfer-ai
theme: minima
plugins:
- jekyll-feed
- jekyll-redirect-from

# Exclude from processing.
# The following items will not be processed, by default.
55 changes: 25 additions & 30 deletions _posts/2024-01-08-cascade-inference.md
@@ -1,20 +1,20 @@
---
layout: post
title: "Cascade Inference: Memory Bandwidth Efficient Shared Prefix Batch Decoding"
date: 2024-01-08
date: 2024-02-02
comments: true
author: Zihao Ye (UW), Ruihang Lai (CMU), Roy Lu (UW), Chien-Yu Lin (UW), Size Zheng (UW & PKU), Lequn Chen (UW), Tianqi Chen (CMU & OctoAI), Luis Ceze (UW & OctoAI)
author: Zihao Ye (UW), Ruihang Lai (CMU), Bo-Ru Lu (UW), Chien-Yu Lin (UW), Size Zheng (UW & PKU), Lequn Chen (UW), Tianqi Chen (CMU & OctoML), Luis Ceze (UW & OctoML)
redirect_from: "/2024/01/08/cascade-inference"
---

Many LLM inference tasks involve multiple independent text generations from a shared prefix (prompt), e.g. [Self-Consistency](https://arxiv.org/abs/2203.11171), [Tree of Thoughts](https://arxiv.org/abs/2305.10601) and [Skeleton-of-thought](https://arxiv.org/abs/2307.15337). Serving LLMs with a common prefix can be memory- and time-consuming, especially when the common prefix is long and the number of requests is large: one such use case is long document QA (Figure 1), where multiple users interact with a chatbot using the same document as the prompt. [vLLM](https://arxiv.org/abs/2309.06180) alleviates the memory issue by storing only one copy of the common prefix, but it still suffers from low efficiency because the default PageAttention implementation does not optimize KV-Cache access to the shared prompt.

In this blog post, we introduce Cascade Inference, which decouples the attention of the shared prefix from that of the unique suffixes and enables storing the shared KV-Cache in GPU shared memory (SMEM for short) for fast access across multiple requests. We show that Cascade Inference can greatly accelerate the shared-prefix batch decoding operator, with up to 31x speedup over the baseline vLLM PageAttention implementation and 26x speedup over the FlashInfer batch decoding operator without cascading on an H100 SXM 80GB. The kernels are available in [FlashInfer](https://github.com/flashinfer-ai/flashinfer/) through its [PyTorch](https://docs.flashinfer.ai/api/python/cascade.html#cascade-attention) and C++ APIs.

<p align="center">
<figure>
<img src="/assets/imgs/document-qa-serving.png" alt="Document QA Serving" width="800"/>
<figcaption> Figure 1. An example of serving Document QA for multiple users, all of the requests share the same book as prompt. </figcaption>
</figure>
<br>
Figure 1. An example of serving Document QA for multiple users, where all of the requests share the same book as the prompt.
</p>

## Background
@@ -33,7 +33,7 @@ The single-query attention kernel (used in decode), on the other hand, assumes t

Neither the multi-query attention kernel nor the single-query attention kernel is a good fit for shared-prefix batch decoding. However, multi-query attention is perfect for the attention between queries and the shared prefix, while single-query attention can handle the attention between queries and their unique suffixes. Can we combine the advantages of both approaches?

### Recursive Softmax/Attention
### Recursive Attention

The answer is "yes" if we can find a way to "merge" the attention of the same queries with shared prefix and unique suffixes. Fortunately, FlashAttention has shown it's possible to combine local
softmax/attention results by not only storing the local attention result, but also the normalization scales and renormalizing local attention results on the fly. Here we formulate the idea in concise notations:
@@ -48,20 +48,27 @@ $$ s(I) = \log\left(\sum_{i\in I} \exp(s_i) \right),$$

let's also generalize the value vector $\mathbf{v}$ from a single index to index sets (note that the generalizations of both $s$ and $\mathbf{v}$ are self-consistent: when $I$ equals $\{i\}$, we have $s(I) = s_i$ and $\mathbf{v}(I) = \mathbf{v}_i$):

$$ \mathbf{v}(I)=\frac{\sum_{i\in I}\exp\left(s_i\right)\mathbf{v}_i}{\exp(s(I))}, $$
$$ \mathbf{v}(I) = \sum_{i\in I}\textrm{softmax}(s_i) \mathbf{v}_i = \frac{\sum_{i\in I}\exp\left(s_i\right)\mathbf{v}_i}{\exp(s(I))}, $$

The **attention state** between a query with KV of an index set $I$ can be defined as a tuple $\begin{bmatrix}\mathbf{v}(I) \\\ s(I)\end{bmatrix}$,
then we can define the **merge** operator $\oplus$ to combine two states as [^3]:
where the $\textrm{softmax}$ function is restricted to the index set $I$. Note that $\mathbf{v}(\{1,2,\cdots, n\})$ is the self-attention output of the entire sequence. The **attention state** between a query and the KV of an index set $I$ can be defined as a tuple $\begin{bmatrix}\mathbf{v}(I) \\\ s(I)\end{bmatrix}$,
then we can define a binary **merge** operator $\oplus$ that combines two states as follows (in practice we subtract the maximum value from $s$ to guarantee numerical stability; we omit this here for simplicity):

$$\begin{bmatrix}\mathbf{v}(I\cup J)\\s(I\cup J)\end{bmatrix}=\begin{bmatrix}\mathbf{v}(I)\\s(I)\end{bmatrix}\oplus\begin{bmatrix}\mathbf{v}(J)\\s(J)\end{bmatrix}=\begin{bmatrix} \frac{\mathbf{v}(I)\exp(s(I)) + \mathbf{v}(J)\exp(s(J))}{\exp(s(I)) + \exp(s(J))} \\ \log(\exp(s(I)) + \exp(s(J))) \end{bmatrix},$$

and we can define **attention state** on the entire sequence (suppose sequence length is $n$):
the **merge** operator can be generalized to any number of attention state inputs:

$$\begin{bmatrix}\mathbf{v}(\{1,2,\dots, n\})\\s(\{1,2,\dots, n\})\end{bmatrix} = \bigoplus_{i=1}^{n} \begin{bmatrix}\mathbf{v}_i\\s_i\end{bmatrix}$$
$$\begin{bmatrix}\mathbf{v}(\bigcup_{i=1}^{n}I_i) \\ s(\bigcup_{i=1}^{n}I_i) \end{bmatrix} = \bigoplus_{i=1}^{n}\begin{bmatrix}\mathbf{v}(I_i) \\ s(I_i)\end{bmatrix} = \begin{bmatrix} \sum_{i=1}^{n} \textrm{softmax}(s(I_i))\mathbf{v}(I_i) \\ \log(\sum_{i=1}^{n} \exp (s(I_i))) \end{bmatrix} $$

Then $\mathbf{v}(\\{1,2,\cdots, n\\})$ is the self-attention result between query and the entire KV. Note that $\oplus$ is communicative and associative, which means we can get the exact attention result by merging the attention states of index subsets as long as their disjoint union is the $\\{1,2,\cdots, n\\}$, regardless of merge order.
The above n-ary merge operator is consistent with the binary merge operator, and we can prove that the operator is *commutative* and *associative*. There are different ways to obtain the attention state of the entire sequence by merging the attention states of index subsets, and all of them are mathematically equivalent:

The KV sequence partitioning trick in FlashInfer and Flash-Decoding uses the same idea to merge partial attention states from different thread blocks.
<p align="center">
<img src="/assets/imgs/recursive-attention.png" alt="recursive-attention" width="800"/>
<br>
Figure 3. Different orders of merging attention states are mathematically equivalent.
</p>

Recursive Attention allows us to decompose the attention computation into multiple stages; different stages
can be dispatched to different compute units/devices. The KV sequence partitioning trick in FlashInfer and Flash-Decoding uses the same idea to merge partial attention states from different thread blocks.
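
To make the recursive attention idea concrete, here is a minimal NumPy sketch of the merge operator. The helper names (`attention`, `merge_state`) and shapes are illustrative rather than FlashInfer's actual API, and attention score scaling is omitted; the sketch checks that merging the attention states of a prefix and a suffix reproduces attention over the full sequence:

```python
import numpy as np

def merge_state(v_a, s_a, v_b, s_b):
    """Merge two attention states (v, s) computed over disjoint KV index sets."""
    s_max = np.maximum(s_a, s_b)                 # subtract the max for numerical stability
    w_a, w_b = np.exp(s_a - s_max), np.exp(s_b - s_max)
    v = (w_a * v_a + w_b * v_b) / (w_a + w_b)    # renormalized value
    s = np.log(w_a + w_b) + s_max                # merged log-sum-exp score
    return v, s

def attention(q, K, V):
    """Return the attention output v(I) and log-sum-exp score s(I) for one query."""
    s = K @ q                                    # raw scores s_i = q . k_i (scaling omitted)
    p = np.exp(s - s.max())
    return (p / p.sum()) @ V, np.log(p.sum()) + s.max()

rng = np.random.default_rng(0)
d, n_prefix, n_suffix = 8, 16, 4
q = rng.standard_normal(d)
K = rng.standard_normal((n_prefix + n_suffix, d))
V = rng.standard_normal((n_prefix + n_suffix, d))

# Merging the prefix and suffix attention states reproduces full attention.
v_full, _ = attention(q, K, V)
v_pre, s_pre = attention(q, K[:n_prefix], V[:n_prefix])
v_suf, s_suf = attention(q, K[n_prefix:], V[n_prefix:])
v_merged, _ = merge_state(v_pre, s_pre, v_suf, s_suf)
assert np.allclose(v_full, v_merged)
```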

### Cascade Inference: The Algorithm

@@ -74,10 +81,9 @@ we propose the following Divide-and-Conquer algorithm:
The overall workflow is explained on the left side of Figure 2; rectangles of different colors are processed by different thread blocks on the GPU. Note that multi-query attention kernels access the KV-Cache through SMEM or registers, while decode kernels can only access the KV-Cache through L2 cache or global memory. Cascade Inference allows us to maximize memory reuse for the common prefix, making the attention computation much more memory-efficient.

<p align="center">
<figure>
<img src="/assets/imgs/cascade-inference.png" alt="Cascade Inference" width="800"/>
<figcaption> Figure 2. Workflow of Cascade Inference, throughput values adapted from blog: <a href="https://khairy2011.medium.com/tpu-vs-gpu-vs-cerebras-vs-graphcore-a-fair-comparison-between-ml-hardware-3f5a19d89e38">TPU vs GPU vs Cerebras vs Graphcore: A Fair Comparison between ML Hardware</a></figcaption>
</figure>
<br>
Figure 2. Workflow of Cascade Inference; throughput values adapted from the blog post <a href="https://khairy2011.medium.com/tpu-vs-gpu-vs-cerebras-vs-graphcore-a-fair-comparison-between-ml-hardware-3f5a19d89e38">TPU vs GPU vs Cerebras vs Graphcore: A Fair Comparison between ML Hardware</a>
</p>

We call the divide-and-conquer approach for shared-prefix attention the "Cascade Inference".
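
Below is a minimal NumPy sketch of this cascade dataflow, reusing the illustrative `attention` and `merge_state` helpers from the previous sketch; in the real FlashInfer kernels these stages run as fused CUDA kernels over a paged KV-Cache, with the shared prefix held in SMEM/registers during the multi-query stage:

```python
# Reuses `attention` and `merge_state` from the sketch above (illustrative only).
import numpy as np

def cascade_decode(queries, K_shared, V_shared, K_unique, V_unique):
    """queries: (batch, d); K_shared/V_shared: (n_shared, d);
    K_unique/V_unique: lists of per-request (n_i, d) arrays."""
    outputs = []
    for q, K_u, V_u in zip(queries, K_unique, V_unique):
        # Stage 1: attention against the shared prefix. In the real kernel, all
        # queries in the batch reuse the same prefix KV held in SMEM/registers.
        v_s, s_s = attention(q, K_shared, V_shared)
        # Stage 2: attention against this request's unique suffix (batch decode).
        v_u, s_u = attention(q, K_u, V_u)
        # Stage 3: merge the two partial attention states.
        v, _ = merge_state(v_s, s_s, v_u, s_u)
        outputs.append(v)
    return np.stack(outputs)

# Sanity check against attention over the concatenated prefix + suffix KV.
rng = np.random.default_rng(1)
batch, d, n_shared = 3, 8, 32
Q = rng.standard_normal((batch, d))
Ks, Vs = rng.standard_normal((n_shared, d)), rng.standard_normal((n_shared, d))
Ku = [rng.standard_normal((5, d)) for _ in range(batch)]
Vu = [rng.standard_normal((5, d)) for _ in range(batch)]
out = cascade_decode(Q, Ks, Vs, Ku, Vu)
ref = np.stack([attention(q, np.vstack([Ks, k]), np.vstack([Vs, v]))[0]
                for q, k, v in zip(Q, Ku, Vu)])
assert np.allclose(out, ref)
```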
@@ -87,25 +93,15 @@ We call the divide-and-conquer approach for shared-prefix attention the "Cascade
We evaluate Cascade Inference on H100 SXM 80GB and A100 PCIe 80GB GPUs. The input shapes are adapted from LLaMA2-7B (32 heads, 128 dimensions per head). We vary three parameters: the number of requests (batch size), the shared prefix length, and the unique suffix length per request. The baseline implementation is the PageAttention kernel in vLLM 0.2.6; we also show the performance of the FlashInfer batch decoding operator without cascading. The page size (or block size, equivalently) is fixed at 16 for all implementations.

<p align="center">
<figure>
<img src="/assets/imgs/cascade-inference-performance-h100.png" alt="speedup-h100" width="800"/>
<figcaption>
<center>
<br>
Figure 4. Speedup over vLLM PageAttention on H100 SXM 80GB
</center>
</figcaption>
</figure>
</p>

<p align="center">
<figure>
<img src="/assets/imgs/cascade-inference-performance-a100.png" alt="speedup-a100" width="800"/>
<figcaption>
<center>
Speedup over vLLM PageAttention on A100 PCIe 80GB
</center>
</figcaption>
</figure>
<br>
Figure 5. Speedup over vLLM PageAttention on A100 PCIe 80GB
</p>

Figures 4 and 5 show the normalized performance of FlashInfer kernels in cascading and non-cascading setting
@@ -122,4 +118,3 @@ Recently, [SGLang](https://arxiv.org/abs/2312.07104) (a domain-specific language

[^1]: thread block: the programming abstraction that represents a group of cooperative threads; one SM can execute multiple thread blocks, and one thread block cannot span multiple SMs.
[^2]: [Hopper architecture](https://resources.nvidia.com/en-us-tensor-core) introduces a new abstraction called Thread Block Clusters, which enables a thread block to access the shared memory of other thread blocks within the same cluster. Hopper also supports direct SM-to-SM communication without accessing global memory (a.k.a. Distributed Shared Memory), which can greatly accelerate cross-SM communication. However, these features are not available in pre-Hopper architectures such as A100 GPUs.
[^3]: The tricks such as minus $s$ with max value to avoid numerically issues are omitted for simplicity
File renamed without changes
Binary file modified assets/imgs/devices-roofline.png
Binary file added assets/imgs/flashinfer-roofline-devices.pdf
Binary file not shown.
Binary file added assets/imgs/recursive-attention.pdf
Binary file not shown.
Binary file added assets/imgs/recursive-attention.png
