Faster MHA backwards pass #22820

Merged: 1 commit, Aug 7, 2024
272 changes: 162 additions & 110 deletions jax/experimental/pallas/ops/gpu/attention.py
@@ -70,11 +70,6 @@ def body(start_k, carry):
curr_k_slice = pl.dslice(start_k * block_k, block_k)

k = pl.load(k_ref, (curr_k_slice, slice(None)))
kv_segment_ids = (
None
if segment_ids_ref is None
else pl.load(segment_ids_ref, (curr_k_slice,))
)
qk = pl.dot(q, k.T) # [block_q, block_k]
if sm_scale != 1.:
qk *= sm_scale # [block_q, block_k]
@@ -87,6 +82,7 @@ def body(start_k, carry):
if causal or segment_ids_ref is not None:
mask = None
if segment_ids_ref is not None:
kv_segment_ids = pl.load(segment_ids_ref, (curr_k_slice,))
mask = segment_mask(q_segment_ids, kv_segment_ids)
if causal:
span_q = start_q * block_q + jnp.arange(block_q)
@@ -354,6 +350,9 @@ def _preprocess_backward(out, do, l, block_q: int,
return do_scaled, delta


# This kernel computes dK_i, dV_i and dQ_i in parallel across the sequence
# length.
# Inspired by the triton tutorial: https://github.com/triton-lang/triton/blob/main/python/tutorials/06-fused-attention.py
def mha_backward_kernel(
# Inputs
q_ref,
@@ -365,92 +364,148 @@ def mha_backward_kernel(
l_ref,
m_ref,
delta_ref,
_,
# Outputs
dq_ref,
dk_ref,
dv_ref,
*,
sm_scale: float,
causal: bool,
block_q: int,
block_q1: int,
block_k1: int,
block_q2: int,
block_k2: int,
block_d: int,
block_k: int,
):
del out_ref, l_ref # Not needed
seq_len = q_ref.shape[0]

def outer_loop(start_k, _):

dv = jnp.zeros([block_k, block_d], dtype=jnp.float32)
dk = jnp.zeros([block_k, block_d], dtype=jnp.float32)
k = pl.load(k_ref, (pl.ds(start_k * block_k, block_k), slice(None)))
v = pl.load(v_ref, (pl.ds(start_k * block_k, block_k), slice(None)))
span_k = start_k * block_k + jnp.arange(block_k)
kv_segment_ids = (
None
if segment_ids_ref is None
else pl.load(segment_ids_ref, (pl.ds(start_k * block_k, block_k),))
)

def inner_loop(start_q, carry):
dv, dk = carry
q = pl.load(q_ref, (pl.ds(start_q * block_q, block_q), slice(None)))
qk = pl.dot(q, k.T)
qk = qk.astype(q_ref.dtype)
qk = qk.astype(jnp.float32)
if sm_scale != 1.0:
qk *= sm_scale

q_segment_ids = (
None
if segment_ids_ref is None
else pl.load(segment_ids_ref, (pl.ds(start_q * block_q, block_q),))
)

if causal or segment_ids_ref is not None:
mask = None
if segment_ids_ref is not None:
mask = segment_mask(q_segment_ids, kv_segment_ids)

if causal:
span_q = start_q * block_q + jnp.arange(block_q)
causal_mask = span_q[:, None] >= span_k[None, :]
mask = (
causal_mask
if mask is None
else jnp.logical_and(mask, causal_mask)
)
qk = jnp.where(mask, qk, DEFAULT_MASK_VALUE)

m = pl.load(m_ref, (pl.ds(start_q * block_q, block_q),))
p = jnp.exp(qk - m[:, None])
do = pl.load(do_scaled_ref, (pl.ds(start_q * block_q, block_q), slice(None)))
dv = dv + pl.dot(p.astype(do.dtype).T, do)
di = pl.load(delta_ref, (pl.ds(start_q * block_q, block_q),))
dp = jnp.zeros((block_q, block_k), dtype=jnp.float32) - di[:, None]
dp = dp + pl.dot(do, v.T)
ds = p * dp
if sm_scale != 1.0:
ds = ds * sm_scale
dk = dk + pl.dot(ds.astype(q_ref.dtype).T, q)
dq = pl.load(dq_ref, (pl.ds(start_q * block_q, block_q),
slice(None)), eviction_policy="evict_last")
dq = dq + pl.dot(ds.astype(k.dtype), k).astype(dq.dtype)
pl.store(dq_ref, (pl.ds(start_q * block_q, block_q),
slice(None)), dq, eviction_policy="evict_last")
return dv, dk
if causal:
lower_bound = lax.div(start_k * block_k, block_q)
else:
lower_bound = 0
dv, dk = lax.fori_loop(lower_bound, pl.cdiv(seq_len, block_q), inner_loop,
(dv, dk))
pl.store(dv_ref, (pl.ds(start_k * block_k, block_k),
slice(None)), dv.astype(dv_ref.dtype))
pl.store(dk_ref, (pl.ds(start_k * block_k, block_k),
slice(None)), dk.astype(dk_ref.dtype))
lax.fori_loop(0, pl.cdiv(seq_len, block_k), outer_loop, None)
# Scan #1: dK and dV
# 1. Load a block of K and V of size (block_k1, head_dim) in SMEM.
# 2. Iterate through Q in chunks of (block_q1, head_dim) to accumulate
# dK and dV.
start_k = pl.program_id(2)
curr_k_slice = pl.dslice(start_k * block_k1, block_k1)

dv = jnp.zeros([block_k1, block_d], dtype=jnp.float32)
dk = jnp.zeros([block_k1, block_d], dtype=jnp.float32)

v = pl.load(v_ref, (curr_k_slice, slice(None)))
k = pl.load(k_ref, (curr_k_slice, slice(None)))
span_k = start_k * block_k1 + jnp.arange(block_k1)
kv_segment_ids = (
None
if segment_ids_ref is None
else pl.load(segment_ids_ref, (curr_k_slice,))
)

def inner_loop_dkdv(start_q, carry):
dv, dk = carry
curr_q_slice = pl.dslice(start_q * block_q1, block_q1)

q = pl.load(q_ref, (curr_q_slice, slice(None)))
qk = pl.dot(q, k.T)
if sm_scale != 1.0:
qk *= sm_scale

if causal or segment_ids_ref is not None:
mask = None
if segment_ids_ref is not None:
q_segment_ids = pl.load(segment_ids_ref, (curr_q_slice,))
mask = segment_mask(q_segment_ids, kv_segment_ids)

if causal:
span_q = start_q * block_q1 + jnp.arange(block_q1)
causal_mask = span_q[:, None] >= span_k[None, :]
mask = (
causal_mask if mask is None else jnp.logical_and(mask, causal_mask)
)
qk = jnp.where(mask, qk, DEFAULT_MASK_VALUE)

m = pl.load(m_ref, (curr_q_slice,))
di = pl.load(delta_ref, (curr_q_slice,))
do = pl.load(do_scaled_ref, (curr_q_slice, slice(None)))

p = jnp.exp(qk - m[:, None])
dv = dv + pl.dot(p.astype(do.dtype).T, do)
dp = jnp.zeros((block_q1, block_k1), dtype=jnp.float32) - di[:, None]
dp = dp + pl.dot(do, v.T)
ds = p * dp
if sm_scale != 1.0:
ds = ds * sm_scale
dk = dk + pl.dot(ds.astype(q_ref.dtype).T, q)

return dv, dk

lower_bound = lax.div(start_k * block_k1, block_q1) if causal else 0
dv, dk = lax.fori_loop(
lower_bound, pl.cdiv(seq_len, block_q1), inner_loop_dkdv, (dv, dk)
)
pl.store(dv_ref, (curr_k_slice, slice(None)), dv.astype(dv_ref.dtype))
pl.store(dk_ref, (curr_k_slice, slice(None)), dk.astype(dk_ref.dtype))

del dv, dk

# Scan #2: dQ
Review comment (Collaborator):

Out of curiosity, is there an advantage to doing this in one kernel vs. two kernels?

Reply (Author):

I think it comes down to the fact that there's more work to do in the kernel, which leads to better warp occupancy.

Other factors that influence GPU utilization:

  • There's some data locality between the two loops, but it's more significant for smaller sequence lengths.
  • The overhead of launching two kernels.
  • Making sure the two kernels actually execute in parallel: they would need to be launched on separate CUDA streams, and even then I don't think it's guaranteed.

# 1. Load a block of Q of size (block_q2, head_dim) in SMEM.
# 2. Iterate through K and V in chunks of (block_k2, head_dim) to
# accumulate dQ.
start_q = pl.program_id(2)
curr_q_slice = pl.ds(start_q * block_q2, block_q2)
span_q = start_q * block_q2 + jnp.arange(block_q2)
dq = jnp.zeros([block_q2, block_d], dtype=jnp.float32)

q = pl.load(q_ref, (curr_q_slice, slice(None)))
q_segment_ids = (
None
if segment_ids_ref is None
else pl.load(segment_ids_ref, (curr_q_slice,))
)
m = pl.load(m_ref, (curr_q_slice,))
do = pl.load(do_scaled_ref, (curr_q_slice, slice(None)))
di = pl.load(delta_ref, (curr_q_slice,))

def inner_loop_dq(start_k, dq):
curr_k_slice = pl.dslice(start_k * block_k2, block_k2)
k = pl.load(k_ref, (curr_k_slice, slice(None)))
v = pl.load(v_ref, (curr_k_slice, slice(None)))

qk = pl.dot(q, k.T)
if sm_scale != 1.0:
qk *= sm_scale

if causal or segment_ids_ref is not None:
mask = None
if segment_ids_ref is not None:
kv_segment_ids = pl.load(segment_ids_ref, (curr_k_slice,))
mask = segment_mask(q_segment_ids, kv_segment_ids)

if causal:
span_k = start_k * block_k2 + jnp.arange(block_k2)
causal_mask = span_q[:, None] >= span_k[None, :]
mask = (
causal_mask if mask is None else jnp.logical_and(mask, causal_mask)
)
qk = jnp.where(mask, qk, DEFAULT_MASK_VALUE)

p = jnp.exp(qk - m[:, None])
dp = jnp.zeros((block_q2, block_k2), dtype=jnp.float32) - di[:, None]
dp = dp + pl.dot(do, v.T)
ds = p * dp
if sm_scale != 1.0:
ds = ds * sm_scale

dq = dq + pl.dot(ds.astype(k.dtype), k).astype(dq.dtype)

return dq

if causal:
upper_bound = lax.div((start_q + 1) * block_q2, block_k2)
else:
upper_bound = pl.cdiv(seq_len, block_k2)

dq = lax.fori_loop(0, upper_bound, inner_loop_dq, (dq))
pl.store(dq_ref, (curr_q_slice, slice(None)), dq.astype(dq_ref.dtype))
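
For orientation (not part of the diff): a hedged pure-JAX sketch of the gradients the two scans above jointly reproduce, obtained by differentiating plain softmax attention with jax.vjp. The function names and the single-(batch, head) slicing are illustrative assumptions, not the library API.

import jax
import jax.numpy as jnp

def reference_attention(q, k, v, *, causal=False, sm_scale=1.0):
  # q, k, v: [seq_len, head_dim] for one (batch, head) slice.
  logits = (q @ k.T) * sm_scale
  if causal:
    seq_len = q.shape[0]
    mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))
    logits = jnp.where(mask, logits, jnp.finfo(logits.dtype).min)
  return jax.nn.softmax(logits, axis=-1) @ v

def reference_backward(q, k, v, do, **kwargs):
  # Returns (dq, dk, dv) for an upstream cotangent `do` of the attention output.
  # The Pallas kernel produces the same values blockwise, without materializing
  # the full [seq_len, seq_len] probability matrix.
  _, vjp = jax.vjp(lambda q, k, v: reference_attention(q, k, v, **kwargs), q, k, v)
  return vjp(do)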


@@ -473,75 +528,72 @@ def _mha_backward(sm_scale: float, causal: bool, block_q: int, block_k: int,
block_q = min(block_q, seq_len)
block_k = min(block_k, seq_len)
do_scaled, delta = _preprocess_backward(out, do, l, block_q, debug, interpret)
# We accumulate into dq so we need to initialize it to zeros.
dq = jnp.zeros(q.shape, jnp.float32)
out_shapes = [
jax.ShapeDtypeStruct(dq.shape, dq.dtype),
jax.ShapeDtypeStruct(k.shape, k.dtype),
jax.ShapeDtypeStruct(v.shape, v.dtype),
jax.ShapeDtypeStruct(q.shape, q.dtype),
jax.ShapeDtypeStruct(k.shape, k.dtype),
jax.ShapeDtypeStruct(v.shape, v.dtype),
]

in_specs = [
pl.BlockSpec(
(None, seq_len, None, head_dim), lambda j, k: (j, 0, k, 0)
),
pl.BlockSpec(
(None, seq_len, None, head_dim), lambda j, k: (j, 0, k, 0)
(None, seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0)
),
pl.BlockSpec(
(None, seq_len, None, head_dim), lambda j, k: (j, 0, k, 0)
(None, seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0)
),
pl.BlockSpec(
(None, seq_len, None, head_dim), lambda j, k: (j, 0, k, 0)
(None, seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0)
),
pl.BlockSpec(
(None, seq_len, None, head_dim), lambda j, k: (j, 0, k, 0)
(None, seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0)
),
pl.BlockSpec((None, None, seq_len), lambda j, k: (j, k, 0)),
pl.BlockSpec((None, None, seq_len), lambda j, k: (j, k, 0)),
pl.BlockSpec((None, None, seq_len), lambda j, k: (j, k, 0)),
pl.BlockSpec(
(None, seq_len, None, head_dim), lambda j, k: (j, 0, k, 0)
(None, seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0)
),
pl.BlockSpec((None, None, seq_len), lambda i, j, _: (i, j, 0)),
pl.BlockSpec((None, None, seq_len), lambda i, j, _: (i, j, 0)),
pl.BlockSpec((None, None, seq_len), lambda i, j, _: (i, j, 0)),
]
if segment_ids is None:
in_specs.insert(3, None) # type: ignore[arg-type]
input_output_aliases = {8: 0}
else:
in_specs.insert(3, pl.BlockSpec((None, seq_len), lambda j, k: (j, 0)))
input_output_aliases = {9: 0}
grid = (batch_size, num_heads)
# TODO(sharadmv): figure out why num_warps=8 doesn't work!
in_specs.insert(3, pl.BlockSpec((None, seq_len), lambda i, j, _: (i, 0)))

grid = (batch_size, num_heads, pl.cdiv(seq_len, block_k))
num_warps = 8
dq, dk, dv = pl.pallas_call(
functools.partial(
mha_backward_kernel,
block_q=block_q,
block_d=head_dim,
block_k=block_k,
sm_scale=sm_scale,
causal=causal,
block_q1=block_q,
block_k1=block_k,
block_q2=block_q,
block_k2=block_k,
block_d=head_dim,
),
grid=grid,
out_shape=out_shapes,
in_specs=in_specs,
grid=grid,
out_specs=[
pl.BlockSpec(
(None, seq_len, None, head_dim), lambda j, k: (j, 0, k, 0)
(None, seq_len, None, head_dim),
lambda i, j, _: (i, 0, j, 0), # dq
),
pl.BlockSpec(
(None, seq_len, None, head_dim), lambda j, k: (j, 0, k, 0)
(None, seq_len, None, head_dim),
lambda i, j, _: (i, 0, j, 0), # dk
),
pl.BlockSpec(
(None, seq_len, None, head_dim), lambda j, k: (j, 0, k, 0)
(None, seq_len, None, head_dim),
lambda i, j, _: (i, 0, j, 0), # dv
),
],
name="mha_backward",
debug=debug,
interpret=interpret,
compiler_params=dict(triton=dict(num_warps=num_warps, num_stages=1)),
input_output_aliases=input_output_aliases,
)(q, k, v, segment_ids, out, do_scaled, l, m, delta, dq)
compiler_params=dict(triton=dict(num_warps=num_warps, num_stages=2)),
)(q, k, v, segment_ids, out, do_scaled, l, m, delta)
else:
raise ValueError(f"Invalid backward pass implementation: {backward_pass_impl}")
return dq.astype(q.dtype), dk, dv, None
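
A hedged end-to-end sketch of how this backward pass is exercised from user code, assuming the module's existing mha entry point with a positional segment_ids argument (the exact signature is an assumption here):

import jax
import jax.numpy as jnp
from jax.experimental.pallas.ops.gpu import attention

batch, seq_len, num_heads, head_dim = 2, 384, 2, 32
kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (batch, seq_len, num_heads, head_dim), jnp.float16)
k = jax.random.normal(kk, (batch, seq_len, num_heads, head_dim), jnp.float16)
v = jax.random.normal(kv, (batch, seq_len, num_heads, head_dim), jnp.float16)

def loss(q, k, v):
  # causal=True exercises the masked branches of both scans.
  return attention.mha(q, k, v, None, causal=True).sum()

# jax.grad routes through _mha_backward and hence the fused dK/dV + dQ kernel.
dq, dk, dv = jax.grad(loss, argnums=(0, 1, 2))(q, k, v)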
4 changes: 2 additions & 2 deletions tests/pallas/gpu_ops_test.py
@@ -252,8 +252,8 @@ def impl(q, k, v):
(1, 384, 1, 32, False, False),
(2, 384, 2, 32, False, True),
(2, 384, 2, 32, False, False),
# TODO(b/283035396): (1, 384, 1, 32, True, True),
# TODO(b/283035396): (2, 384, 2, 32, True, True),
(1, 384, 1, 32, True, True),
(2, 384, 2, 32, True, True),
]
]
)
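
The re-enabled cases combine causal masking with segment ids. A hedged sketch of that combination, assuming segment ids are a per-token integer array of shape [batch, seq_len] shared by queries and keys (as the kernel's segment_ids_ref loads suggest):

import jax
import jax.numpy as jnp
from jax.experimental.pallas.ops.gpu import attention

batch, seq_len, num_heads, head_dim = 1, 384, 1, 32
kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (batch, seq_len, num_heads, head_dim), jnp.float16)
k = jax.random.normal(kk, (batch, seq_len, num_heads, head_dim), jnp.float16)
v = jax.random.normal(kv, (batch, seq_len, num_heads, head_dim), jnp.float16)

# Two equal-length segments; attention never crosses the segment boundary.
segment_ids = jnp.concatenate(
    [jnp.zeros((batch, seq_len // 2), jnp.int32),
     jnp.ones((batch, seq_len // 2), jnp.int32)], axis=-1)

f = lambda q, k, v: attention.mha(q, k, v, segment_ids, causal=True).sum()
dq, dk, dv = jax.grad(f, argnums=(0, 1, 2))(q, k, v)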