Skip to content

Commit

Permalink
add GSBE
Browse files Browse the repository at this point in the history
  • Loading branch information
Gautham authored and ahgamut committed Jan 2, 2024
1 parent 74605e2 commit c9ff289
Showing 1 changed file with 28 additions and 21 deletions.
49 changes: 28 additions & 21 deletions llamafile/tinyblas.cu
Original file line number Diff line number Diff line change
Expand Up @@ -329,21 +329,31 @@ static __global__ void tinyblasGSBE_entry(int m, int n, int k,
int ldc,
long long int strideC,
int batchCount) {
int x = blockIdx.x * blockDim.x + threadIdx.x;
int y = blockIdx.y * blockDim.y + threadIdx.y;
int z = threadIdx.z;
int jump = blockDim.x * gridDim.x;
int jump2 = blockDim.y * gridDim.y;
int jump3 = blockDim.z;
int x = blockIdx.x * BM;
const int jump1 = gridDim.x * BM;
int y = blockIdx.y * BN;
const int jump2 = gridDim.y * BN;
int z = blockIdx.z;
const int jump3 = gridDim.z;

assert(blockDim.x == BM);
extern __shared__ float svals[]; // shared across all threads in a block
float *As = svals;
float *Bs = svals + BM * BK;
float Cs[BN]; // only within a particular thread

for (; x < batchCount; x += jump) {
for (y=blockIdx.y * blockDim.y + threadIdx.y; y < m; y += jump2) {
for (z=threadIdx.z; z < n; z += jump3) {
matmul_single(m, n, k, y, z, A + x * strideA, lda, B + x * strideB, ldb,
// each block handles a sub-matrix of C, of size BM * BN
// each thread handles a sub-row of size BN
for (z = blockIdx.z; z < batchCount; z += jump3) {
for (x = blockIdx.x * BM; x < m; x += jump1) {
for (y = blockIdx.y * BN; y < n; y += jump2) {
matmul_block2d(m, n, k, x, y, //
A + z * strideA, lda, As, //
B + z * strideB, ldb, Bs, //
(Ctype == CUDA_R_16F
? (void *)((half *)C + x * strideC)
: (void *)((float *)C + x * strideC)),
Ctype, ldc);
? (void *)((half *)C + z * strideC)
: (void *)((float *)C + z * strideC)),
Ctype, ldc, Cs);
}
}
}
Expand Down Expand Up @@ -375,15 +385,12 @@ cublasStatus_t tinyblasGemmStridedBatchedEx(cudaStream_t stream,
return CUBLAS_STATUS_NOT_SUPPORTED;
}

// https://developer.nvidia.com/blog/cuda-pro-tip-write-flexible-kernels-grid-stride-loops/
int numSMs, devId;
cudaGetDevice(&devId);
cudaDeviceGetAttribute(&numSMs, cudaDevAttrMultiProcessorCount, devId);
dim3 maxblocks(numSMs, 16, 1);
dim3 maxthreads(4, 4, 64);

// call the entry function
tinyblasGSBE_entry<<<maxblocks, maxthreads, 0, stream>>>(
dim3 maxblocks(CEIL_DIV(m, BM), CEIL_DIV(n, BN), 32);
int maxthreads = BM;

tinyblasGSBE_entry<<<maxblocks, maxthreads,
(sizeof(float) * (BM * BK + BK * BN)), stream>>>(
m, n, k, (const half*)A, lda, strideA, (const half*)B, ldb, strideB,
C, Ctype, ldc, strideC, batchCount);

Expand Down

0 comments on commit c9ff289

Please sign in to comment.