diff --git a/include/flashinfer/sampling.cuh b/include/flashinfer/sampling.cuh index 32e9f4df..2ae21947 100644 --- a/include/flashinfer/sampling.cuh +++ b/include/flashinfer/sampling.cuh @@ -16,14 +16,17 @@ #ifndef FLASHINFER_SAMPLING_CUH_ #define FLASHINFER_SAMPLING_CUH_ -#include -#include -#include +// #include +// #include +// #include +// #include +// #include +#include +#include #include #include "math.cuh" #include "utils.cuh" -#include "vec_dtypes.cuh" namespace flashinfer { @@ -59,10 +62,11 @@ struct BoolDiffOp { } }; -template struct SamplingTempStorage { union { + typename BlockLoad::TempStorage load; typename BlockScan::TempStorage scan; typename BlockReduce::TempStorage reduce; typename BlockReduce, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage reduce_pair; @@ -80,8 +84,9 @@ struct SamplingTempStorage { template __device__ __forceinline__ void DeviceSamplingFromProb( - uint32_t i, uint32_t d, T threshold, T u, vec_t prob_vec, T& aggregate, - SamplingTempStorage* temp_storage) { + uint32_t i, uint32_t d, T threshold, T u, T prob_vec[VEC_SIZE], T& aggregate, + SamplingTempStorage* + temp_storage) { const uint32_t tx = threadIdx.x; T prob_greater_than_threshold[VEC_SIZE]; T inclusive_cdf[VEC_SIZE]; @@ -137,23 +142,24 @@ __global__ void SamplingFromProbKernel(DType* probs, DType* uniform_samples, IdT const uint32_t bx = blockIdx.x, tx = threadIdx.x; const uint32_t row_idx = row_indices == nullptr ? bx : row_indices[bx]; - extern __shared__ __align__( - alignof(SamplingTempStorage)) + extern __shared__ __align__(alignof( + SamplingTempStorage)) uint8_t smem_sampling[]; auto& temp_storage = reinterpret_cast< - SamplingTempStorage&>(smem_sampling); + SamplingTempStorage&>( + smem_sampling); temp_storage.data.sampled_id = d - 1; __syncthreads(); - vec_t probs_vec; + DType probs_vec[VEC_SIZE]; DType aggregate(0); float u = uniform_samples[bx]; for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { - probs_vec.fill(DType(0)); - if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { - probs_vec.load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE); - } + BlockLoad(temp_storage.block_prim.load) + .Load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE, probs_vec, + d - i * BLOCK_THREADS * VEC_SIZE, DType(0)); + __syncthreads(); DeviceSamplingFromProb( i, d, DType(0), u, probs_vec, aggregate, &temp_storage); @@ -172,13 +178,14 @@ __global__ void TopKSamplingFromProbKernel(DType* probs, DType* uniform_samples, const uint32_t batch_size = gridDim.x; const uint32_t bx = blockIdx.x, tx = threadIdx.x; - extern __shared__ __align__( - alignof(SamplingTempStorage)) + extern __shared__ __align__(alignof( + SamplingTempStorage)) uint8_t smem_sampling[]; auto& temp_storage = reinterpret_cast< - SamplingTempStorage&>(smem_sampling); + SamplingTempStorage&>( + smem_sampling); - vec_t probs_vec; + DType probs_vec[VEC_SIZE]; DType aggregate; DType q = DType(0); DType pivot = DType(0); @@ -189,10 +196,10 @@ __global__ void TopKSamplingFromProbKernel(DType* probs, DType* uniform_samples, DType u = uniform_samples[round * batch_size + bx] * (DType(1) - q); aggregate = DType(0); for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { - probs_vec.fill(DType(0)); - if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { - probs_vec.load(probs + bx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE); - } + BlockLoad(temp_storage.block_prim.load) + .Load(probs + bx * d + i * BLOCK_THREADS * VEC_SIZE, probs_vec, + d - i * BLOCK_THREADS * VEC_SIZE, DType(0)); + __syncthreads(); 
DeviceSamplingFromProb( i, d, pivot, u, probs_vec, aggregate, &temp_storage); @@ -206,10 +213,10 @@ __global__ void TopKSamplingFromProbKernel(DType* probs, DType* uniform_samples, Pair aggregate_leq_pivot{DType(0), 0}; for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { - probs_vec.fill(DType(0)); - if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { - probs_vec.load(probs + bx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE); - } + BlockLoad(temp_storage.block_prim.load) + .Load(probs + bx * d + i * BLOCK_THREADS * VEC_SIZE, probs_vec, + d - i * BLOCK_THREADS * VEC_SIZE, DType(0)); + __syncthreads(); Pair probs_leq_pivot[VEC_SIZE]; #pragma unroll @@ -266,13 +273,14 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, DType* uniform_samples, } const uint32_t row_idx = row_indices == nullptr ? bx : row_indices[bx]; - extern __shared__ __align__( - alignof(SamplingTempStorage)) + extern __shared__ __align__(alignof( + SamplingTempStorage)) uint8_t smem_sampling[]; auto& temp_storage = reinterpret_cast< - SamplingTempStorage&>(smem_sampling); + SamplingTempStorage&>( + smem_sampling); - vec_t probs_vec; + DType probs_vec[VEC_SIZE]; DType aggregate; DType q = DType(0); DType pivot = DType(0); @@ -283,10 +291,10 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, DType* uniform_samples, DType u = uniform_samples[round * batch_size + bx] * (DType(1) - q); aggregate = DType(0); for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { - probs_vec.fill(DType(0)); - if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { - probs_vec.load(probs + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE); - } + BlockLoad(temp_storage.block_prim.load) + .Load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE, probs_vec, + d - i * BLOCK_THREADS * VEC_SIZE, DType(0)); + __syncthreads(); DeviceSamplingFromProb( i, d, pivot, u, probs_vec, aggregate, &temp_storage); @@ -300,10 +308,10 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, DType* uniform_samples, DType aggregate_leq_pivot = DType(0); for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { - probs_vec.fill(DType(0)); - if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { - probs_vec.load(probs + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE); - } + BlockLoad(temp_storage.block_prim.load) + .Load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE, probs_vec, + d - i * BLOCK_THREADS * VEC_SIZE, DType(0)); + __syncthreads(); DType probs_leq_pivot[VEC_SIZE]; #pragma unroll @@ -346,18 +354,16 @@ template cudaError_t SamplingFromProb(T* probs, T* uniform_samples, IdType* output, uint32_t batch_size, uint32_t d, cudaStream_t stream = 0) { constexpr uint32_t BLOCK_THREADS = 1024; - const uint32_t vec_size = std::gcd(16 / sizeof(T), d); + constexpr uint32_t VEC_SIZE = 16 / sizeof(T); dim3 nblks(batch_size); dim3 nthrs(BLOCK_THREADS); IdType* row_indices_placeholder = nullptr; void* args[] = {&probs, &uniform_samples, &output, &row_indices_placeholder, &d}; - const uint32_t smem_size = sizeof(SamplingTempStorage); + const uint32_t smem_size = + sizeof(SamplingTempStorage); - DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, { - auto kernel = - SamplingFromProbKernel; - FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); - }); + auto kernel = SamplingFromProbKernel; + FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); return cudaSuccess; } @@ -366,17 +372,15 @@ cudaError_t ParallelSamplingFromProb(T* probs, T* uniform_samples, 
IdType* outpu IdType* row_indices, uint32_t batch_size, uint32_t d, cudaStream_t stream = 0) { constexpr uint32_t BLOCK_THREADS = 1024; - const uint32_t vec_size = std::gcd(16 / sizeof(T), d); + constexpr uint32_t VEC_SIZE = 16 / sizeof(T); dim3 nblks(batch_size); dim3 nthrs(BLOCK_THREADS); void* args[] = {&probs, &uniform_samples, &output, &row_indices, &d}; - const uint32_t smem_size = sizeof(SamplingTempStorage); + const uint32_t smem_size = + sizeof(SamplingTempStorage); - DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, { - auto kernel = - SamplingFromProbKernel; - FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); - }); + auto kernel = SamplingFromProbKernel; + FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); return cudaSuccess; } @@ -385,20 +389,19 @@ cudaError_t TopKSamplingFromProb(T* probs, T* uniform_samples, IdType* output, b IdType top_k, uint32_t batch_size, uint32_t d, uint32_t max_top_k_rounds, cudaStream_t stream = 0) { constexpr uint32_t BLOCK_THREADS = 1024; - const uint32_t vec_size = std::gcd(16 / sizeof(T), d); + constexpr uint32_t VEC_SIZE = 16 / sizeof(T); - const uint32_t smem_size = sizeof(SamplingTempStorage); + const uint32_t smem_size = + sizeof(SamplingTempStorage); dim3 nblks(batch_size); dim3 nthrs(BLOCK_THREADS); void* args[] = {&probs, &uniform_samples, &output, &success, &top_k, &d, &max_top_k_rounds}; - DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, { - auto kernel = - TopKSamplingFromProbKernel; - FLASHINFER_CUDA_CALL( - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); - }); + auto kernel = + TopKSamplingFromProbKernel; + FLASHINFER_CUDA_CALL( + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); return cudaSuccess; } @@ -407,9 +410,10 @@ cudaError_t TopPSamplingFromProb(T* probs, T* uniform_samples, IdType* output, b T top_p, uint32_t batch_size, uint32_t d, uint32_t max_top_p_rounds, cudaStream_t stream = 0) { constexpr uint32_t BLOCK_THREADS = 1024; - const uint32_t vec_size = std::gcd(16 / sizeof(T), d); + constexpr uint32_t VEC_SIZE = 16 / sizeof(T); - const uint32_t smem_size = sizeof(SamplingTempStorage); + const uint32_t smem_size = + sizeof(SamplingTempStorage); dim3 nblks(batch_size); dim3 nthrs(BLOCK_THREADS); IdType* row_indices_placeholder = nullptr; @@ -424,21 +428,22 @@ cudaError_t TopPSamplingFromProb(T* probs, T* uniform_samples, IdType* output, b &d, &max_top_p_rounds}; - DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, { - auto kernel = - TopPSamplingFromProbKernel; - FLASHINFER_CUDA_CALL( - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); - }); + auto kernel = + TopPSamplingFromProbKernel; + FLASHINFER_CUDA_CALL( + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); return cudaSuccess; } -template +template struct RenormTempStorage { union { + typename BlockLoad::TempStorage load; typename BlockReduce::TempStorage reduce; typename BlockReduce, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage reduce_pair; + typename 
BlockStore::TempStorage store; } block_prim; struct { T max_val; @@ -456,20 +461,21 @@ __global__ void TopPRenormProbKernel(DType* probs, IdType* renormed_prob, float const uint32_t bx = blockIdx.x, tx = threadIdx.x; const uint32_t row_idx = bx; - extern __shared__ __align__(alignof(RenormTempStorage)) - uint8_t smem_renorm[]; + extern __shared__ __align__(alignof( + RenormTempStorage)) uint8_t smem_renorm[]; auto& temp_storage = - reinterpret_cast&>(smem_renorm); + reinterpret_cast&>( + smem_renorm); temp_storage.data.max_val = DType(0); - vec_t probs_vec; + DType probs_vec[VEC_SIZE]; DType probs_greater_than_pivot[VEC_SIZE]; // pivot initialized to 0 DType threadlocal_max_val = DType(0); for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { - probs_vec.fill(DType(0)); - if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { - probs_vec.load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE); - } + BlockLoad(temp_storage.block_prim.load) + .Load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE, probs_vec, + d - i * BLOCK_THREADS * VEC_SIZE, DType(0)); + __syncthreads(); #pragma unroll for (uint32_t j = 0; j < VEC_SIZE; ++j) { probs_greater_than_pivot[j] = probs_vec[j]; @@ -494,10 +500,10 @@ __global__ void TopPRenormProbKernel(DType* probs, IdType* renormed_prob, float DType threadlocal_sum(0); float mid = (low + high) / 2; for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { - probs_vec.fill(DType(0)); - if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { - probs_vec.load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE); - } + BlockLoad(temp_storage.block_prim.load) + .Load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE, probs_vec, + d - i * BLOCK_THREADS * VEC_SIZE, DType(0)); + __syncthreads(); #pragma unroll for (uint32_t j = 0; j < VEC_SIZE; ++j) { probs_greater_than_pivot[j] = (probs_vec[j] > mid) ? probs_vec[j] : DType(0); @@ -524,17 +530,18 @@ __global__ void TopPRenormProbKernel(DType* probs, IdType* renormed_prob, float // normalize for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { - probs_vec.fill(DType(0)); - if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { - probs_vec.load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE); - } + BlockLoad(temp_storage.block_prim.load) + .Load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE, probs_vec, + d - i * BLOCK_THREADS * VEC_SIZE, DType(0)); + __syncthreads(); #pragma unroll for (uint32_t j = 0; j < VEC_SIZE; ++j) { probs_vec[j] = (probs_vec[j] > low) ? 
probs_vec[j] * normalizer : DType(0); } - if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { - probs_vec.store(renormed_prob + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE); - } + BlockStore(temp_storage.block_prim.store) + .Store(renormed_prob + row_idx * d + i * BLOCK_THREADS * VEC_SIZE, probs_vec, + d - i * BLOCK_THREADS * VEC_SIZE, DType(0)); + __syncthreads(); } } @@ -545,20 +552,21 @@ __global__ void TopKRenormProbKernel(DType* probs, IdType* renormed_prob, uint32 const uint32_t bx = blockIdx.x, tx = threadIdx.x; const uint32_t row_idx = bx; - extern __shared__ __align__(alignof(RenormTempStorage)) - uint8_t smem_renorm[]; + extern __shared__ __align__(alignof( + RenormTempStorage)) uint8_t smem_renorm[]; auto& temp_storage = - reinterpret_cast&>(smem_renorm); + reinterpret_cast&>( + smem_renorm); temp_storage.data.max_val = DType(0); - vec_t probs_vec; + DType probs_vec[VEC_SIZE]; DType probs_greater_than_pivot[VEC_SIZE]; // pivot initialized to 0 DType threadlocal_max_val = DType(0); for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { - probs_vec.fill(DType(0)); - if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { - probs_vec.load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE); - } + BlockLoad(temp_storage.block_prim.load) + .Load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE, probs_vec, + d - i * BLOCK_THREADS * VEC_SIZE, DType(0)); + __syncthreads(); #pragma unroll for (uint32_t j = 0; j < VEC_SIZE; ++j) { probs_greater_than_pivot[j] = probs_vec[j]; @@ -584,10 +592,10 @@ __global__ void TopKRenormProbKernel(DType* probs, IdType* renormed_prob, uint32 Pair probs_greater_than_pivot_pair[VEC_SIZE]; // pivot initialized to 0 float mid = (low + high) / 2; for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { - probs_vec.fill(DType(0)); - if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { - probs_vec.load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE); - } + BlockLoad(temp_storage.block_prim.load) + .Load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE, probs_vec, + d - i * BLOCK_THREADS * VEC_SIZE, DType(0)); + __syncthreads(); #pragma unroll for (uint32_t j = 0; j < VEC_SIZE; ++j) { probs_greater_than_pivot_pair[j] = { @@ -616,36 +624,35 @@ __global__ void TopKRenormProbKernel(DType* probs, IdType* renormed_prob, uint32 // normalize for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { - probs_vec.fill(DType(0)); - if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { - probs_vec.load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE); - } + BlockLoad(temp_storage.block_prim.load) + .Load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE, probs_vec, + d - i * BLOCK_THREADS * VEC_SIZE, DType(0)); + __syncthreads(); #pragma unroll for (uint32_t j = 0; j < VEC_SIZE; ++j) { probs_vec[j] = (probs_vec[j] > low) ? 
probs_vec[j] * normalizer : DType(0); } - if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { - probs_vec.store(renormed_prob + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE); - } + BlockStore(temp_storage.block_prim.store) + .Store(renormed_prob + row_idx * d + i * BLOCK_THREADS * VEC_SIZE, probs_vec, + d - i * BLOCK_THREADS * VEC_SIZE, DType(0)); + __syncthreads(); } } template cudaError_t TopPRenormProb(DType* probs, IdType* renormed_prob, float p, float eps, uint32_t batch_size, uint32_t d, cudaStream_t stream = 0) { - const uint32_t BLOCK_THREADS = 1024; - const uint32_t vec_size = std::gcd(16 / sizeof(DType), d); + constexpr uint32_t BLOCK_THREADS = 1024; + constexpr uint32_t VEC_SIZE = 16 / sizeof(DType); - const uint32_t smem_size = sizeof(RenormTempStorage); + const uint32_t smem_size = sizeof(RenormTempStorage); dim3 nblks(batch_size); dim3 nthrs(BLOCK_THREADS); void* args[] = {&probs, &renormed_prob, &p, &eps, &d}; - DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, { - auto kernel = TopPRenormProbKernel; - FLASHINFER_CUDA_CALL( - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); - }); + auto kernel = TopPRenormProbKernel; + FLASHINFER_CUDA_CALL( + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); return cudaSuccess; } @@ -653,18 +660,16 @@ template cudaError_t TopKRenormProb(DType* probs, IdType* renormed_prob, uint32_t k, float eps, uint32_t batch_size, uint32_t d, cudaStream_t stream = 0) { const uint32_t BLOCK_THREADS = 1024; - const uint32_t vec_size = std::gcd(16 / sizeof(DType), d); + const uint32_t VEC_SIZE = 16 / sizeof(DType); - const uint32_t smem_size = sizeof(RenormTempStorage); + const uint32_t smem_size = sizeof(RenormTempStorage); dim3 nblks(batch_size); dim3 nthrs(BLOCK_THREADS); void* args[] = {&probs, &renormed_prob, &k, &eps, &d}; - DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, { - auto kernel = TopKRenormProbKernel; - FLASHINFER_CUDA_CALL( - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); - }); + auto kernel = TopKRenormProbKernel; + FLASHINFER_CUDA_CALL( + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); return cudaSuccess; } @@ -677,11 +682,12 @@ __global__ void ChainSpeculativeSampling(DType* draft_probs, IdType* draft_token const uint32_t bx = blockIdx.x, tx = threadIdx.x; const uint32_t row_idx = bx; - extern __shared__ __align__( - alignof(SamplingTempStorage)) + extern __shared__ __align__(alignof( + SamplingTempStorage)) uint8_t smem_sampling[]; auto& temp_storage = reinterpret_cast< - SamplingTempStorage&>(smem_sampling); + SamplingTempStorage&>( + smem_sampling); uint32_t pos = 0; for (pos = 0; pos < num_speculative_tokens; ++pos) { @@ -699,19 +705,20 @@ __global__ void ChainSpeculativeSampling(DType* draft_probs, IdType* draft_token // sample from relu(target_probs - draft_probs) DType sum_relu_q_minus_p(0); - vec_t q_vec, p_vec; + DType q_vec[VEC_SIZE], p_vec[VEC_SIZE]; DType relu_q_minus_p[VEC_SIZE]; for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { - q_vec.fill(DType(0)); - p_vec.fill(DType(0)); 
- if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { - q_vec.load(target_probs + row_idx * (num_speculative_tokens + 1) * d + - i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE); - if (pos != num_speculative_tokens) { - // there is no draft_probs for the bonus token - p_vec.load(draft_probs + row_idx * num_speculative_tokens * d + - i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE); - } + BlockLoad(temp_storage.block_prim.load) + .Load(target_probs + row_idx * (num_speculative_tokens + 1) * d + + i * BLOCK_THREADS * VEC_SIZE, + q_vec, d - i * BLOCK_THREADS * VEC_SIZE, DType(0)); + __syncthreads(); + if (pos != num_speculative_tokens) { + // there is no draft_probs for the bonus token + BlockLoad(temp_storage.block_prim.load) + .Load(draft_probs + row_idx * num_speculative_tokens * d + i * BLOCK_THREADS * VEC_SIZE, + p_vec, d - i * BLOCK_THREADS * VEC_SIZE, DType(0)); + __syncthreads(); } #pragma unroll for (uint32_t j = 0; j < VEC_SIZE; ++j) { @@ -732,19 +739,20 @@ __global__ void ChainSpeculativeSampling(DType* draft_probs, IdType* draft_token DType aggregate_relu_q_minus_p(0); for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { - q_vec.fill(DType(0)); - p_vec.fill(DType(0)); - if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { - q_vec.load(target_probs + row_idx * (num_speculative_tokens + 1) * d + - i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE); - if (pos != num_speculative_tokens) { - // there is no draft_probs for the bonus token - p_vec.load(draft_probs + row_idx * num_speculative_tokens * d + - i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE); - } + BlockLoad(temp_storage.block_prim.load) + .Load(target_probs + row_idx * (num_speculative_tokens + 1) * d + + i * BLOCK_THREADS * VEC_SIZE, + q_vec, d - i * BLOCK_THREADS * VEC_SIZE, DType(0)); + __syncthreads(); + if (pos != num_speculative_tokens) { + // there is no draft_probs for the bonus token + BlockLoad(temp_storage.block_prim.load) + .Load(draft_probs + row_idx * num_speculative_tokens * d + i * BLOCK_THREADS * VEC_SIZE, + p_vec, d - i * BLOCK_THREADS * VEC_SIZE, DType(0)); + __syncthreads(); } - vec_t relu_q_minus_p_vec; + DType relu_q_minus_p_vec[VEC_SIZE]; #pragma unroll for (uint32_t j = 0; j < VEC_SIZE; ++j) { relu_q_minus_p_vec[j] = max(q_vec[j] - p_vec[j], DType(0)); @@ -774,22 +782,21 @@ cudaError_t ParallelTopPSamplingFromProb(T* probs, T* uniform_samples, IdType* o uint32_t batch_size, uint32_t d, uint32_t max_top_p_rounds, cudaStream_t stream = 0) { constexpr uint32_t BLOCK_THREADS = 1024; - const uint32_t vec_size = std::gcd(16 / sizeof(T), d); + constexpr uint32_t VEC_SIZE = 16 / sizeof(T); - const uint32_t smem_size = sizeof(SamplingTempStorage); + const uint32_t smem_size = + sizeof(SamplingTempStorage); dim3 nblks(batch_size); dim3 nthrs(BLOCK_THREADS); T top_p_placeholder = 0; void* args[] = {&probs, &uniform_samples, &output, &success, &row_indices, &top_p_arr, &top_p_placeholder, &d, &max_top_p_rounds}; - DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, { - auto kernel = - TopPSamplingFromProbKernel; - FLASHINFER_CUDA_CALL( - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); - }); + auto kernel = + TopPSamplingFromProbKernel; + FLASHINFER_CUDA_CALL( + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); return cudaSuccess; } @@ -800,10 +807,10 @@ 
cudaError_t ChainSpeculativeSampling(DType* draft_probs, IdType* draft_token_ids uint32_t num_speculative_tokens, uint32_t d, cudaStream_t stream = 0) { constexpr uint32_t BLOCK_THREADS = 1024; - const uint32_t vec_size = std::gcd(16 / sizeof(DType), d); + constexpr uint32_t VEC_SIZE = 16 / sizeof(DType); const uint32_t smem_size = - sizeof(SamplingTempStorage); + sizeof(SamplingTempStorage); dim3 nblks(batch_size); dim3 nthrs(BLOCK_THREADS); void* args[] = {&draft_probs, @@ -813,13 +820,11 @@ cudaError_t ChainSpeculativeSampling(DType* draft_probs, IdType* draft_token_ids &output_token_ids, &num_speculative_tokens, &d}; - DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, { - auto kernel = - ChainSpeculativeSampling; - FLASHINFER_CUDA_CALL( - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); - }); + auto kernel = + ChainSpeculativeSampling; + FLASHINFER_CUDA_CALL( + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); return cudaSuccess; } diff --git a/include/flashinfer/utils.cuh b/include/flashinfer/utils.cuh index 849dae19..097d2458 100644 --- a/include/flashinfer/utils.cuh +++ b/include/flashinfer/utils.cuh @@ -218,40 +218,6 @@ } \ } -#define DISPATCH_ALIGNED_VEC_SIZE(aligned_vec_size, ALIGNED_VEC_SIZE, ...) \ - switch (aligned_vec_size) { \ - case 16: { \ - constexpr size_t ALIGNED_VEC_SIZE = 16; \ - __VA_ARGS__ \ - break; \ - } \ - case 8: { \ - constexpr size_t ALIGNED_VEC_SIZE = 8; \ - __VA_ARGS__ \ - break; \ - } \ - case 4: { \ - constexpr size_t ALIGNED_VEC_SIZE = 4; \ - __VA_ARGS__ \ - break; \ - } \ - case 2: { \ - constexpr size_t ALIGNED_VEC_SIZE = 2; \ - __VA_ARGS__ \ - break; \ - } \ - case 1: { \ - constexpr size_t ALIGNED_VEC_SIZE = 1; \ - __VA_ARGS__ \ - break; \ - } \ - default: { \ - std::ostringstream err_msg; \ - err_msg << "Unsupported aligned_vec_size: " << aligned_vec_size; \ - throw std::invalid_argument(err_msg.str()); \ - } \ - } - namespace flashinfer { inline bool is_device_ptr(const void* ptr) {
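Below is a minimal, self-contained sketch (not part of the patch) of the CUB pattern the hunks above adopt: cub::BlockLoad / cub::BlockStore with the guarded overloads (valid-item count plus an out-of-bounds default) in place of the per-thread vec_t loads, a shared-memory union holding the primitives' TempStorage, and a __syncthreads() between uses because that storage is reused. BLOCK_THREADS, VEC_SIZE, the kernel name ScaleRowKernel and the row-scaling payload are assumptions chosen for illustration, not names from the patch.

#include <cub/block/block_load.cuh>
#include <cub/block/block_store.cuh>

template <uint32_t BLOCK_THREADS, uint32_t VEC_SIZE, typename T>
__global__ void ScaleRowKernel(const T* __restrict__ in, T* __restrict__ out, uint32_t d,
                               T scale) {
  using BlockLoad = cub::BlockLoad<T, BLOCK_THREADS, VEC_SIZE, cub::BLOCK_LOAD_VECTORIZE>;
  using BlockStore = cub::BlockStore<T, BLOCK_THREADS, VEC_SIZE, cub::BLOCK_STORE_VECTORIZE>;

  // One union, in the spirit of SamplingTempStorage/RenormTempStorage above: the
  // primitives are used one at a time, so their TempStorage can alias the same memory.
  __shared__ union {
    typename BlockLoad::TempStorage load;
    typename BlockStore::TempStorage store;
  } temp_storage;

  const uint32_t row = blockIdx.x;
  T items[VEC_SIZE];  // per-thread tile fragment; plays the role of the old vec_t

  const uint32_t tile = BLOCK_THREADS * VEC_SIZE;
  for (uint32_t i = 0; i < (d + tile - 1) / tile; ++i) {
    const uint32_t tile_base = i * tile;
    // Guarded load: only d - tile_base items are valid; the remainder is filled
    // with T(0), which replaces the explicit probs_vec.fill(DType(0)) of the old code.
    BlockLoad(temp_storage.load).Load(in + row * d + tile_base, items, d - tile_base, T(0));
    __syncthreads();  // temp_storage is reused by the store and by the next iteration

#pragma unroll
    for (uint32_t j = 0; j < VEC_SIZE; ++j) {
      items[j] *= scale;
    }

    // Guarded store: writes back only the valid items of the final partial tile.
    BlockStore(temp_storage.store).Store(out + row * d + tile_base, items, d - tile_base);
    __syncthreads();
  }
}

// Example launch under the same assumptions, one block per row:
//   ScaleRowKernel<256, 4, float><<<batch_size, 256, 0, stream>>>(in, out, d, scale);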