Commit 483b26e: add working fp8 varlen
jayhshah committed Oct 31, 2024 (1 parent: 478ee66)
Showing 8 changed files with 393 additions and 58 deletions.
39 changes: 36 additions & 3 deletions hopper/benchmark_flash_attention_fp8.py
@@ -13,7 +13,7 @@
from flash_attn.utils.benchmark import benchmark_fwd_bwd, benchmark_combined

from flash_attn import flash_attn_qkvpacked_func
from flash_attn_interface import flash_attn_func, _flash_attn_forward
from flash_attn_interface import flash_attn_func, _flash_attn_forward, _flash_attn_varlen_forward

try:
from triton_fused_attention import attention as attention_triton
@@ -128,7 +128,7 @@ def get_default_scale_tensor():
descale_s=descale_s,
scale_s=scale_s,
scale_o=scale_o,
is_inference=True,
is_inference=False,
attn_scale=1.0 / math.sqrt(headdim),
use_causal_mask=causal,
name="sdpa",
@@ -214,7 +214,7 @@ def time_fwd(func, *args, **kwargs):

torch.manual_seed(0)

repeats = 30
repeats = 20
device = 'cuda'
# dtype = torch.float16
dtype = torch.float8_e4m3fn
@@ -225,11 +225,13 @@ def time_fwd(func, *args, **kwargs):
# bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048)]
causal_vals = [False, True]
headdim_vals = [64, 128, 256]
# headdim_vals = [128]
dim = 2048
# dim = 256
dropout_p = 0.0

methods = (["Pytorch", "Flash3"]
+ (["FA3 varlen"])
+ (["cuDNN"] if cudnn is not None else [])
# + (["Triton"] if attention_triton is not None else [])
# + (["xformers.c"] if xops is not None else [])
@@ -337,6 +339,37 @@ def time_fwd(func, *args, **kwargs):
# print(item_baseline)
# torch.testing.assert_close(res, res_baseline, atol=0.05, rtol=0.05)

# For var-seq-len
lens = torch.full([q.shape[0]], seqlen, dtype=torch.int32)
cu_seqlens = torch.cat([torch.tensor([0], dtype=torch.int32), torch.cumsum(lens, dim=0, dtype=torch.int32)]).cuda()

q_var = q.reshape(-1, q.shape[-2], q.shape[-1])
k_var = k.reshape(-1, k.shape[-2], k.shape[-1])
v_var = v.reshape(-1, v.shape[-2], v.shape[-1])
# print(q_var.shape, q_var.dtype)
# print(k_var.shape, k_var.dtype)
# print(v_var.shape, v_var.dtype)

f = time_fwd(
_flash_attn_varlen_forward,
q_var,
k_var,
v_var,
cu_seqlens,
cu_seqlens,
seqlen,
seqlen,
softmax_scale,
causal=causal,
window_size=(-1,-1),
descale_q=descale_q,
descale_k=descale_k,
descale_v=descale_v,
repeats=repeats,
verbose=False
)
time_f[config, "FA3 varlen"] = f

print(f"### causal={causal}, headdim={headdim}, batch_size={batch_size}, seqlen={seqlen} ###")
for method in methods:
speed_f[config, method] = efficiency(
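The benchmark above exercises the varlen path with equal-length sequences, so cu_seqlens reduces to a uniform cumulative sum. As a minimal sketch of the same packing with genuinely variable lengths (the lengths and shapes below are illustrative, not taken from this commit):

```python
import torch

# Hypothetical per-sequence lengths; the varlen API packs every sequence along one
# "total tokens" dimension and locates sequence i via cu_seqlens[i]:cu_seqlens[i+1].
lens = torch.tensor([500, 1024, 77, 2048], dtype=torch.int32)
cu_seqlens = torch.cat([
    torch.zeros(1, dtype=torch.int32),
    torch.cumsum(lens, dim=0, dtype=torch.int32),
]).cuda()                                   # tensor([0, 500, 1524, 1601, 3649])
total_tokens = int(cu_seqlens[-1])          # 3649
max_seqlen = int(lens.max())                # 2048

# Packed layout expected by the varlen kernels: (total_tokens, nheads, headdim).
nheads, headdim = 16, 128
q = torch.randn(total_tokens, nheads, headdim, device="cuda", dtype=torch.bfloat16)
```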
48 changes: 44 additions & 4 deletions hopper/flash_api.cpp
@@ -639,6 +639,9 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
int max_seqlen_q,
const int max_seqlen_k,
const float softmax_scale,
c10::optional<at::Tensor> &descale_q_, // 1
c10::optional<at::Tensor> &descale_k_, // 1
c10::optional<at::Tensor> &descale_v_, // 1
bool is_causal,
int window_size_left,
int window_size_right) {
@@ -648,8 +651,8 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
TORCH_CHECK(is_sm90, "FlashAttention only supports Hopper GPUs or newer.");

auto q_dtype = q.dtype();
TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
"FlashAttention only support fp16 and bf16 data type");
TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16 || q_dtype == at::ScalarType::Float8_e4m3fn,
"FlashAttention-3 varlen only support fp16, bf16, or fp8 e4m3 data type");
TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32");
@@ -717,13 +720,20 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
at::Tensor out;
if (out_.has_value()) {
out = out_.value();
TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
TORCH_CHECK(q_dtype == at::ScalarType::Float8_e4m3fn
? (out.dtype() == at::kBFloat16)
: (out.dtype() == q_dtype),
"Output must have the same dtype as input dtype if dtype is "
"not fp8, or fp16 for fp8 input.");
CHECK_DEVICE(out);
TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
CHECK_SHAPE(out, sizes[0], sizes[1], head_size_og);
if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }
} else {
out = torch::empty_like(q_padded);
if (q_dtype == at::ScalarType::Float8_e4m3fn)
out = torch::empty_like(q_padded, at::kBFloat16);
else
out = torch::empty_like(q_padded);
}

auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
@@ -764,6 +774,36 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
params.total_q = total_q;
params.total_k = total_k;

auto tile_count_semaphore = is_causal || params.is_local
? torch::zeros({1}, opts.dtype(torch::kInt32)) : torch::empty({1}, opts.dtype(torch::kInt32));
params.tile_count_semaphore = tile_count_semaphore.data_ptr<int>();

at::Tensor descale_q, descale_k, descale_v;
if(q_dtype == at::ScalarType::Float8_e4m3fn) {
if (descale_q_.has_value()) {
descale_q = descale_q_.value();
CHECK_DEVICE(descale_q);
CHECK_SHAPE(descale_q, 1);
} else { descale_q = torch::ones({1}, opts.dtype(at::kFloat)); }
if (descale_k_.has_value()) {
descale_k = descale_k_.value();
CHECK_DEVICE(descale_k);
CHECK_SHAPE(descale_k, 1);
} else { descale_k = torch::ones({1}, opts.dtype(at::kFloat)); }
if (descale_v_.has_value()) {
descale_v = descale_v_.value();
CHECK_DEVICE(descale_v);
CHECK_SHAPE(descale_v, 1);
} else { descale_v = torch::ones({1}, opts.dtype(at::kFloat)); }
params.descale_q_ptr = descale_q.data_ptr<float>();
params.descale_k_ptr = descale_k.data_ptr<float>();
params.descale_v_ptr = descale_v.data_ptr<float>();
} else {
params.descale_q_ptr = nullptr;
params.descale_k_ptr = nullptr;
params.descale_v_ptr = nullptr;
}

if (max_seqlen_k > 0) {
auto stream = at::cuda::getCurrentCUDAStream().stream();
run_mha_fwd(params, stream);
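mha_varlen_fwd now takes optional descale_q/descale_k/descale_v tensors (float32, shape [1], on the GPU) and falls back to torch::ones when they are omitted, matching the non-varlen fp8 path. As a rough sketch of where such per-tensor descale factors usually come from — the amax-based recipe below is an assumption on my part, not something this commit specifies:

```python
import torch

def quantize_fp8_per_tensor(x: torch.Tensor):
    """Assumed recipe: per-tensor scaling so the largest magnitude maps near the e4m3 max."""
    fp8_max = torch.finfo(torch.float8_e4m3fn).max            # 448.0 for e4m3
    amax = x.abs().amax().float().clamp(min=1e-12)
    scale = fp8_max / amax                                     # applied before casting down
    descale = (1.0 / scale).reshape(1)                         # shape [1] float32, as the kernel expects
    return (x.float() * scale).to(torch.float8_e4m3fn), descale

# e.g. q_fp8, descale_q = quantize_fp8_per_tensor(q)  # q: packed (total_q, nheads, headdim) on GPU
```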
33 changes: 31 additions & 2 deletions hopper/flash_attn_interface.py
@@ -14,7 +14,7 @@
def maybe_contiguous(x):
return x.contiguous() if x is not None and x.stride(-1) != 1 else x

def _flash_attn_forward(q, k, v, softmax_scale, causal, window_size, descale_q = None, descale_k = None, descale_v = None, gqa_parallel=False):
def _flash_attn_forward(
q,
k,
v,
softmax_scale,
causal,
window_size,
descale_q = None,
descale_k = None,
descale_v = None,
gqa_parallel=False
):
q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
out, q, k, v, out_padded, softmax_lse, S_dmask = flashattn_hopper_cuda.fwd(
q,
@@ -81,6 +92,9 @@ def _flash_attn_varlen_forward(
window_size=(-1, -1),
seqused_q=None,
seqused_k=None,
descale_q=None,
descale_k=None,
descale_v=None,
):
maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
@@ -96,6 +110,9 @@ def _flash_attn_varlen_forward(
max_seqlen_q,
max_seqlen_k,
softmax_scale,
descale_q,
descale_k,
descale_v,
causal,
window_size[0],
window_size[1],
@@ -242,6 +259,9 @@ def forward(
deterministic=False,
seqused_q=None,
seqused_k=None,
descale_q=None,
descale_k=None,
descale_v=None,
):
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
@@ -258,6 +278,9 @@ def forward(
window_size=window_size,
seqused_q=seqused_q,
seqused_k=seqused_k,
descale_q=descale_q,
descale_k=descale_k,
descale_v=descale_v,
)
ctx.save_for_backward(
q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k,
@@ -299,7 +322,7 @@ def backward(ctx, dout, *args):
dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
dk = dk[..., : dout.shape[-1]]
dv = dv[..., : dout.shape[-1]]
return dq, dk, dv, None, None, None, None, None, None, None, None, None, None
return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None


def flash_attn_func(
@@ -395,6 +418,9 @@ def flash_attn_varlen_func(
deterministic=False,
seqused_q=None,
seqused_k=None,
descale_q=None,
descale_k=None,
descale_v=None,
):
"""
Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
@@ -450,6 +476,9 @@ def flash_attn_varlen_func(
deterministic,
seqused_q,
seqused_k,
descale_q,
descale_k,
descale_v,
)


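Putting the Python-side changes together, a forward-only call through the new fp8 varlen path could look like the sketch below; the shapes, sequence lengths, and unit descale factors are made up for illustration, and the argument order simply mirrors the benchmark call above:

```python
import math
import torch
from flash_attn_interface import _flash_attn_varlen_forward

# Two sequences of lengths 300 and 724 packed into one 1024-token batch (made-up sizes).
nheads, headdim = 16, 128
cu_seqlens = torch.tensor([0, 300, 1024], dtype=torch.int32, device="cuda")
max_seqlen = 724
q, k, v = [
    torch.randn(1024, nheads, headdim, device="cuda").to(torch.float8_e4m3fn)
    for _ in range(3)
]
descale_q = descale_k = descale_v = torch.ones(1, dtype=torch.float32, device="cuda")

out, *_ = _flash_attn_varlen_forward(
    q, k, v,
    cu_seqlens, cu_seqlens,        # same cumulative lengths for Q and K here
    max_seqlen, max_seqlen,
    1.0 / math.sqrt(headdim),      # softmax_scale, as in the benchmark call above
    causal=True,
    window_size=(-1, -1),
    descale_q=descale_q,
    descale_k=descale_k,
    descale_v=descale_v,
)
# With fp8 e4m3 inputs, mha_varlen_fwd allocates a bf16 output (see flash_api.cpp above).
```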
32 changes: 26 additions & 6 deletions hopper/flash_fwd_kernel.h
@@ -80,7 +80,13 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp,

if (warp_idx == 0 && lane_predicate) {
shared_storage.barrier_Q.init(1 /*numThreads*/);
if constexpr (!No_smem_O) { shared_storage.barrier_O.init(size(ClusterShape{}) /*numThreads*/); }
if constexpr (!No_smem_O) {
if constexpr(seqlen_traits_q.UseVarSeqLen) {
shared_storage.barrier_O.init(NumMmaThreads /*numThreads*/);
} else {
shared_storage.barrier_O.init(size(ClusterShape{}) /*numThreads*/);
}
}
}
// We're counting on pipeline_k to call cutlass::arch::fence_barrier_init();
MainloopPipeline pipeline_k(shared_storage.pipeline_k, pipeline_params, ClusterShape{});
@@ -121,6 +127,8 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp,
if constexpr(seqlen_traits_q.UseVarSeqLen) {
// NOTE: to support in future with gqa packed layouts, changed kBlockM to kBlockM/kBlockH
if (m_block * (kBlockM/kBlockH) >= seqlen_traits_q.actual_seq_len) {
scheduler.prefetch_next_work(scheduler_params, work_tile_info);
scheduler.broadcast_next_work(work_tile_info);
continue;
}
}
@@ -202,7 +210,7 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp,
collective_epilogue.store(
epilogue_params, tOrO, softmax.row_sum, shared_storage, tiled_mma1,
threadIdx.x - NumCopyThreads, block_coord, seqlen_traits_q, mainloop_params.qhead_per_khead_divmod);

if constexpr(!No_smem_O && seqlen_traits_q.UseVarSeqLen) { shared_storage.barrier_O.arrive(); }
++work_idx;
}
collective_epilogue.store_tail();
@@ -277,7 +285,13 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp,

if (warp_idx == 0 && lane_predicate) {
shared_storage.barrier_Q.init(1 /*numThreads*/);
if constexpr (!No_smem_O) { shared_storage.barrier_O.init(size(ClusterShape{}) /*numThreads*/); }
if constexpr (!No_smem_O) {
if constexpr(seqlen_traits_q.UseVarSeqLen) {
shared_storage.barrier_O.init(NumMmaThreads /*numThreads*/);
} else {
shared_storage.barrier_O.init(size(ClusterShape{}) /*numThreads*/);
}
}
}
// We're counting on pipeline_k to call cutlass::arch::fence_barrier_init();
MainloopPipeline pipeline_k(shared_storage.pipeline_k, pipeline_params, ClusterShape{});
@@ -321,10 +335,15 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp,
auto [m_block, n_split_idx, bidh, bidb] = block_coord;

if constexpr (seqlen_traits_q.UseVarSeqLen) { seqlen_traits_q.init(bidb); }
if (shared_storage.seqlen_init_k) { seqlen_traits_k.init_no_guard(bidb); }
if constexpr (seqlen_traits_k.UseVarSeqLen) { seqlen_traits_k.init(bidb); }
else if (shared_storage.seqlen_init_k) { seqlen_traits_k.init_no_guard(bidb); }
if constexpr(seqlen_traits_q.UseVarSeqLen) {
// NOTE: to support in future with gqa packed layout, changed kBlockM to kBlockM/kBlockH
if (m_block * (kBlockM/kBlockH) >= seqlen_traits_q.actual_seq_len) {
scheduler.prefetch_next_work(scheduler_params, work_tile_info);
scheduler.broadcast_next_work(work_tile_info);
// need to sync producer warpgroup
cutlass::arch::NamedBarrier::sync(NumCopyThreads, static_cast<int>(FwdNamedBarriers::ProducerWG) /*id*/);
continue;
}
}
@@ -377,7 +396,8 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp,
auto [m_block, n_split_idx, bidh, bidb] = block_coord;

if constexpr (seqlen_traits_q.UseVarSeqLen) { seqlen_traits_q.init(bidb); }
if (shared_storage.seqlen_init_k) { seqlen_traits_k.init_no_guard(bidb); }
if constexpr (seqlen_traits_k.UseVarSeqLen) { seqlen_traits_k.init(bidb); }
else if (shared_storage.seqlen_init_k) { seqlen_traits_k.init_no_guard(bidb); }
if constexpr(seqlen_traits_q.UseVarSeqLen) {
// NOTE: to support in future with gqa packed layout, changed kBlockM to kBlockM/kBlockH
if (m_block * (kBlockM/kBlockH) >= seqlen_traits_q.actual_seq_len) {
@@ -409,7 +429,7 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp,
collective_epilogue.store(
epilogue_params, tOrO, softmax.row_sum, shared_storage, tiled_mma1,
threadIdx.x - NumCopyThreads, block_coord, seqlen_traits_q, mainloop_params.qhead_per_khead_divmod);

if constexpr(!No_smem_O && seqlen_traits_q.UseVarSeqLen) { shared_storage.barrier_O.arrive(); }
++work_idx;
}
collective_epilogue.store_tail();
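On the kernel side, the new guard skips M-tiles that start past a sequence's actual length while still prefetching and broadcasting the next work tile (and, in the producer warpgroup, syncing on the ProducerWG barrier). A tiny arithmetic sketch of which tiles survive the m_block * (kBlockM/kBlockH) >= actual_seq_len test; the tile sizes and sequence length are made-up values:

```python
kBlockM, kBlockH = 128, 1      # illustrative tile sizes, not the kernel's real configuration
actual_seq_len = 300           # tokens in this variable-length sequence
num_m_blocks = 4               # as if the scheduler were sized for max_seqlen_q = 512

for m_block in range(num_m_blocks):
    skip = m_block * (kBlockM // kBlockH) >= actual_seq_len
    print(m_block, "skip" if skip else "compute")
# Blocks 0-2 start at tokens 0, 128, 256 (< 300) and compute; block 3 starts at 384 >= 300
# and is skipped after handing the next work tile to the rest of the warps.
```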