
Commit

add bounds check
ahgamut committed Jan 2, 2024
1 parent 54bb243 commit b96885c
Showing 1 changed file with 43 additions and 31 deletions.
74 changes: 43 additions & 31 deletions llamafile/tinyblas.cu
@@ -27,9 +27,11 @@
 #define CEIL_DIV(M, N) (((M) + (N)-1) / (N))

 static __device__ void matmul32_block2d(int m, int n, int k, int x, int y,
-        const float *A, int lda, float *As,
-        const float *B, int ldb, float *Bs,
-        void *C, int ldc, float *Cs) {
+                                        const float *A, int lda, float *As,
+                                        const float *B, int ldb, float *Bs,
+                                        void *C, int ldc, float *Cs) {
+    assert(blockDim.x == BM);
+    static_assert(BK <= BM);
     const int i = threadIdx.x;
     int j, l, blob;
     // within each block
@@ -38,23 +40,29 @@ static __device__ void matmul32_block2d(int m, int n, int k, int x, int y,

     for (blob = 0; blob < k; blob += BK) {
         // we copy into As from A
-        if (i < BM && (x + i) < m) {
-            for (j = 0; j < BK && blob + j < k; ++j) {
-                As[(i * BK) + j] = READ(A, CUBLAS_OP_T, lda, x + i, blob + j);
+        if (i < BM) {
+            if ((x + i) < m) {
+                for (j = 0; j < BK && blob + j < k; ++j) {
+                    As[(i * BK) + j] =
+                        READ(A, CUBLAS_OP_T, lda, x + i, blob + j);
+                }
+                for (; j < BK; ++j) As[(i * BK) + j] = 0;
+            } else { // UNLIKELY
+                for (j = 0; j < BK; ++j) As[(i * BK) + j] = 0;
             }
-            for (; j < BK; ++j) As[(i * BK) + j] = 0;
-        } else { // UNLIKELY
-            for (j = 0; j < BK; ++j) As[(i * BK) + j] = 0;
         }

         // we copy into Bs from B
-        if (i < BK && (blob + i) < k) {
-            for (j = 0; j < BN && y + j < n; ++j) {
-                Bs[(i * BN) + j] = READ(B, CUBLAS_OP_N, ldb, blob + i, y + j);
+        if (i < BK) {
+            if ((blob + i) < k) {
+                for (j = 0; j < BN && y + j < n; ++j) {
+                    Bs[(i * BN) + j] =
+                        READ(B, CUBLAS_OP_N, ldb, blob + i, y + j);
+                }
+                for (; j < BN; ++j) Bs[(i * BN) + j] = 0;
+            } else { // UNLIKELY
+                for (j = 0; j < BN; ++j) Bs[(i * BN) + j] = 0;
             }
-            for (; j < BN; ++j) Bs[(i * BN) + j] = 0;
-        } else { // UNLIKELY
-            for (j = 0; j < BN; ++j) Bs[(i * BN) + j] = 0;
         }
         __syncthreads();

@@ -85,7 +93,6 @@ static __global__ void tinyblasS_entry(int m, int n, int k,
     int y = blockIdx.y * BN;
     const int jump2 = gridDim.y * BN;

-    assert(blockDim.x == BM);
     extern __shared__ float svals[]; // shared across all threads in a block
     float *As = svals;
     float *Bs = svals + BM * BK;
@@ -149,6 +156,8 @@ static __device__ void matmul_block2d(int m, int n, int k, int x, int y,
                                       const half *B, int ldb, float *Bs,
                                       void *C, cudaDataType_t Ctype, int ldc,
                                       float *Cs) {
+    assert(blockDim.x == BM);
+    static_assert(BK <= BM);
     const int i = threadIdx.x;
     int j, l, blob;
     // within each block
@@ -157,23 +166,29 @@ static __device__ void matmul_block2d(int m, int n, int k, int x, int y,

     for (blob = 0; blob < k; blob += BK) {
         // we copy into As from A
-        if (i < BM && (x + i) < m) {
-            for (j = 0; j < BK && blob + j < k; ++j) {
-                As[(i * BK) + j] = READ16(A, CUBLAS_OP_T, lda, x + i, blob + j);
+        if (i < BM) {
+            if ((x + i) < m) {
+                for (j = 0; j < BK && blob + j < k; ++j) {
+                    As[(i * BK) + j] =
+                        READ16(A, CUBLAS_OP_T, lda, x + i, blob + j);
+                }
+                for (; j < BK; ++j) As[(i * BK) + j] = 0;
+            } else { // UNLIKELY
+                for (j = 0; j < BK; ++j) As[(i * BK) + j] = 0;
             }
-            for (; j < BK; ++j) As[(i * BK) + j] = 0;
-        } else { // UNLIKELY
-            for (j = 0; j < BK; ++j) As[(i * BK) + j] = 0;
         }

         // we copy into Bs from B
-        if (i < BK && (blob + i) < k) {
-            for (j = 0; j < BN && y + j < n; ++j) {
-                Bs[(i * BN) + j] = READ16(B, CUBLAS_OP_N, ldb, blob + i, y + j);
+        if (i < BK) {
+            if ((blob + i) < k) {
+                for (j = 0; j < BN && y + j < n; ++j) {
+                    Bs[(i * BN) + j] =
+                        READ16(B, CUBLAS_OP_N, ldb, blob + i, y + j);
+                }
+                for (; j < BN; ++j) Bs[(i * BN) + j] = 0;
+            } else { // UNLIKELY
+                for (j = 0; j < BN; ++j) Bs[(i * BN) + j] = 0;
             }
-            for (; j < BN; ++j) Bs[(i * BN) + j] = 0;
-        } else { // UNLIKELY
-            for (j = 0; j < BN; ++j) Bs[(i * BN) + j] = 0;
         }
         __syncthreads();

@@ -211,7 +226,6 @@ static __global__ void tinyblasGE_entry(int m, int n, int k, const half *A,
     int y = blockIdx.y * BN;
     const int jump2 = gridDim.y * BN;

-    assert(blockDim.x == BM);
     extern __shared__ float svals[]; // shared across all threads in a block
     float *As = svals;
     float *Bs = svals + BM * BK;
@@ -277,7 +291,6 @@ static __global__ void tinyblasGBE_entry(int m, int n, int k,
     int z = blockIdx.z;
     const int jump3 = gridDim.z;

-    assert(blockDim.x == BM);
     extern __shared__ float svals[]; // shared across all threads in a block
     float *As = svals;
     float *Bs = svals + BM * BK;
@@ -353,7 +366,6 @@ static __global__ void tinyblasGSBE_entry(int m, int n, int k,
     int z = blockIdx.z;
     const int jump3 = gridDim.z;

-    assert(blockDim.x == BM);
     extern __shared__ float svals[]; // shared across all threads in a block
     float *As = svals;
     float *Bs = svals + BM * BK;
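
Note on the pattern this commit moves to: as I read the old code, the else branch of the Bs copy could zero-fill Bs[(i * BN) + j] for threads with i >= BK, which steps past the BK*BN tile; the new code only lets threads with i < BK touch Bs (and threads with i < BM touch As), zero-pads rows that fall past the end of the matrix, and still lets every thread reach __syncthreads(). The standalone sketch below illustrates that guarded tile-fill pattern; the fill_tile kernel, tile sizes, plain row-major indexing (instead of the file's READ/READ16 macros), and host driver are illustrative only and not part of the llamafile sources.

// guarded_tile_fill.cu - illustrative sketch only, not llamafile source.
// Demonstrates the bounds-checked shared-tile fill from the diff above:
// only threads with i < BK write the BK*BN tile Bs, rows past the end of
// B are zeroed, and every thread still reaches __syncthreads().
#include <cstdio>
#include <cuda_runtime.h>

#define BM 64
#define BK 8
#define BN 32

__global__ void fill_tile(const float *B, int ldb, int k, int n, float *out) {
    __shared__ float Bs[BK * BN];
    const int i = threadIdx.x; // 0 .. BM-1, but only i < BK owns a row of Bs

    if (i < BK) {
        if (i < k) { // row exists in B: copy it, zero-pad the tail
            int j;
            for (j = 0; j < BN && j < n; ++j) Bs[i * BN + j] = B[i * ldb + j];
            for (; j < BN; ++j) Bs[i * BN + j] = 0;
        } else { // row past the end of B: zero the whole row
            for (int j = 0; j < BN; ++j) Bs[i * BN + j] = 0;
        }
    }
    __syncthreads(); // threads with i >= BK wrote nothing, but still sync

    if (i == 0) // copy the tile out so the host can inspect it
        for (int idx = 0; idx < BK * BN; ++idx) out[idx] = Bs[idx];
}

int main() {
    const int k = 5, n = 20, ldb = 20; // a 5x20 row-major matrix
    float *B, *out;
    cudaMallocManaged(&B, k * ldb * sizeof(float));
    cudaMallocManaged(&out, BK * BN * sizeof(float));
    for (int idx = 0; idx < k * ldb; ++idx) B[idx] = 1.0f;
    fill_tile<<<1, BM>>>(B, ldb, k, n, out);
    cudaDeviceSynchronize();
    printf("Bs[0][0]=%g Bs[%d][%d]=%g\n", out[0], BK - 1, BN - 1, out[BK * BN - 1]);
    cudaFree(B);
    cudaFree(out);
    return 0;
}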