diff --git a/test/inductor/test_fp8.py b/test/inductor/test_fp8.py index 4cb6223933aef0..6241853a28824c 100644 --- a/test/inductor/test_fp8.py +++ b/test/inductor/test_fp8.py @@ -468,7 +468,7 @@ def linear(x_fp8, x_inverse_scale, w_t_fp8, w_inverse_scale, bias): w_inverse_scale, bias, ) - with config.patch({"triton.enable_persistent_tma_matmul": True}): + with config.patch({"triton.enable_persistent_tma_matmul": persistent_matmul}): linear_compiled = torch.compile( linear, backend="inductor", mode="max-autotune" ) @@ -538,7 +538,7 @@ def linear(x_fp8, x_inverse_scale, w_t_fp8, w_inverse_scale, bias): w_inverse_scale, bias, ) - with config.patch({"triton.enable_persistent_tma_matmul": True}): + with config.patch({"triton.enable_persistent_tma_matmul": persistent_matmul}): linear_compiled = torch.compile( linear, backend="inductor", mode="max-autotune" ) @@ -596,7 +596,7 @@ def linear(x_fp8, x_inverse_scale, w_t_fp8, w_inverse_scale, bias): w_inverse_scale, bias, ) - with config.patch({"triton.enable_persistent_tma_matmul": True}): + with config.patch({"triton.enable_persistent_tma_matmul": persistent_matmul}): linear_compiled = torch.compile( linear, backend="inductor", mode="max-autotune" ) @@ -655,7 +655,7 @@ def linear(x_fp8, x_inverse_scale, w_t_fp8, w_inverse_scale, bias): w_inverse_scale, bias, ) - with config.patch({"triton.enable_persistent_tma_matmul": True}): + with config.patch({"triton.enable_persistent_tma_matmul": persistent_matmul}): linear_compiled = torch.compile( linear, backend="inductor", mode="max-autotune" ) diff --git a/test/inductor/test_max_autotune.py b/test/inductor/test_max_autotune.py index c0c4d5d580efdf..d52727884dbc7e 100644 --- a/test/inductor/test_max_autotune.py +++ b/test/inductor/test_max_autotune.py @@ -28,6 +28,7 @@ parametrize, TEST_WITH_ROCM, ) +from torch.utils._triton import has_triton_tma_device aten = torch.ops.aten @@ -212,6 +213,76 @@ def mm(a, b): with config.patch({"max_autotune": True, "autotune_in_subproc": True}): torch.compile(mm, dynamic=dynamic)(a, b) + @unittest.skipIf( + not has_triton_tma_device(), "Need device-side TMA support in Triton" + ) + @parametrize("a_transposed", (False, True)) + @parametrize("b_transposed", (False, True)) + @parametrize("dynamic", (False, True)) + def test_max_autotune_regular_mm_persistent_tma( + self, + a_transposed: bool, + b_transposed: bool, + dynamic: bool, + ): + def mm(a, b): + # TMA requires 16-byte alignment: here we repeat the dims + # by the factor of 8, as float16 is 2-byte. All dims are + # repeated due to the possible transpositions below. 
+ a = a.repeat(8, 8) + b = b.repeat(8, 8) + + if a_transposed: + a = a.T + if b_transposed: + b = b.T + + return torch.mm(a, b) + + M, N, K = 21, 31, 11 + a = torch.randn(*((K, M) if a_transposed else (M, K))).to(torch.float16).cuda() + b = torch.randn(*((N, K) if b_transposed else (K, N))).to(torch.float16).cuda() + + with config.patch( + { + "max_autotune": True, + "autotune_fallback_to_aten": False, + "triton.enable_persistent_tma_matmul": "1", + "test_configs.autotune_choice_name_regex": "mm_persistent_tma", + } + ): + c_actual = torch.compile(mm, dynamic=dynamic)(a, b) + c_expected = mm(a, b) + + torch.testing.assert_close(c_actual, c_expected, atol=1e-2, rtol=1e-2) + + @unittest.skipIf( + not has_triton_tma_device(), "Need device-side TMA support in Triton" + ) + @parametrize("dynamic", (False, True)) + def test_max_autotune_regular_mm_persistent_tma_illegal_alignment(self, dynamic): + def mm(a, b): + return torch.mm(a, b) + + M, N, K = 21, 31, 11 + a = torch.randn(M, K).to(torch.float16).cuda() + b = torch.randn(K, N).to(torch.float16).cuda() + + with self.assertRaises(BackendCompilerFailed) as context, config.patch( + { + "max_autotune": True, + "autotune_fallback_to_aten": False, + "triton.enable_persistent_tma_matmul": "1", + "test_configs.autotune_choice_name_regex": "mm_persistent_tma", + } + ): + torch.compile(mm, dynamic=dynamic)(a, b) + + # Lowering to the persistent+TMA Triton template should be skipped + # if any of the input inner dims are not 16-byte aligned. As a result, + # given the config flags above, we should have no choices left. + self.assertIn("NoValidChoicesError", str(context.exception)) + @parametrize("dynamic", (False, True)) def test_max_autotune_regular_mm_zero_size_input(self, dynamic: bool): """ @@ -316,6 +387,79 @@ def addmm(x, a, b): Y = addmm(x, a, b) torch.testing.assert_close(Y_compiled, Y, atol=1e-2, rtol=1e-2) + @unittest.skipIf( + not has_triton_tma_device(), "Need device-side TMA support in Triton" + ) + @parametrize("a_transposed", (False, True)) + @parametrize("b_transposed", (False, True)) + @parametrize("dynamic", (False, True)) + def test_max_autotune_addmm_persistent_tma( + self, + a_transposed: bool, + b_transposed: bool, + dynamic: bool, + ): + def addmm(x, a, b): + # TMA requires 16-byte alignment: here we repeat the dims + # by the factor of 8, as float16 is 2-byte. All dims are + # repeated due to the possible transpositions below. 
+ x = x.repeat(8) + a = a.repeat(8, 8) + b = b.repeat(8, 8) + + if a_transposed: + a = a.T + if b_transposed: + b = b.T + + return torch.addmm(x, a, b) + + M, N, K = 21, 31, 11 + a = torch.randn(*((K, M) if a_transposed else (M, K))).to(torch.float16).cuda() + b = torch.randn(*((N, K) if b_transposed else (K, N))).to(torch.float16).cuda() + x = torch.randn(N).to(torch.float16).cuda() + + with config.patch( + { + "max_autotune": True, + "autotune_fallback_to_aten": False, + "triton.enable_persistent_tma_matmul": "1", + "test_configs.autotune_choice_name_regex": "mm_persistent_tma", + } + ): + c_actual = torch.compile(addmm, dynamic=dynamic)(x, a, b) + c_expected = addmm(x, a, b) + + torch.testing.assert_close(c_actual, c_expected, atol=1e-2, rtol=1e-2) + + @unittest.skipIf( + not has_triton_tma_device(), "Need device-side TMA support in Triton" + ) + @parametrize("dynamic", (False, True)) + def test_max_autotune_addmm_persistent_tma_illegal_alignment(self, dynamic): + def addmm(x, a, b): + return torch.addmm(x, a, b) + + M, N, K = 21, 31, 11 + a = torch.randn(M, K).to(torch.float16).cuda() + b = torch.randn(K, N).to(torch.float16).cuda() + x = torch.randn(N).to(torch.float16).cuda() + + with self.assertRaises(BackendCompilerFailed) as context, config.patch( + { + "max_autotune": True, + "autotune_fallback_to_aten": False, + "triton.enable_persistent_tma_matmul": "1", + "test_configs.autotune_choice_name_regex": "mm_persistent_tma", + } + ): + torch.compile(addmm, dynamic=dynamic)(x, a, b) + + # Lowering to the persistent+TMA Triton template should be skipped + # if any of the input inner dims are not 16-byte aligned. As a result, + # given the config flags above, we should have no choices left. + self.assertIn("NoValidChoicesError", str(context.exception)) + @parametrize("dynamic", (False, True)) def test_max_autotune_addmm_zero_size_input(self, dynamic): """ diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 3863d7e15a0721..d7f7e253e5a375 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -1368,6 +1368,11 @@ class test_configs: runtime_triton_dtype_assert = False + # regex to control the set of considered autotuning + # choices (aka configs) by name and / or description + autotune_choice_name_regex: Optional[str] = None + autotune_choice_desc_regex: Optional[str] = None + if TYPE_CHECKING: from torch.utils._config_typing import * # noqa: F401, F403 diff --git a/torch/_inductor/kernel/mm.py b/torch/_inductor/kernel/mm.py index b4b164ff4a499c..14200943258587 100644 --- a/torch/_inductor/kernel/mm.py +++ b/torch/_inductor/kernel/mm.py @@ -31,12 +31,14 @@ ) from ..utils import ( get_gpu_shared_memory, + get_tma_workspace_arg, use_aten_gemm_kernels, use_ck_gemm_template, use_cpp_gemm_template, use_cutlass_template, use_max_autotune, use_triton_template, + use_triton_tma_template, ) from .mm_common import ( _is_static_problem, @@ -48,6 +50,9 @@ mm_configs, mm_grid, mm_options, + persistent_mm_configs, + persistent_mm_grid, + persistent_mm_options, triton_config, ) @@ -128,6 +133,110 @@ """, ) +persistent_tma_mm_template = TritonTemplate( + name="mm_persistent_tma", + grid=persistent_mm_grid, + source=r""" +{{def_kernel("A", "B")}} + M = {{size("A", 0)}} + N = {{size("B", 1)}} + K = {{size("A", 1)}} + if M * N == 0: + # early exit due to zero-size input(s) + return + + start_pid = tl.program_id(0) + grid_m = tl.cdiv(M, BLOCK_M) + grid_n = tl.cdiv(N, BLOCK_N) + k_tiles = tl.cdiv(K, BLOCK_K) + num_tiles = grid_m * grid_n + tiles_per_SM = num_tiles // 
NUM_SMS + if start_pid < num_tiles % NUM_SMS: + tiles_per_SM += 1 + + tile_id = start_pid - NUM_SMS + ki = -1 + + width = GROUP_M * grid_n + rk_for_mask = tl.arange(0, BLOCK_K) + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + + workspace_base = ws_ptr + start_pid * 2 * TMA_SIZE + a_desc_ptr = workspace_base + b_desc_ptr = workspace_base + TMA_SIZE + + triton.language.extra.cuda.experimental_device_tensormap_create2d( + desc_ptr=a_desc_ptr, + global_address=A, + load_size=[BLOCK_M, BLOCK_K] if A_ROW_MAJOR else [BLOCK_K, BLOCK_M], + global_size=[M, K] if A_ROW_MAJOR else [K, M], + element_ty=A.dtype.element_ty, + ) + triton.language.extra.cuda.experimental_device_tensormap_create2d( + desc_ptr=b_desc_ptr, + global_address=B, + load_size=[BLOCK_K, BLOCK_N] if B_ROW_MAJOR else [BLOCK_N, BLOCK_K], + global_size=[K, N] if B_ROW_MAJOR else [N, K], + element_ty=B.dtype.element_ty, + ) + + tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(a_desc_ptr) + tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(b_desc_ptr) + + pid_m = 0 + pid_n = 0 + rm = 0 + rn = 0 + + for _ in range(0, k_tiles * tiles_per_SM): + ki = tl.where(ki == k_tiles - 1, 0, ki + 1) + if ki == 0: + tile_id += NUM_SMS + # re-order program ID for better L2 performance + group_id = tile_id // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (tile_id % group_size) + pid_n = (tile_id % width) // (group_size) + + rm = pid_m * BLOCK_M + rn = pid_n * BLOCK_N + + rk = ki * BLOCK_K + + a = tl._experimental_descriptor_load( + a_desc_ptr, + [rm, rk] if A_ROW_MAJOR else [rk, rm], + [BLOCK_M, BLOCK_K] if A_ROW_MAJOR else [BLOCK_K, BLOCK_M], + A.dtype.element_ty, + ) + b = tl._experimental_descriptor_load( + b_desc_ptr, + [rk, rn] if B_ROW_MAJOR else [rn, rk], + [BLOCK_K, BLOCK_N] if B_ROW_MAJOR else [BLOCK_N, BLOCK_K], + B.dtype.element_ty, + ) + if B_PROLOGUE_CAST_TYPE is not None: + b = b.to(B_PROLOGUE_CAST_TYPE) + acc += tl.dot( + a if A_ROW_MAJOR else a.T, + b if B_ROW_MAJOR else b.T, + allow_tf32=ALLOW_TF32, + ) + + if ki == k_tiles - 1: + # rematerialize rm and rn to save registers + rcm = rm + tl.arange(0, BLOCK_M) + rcn = rn + tl.arange(0, BLOCK_N) + idx_m = rcm[:, None] + idx_n = rcn[None, :] + mask = (idx_m < M) & (idx_n < N) + + # inductor generates a suffix + {{store_output(("idx_m", "idx_n"), "acc", "mask", indent_width=12)}} + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) +""", +) + # prevent duplication registration of extern functions @functools.lru_cache(None) @@ -206,6 +315,22 @@ def tuned_mm(mat1, mat2, *, layout=None): layout=layout, **mm_options(config, m, n, k, layout), ) + if use_triton_tma_template(mat1, mat2): + for config in persistent_mm_configs( + m, n, k, **mm_config_kwargs(ir.get_device_type(mat1)) + ): + persistent_tma_mm_template.maybe_append_choice( + choices, + input_nodes=(mat1, mat2), + layout=layout, + workspace_arg=get_tma_workspace_arg( + num_tma_descriptors=2, + device=mat1.get_device(), + ), + **mm_options(config, m, n, k, layout), + **persistent_mm_options(mat1, mat2), + ) + if is_nonzero and use_cutlass_template(layout, m, n, k): CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(choices, layout, [mat1, mat2]) @@ -398,6 +523,24 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None): epilogue_fn=addmm_epilogue(layout.dtype, alpha, beta), ) + if use_triton_tma_template(mat1, mat2): + for config in persistent_mm_configs( + m, n, k, **mm_config_kwargs(ir.get_device_type(mat1)) + ): + persistent_tma_mm_template.maybe_append_choice( + 
choices, + input_nodes=(inp_expanded, mat1, mat2), + layout=layout, + workspace_arg=get_tma_workspace_arg( + num_tma_descriptors=2, + device=mat1.get_device(), + ), + **mm_options(config, m, n, k, layout), + **persistent_mm_options(mat1, mat2), + prefix_args=1, + epilogue_fn=addmm_epilogue(layout.dtype, alpha, beta), + ) + if static_shape and is_nonzero and use_cutlass_template(layout, m, n, k): # Filter out a known cause of CUDA illegal memory access errors # broadcasting on the last dim of the bias term seems not to be working diff --git a/torch/_inductor/kernel/mm_common.py b/torch/_inductor/kernel/mm_common.py index 0d891f3673012d..6b491a61e5bfea 100644 --- a/torch/_inductor/kernel/mm_common.py +++ b/torch/_inductor/kernel/mm_common.py @@ -15,7 +15,12 @@ from ..codegen.wrapper import PythonWrapperCodegen from ..ir import Layout from ..runtime.runtime_utils import next_power_of_2 -from ..utils import ceildiv as cdiv, get_backend_num_stages +from ..utils import ( + ceildiv as cdiv, + get_backend_num_stages, + get_num_sms, + TMA_DESCRIPTOR_SIZE, +) log = logging.getLogger(__name__) @@ -225,18 +230,13 @@ def filtered_configs( ) persistent_mm_kernel_configs = [ + {"config": (128, 256, 64, 3, 8), "cond": True}, {"config": (128, 128, 64, 3, 8), "cond": True}, {"config": (128, 128, 128, 3, 8), "cond": True}, - {"config": (128, 128, 128, 4, 8), "cond": True}, - {"config": (128, 128, 128, 4, 4), "cond": True}, {"config": (128, 128, 128, 3, 4), "cond": True}, - {"config": (128, 128, 128, 5, 4), "cond": True}, - {"config": (128, 128, 128, 5, 8), "cond": True}, - {"config": (128, 128, 128, 6, 8), "cond": True}, {"config": (128, 128, 64, 4, 8), "cond": True}, ] - scaled_mm_kernel_configs = [ {"config": (128, 256, 32, 3, 8), "cond": True}, {"config": (256, 128, 32, 3, 8), "cond": True}, @@ -337,6 +337,18 @@ def filtered_configs( {"config": (32, 256, 64, 6, 4), "cond": True}, ] +scaled_persistent_mm_kernel_configs = [ + {"config": (128, 128, 64, 3, 8), "cond": True}, + {"config": (128, 128, 128, 3, 8), "cond": True}, + {"config": (128, 128, 128, 4, 8), "cond": True}, + {"config": (128, 128, 128, 4, 4), "cond": True}, + {"config": (128, 128, 128, 3, 4), "cond": True}, + {"config": (128, 128, 128, 5, 4), "cond": True}, + {"config": (128, 128, 128, 5, 8), "cond": True}, + {"config": (128, 128, 128, 6, 8), "cond": True}, + {"config": (128, 128, 64, 4, 8), "cond": True}, +] + # Create filtered list of configs based on cond evaluation mm_platform_configs = tuple( @@ -359,15 +371,19 @@ def filtered_configs( for config in mixed_mm_kernel_configs if config["cond"] ) +persistent_mm_platform_configs = tuple( + cast(Tuple[int, int, int, int, int], config["config"]) + for config in persistent_mm_kernel_configs + if config["cond"] +) scaled_mm_platform_configs = tuple( cast(Tuple[int, int, int, int, int], config["config"]) for config in scaled_mm_kernel_configs if config["cond"] ) - -persistent_mm_platform_configs = tuple( +scaled_persistent_mm_platform_configs = tuple( cast(Tuple[int, int, int, int, int], config["config"]) - for config in persistent_mm_kernel_configs + for config in scaled_persistent_mm_kernel_configs if config["cond"] ) @@ -399,13 +415,19 @@ def filtered_configs( configs=mixed_mm_platform_configs, ) +persistent_mm_configs = functools.partial( + filtered_configs, + configs=persistent_mm_platform_configs, +) + scaled_mm_configs = functools.partial( filtered_configs, configs=scaled_mm_platform_configs, ) -persistent_mm_configs = functools.partial( - filtered_configs, 
configs=persistent_mm_platform_configs +scaled_persistent_mm_configs = functools.partial( + filtered_configs, + configs=scaled_persistent_mm_platform_configs, ) @@ -416,7 +438,7 @@ def mm_grid(m, n, meta): return (cdiv(m, meta["BLOCK_M"]) * cdiv(n, meta["BLOCK_N"]), 1, 1) -def persistent_grid(M: int, N: int, meta: Dict[str, Any]): +def persistent_mm_grid(M: int, N: int, meta: Dict[str, Any]): """Defines the grid for persistent kernels.""" return ( min(meta["NUM_SMS"], cdiv(M, meta["BLOCK_M"]) * cdiv(N, meta["BLOCK_N"])), @@ -456,6 +478,15 @@ def mm_options(config, sym_m, sym_n, sym_k, layout, b_prologue_cast_type=None): ) +def persistent_mm_options(mat1, mat2): + return dict( + A_ROW_MAJOR=not mat1.layout.is_transposed(), + B_ROW_MAJOR=not mat2.layout.is_transposed(), + NUM_SMS=get_num_sms(), + TMA_SIZE=TMA_DESCRIPTOR_SIZE, + ) + + def mm_args( mat1, mat2, diff --git a/torch/_inductor/kernel/mm_scaled.py b/torch/_inductor/kernel/mm_scaled.py index ad8994416c6b85..5f7747146496b8 100644 --- a/torch/_inductor/kernel/mm_scaled.py +++ b/torch/_inductor/kernel/mm_scaled.py @@ -8,7 +8,6 @@ from torch.utils._triton import has_triton_tma_device from .. import config as inductor_config -from ..codegen.common import WorkspaceArg, WorkspaceZeroMode from ..config import triton as triton_config from ..ir import _IntLike, ChoiceCaller, Layout, StorageBox, TensorBox from ..lowering import add_layout_constraint, constrain_to_fx_strides, register_lowering @@ -19,18 +18,24 @@ realize_inputs, TritonTemplate, ) -from ..utils import use_aten_gemm_kernels, use_ck_gemm_template, use_triton_template +from ..utils import ( + get_num_sms, + get_tma_workspace_arg, + TMA_DESCRIPTOR_SIZE, + use_aten_gemm_kernels, + use_ck_gemm_template, + use_triton_template, +) from .mm_common import ( _is_static_problem, mm_args, mm_grid, - persistent_grid, - persistent_mm_configs, + persistent_mm_grid, scaled_mm_configs, + scaled_persistent_mm_configs, ) -_TMA_SIZE = 128 log = logging.getLogger(__name__) aten = torch.ops.aten @@ -110,10 +115,9 @@ def apply_scaling( k_tiles = tl.cdiv(K, BLOCK_K) num_tiles = num_pid_m * num_pid_n - workspace_base = ws_ptr + start_pid * 3 * TMA_SIZE + workspace_base = ws_ptr + start_pid * 2 * TMA_SIZE a_desc_ptr = workspace_base b_desc_ptr = workspace_base + TMA_SIZE - c_desc_ptr = workspace_base + 2 * TMA_SIZE triton.language.extra.cuda.experimental_device_tensormap_create2d( desc_ptr=a_desc_ptr, @@ -204,7 +208,7 @@ def apply_scaling( scaled_mm_device_tma_template = TritonTemplate( name="scaled_mm_device_tma", - grid=persistent_grid, + grid=persistent_mm_grid, source=device_tma + load_scales + apply_scaling, ) @@ -445,8 +449,8 @@ def scaled_mm_options_device_tma( # type: ignore[no-untyped-def] num_warps=config.num_warps, # tensor-wise scaling if scalar scales SCALING_ROWWISE=len(scale_a.get_size()) == 2, - TMA_SIZE=_TMA_SIZE, - NUM_SMS=NUM_SMS, + TMA_SIZE=TMA_DESCRIPTOR_SIZE, + NUM_SMS=get_num_sms(), **config.kwargs, ) @@ -488,25 +492,6 @@ def scaled_mm_options( # type: ignore[no-untyped-def] add_layout_constraint(aten._scaled_mm.default, constrain_to_fx_strides) -def get_workspace_size( - num_sms: int, TMA_SIZE: int = _TMA_SIZE, NUM_TMA_DESCRIPTORS: int = 3 -) -> int: - """Device side TMA requires a workspace buffer to be allocated in global memory.""" - return num_sms * NUM_TMA_DESCRIPTORS * TMA_SIZE - - -def get_workspace_arg(num_sms: int, device: torch.device) -> WorkspaceArg: - """Builds and returns a WorkspaceArg for the device side TMA workspace buffer.""" - size = get_workspace_size(num_sms) - 
zero_mode = WorkspaceZeroMode.from_bool(False) - return WorkspaceArg( - count=size, - zero_mode=zero_mode, - device=device, - outer_name=WorkspaceArg.unique_name(), - ) - - def use_persistent_tma(k: sympy.core.numbers.Integer, has_bias: bool) -> bool: available = has_triton_tma_device() and triton_config.enable_persistent_tma_matmul # _determine_swizzle_mode_2d requires BLOCK_K to be at least 32 contiguous bytes @@ -553,11 +538,11 @@ def tuned_scaled_mm( if use_aten_gemm_kernels(): choices.append(aten_choice) - _static_shape, is_nonzero = _is_static_problem(layout) + _, is_nonzero = _is_static_problem(layout) if is_nonzero and use_triton_template(layout, enable_float8=True): if use_persistent_tma(k, bias is not None): - for config in persistent_mm_configs(m, n, k): + for config in scaled_persistent_mm_configs(m, n, k): kwargs = scaled_mm_options_device_tma( config, m, n, k, layout, scale_a, scale_b, use_fast_accum ) @@ -566,8 +551,9 @@ def tuned_scaled_mm( choices, input_nodes=input_nodes, layout=layout, - workspace_arg=get_workspace_arg( - kwargs["NUM_SMS"], mat_a.get_device() + workspace_arg=get_tma_workspace_arg( + num_tma_descriptors=2, + device=mat_a.get_device(), ), **kwargs, ) diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index 209d6205c6ff73..e4c072e9af8d93 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -2670,6 +2670,18 @@ def log_fusion(ms_fused: float, ms1: float, ms2: float) -> None: if not isinstance(choice, torch._inductor.ir.TritonTemplateCallerBase): continue + # For prologue fusion we check if the underlying template of the choice + # supports all allowed prologue inputs. If not, we skip this choice in + # the fusion benchmark. + # TODO: Remove this check after all Triton templates support prologue fusion. + # Currently, persistent+TMA Triton template does not due to the TMA-based loads. + if ( + not epilogue_fusion + and hasattr(choice, "allowed_prologue_inps") + and choice.allowed_prologue_inps != multi_node.allowed_prologue_inps + ): + continue + if unfused_time >= ms1 + ms2: break diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index 2fe9bf3dc2b1a1..e0e0c9afd6ab85 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -10,6 +10,7 @@ import math import operator import os +import re import sys import textwrap import time @@ -1644,6 +1645,25 @@ def __call__( # TODO(nmacchioni): remove once CI tests are fixed choices = [choice for choice in choices if choice is not None] + if config.test_configs.autotune_choice_name_regex is not None: + choices = [ + c + for c in choices + if re.search( + config.test_configs.autotune_choice_name_regex, + c.name, + ) + ] + if config.test_configs.autotune_choice_desc_regex is not None: + choices = [ + c + for c in choices + if re.search( + config.test_configs.autotune_choice_desc_regex, + c.description, + ) + ] + if mm_file_name := get_mm_log_filename(): M, K = input_nodes[-2].get_size()[:2] N = input_nodes[-1].get_size()[-1] @@ -1846,20 +1866,14 @@ def get_timings(): return timings - # Assume the same base template, with same prologue support. - # We could relax this assumption by taking union of allowed prologue inputs, - # and within benchmark fusion not allow prologue fusion for choices which dont support it - # No use case of that yet. 
- allowed_prologue_inps: Optional[OrderedSet[str]] = None + # We take the union of allowed prologue inputs from all choices, + # and, within benchmark fusion, don't allow prologue fusion for + # choices which dont support the whole union. + allowed_prologue_inps: OrderedSet[str] = OrderedSet() for c in choices: if isinstance(c, TritonTemplateCaller): - if allowed_prologue_inps is None: - allowed_prologue_inps = c.allowed_prologue_inps - else: - assert allowed_prologue_inps == c.allowed_prologue_inps + allowed_prologue_inps |= c.allowed_prologue_inps - if allowed_prologue_inps is None: - allowed_prologue_inps = OrderedSet() return torch._inductor.ir.TensorBox.create( torch._inductor.ir.MultiTemplateBuffer( layout, diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index b7d68a32933f31..c9293ad0f69b7e 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -51,6 +51,7 @@ if TYPE_CHECKING: from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND + from .codegen.common import WorkspaceArg from torch.utils._ordered_set import OrderedSet from torch.utils._pytree import tree_map_only @@ -102,6 +103,9 @@ def get_gpu_type(): GPU_ALIGN_BYTES = 16 ALIGNMENT = 16 +TMA_ALIGNMENT = 16 +TMA_DESCRIPTOR_SIZE = 128 + ALIGN_BYTES = 64 assert (ALIGN_BYTES & (ALIGN_BYTES - 1)) == 0 and ALIGN_BYTES >= 8, "must be power of 2" @@ -1149,6 +1153,28 @@ def is_big_gpu(index_or_device: Union[int, torch.device] = 0) -> bool: return True +@functools.lru_cache +def get_num_sms() -> int: + return torch.cuda.get_device_properties("cuda").multi_processor_count + + +def get_tma_workspace_arg( + num_tma_descriptors: int, + device: torch.device, +) -> WorkspaceArg: + """Builds and returns a WorkspaceArg for the device side TMA workspace buffer.""" + from .codegen.common import WorkspaceArg, WorkspaceZeroMode + + zero_mode = WorkspaceZeroMode.from_bool(False) + size = get_num_sms() * num_tma_descriptors * TMA_DESCRIPTOR_SIZE + return WorkspaceArg( + count=size, + zero_mode=zero_mode, + device=device, + outer_name=WorkspaceArg.unique_name(), + ) + + def use_max_autotune() -> bool: return ( config.max_autotune or config.max_autotune_gemm or config.search_autotune_cache @@ -1197,6 +1223,37 @@ def use_triton_template(layout, *, enable_int32=False, enable_float8=False): ) +def use_triton_tma_template(*matrices): + from torch.utils._triton import has_triton_tma_device + + from .virtualized import V + + def _is_tma_compatible(x): + if len(x.get_size()) != 2: + return False + + dtype = x.get_dtype() + if dtype not in (torch.float16, torch.bfloat16): + return False + + layout = x.get_layout() + transposed = layout.is_transposed() + if not (layout.is_contiguous() or transposed): + return False + + inner_dim = layout.size[1] + if transposed: + inner_dim = layout.size[0] + inner_bytes = inner_dim * dtype.itemsize + return V.graph.sizevars.statically_known_multiple_of(inner_bytes, TMA_ALIGNMENT) + + return ( + config.triton.enable_persistent_tma_matmul + and has_triton_tma_device() + and all(_is_tma_compatible(m) for m in matrices) + ) + + def use_cutlass_template(layout, m, n, k): from .virtualized import V
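A minimal usage sketch (not part of the patch) showing how the persistent+TMA matmul path added by this diff is exercised end to end, mirroring the new tests in test_max_autotune.py. The config keys come from this diff; the shapes are illustrative only and are chosen so the inner dims satisfy the 16-byte alignment checked by use_triton_tma_template.

    # Sketch: enable the persistent+TMA Triton matmul template under max-autotune.
    # Assumes a CUDA device and a Triton build where has_triton_tma_device() is True.
    import torch
    from torch._inductor import config

    def mm(a, b):
        return torch.mm(a, b)

    # float16 is 2 bytes, so an inner dim of 128 gives 256 bytes, a multiple of
    # the 16-byte TMA_ALIGNMENT required by use_triton_tma_template().
    a = torch.randn(256, 128, dtype=torch.float16, device="cuda")
    b = torch.randn(128, 512, dtype=torch.float16, device="cuda")

    with config.patch(
        {
            "max_autotune": True,
            "triton.enable_persistent_tma_matmul": True,
        }
    ):
        compiled = torch.compile(mm, mode="max-autotune")
        out = compiled(a, b)

    torch.testing.assert_close(out, mm(a, b), atol=1e-2, rtol=1e-2)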