From 9a00cc2ea2d3b79e071a8a05a3e0c2b90e081e01 Mon Sep 17 00:00:00 2001
From: Bo Li
Date: Tue, 7 Jan 2025 03:37:47 +0800
Subject: [PATCH] bugfix: FusedAddRMSNorm kernels might require more than 48KB
 shared memory when d is large. (#718)

The original implementation raises "RuntimeError: invalid argument" when
hidden_size=16384: the kernels request more dynamic shared memory than the
default 48KB limit without first opting in to a larger allocation via
cudaFuncSetAttribute.

---------

Co-authored-by: Bo Li
---
 include/flashinfer/norm.cuh | 2 ++
 tests/test_norm.py          | 8 ++++----
 2 files changed, 6 insertions(+), 4 deletions(-)

diff --git a/include/flashinfer/norm.cuh b/include/flashinfer/norm.cuh
index ee807ab0..5a50c639 100644
--- a/include/flashinfer/norm.cuh
+++ b/include/flashinfer/norm.cuh
@@ -213,6 +213,7 @@ cudaError_t FusedAddRMSNorm(T* input, T* residual, T* weight, uint32_t batch_siz
   DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
     auto kernel = FusedAddRMSNormKernel<VEC_SIZE, T>;
+    FLASHINFER_CUDA_CALL(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
     FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
   });
@@ -255,6 +256,7 @@ cudaError_t GemmaFusedAddRMSNorm(T* input, T* residual, T* weight, uint32_t batc
   DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
     auto kernel = FusedAddRMSNormKernel<VEC_SIZE, T>;
+    FLASHINFER_CUDA_CALL(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
     FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
   });
diff --git a/tests/test_norm.py b/tests/test_norm.py
index 8827f5c8..392e47b7 100644
--- a/tests/test_norm.py
+++ b/tests/test_norm.py
@@ -65,7 +65,7 @@ def fused_add_rms_norm(x, residual, weight, eps):
 @pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
-@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192])
+@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192, 16384])
 @pytest.mark.parametrize("dtype", [torch.float16])
 @pytest.mark.parametrize("specify_out", [True, False])
 def test_norm(batch_size, hidden_size, dtype, specify_out):
@@ -83,7 +83,7 @@ def test_norm(batch_size, hidden_size, dtype, specify_out):
 @pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
-@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192])
+@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192, 16384])
 @pytest.mark.parametrize("dtype", [torch.float16])
 def test_fused_add_rmsnorm(batch_size, hidden_size, dtype):
     eps = 1e-6
@@ -105,7 +105,7 @@ def test_fused_add_rmsnorm(batch_size, hidden_size, dtype):
 @pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
-@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192])
+@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192, 16384])
 @pytest.mark.parametrize("dtype", [torch.float16])
 @pytest.mark.parametrize("specify_out", [True, False])
 def test_gemma_norm(batch_size, hidden_size, dtype, specify_out):
@@ -123,7 +123,7 @@ def test_gemma_norm(batch_size, hidden_size, dtype, specify_out):
 @pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
-@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192])
+@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192, 16384])
 @pytest.mark.parametrize("dtype", [torch.float16])
 def test_gemma_fused_add_rmsnorm(batch_size, hidden_size, dtype):
     eps = 1e-6
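
Note (illustrative sketch, not part of the patch): CUDA caps dynamic shared
memory per block at 48KB by default, so a launch that requests more fails with
"invalid argument" unless the kernel first opts in through
cudaFuncSetAttribute(cudaFuncAttributeMaxDynamicSharedMemorySize, ...), which
is what the two added lines do before cudaLaunchKernel. A minimal standalone
example of the pattern follows; the kernel, names, and sizes are made up for
illustration and are not FlashInfer code.

    #include <cstdio>
    #include <cuda_runtime.h>

    // Toy kernel that stages a row of d floats through dynamic shared memory.
    __global__ void copy_via_smem(const float* in, float* out, int d) {
      extern __shared__ float smem[];  // size is supplied at launch time
      for (int i = threadIdx.x; i < d; i += blockDim.x) smem[i] = in[i];
      __syncthreads();
      for (int i = threadIdx.x; i < d; i += blockDim.x) out[i] = smem[i];
    }

    int main() {
      const int d = 16384;                         // hidden_size from the bug report
      const size_t smem_size = d * sizeof(float);  // 64KB > 48KB default limit
      float *in, *out;
      cudaMalloc(&in, d * sizeof(float));
      cudaMalloc(&out, d * sizeof(float));

      // Opt in to the larger dynamic shared memory carve-out first; without
      // this call the launch below fails with "invalid argument" because
      // smem_size exceeds 48KB. Requires compute capability >= 7.0.
      cudaFuncSetAttribute(copy_via_smem,
                           cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);

      copy_via_smem<<<1, 1024, smem_size>>>(in, out, d);
      printf("launch: %s\n", cudaGetErrorString(cudaGetLastError()));

      cudaFree(in);
      cudaFree(out);
      return 0;
    }

The patch applies the same two-step pattern (set the attribute, then launch)
inside DISPATCH_ALIGNED_VEC_SIZE for both FusedAddRMSNorm and
GemmaFusedAddRMSNorm, and extends the tests to hidden_size=16384 to cover the
previously failing case.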