[inductor] don't allow triton config pre_hook (pytorch#134633)
The caching autotuner caches triton configs, and it doesn't try to hash or save a config's pre_hook if one exists. So if we had a config with a pre_hook, we might autotune -> save the config (without the pre_hook) -> later load the saved config and try to run it, but this time without the pre_hook.

So this PR adds an assert and deletes the pre_hook handling. We can be confident that there were no functional pre_hooks, because the pre_hook handling tried to use `self.arg_name`, which doesn't exist.
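
To make the failure mode concrete, here is a minimal, self-contained sketch of that round trip. It is not the inductor code: `FakeConfig`, `save_config`, and `load_config` are hypothetical stand-ins that serialize only the plain config fields, which is enough to show how a pre_hook silently disappears once a config has been through a cache.

    import json


    class FakeConfig:
        """Hypothetical stand-in for a triton config: plain fields plus an optional pre_hook."""

        def __init__(self, kwargs, num_warps=4, num_stages=2, pre_hook=None):
            self.kwargs = kwargs
            self.num_warps = num_warps
            self.num_stages = num_stages
            self.pre_hook = pre_hook  # a callable; not serialized below


    def save_config(cfg):
        # Only the plain fields go into the cache entry; the pre_hook callable is dropped.
        return json.dumps(
            {"kwargs": cfg.kwargs, "num_warps": cfg.num_warps, "num_stages": cfg.num_stages}
        )


    def load_config(blob):
        return FakeConfig(**json.loads(blob))


    cfg = FakeConfig({"XBLOCK": 64}, pre_hook=lambda launch_args: None)
    reloaded = load_config(save_config(cfg))
    assert cfg.pre_hook is not None
    assert reloaded.pre_hook is None  # the hook is silently gone after the cache round trip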

Pull Request resolved: pytorch#134633
Approved by: https://github.com/shunting314, https://github.com/jansel
davidberard98 authored and pytorchmergebot committed Aug 30, 2024
1 parent e21d7b7 commit 9e0ddc0
Showing 3 changed files with 77 additions and 12 deletions.
test/inductor/test_triton_heuristics.py (61 additions, 2 deletions)
@@ -10,14 +10,20 @@

 try:
     import triton  # noqa: F401
+    import triton.language as tl
 except ImportError:
     if __name__ == "__main__":
         sys.exit(0)
     raise unittest.SkipTest("requires triton")  # noqa: B904

 from torch._inductor import config
-from torch._inductor.runtime.hints import TRITON_MAX_BLOCK
-from torch._inductor.runtime.triton_heuristics import triton_config
+from torch._inductor.runtime.hints import (
+    DeviceProperties,
+    HeuristicType,
+    TRITON_MAX_BLOCK,
+)
+from torch._inductor.runtime.triton_helpers import math as tl_math
+from torch._inductor.runtime.triton_heuristics import CachingAutotuner, triton_config
 from torch._inductor.test_case import run_tests, TestCase


@@ -81,6 +87,59 @@ def test_artificial_zgrid(self):
     def test_artificial_grid_cpp_wrapper(self):
         self._test_artificial_zgrid()

+    def _get_cos_kernel_caching_autotuner_args(self):
+        from triton.compiler.compiler import AttrsDescriptor
+
+        @triton.jit
+        def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr):
+            xnumel = 16
+            xoffset = tl.program_id(0) * XBLOCK
+            xindex = xoffset + tl.arange(0, XBLOCK)[:]
+            xmask = xindex < xnumel
+            x0 = xindex
+            tmp0 = tl.load(in_ptr0 + (x0), xmask)
+            tmp1 = tl_math.cos(tmp0)
+            tl.store(out_ptr0 + (x0), tmp1, xmask)
+
+        triton_meta = {
+            "signature": {0: "*fp32", 1: "*fp32", 2: "i32"},
+            "device": DeviceProperties.create(torch.device("cuda")),
+            "constants": {},
+            "configs": [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())],
+        }
+
+        configs = [
+            triton_config([16], 64),
+            triton_config([256], 64),
+        ]
+
+        inductor_meta = {}
+
+        return {
+            "fn": triton_,
+            "triton_meta": triton_meta,
+            "configs": configs,
+            "save_cache_hook": False,
+            "mutated_arg_names": [],
+            "heuristic_type": HeuristicType.POINTWISE,
+            "inductor_meta": inductor_meta,
+        }
+
+    @skipIfXpu
+    def test_pre_hook_assert(self):
+        # assert if any of the configs passed to the CachingAutotuner have pre-hooks
+        args = self._get_cos_kernel_caching_autotuner_args()
+
+        def pre_hook(kwargs):
+            if "in_ptr0" in kwargs:
+                kwargs["in_ptr0"].zero_()
+
+        for cfg in args["configs"]:
+            cfg.pre_hook = pre_hook
+
+        with self.assertRaisesRegex(AssertionError, "pre_hook"):
+            autotuner = CachingAutotuner(**args)
+

 if __name__ == "__main__":
     if IS_LINUX and HAS_GPU:
torch/_inductor/runtime/runtime_utils.py (11 additions, 0 deletions)
@@ -65,6 +65,17 @@ def triton_config_to_hashable(cfg):
     return tuple(items)


+def validate_triton_config(cfg):
+    # [Note: Triton pre_hook in inductor]
+    # pre-hook is a lambda function, which we don't attempt to serialize.
+    # right now, if a pre-hook is attached to the config, it will not be saved;
+    # and then it won't be used when the config is loaded from cache.
+    # So we assert - if we do get a pre_hook, it might get ignored after caching.
+    assert (
+        getattr(cfg, "pre_hook", None) is None
+    ), "triton configs with pre_hooks not supported"
+
+
 def create_bandwidth_info_str(ms, num_gb, gb_per_s, prefix="", suffix="", color=True):
     info_str = f"{prefix}{ms:.3f}ms \t{num_gb:.3f} GB \t {gb_per_s:7.2f}GB/s{suffix}"
     slow = ms > 0.012 and gb_per_s < 650
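
A minimal usage sketch of the new helper (assuming a torch build that includes this change; `types.SimpleNamespace` stands in for a triton config object so the snippet runs without Triton installed): validation passes for an ordinary config and raises for one that carries a pre_hook.

    from types import SimpleNamespace

    from torch._inductor.runtime.runtime_utils import validate_triton_config

    # A config with no pre_hook passes validation.
    validate_triton_config(SimpleNamespace(kwargs={"XBLOCK": 64}, num_warps=4, num_stages=2))

    # A config carrying a pre_hook is rejected up front rather than being cached
    # and later reloaded without the hook.
    bad = SimpleNamespace(
        kwargs={"XBLOCK": 64}, num_warps=4, num_stages=2, pre_hook=lambda launch_args: None
    )
    try:
        validate_triton_config(bad)
    except AssertionError as exc:
        print(exc)  # triton configs with pre_hooks not supported
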
torch/_inductor/runtime/triton_heuristics.py (5 additions, 10 deletions)
@@ -40,6 +40,7 @@
     get_num_bytes,
     next_power_of_2,
     triton_config_to_hashable,
+    validate_triton_config,
 )


@@ -182,6 +183,10 @@ def __init__(
         super().__init__()

         assert len(configs) > 0, "Non-empty TritonConfig list required for compiling"
+        # makes sure there are no pre-hooks on any of the triton configs
+        for cfg in configs:
+            validate_triton_config(cfg)
+
         self.fn = fn
         self.device_props: DeviceProperties = triton_meta["device"]
         self.triton_meta = {
@@ -655,11 +660,6 @@ def bench(self, launcher, *args, grid, with_profiler=False, **kwargs):
         stream = device_interface.get_raw_stream(device_interface.current_device())

         def kernel_call():
-            if launcher.config.pre_hook is not None:
-                launcher.config.pre_hook(
-                    {**dict(zip(self.arg_names, args)), **launcher.config.kwargs}
-                )
-
             cloned_args, cloned_kwargs = self.clone_args(*args, **kwargs)
             launcher(
                 *cloned_args,
@@ -848,11 +848,6 @@ def run(self, *args, grid, stream, **kwargs):
         if launcher.store_cubin:
             self.save_gpu_kernel(grid, stream, launcher)

-        if launcher.config.pre_hook is not None:
-            launcher.config.pre_hook(
-                {**dict(zip(self.arg_names, args)), **launcher.config.kwargs, **kwargs}
-            )
-
         if os.environ.get("TORCHINDUCTOR_DUMP_LAUNCH_PARAMS", 0) == "1":
             _dump_launch_params(args, kwargs, launcher, self.fn.__name__)

