Skip to content

Commit

Permalink
bugfix: FusedAddRMSNorm kernels might require more than 48KB shared m…
Browse files Browse the repository at this point in the history
…emory when d is large. (#718)

The original implementation will cause RuntimeError: invalid argument
when hidden_size=16384.

---------

Co-authored-by: Bo Li <[email protected]>
  • Loading branch information
bobboli and Bo Li authored Jan 6, 2025
1 parent f72745b commit 9a00cc2
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 4 deletions.
2 changes: 2 additions & 0 deletions include/flashinfer/norm.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ cudaError_t FusedAddRMSNorm(T* input, T* residual, T* weight, uint32_t batch_siz

DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
auto kernel = FusedAddRMSNormKernel<VEC_SIZE, T>;
FLASHINFER_CUDA_CALL(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
});

Expand Down Expand Up @@ -255,6 +256,7 @@ cudaError_t GemmaFusedAddRMSNorm(T* input, T* residual, T* weight, uint32_t batc

DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
auto kernel = FusedAddRMSNormKernel<VEC_SIZE, T>;
FLASHINFER_CUDA_CALL(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
});

Expand Down
8 changes: 4 additions & 4 deletions tests/test_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def fused_add_rms_norm(x, residual, weight, eps):


@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192])
@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192, 16384])
@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("specify_out", [True, False])
def test_norm(batch_size, hidden_size, dtype, specify_out):
Expand All @@ -83,7 +83,7 @@ def test_norm(batch_size, hidden_size, dtype, specify_out):


@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192])
@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192, 16384])
@pytest.mark.parametrize("dtype", [torch.float16])
def test_fused_add_rmsnorm(batch_size, hidden_size, dtype):
eps = 1e-6
Expand All @@ -105,7 +105,7 @@ def test_fused_add_rmsnorm(batch_size, hidden_size, dtype):


@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192])
@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192, 16384])
@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("specify_out", [True, False])
def test_gemma_norm(batch_size, hidden_size, dtype, specify_out):
Expand All @@ -123,7 +123,7 @@ def test_gemma_norm(batch_size, hidden_size, dtype, specify_out):


@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192])
@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192, 16384])
@pytest.mark.parametrize("dtype", [torch.float16])
def test_gemma_fused_add_rmsnorm(batch_size, hidden_size, dtype):
eps = 1e-6
Expand Down

0 comments on commit 9a00cc2

Please sign in to comment.