Add persistent+TMA version of Triton mm and addmm (pytorch#142101)
This PR adds persistent+TMA versions (a Triton template plus the corresponding infra) of the `tuned_mm` and `tuned_addmm` lowerings. The persistent+TMA choices are added to the GEMM autotuning when all of the following hold (checked by the `use_triton_tma_template` helper):

1. The minimum hardware and Triton version requirements for TMA support are met.
2. The GEMM inputs are compatible with the Triton TMA API (i.e., 16-byte aligned and contiguous); a sketch of this check follows after the notes below.
3. `config.triton.enable_persistent_tma_matmul` is set to `True`; a usage sketch follows after the notes below.

Additional notes:

1. As added in this PR, the TMA use is not compatible with prologue / epilogue fusion. Accordingly, the new Triton template currently supports TMA-based loads of A/B but no prologue fusion, and epilogue fusion but no TMA-based stores of C. TMA + fusion compatibility can be added as a follow-up.
2. The current Triton TMA API (`experimental_device_tensormap_create2d`) does not support strides. Due to this, we limit the applicability of the new Triton template to cases where the inputs are contiguous.
3. Transposed layouts of A and/or B are supported by passing constexpr flags to the kernel and adjusting the ordering of the block sizes accordingly in the kernel code; this should have no effect on kernel perf, as the flags are resolved at Triton compile time. A toy illustration follows below.
4. After the next Triton pin update, we can switch the new Triton template to the tensor descriptor API (landed recently in triton-lang/triton#5290), which should allow lifting limitations 2 and 3 above.
5. The configs for the new Triton template in `persistent_mm_kernel_configs` are preliminary. We should do more perf exploration and possibly augment them in a follow-up.
6. This PR is rebased onto and unifies with two related PRs landed previously: pytorch#142045 (some infra unification with the persistent+TMA template for `_scaled_mm`) and pytorch#134532 (the possibility to disable prologue fusion for selected choices).
7. The current Triton TMA API only supports 1D and 2D descriptors (even after triton-lang/triton#5290, see [here](https://github.com/triton-lang/triton/blob/9829ce87ccb333a2b264b3a80b39a534bfa865ac/python/triton/language/core.py#L1957)). For now, this blocks adding a persistent+TMA template for `torch.bmm`.

Pull Request resolved: pytorch#142101

Approved by: https://github.com/drisspg, https://github.com/eellison
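The input-compatibility condition (2 above) can be pictured as a small predicate. This is a hypothetical sketch, not the actual `use_triton_tma_template` helper, which lives in inductor and also enforces the hardware and Triton version requirements:

```python
import torch

def tma_inputs_ok(t: torch.Tensor) -> bool:
    # Hypothetical sketch of condition 2: the Triton TMA API wants
    # 16-byte-aligned, contiguous 2D operands (it takes no strides,
    # see note 2 above).
    return (
        t.dim() == 2
        and t.is_contiguous()
        and t.data_ptr() % 16 == 0
    )
```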
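A minimal sketch of opting in (condition 3), assuming the knob named above lives at `torch._inductor.config.triton.enable_persistent_tma_matmul`; the rest is ordinary `torch.compile` autotuning, and the shapes/dtypes are arbitrary:

```python
import torch
import torch._inductor.config as inductor_config

# Condition 3: opt in to the persistent+TMA matmul template.
# Conditions 1 and 2 are checked internally by `use_triton_tma_template`.
inductor_config.triton.enable_persistent_tma_matmul = True

# Contiguous fp16 CUDA inputs, so the TMA choice is eligible.
a = torch.randn(1024, 512, device="cuda", dtype=torch.float16)
b = torch.randn(512, 2048, device="cuda", dtype=torch.float16)

# max-autotune makes inductor benchmark the Triton GEMM templates,
# now including the persistent+TMA one, against the other choices.
mm = torch.compile(torch.mm, mode="max-autotune")
c = mm(a, b)
```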
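Finally, a toy illustration of the constexpr-flag trick from note 3. This is not the PR's template, and every name in it is made up; the point is only that the block-dimension ordering is chosen when Triton compiles the kernel, so a transposed operand costs nothing at runtime:

```python
import torch
import triton
import triton.language as tl

@triton.jit
def load_tile_demo(a_ptr, out_ptr, M, K,
                   BLOCK_M: tl.constexpr, BLOCK_K: tl.constexpr,
                   A_IS_TRANSPOSED: tl.constexpr):
    offs_m = tl.arange(0, BLOCK_M)
    offs_k = tl.arange(0, BLOCK_K)
    if A_IS_TRANSPOSED:
        # A is stored as a contiguous (K, M) buffer: swap the block
        # ordering at compile time, then transpose in registers.
        tile = tl.load(a_ptr + offs_k[:, None] * M + offs_m[None, :])
        tile = tl.trans(tile)
    else:
        # A is stored as a contiguous (M, K) buffer.
        tile = tl.load(a_ptr + offs_m[:, None] * K + offs_k[None, :])
    tl.store(out_ptr + offs_m[:, None] * BLOCK_K + offs_k[None, :], tile)

# Both layouts yield the same logical (BLOCK_M, BLOCK_K) tile; the branch
# is resolved during Triton compilation, so it has no runtime cost.
a = torch.randn(64, 32, device="cuda")
out = torch.empty(16, 16, device="cuda")
load_tile_demo[(1,)](a, out, 64, 32,
                     BLOCK_M=16, BLOCK_K=16, A_IS_TRANSPOSED=False)
load_tile_demo[(1,)](a.t().contiguous(), out, 64, 32,
                     BLOCK_M=16, BLOCK_K=16, A_IS_TRANSPOSED=True)
```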
1 parent 17b71e5 · commit e885225 · 9 changed files with 453 additions and 61 deletions.