From 22ea932b516e1cf4e149bffa9e0d3ae9ebbd4246 Mon Sep 17 00:00:00 2001 From: The jax_triton Authors Date: Mon, 5 Jun 2023 05:56:04 -0700 Subject: [PATCH] Raise exception on too much shared memory requested PiperOrigin-RevId: 537849760 --- jax_triton/triton_lib.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/jax_triton/triton_lib.py b/jax_triton/triton_lib.py index d71ae845..9c4f2920 100644 --- a/jax_triton/triton_lib.py +++ b/jax_triton/triton_lib.py @@ -15,6 +15,7 @@ """Module for calling Triton kernels from JAX.""" import functools import os +import sys import types from typing import Any, Callable, Dict, Optional, Protocol, Sequence, Tuple, Union import weakref @@ -134,6 +135,20 @@ def ptx_get_kernel_name(module) -> str: return tc.get_kernel_name(module, pattern='// .globl') +# From https://en.wikipedia.org/wiki/CUDA#Technical_Specification +# "Amount of shared memory per multiprocessor". In bytes. +_SHARED_MEMORY_PER_SM = { + 70: 98304, + 72: 98304, + 75: 65536, + 80: 167936, + 86: 102400, + 87: 167936, + 89: 102400, + 90: 233472, +} + + def compile_ttir_inplace( ttir, device: int = 0, @@ -163,6 +178,8 @@ def compile_ttir_inplace( ttgir.dump() raise ValueError("TTGIR->LLIR pass failed!") from e shared_mem = _triton.get_shared_memory_size(ttgir) + if shared_mem > _SHARED_MEMORY_PER_SM.get(compute_capability, sys.maxsize): + raise RuntimeError("Shared memory requested exceeds device resources.") if dump: print(llir) ptx = tc.llir_to_ptx(llir, compute_capability)