Add persistent+TMA version of Triton mm and addmm (pytorch#142101)
This PR adds persistent+TMA versions (Triton template + the corresponding infra) for the `tuned_mm` and `tuned_addmm` lowerings. The persistent+TMA choices are added to the GEMM autotuning when all of the following hold (as checked by the `use_triton_tma_template` helper):

1. The minimum hardware and Triton version requirements for TMA support are met.

2. The GEMM inputs are compatible with the Triton TMA API (i.e., 16-byte aligned and contiguous).

3. `config.triton.enable_persistent_tma_matmul` is set to `True` (a usage sketch follows this list).
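
For illustration only, here is a minimal usage sketch (not part of this PR) that mirrors the new unit tests: it assumes a CUDA device and a Triton build with device-side TMA support, and inputs whose inner dims are 16-byte aligned.

```python
import torch
from torch._inductor import config

def mm(a, b):
    return torch.mm(a, b)

# float16 is 2 bytes, so inner dims of 128 / 256 are 16-byte aligned.
a = torch.randn(256, 128, dtype=torch.float16, device="cuda")
b = torch.randn(128, 256, dtype=torch.float16, device="cuda")

# Opt into the persistent+TMA choices; max-autotune enables GEMM autotuning.
with config.patch({"triton.enable_persistent_tma_matmul": True}):
    compiled_mm = torch.compile(mm, mode="max-autotune")
    out = compiled_mm(a, b)
```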

Additional notes:

1. As added in this PR, the TMA uses are not compatible with prologue / epilogue fusion. Accordingly, the new Triton template currently supports TMA-based loads of A/B but no prologue fusion, and epilogue fusion but no TMA-based stores of C. TMA + fusion compatibility can be added as a follow-up.

2. The current Triton TMA API (`experimental_device_tensormap_create2d`) does not support strides. Because of this, we limit the applicability of the new Triton template to the cases where the inputs are contiguous (a simplified sketch of this compatibility check follows these notes).

3. Transposed layouts of A and / or B are supported by passing constexpr flags to the kernel and adjusting the ordering of the block sizes accordingly in the kernel code (this should have no effect on kernel perf, as the flags are resolved at Triton compilation time).

4. After the next Triton pin update, we can switch to the tensor descriptor API (landed recently in triton-lang/triton#5290) in the new Triton template, which should allow lifting limitations 2 and 3 above.

5. The configs for the new Triton template in `persistent_mm_kernel_configs` are preliminary. We should do more perf exploration and possibly augment the configs in a follow-up.

6. This PR is rebased onto and unifies with two related PRs landed previously: pytorch#142045 (some infra unification with the persistent+TMA template for `_scaled_mm`) and pytorch#134532 (adding the possibility to disable prologue fusion for selected choices).

7. The current Triton TMA API only supports 1D and 2D descriptors (even after triton-lang/triton#5290, see [here](https://github.com/triton-lang/triton/blob/9829ce87ccb333a2b264b3a80b39a534bfa865ac/python/triton/language/core.py#L1957)). For now, this blocks adding a persistent+TMA template for `torch.bmm`.
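
As a simplified, hypothetical sketch of the input-compatibility check referenced in condition 2 and note 2 above (the actual logic lives in the `use_triton_tma_template` helper and additionally handles transposed / column-major layouts), the check amounts to requiring contiguity plus a 16-byte-aligned inner dim:

```python
import torch

TMA_ALIGNMENT_BYTES = 16

def tma_compatible_2d(t: torch.Tensor) -> bool:
    # Hypothetical helper, not the PR's implementation: the persistent+TMA
    # template only applies when the input is contiguous (the current TMA
    # API takes no strides) and its inner dim spans a multiple of 16 bytes.
    if t.dim() != 2 or not t.is_contiguous():
        return False
    return (t.size(-1) * t.element_size()) % TMA_ALIGNMENT_BYTES == 0

a = torch.randn(21, 11, dtype=torch.float16)
print(tma_compatible_2d(a))               # False: 11 * 2 bytes is not a multiple of 16
print(tma_compatible_2d(a.repeat(8, 8)))  # True: 88 * 2 bytes = 176 is a multiple of 16
```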

Pull Request resolved: pytorch#142101
Approved by: https://github.com/drisspg, https://github.com/eellison
aakhundov authored and pytorchmergebot committed Dec 16, 2024
1 parent 17b71e5 commit e885225
Showing 9 changed files with 453 additions and 61 deletions.
8 changes: 4 additions & 4 deletions test/inductor/test_fp8.py
@@ -468,7 +468,7 @@ def linear(x_fp8, x_inverse_scale, w_t_fp8, w_inverse_scale, bias):
w_inverse_scale,
bias,
)
-        with config.patch({"triton.enable_persistent_tma_matmul": True}):
+        with config.patch({"triton.enable_persistent_tma_matmul": persistent_matmul}):
linear_compiled = torch.compile(
linear, backend="inductor", mode="max-autotune"
)
@@ -538,7 +538,7 @@ def linear(x_fp8, x_inverse_scale, w_t_fp8, w_inverse_scale, bias):
w_inverse_scale,
bias,
)
-        with config.patch({"triton.enable_persistent_tma_matmul": True}):
+        with config.patch({"triton.enable_persistent_tma_matmul": persistent_matmul}):
linear_compiled = torch.compile(
linear, backend="inductor", mode="max-autotune"
)
@@ -596,7 +596,7 @@ def linear(x_fp8, x_inverse_scale, w_t_fp8, w_inverse_scale, bias):
w_inverse_scale,
bias,
)
-        with config.patch({"triton.enable_persistent_tma_matmul": True}):
+        with config.patch({"triton.enable_persistent_tma_matmul": persistent_matmul}):
linear_compiled = torch.compile(
linear, backend="inductor", mode="max-autotune"
)
@@ -655,7 +655,7 @@ def linear(x_fp8, x_inverse_scale, w_t_fp8, w_inverse_scale, bias):
w_inverse_scale,
bias,
)
-        with config.patch({"triton.enable_persistent_tma_matmul": True}):
+        with config.patch({"triton.enable_persistent_tma_matmul": persistent_matmul}):
linear_compiled = torch.compile(
linear, backend="inductor", mode="max-autotune"
)
144 changes: 144 additions & 0 deletions test/inductor/test_max_autotune.py
Expand Up @@ -28,6 +28,7 @@
parametrize,
TEST_WITH_ROCM,
)
from torch.utils._triton import has_triton_tma_device


aten = torch.ops.aten
@@ -212,6 +213,76 @@ def mm(a, b):
with config.patch({"max_autotune": True, "autotune_in_subproc": True}):
torch.compile(mm, dynamic=dynamic)(a, b)

@unittest.skipIf(
not has_triton_tma_device(), "Need device-side TMA support in Triton"
)
@parametrize("a_transposed", (False, True))
@parametrize("b_transposed", (False, True))
@parametrize("dynamic", (False, True))
def test_max_autotune_regular_mm_persistent_tma(
self,
a_transposed: bool,
b_transposed: bool,
dynamic: bool,
):
def mm(a, b):
# TMA requires 16-byte alignment: here we repeat the dims
# by the factor of 8, as float16 is 2-byte. All dims are
# repeated due to the possible transpositions below.
a = a.repeat(8, 8)
b = b.repeat(8, 8)

if a_transposed:
a = a.T
if b_transposed:
b = b.T

return torch.mm(a, b)

M, N, K = 21, 31, 11
a = torch.randn(*((K, M) if a_transposed else (M, K))).to(torch.float16).cuda()
b = torch.randn(*((N, K) if b_transposed else (K, N))).to(torch.float16).cuda()

with config.patch(
{
"max_autotune": True,
"autotune_fallback_to_aten": False,
"triton.enable_persistent_tma_matmul": "1",
"test_configs.autotune_choice_name_regex": "mm_persistent_tma",
}
):
c_actual = torch.compile(mm, dynamic=dynamic)(a, b)
c_expected = mm(a, b)

torch.testing.assert_close(c_actual, c_expected, atol=1e-2, rtol=1e-2)

@unittest.skipIf(
not has_triton_tma_device(), "Need device-side TMA support in Triton"
)
@parametrize("dynamic", (False, True))
def test_max_autotune_regular_mm_persistent_tma_illegal_alignment(self, dynamic):
def mm(a, b):
return torch.mm(a, b)

M, N, K = 21, 31, 11
a = torch.randn(M, K).to(torch.float16).cuda()
b = torch.randn(K, N).to(torch.float16).cuda()

with self.assertRaises(BackendCompilerFailed) as context, config.patch(
{
"max_autotune": True,
"autotune_fallback_to_aten": False,
"triton.enable_persistent_tma_matmul": "1",
"test_configs.autotune_choice_name_regex": "mm_persistent_tma",
}
):
torch.compile(mm, dynamic=dynamic)(a, b)

# Lowering to the persistent+TMA Triton template should be skipped
# if any of the input inner dims are not 16-byte aligned. As a result,
# given the config flags above, we should have no choices left.
self.assertIn("NoValidChoicesError", str(context.exception))

@parametrize("dynamic", (False, True))
def test_max_autotune_regular_mm_zero_size_input(self, dynamic: bool):
"""
@@ -316,6 +387,79 @@ def addmm(x, a, b):
Y = addmm(x, a, b)
torch.testing.assert_close(Y_compiled, Y, atol=1e-2, rtol=1e-2)

@unittest.skipIf(
not has_triton_tma_device(), "Need device-side TMA support in Triton"
)
@parametrize("a_transposed", (False, True))
@parametrize("b_transposed", (False, True))
@parametrize("dynamic", (False, True))
def test_max_autotune_addmm_persistent_tma(
self,
a_transposed: bool,
b_transposed: bool,
dynamic: bool,
):
def addmm(x, a, b):
# TMA requires 16-byte alignment: here we repeat the dims
# by the factor of 8, as float16 is 2-byte. All dims are
# repeated due to the possible transpositions below.
x = x.repeat(8)
a = a.repeat(8, 8)
b = b.repeat(8, 8)

if a_transposed:
a = a.T
if b_transposed:
b = b.T

return torch.addmm(x, a, b)

M, N, K = 21, 31, 11
a = torch.randn(*((K, M) if a_transposed else (M, K))).to(torch.float16).cuda()
b = torch.randn(*((N, K) if b_transposed else (K, N))).to(torch.float16).cuda()
x = torch.randn(N).to(torch.float16).cuda()

with config.patch(
{
"max_autotune": True,
"autotune_fallback_to_aten": False,
"triton.enable_persistent_tma_matmul": "1",
"test_configs.autotune_choice_name_regex": "mm_persistent_tma",
}
):
c_actual = torch.compile(addmm, dynamic=dynamic)(x, a, b)
c_expected = addmm(x, a, b)

torch.testing.assert_close(c_actual, c_expected, atol=1e-2, rtol=1e-2)

@unittest.skipIf(
not has_triton_tma_device(), "Need device-side TMA support in Triton"
)
@parametrize("dynamic", (False, True))
def test_max_autotune_addmm_persistent_tma_illegal_alignment(self, dynamic):
def addmm(x, a, b):
return torch.addmm(x, a, b)

M, N, K = 21, 31, 11
a = torch.randn(M, K).to(torch.float16).cuda()
b = torch.randn(K, N).to(torch.float16).cuda()
x = torch.randn(N).to(torch.float16).cuda()

with self.assertRaises(BackendCompilerFailed) as context, config.patch(
{
"max_autotune": True,
"autotune_fallback_to_aten": False,
"triton.enable_persistent_tma_matmul": "1",
"test_configs.autotune_choice_name_regex": "mm_persistent_tma",
}
):
torch.compile(addmm, dynamic=dynamic)(x, a, b)

# Lowering to the persistent+TMA Triton template should be skipped
# if any of the input inner dims are not 16-byte aligned. As a result,
# given the config flags above, we should have no choices left.
self.assertIn("NoValidChoicesError", str(context.exception))

@parametrize("dynamic", (False, True))
def test_max_autotune_addmm_zero_size_input(self, dynamic):
"""
5 changes: 5 additions & 0 deletions torch/_inductor/config.py
@@ -1368,6 +1368,11 @@ class test_configs:

runtime_triton_dtype_assert = False

# regex to control the set of considered autotuning
# choices (aka configs) by name and / or description
autotune_choice_name_regex: Optional[str] = None
autotune_choice_desc_regex: Optional[str] = None


if TYPE_CHECKING:
from torch.utils._config_typing import * # noqa: F401, F403
143 changes: 143 additions & 0 deletions torch/_inductor/kernel/mm.py
@@ -31,12 +31,14 @@
)
from ..utils import (
get_gpu_shared_memory,
get_tma_workspace_arg,
use_aten_gemm_kernels,
use_ck_gemm_template,
use_cpp_gemm_template,
use_cutlass_template,
use_max_autotune,
use_triton_template,
use_triton_tma_template,
)
from .mm_common import (
_is_static_problem,
@@ -48,6 +50,9 @@
mm_configs,
mm_grid,
mm_options,
persistent_mm_configs,
persistent_mm_grid,
persistent_mm_options,
triton_config,
)

@@ -128,6 +133,110 @@
""",
)

persistent_tma_mm_template = TritonTemplate(
name="mm_persistent_tma",
grid=persistent_mm_grid,
source=r"""
{{def_kernel("A", "B")}}
M = {{size("A", 0)}}
N = {{size("B", 1)}}
K = {{size("A", 1)}}
if M * N == 0:
# early exit due to zero-size input(s)
return
start_pid = tl.program_id(0)
grid_m = tl.cdiv(M, BLOCK_M)
grid_n = tl.cdiv(N, BLOCK_N)
k_tiles = tl.cdiv(K, BLOCK_K)
num_tiles = grid_m * grid_n
tiles_per_SM = num_tiles // NUM_SMS
if start_pid < num_tiles % NUM_SMS:
tiles_per_SM += 1
tile_id = start_pid - NUM_SMS
ki = -1
width = GROUP_M * grid_n
rk_for_mask = tl.arange(0, BLOCK_K)
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
workspace_base = ws_ptr + start_pid * 2 * TMA_SIZE
a_desc_ptr = workspace_base
b_desc_ptr = workspace_base + TMA_SIZE
triton.language.extra.cuda.experimental_device_tensormap_create2d(
desc_ptr=a_desc_ptr,
global_address=A,
load_size=[BLOCK_M, BLOCK_K] if A_ROW_MAJOR else [BLOCK_K, BLOCK_M],
global_size=[M, K] if A_ROW_MAJOR else [K, M],
element_ty=A.dtype.element_ty,
)
triton.language.extra.cuda.experimental_device_tensormap_create2d(
desc_ptr=b_desc_ptr,
global_address=B,
load_size=[BLOCK_K, BLOCK_N] if B_ROW_MAJOR else [BLOCK_N, BLOCK_K],
global_size=[K, N] if B_ROW_MAJOR else [N, K],
element_ty=B.dtype.element_ty,
)
tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(a_desc_ptr)
tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(b_desc_ptr)
pid_m = 0
pid_n = 0
rm = 0
rn = 0
for _ in range(0, k_tiles * tiles_per_SM):
ki = tl.where(ki == k_tiles - 1, 0, ki + 1)
if ki == 0:
tile_id += NUM_SMS
# re-order program ID for better L2 performance
group_id = tile_id // width
group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
pid_m = group_id * GROUP_M + (tile_id % group_size)
pid_n = (tile_id % width) // (group_size)
rm = pid_m * BLOCK_M
rn = pid_n * BLOCK_N
rk = ki * BLOCK_K
a = tl._experimental_descriptor_load(
a_desc_ptr,
[rm, rk] if A_ROW_MAJOR else [rk, rm],
[BLOCK_M, BLOCK_K] if A_ROW_MAJOR else [BLOCK_K, BLOCK_M],
A.dtype.element_ty,
)
b = tl._experimental_descriptor_load(
b_desc_ptr,
[rk, rn] if B_ROW_MAJOR else [rn, rk],
[BLOCK_K, BLOCK_N] if B_ROW_MAJOR else [BLOCK_N, BLOCK_K],
B.dtype.element_ty,
)
if B_PROLOGUE_CAST_TYPE is not None:
b = b.to(B_PROLOGUE_CAST_TYPE)
acc += tl.dot(
a if A_ROW_MAJOR else a.T,
b if B_ROW_MAJOR else b.T,
allow_tf32=ALLOW_TF32,
)
if ki == k_tiles - 1:
# rematerialize rm and rn to save registers
rcm = rm + tl.arange(0, BLOCK_M)
rcn = rn + tl.arange(0, BLOCK_N)
idx_m = rcm[:, None]
idx_n = rcn[None, :]
mask = (idx_m < M) & (idx_n < N)
# inductor generates a suffix
{{store_output(("idx_m", "idx_n"), "acc", "mask", indent_width=12)}}
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
""",
)


# prevent duplication registration of extern functions
@functools.lru_cache(None)
@@ -206,6 +315,22 @@ def tuned_mm(mat1, mat2, *, layout=None):
layout=layout,
**mm_options(config, m, n, k, layout),
)
if use_triton_tma_template(mat1, mat2):
for config in persistent_mm_configs(
m, n, k, **mm_config_kwargs(ir.get_device_type(mat1))
):
persistent_tma_mm_template.maybe_append_choice(
choices,
input_nodes=(mat1, mat2),
layout=layout,
workspace_arg=get_tma_workspace_arg(
num_tma_descriptors=2,
device=mat1.get_device(),
),
**mm_options(config, m, n, k, layout),
**persistent_mm_options(mat1, mat2),
)

if is_nonzero and use_cutlass_template(layout, m, n, k):
CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(choices, layout, [mat1, mat2])

@@ -398,6 +523,24 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
epilogue_fn=addmm_epilogue(layout.dtype, alpha, beta),
)

if use_triton_tma_template(mat1, mat2):
for config in persistent_mm_configs(
m, n, k, **mm_config_kwargs(ir.get_device_type(mat1))
):
persistent_tma_mm_template.maybe_append_choice(
choices,
input_nodes=(inp_expanded, mat1, mat2),
layout=layout,
workspace_arg=get_tma_workspace_arg(
num_tma_descriptors=2,
device=mat1.get_device(),
),
**mm_options(config, m, n, k, layout),
**persistent_mm_options(mat1, mat2),
prefix_args=1,
epilogue_fn=addmm_epilogue(layout.dtype, alpha, beta),
)

if static_shape and is_nonzero and use_cutlass_template(layout, m, n, k):
# Filter out a known cause of CUDA illegal memory access errors
# broadcasting on the last dim of the bias term seems not to be working