Skip to content

Commit

Permalink
added matmul32_block2d for sgemm
Browse files Browse the repository at this point in the history
removed the matmul_single methods
  • Loading branch information
Gautham authored and ahgamut committed Jan 2, 2024
1 parent c9ff289 commit 54bb243
Showing 1 changed file with 81 additions and 64 deletions.
145 changes: 81 additions & 64 deletions llamafile/tinyblas.cu
Original file line number Diff line number Diff line change
Expand Up @@ -21,36 +21,86 @@
(((trans) == CUBLAS_OP_N) ? (A)[(i) + (j) * (ld)] : (A)[(j) + (i) * (ld)])
#define READ16(A, trans, ld, i, j) __half2float(READ(A, trans, ld, i, j))

static __device__ __forceinline__ void matmul_single(int m, int n, int k,
int i, int j,
const half *A, int lda,
const half *B, int ldb,
void *C,
cudaDataType_t Ctype,
int ldc) {
float sum = 0.0f;
for (int l = 0; l < k; ++l) {
sum += READ16(A, CUBLAS_OP_T, lda, i, l) *
READ16(B, CUBLAS_OP_N, ldb, l, j);
#define BM 64
#define BN 32
#define BK BM
#define CEIL_DIV(M, N) (((M) + (N)-1) / (N))

static __device__ void matmul32_block2d(int m, int n, int k, int x, int y,
const float *A, int lda, float *As,
const float *B, int ldb, float *Bs,
void *C, int ldc, float *Cs) {
const int i = threadIdx.x;
int j, l, blob;
// within each block
// we first zero out Cs
for (j = 0; j < BN; ++j) Cs[j] = 0;

for (blob = 0; blob < k; blob += BK) {
// we copy into As from A
if (i < BM && (x + i) < m) {
for (j = 0; j < BK && blob + j < k; ++j) {
As[(i * BK) + j] = READ(A, CUBLAS_OP_T, lda, x + i, blob + j);
}
for (; j < BK; ++j) As[(i * BK) + j] = 0;
} else { // UNLIKELY
for (j = 0; j < BK; ++j) As[(i * BK) + j] = 0;
}

// we copy into Bs from B
if (i < BK && (blob + i) < k) {
for (j = 0; j < BN && y + j < n; ++j) {
Bs[(i * BN) + j] = READ(B, CUBLAS_OP_N, ldb, blob + i, y + j);
}
for (; j < BN; ++j) Bs[(i * BN) + j] = 0;
} else { // UNLIKELY
for (j = 0; j < BN; ++j) Bs[(i * BN) + j] = 0;
}
__syncthreads();

// We matmul the blobs, basically Cs += matmul(As, Bs)
for (j = 0; j < BN; ++j) {
for (l = 0; l < BK; ++l) {
Cs[j] += As[(i * BK) + l] * Bs[(l * BN) + j];
}
}
__syncthreads();
}
if (Ctype == CUDA_R_16F) {
*((half *)C + i + j * ldc) = __float2half(sum);
} else {
*((float *)C + i + j * ldc) = sum;

// We write Cs out into C
if (x + i < m) {
for (j = 0; j < BN && y + j < n; ++j) {
*((float *)C + (x + i) + (y + j) * ldc) = Cs[j];
}
}
__syncthreads();
}

static __device__ __forceinline__ void matmul_single32(int m, int n, int k, int i, int j,
const float *A, int lda,
const float *B, int ldb,
float *C, int ldc) {
float sum = 0.0f;
float *cptr = C + i + j * ldc;
for (int l = 0; l < k; ++l) {
sum += READ(A, CUBLAS_OP_T, lda, i, l) *
READ(B, CUBLAS_OP_N, ldb, l, j);
static __global__ void tinyblasS_entry(int m, int n, int k,
const float *A, int lda,
const float *B, int ldb,
float *C, int ldc) {
int x = blockIdx.x * BM;
const int jump1 = gridDim.x * BM;
int y = blockIdx.y * BN;
const int jump2 = gridDim.y * BN;

assert(blockDim.x == BM);
extern __shared__ float svals[]; // shared across all threads in a block
float *As = svals;
float *Bs = svals + BM * BK;
float Cs[BN]; // only within a particular thread

// each block handles a sub-matrix of C, of size BM * BN
// each thread handles a sub-row of size BN
for (x = blockIdx.x * BM; x < m; x += jump1) {
for (y = blockIdx.y * BN; y < n; y += jump2) {
matmul32_block2d(m, n, k, x, y, //
A, lda, As, //
B, ldb, Bs, //
C, ldc, Cs);
}
}
*cptr = sum;
}

static bool check_args(cublasOperation_t transa, cublasOperation_t transb,
Expand All @@ -71,22 +121,6 @@ static bool check_args(cublasOperation_t transa, cublasOperation_t transb,
*(float *)pBeta == 0.0f)));
}

static __global__ void tinyblasS_entry(int m, int n, int k,
const float *A, int lda,
const float *B, int ldb,
float *C, int ldc) {
int x = blockIdx.x * blockDim.x + threadIdx.x;
int jump = blockDim.x * gridDim.x;
int y = threadIdx.y;
int jump2 = blockDim.y;

for (; x < m; x += jump) {
for (y=threadIdx.y; y < n; y += jump2) {
matmul_single32(m, n, k, x, y, A, lda, B, ldb, C, ldc);
}
}
}

cublasStatus_t tinyblasSgemm(cudaStream_t stream,
cublasOperation_t transa,
cublasOperation_t transb,
Expand All @@ -101,23 +135,15 @@ cublasStatus_t tinyblasSgemm(cudaStream_t stream,
return CUBLAS_STATUS_NOT_SUPPORTED;
}

int numSMs, devId;
cudaGetDevice(&devId);
cudaDeviceGetAttribute(&numSMs, cudaDevAttrMultiProcessorCount, devId);
int maxblocks = 16 * numSMs;
dim3 maxthreads(16, 64, 1);
dim3 maxblocks(CEIL_DIV(m, BM), CEIL_DIV(n, BN), 1);
int maxthreads = BM;

tinyblasS_entry<<<maxblocks, maxthreads, 0, stream>>>(
tinyblasS_entry<<<maxblocks, maxthreads,
(sizeof(float) * (BM * BK + BK * BN)), stream>>>(
m, n, k, A, lda, B, ldb, C, ldc);
return CUBLAS_STATUS_SUCCESS;
}

// https://docs.nvidia.com/cuda/cublas/index.html#cublasgemmex
#define BM 64
#define BN 32
#define BK BM
#define CEIL_DIV(M, N) (((M) + (N)-1) / (N))

static __device__ void matmul_block2d(int m, int n, int k, int x, int y,
const half *A, int lda, float *As,
const half *B, int ldb, float *Bs,
Expand Down Expand Up @@ -168,13 +194,14 @@ static __device__ void matmul_block2d(int m, int n, int k, int x, int y,
}
} else {
for (j = 0; j < BN && y + j < n; ++j) {
*((float *)C + (x + i) + (y + j) * ldc) = __float2half(Cs[j]);
*((float *)C + (x + i) + (y + j) * ldc) = Cs[j];
}
}
}
__syncthreads();
}

// https://docs.nvidia.com/cuda/cublas/index.html#cublasgemmex
static __global__ void tinyblasGE_entry(int m, int n, int k, const half *A,
int lda, const half *B, int ldb,
void *C, cudaDataType_t Ctype,
Expand Down Expand Up @@ -307,16 +334,6 @@ cublasStatus_t tinyblasGemmBatchedEx(cudaStream_t stream,

// https://docs.nvidia.com/cuda/cublas/index.html#cublasgemmstridedbatchedex

#define STRIDE0(A, i, stride) ((A) + (i) * (stride))
#define STRIDE(A, type, i, stride) \
((type) == CUDA_R_16F \
? (void *)STRIDE0((half *)(A), (i), (stride)) \
: (void *)STRIDE0((float *)(A), (i), (stride)))
#define STRIDE_CONST(A, type, i, stride) \
((type) == CUDA_R_16F \
? (const void *)STRIDE0((const half *)(A), (i), (stride)) \
: (const void *)STRIDE0((const float *)(A), (i), (stride)))

static __global__ void tinyblasGSBE_entry(int m, int n, int k,
const half *A,
int lda,
Expand Down

0 comments on commit 54bb243

Please sign in to comment.