From 54bb2430eea7048235123e5c919b25bdccdd0c5e Mon Sep 17 00:00:00 2001
From: Gautham
Date: Tue, 2 Jan 2024 18:24:21 +0000
Subject: [PATCH] added matmul32_block2d for sgemm

removed the matmul_single methods
---
 llamafile/tinyblas.cu | 145 +++++++++++++++++++++++-------------------
 1 file changed, 81 insertions(+), 64 deletions(-)

diff --git a/llamafile/tinyblas.cu b/llamafile/tinyblas.cu
index ac3d6f487c..2b8c550cbb 100644
--- a/llamafile/tinyblas.cu
+++ b/llamafile/tinyblas.cu
@@ -21,36 +21,86 @@
     (((trans) == CUBLAS_OP_N) ? (A)[(i) + (j) * (ld)] : (A)[(j) + (i) * (ld)])
 #define READ16(A, trans, ld, i, j) __half2float(READ(A, trans, ld, i, j))
 
-static __device__ __forceinline__ void matmul_single(int m, int n, int k,
-                                                     int i, int j,
-                                                     const half *A, int lda,
-                                                     const half *B, int ldb,
-                                                     void *C,
-                                                     cudaDataType_t Ctype,
-                                                     int ldc) {
-    float sum = 0.0f;
-    for (int l = 0; l < k; ++l) {
-        sum += READ16(A, CUBLAS_OP_T, lda, i, l) *
-               READ16(B, CUBLAS_OP_N, ldb, l, j);
+#define BM 64
+#define BN 32
+#define BK BM
+#define CEIL_DIV(M, N) (((M) + (N)-1) / (N))
+
+static __device__ void matmul32_block2d(int m, int n, int k, int x, int y,
+                                        const float *A, int lda, float *As,
+                                        const float *B, int ldb, float *Bs,
+                                        void *C, int ldc, float *Cs) {
+    const int i = threadIdx.x;
+    int j, l, blob;
+    // within each block
+    // we first zero out Cs
+    for (j = 0; j < BN; ++j) Cs[j] = 0;
+
+    for (blob = 0; blob < k; blob += BK) {
+        // we copy into As from A
+        if (i < BM && (x + i) < m) {
+            for (j = 0; j < BK && blob + j < k; ++j) {
+                As[(i * BK) + j] = READ(A, CUBLAS_OP_T, lda, x + i, blob + j);
+            }
+            for (; j < BK; ++j) As[(i * BK) + j] = 0;
+        } else {  // UNLIKELY
+            for (j = 0; j < BK; ++j) As[(i * BK) + j] = 0;
+        }
+
+        // we copy into Bs from B
+        if (i < BK && (blob + i) < k) {
+            for (j = 0; j < BN && y + j < n; ++j) {
+                Bs[(i * BN) + j] = READ(B, CUBLAS_OP_N, ldb, blob + i, y + j);
+            }
+            for (; j < BN; ++j) Bs[(i * BN) + j] = 0;
+        } else {  // UNLIKELY
+            for (j = 0; j < BN; ++j) Bs[(i * BN) + j] = 0;
+        }
+        __syncthreads();
+
+        // We matmul the blobs, basically Cs += matmul(As, Bs)
+        for (j = 0; j < BN; ++j) {
+            for (l = 0; l < BK; ++l) {
+                Cs[j] += As[(i * BK) + l] * Bs[(l * BN) + j];
+            }
+        }
+        __syncthreads();
     }
-    if (Ctype == CUDA_R_16F) {
-        *((half *)C + i + j * ldc) = __float2half(sum);
-    } else {
-        *((float *)C + i + j * ldc) = sum;
+
+    // We write Cs out into C
+    if (x + i < m) {
+        for (j = 0; j < BN && y + j < n; ++j) {
+            *((float *)C + (x + i) + (y + j) * ldc) = Cs[j];
+        }
     }
+    __syncthreads();
 }
 
-static __device__ __forceinline__ void matmul_single32(int m, int n, int k, int i, int j,
-                                                       const float *A, int lda,
-                                                       const float *B, int ldb,
-                                                       float *C, int ldc) {
-    float sum = 0.0f;
-    float *cptr = C + i + j * ldc;
-    for (int l = 0; l < k; ++l) {
-        sum += READ(A, CUBLAS_OP_T, lda, i, l) *
-               READ(B, CUBLAS_OP_N, ldb, l, j);
+static __global__ void tinyblasS_entry(int m, int n, int k,
+                                       const float *A, int lda,
+                                       const float *B, int ldb,
+                                       float *C, int ldc) {
+    int x = blockIdx.x * BM;
+    const int jump1 = gridDim.x * BM;
+    int y = blockIdx.y * BN;
+    const int jump2 = gridDim.y * BN;
+
+    assert(blockDim.x == BM);
+    extern __shared__ float svals[];  // shared across all threads in a block
+    float *As = svals;
+    float *Bs = svals + BM * BK;
+    float Cs[BN];  // only within a particular thread
+
+    // each block handles a sub-matrix of C, of size BM * BN
+    // each thread handles a sub-row of size BN
+    for (x = blockIdx.x * BM; x < m; x += jump1) {
+        for (y = blockIdx.y * BN; y < n; y += jump2) {
+            matmul32_block2d(m, n, k, x, y,  //
+                             A, lda, As,     //
+                             B, ldb, Bs,     //
+                             C, ldc, Cs);
+        }
     }
-    *cptr = sum;
 }
 
 static bool check_args(cublasOperation_t transa, cublasOperation_t transb,
@@ -71,22 +121,6 @@ static bool check_args(cublasOperation_t transa, cublasOperation_t transb,
              *(float *)pBeta == 0.0f)));
 }
 
-static __global__ void tinyblasS_entry(int m, int n, int k,
-                                       const float *A, int lda,
-                                       const float *B, int ldb,
-                                       float *C, int ldc) {
-    int x = blockIdx.x * blockDim.x + threadIdx.x;
-    int jump = blockDim.x * gridDim.x;
-    int y = threadIdx.y;
-    int jump2 = blockDim.y;
-
-    for (; x < m; x += jump) {
-        for (y=threadIdx.y; y < n; y += jump2) {
-            matmul_single32(m, n, k, x, y, A, lda, B, ldb, C, ldc);
-        }
-    }
-}
-
 cublasStatus_t tinyblasSgemm(cudaStream_t stream,
                              cublasOperation_t transa, cublasOperation_t transb,
                              int m, int n, int k,
@@ -101,23 +135,15 @@ cublasStatus_t tinyblasSgemm(cudaStream_t stream,
         return CUBLAS_STATUS_NOT_SUPPORTED;
     }
 
-    int numSMs, devId;
-    cudaGetDevice(&devId);
-    cudaDeviceGetAttribute(&numSMs, cudaDevAttrMultiProcessorCount, devId);
-    int maxblocks = 16 * numSMs;
-    dim3 maxthreads(16, 64, 1);
+    dim3 maxblocks(CEIL_DIV(m, BM), CEIL_DIV(n, BN), 1);
+    int maxthreads = BM;
 
-    tinyblasS_entry<<<maxblocks, maxthreads, 0, stream>>>(
+    tinyblasS_entry<<<maxblocks, maxthreads,
+                      (sizeof(float) * (BM * BK + BK * BN)), stream>>>(
         m, n, k, A, lda, B, ldb, C, ldc);
     return CUBLAS_STATUS_SUCCESS;
 }
 
-// https://docs.nvidia.com/cuda/cublas/index.html#cublasgemmex
-#define BM 64
-#define BN 32
-#define BK BM
-#define CEIL_DIV(M, N) (((M) + (N)-1) / (N))
-
 static __device__ void matmul_block2d(int m, int n, int k, int x, int y,
                                       const half *A, int lda, float *As,
                                       const half *B, int ldb, float *Bs,
@@ -168,13 +194,14 @@ static __device__ void matmul_block2d(int m, int n, int k, int x, int y,
             }
         } else {
             for (j = 0; j < BN && y + j < n; ++j) {
-                *((float *)C + (x + i) + (y + j) * ldc) = __float2half(Cs[j]);
+                *((float *)C + (x + i) + (y + j) * ldc) = Cs[j];
             }
         }
     }
     __syncthreads();
 }
 
+// https://docs.nvidia.com/cuda/cublas/index.html#cublasgemmex
 static __global__ void tinyblasGE_entry(int m, int n, int k,
                                         const half *A, int lda,
                                         const half *B, int ldb,
                                         void *C, cudaDataType_t Ctype,
                                         int ldc) {
@@ -307,16 +334,6 @@ cublasStatus_t tinyblasGemmBatchedEx(cudaStream_t stream,
 
 // https://docs.nvidia.com/cuda/cublas/index.html#cublasgemmstridedbatchedex
 
-#define STRIDE0(A, i, stride) ((A) + (i) * (stride))
-#define STRIDE(A, type, i, stride)                                  \
-    ((type) == CUDA_R_16F                                           \
-         ? (void *)STRIDE0((half *)(A), (i), (stride))              \
-         : (void *)STRIDE0((float *)(A), (i), (stride)))
-#define STRIDE_CONST(A, type, i, stride)                            \
-    ((type) == CUDA_R_16F                                           \
-         ? (const void *)STRIDE0((const half *)(A), (i), (stride))  \
-         : (const void *)STRIDE0((const float *)(A), (i), (stride)))
-
 static __global__ void tinyblasGSBE_entry(int m, int n, int k,
                                           const half *A, int lda,