diff --git a/llamafile/tinyblas.cu b/llamafile/tinyblas.cu index 47308cbc6f..ac3d6f487c 100644 --- a/llamafile/tinyblas.cu +++ b/llamafile/tinyblas.cu @@ -329,21 +329,31 @@ static __global__ void tinyblasGSBE_entry(int m, int n, int k, int ldc, long long int strideC, int batchCount) { - int x = blockIdx.x * blockDim.x + threadIdx.x; - int y = blockIdx.y * blockDim.y + threadIdx.y; - int z = threadIdx.z; - int jump = blockDim.x * gridDim.x; - int jump2 = blockDim.y * gridDim.y; - int jump3 = blockDim.z; + int x = blockIdx.x * BM; + const int jump1 = gridDim.x * BM; + int y = blockIdx.y * BN; + const int jump2 = gridDim.y * BN; + int z = blockIdx.z; + const int jump3 = gridDim.z; + + assert(blockDim.x == BM); + extern __shared__ float svals[]; // shared across all threads in a block + float *As = svals; + float *Bs = svals + BM * BK; + float Cs[BN]; // only within a particular thread - for (; x < batchCount; x += jump) { - for (y=blockIdx.y * blockDim.y + threadIdx.y; y < m; y += jump2) { - for (z=threadIdx.z; z < n; z += jump3) { - matmul_single(m, n, k, y, z, A + x * strideA, lda, B + x * strideB, ldb, + // each block handles a sub-matrix of C, of size BM * BN + // each thread handles a sub-row of size BN + for (z = blockIdx.z; z < batchCount; z += jump3) { + for (x = blockIdx.x * BM; x < m; x += jump1) { + for (y = blockIdx.y * BN; y < n; y += jump2) { + matmul_block2d(m, n, k, x, y, // + A + z * strideA, lda, As, // + B + z * strideB, ldb, Bs, // (Ctype == CUDA_R_16F - ? (void *)((half *)C + x * strideC) - : (void *)((float *)C + x * strideC)), - Ctype, ldc); + ? (void *)((half *)C + z * strideC) + : (void *)((float *)C + z * strideC)), + Ctype, ldc, Cs); } } } @@ -375,15 +385,12 @@ cublasStatus_t tinyblasGemmStridedBatchedEx(cudaStream_t stream, return CUBLAS_STATUS_NOT_SUPPORTED; } - // https://developer.nvidia.com/blog/cuda-pro-tip-write-flexible-kernels-grid-stride-loops/ - int numSMs, devId; - cudaGetDevice(&devId); - cudaDeviceGetAttribute(&numSMs, cudaDevAttrMultiProcessorCount, devId); - dim3 maxblocks(numSMs, 16, 1); - dim3 maxthreads(4, 4, 64); - // call the entry function - tinyblasGSBE_entry<<>>( + dim3 maxblocks(CEIL_DIV(m, BM), CEIL_DIV(n, BN), 32); + int maxthreads = BM; + + tinyblasGSBE_entry<<>>( m, n, k, (const half*)A, lda, strideA, (const half*)B, ldb, strideB, C, Ctype, ldc, strideC, batchCount);