Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

more perfo with llamafile tinyblas #655

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
299 changes: 147 additions & 152 deletions llamafile/tinyblas_cpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
#pragma GCC diagnostic ignored "-Wpedantic"
#pragma GCC diagnostic ignored "-Wignored-attributes"

#define CHUNK 8
#define CHUNK 16
#define ROW_ALIGN 64
#define MATRIX_ALIGN 4096
#define MAX_ALIGN 4096
Expand Down Expand Up @@ -416,6 +416,12 @@ inline void store(ggml_bf16_t *p, float f) {
////////////////////////////////////////////////////////////////////////////////////////////////////
// FLOATING POINT MATRIX MULTIPLICATION

/**
 * Computes a balanced block width for covering `m` items with blocks of at
 * most `M` items each.
 *
 * The extent is first cut into the minimum number of blocks,
 * ceil(m / M); the returned width is then ceil(m / that count), so the
 * blocks are as even as possible instead of leaving one short tail block.
 *
 * @tparam M  maximum block width (must be positive).
 * @param  m  number of items to cover.
 * @return    the block width in [1, M] for m > 0, or 0 when m <= 0
 *            (there is nothing to cover; the unguarded formula would
 *            divide by zero for m == 0).
 */
template <int M>
static long BLOCK_SIZE(long m) {
    static_assert(M > 0, "BLOCK_SIZE requires a positive maximum block width");
    // Guard: with m == 0 the block count below is 0 and the division is UB.
    if (m <= 0) {
        return 0;
    }
    const long NB_BLOC_M = (m + M - 1) / M;  // minimum number of blocks, >= 1 here
    // ceil(m / NB_BLOC_M): same value as the "exact ? q : q + 1" form for m > 0.
    return (m + NB_BLOC_M - 1) / NB_BLOC_M;
}

template <int CONFIG, int KN, typename D, typename V, typename TA, typename TB, typename TC>
class tinyBLAS {
public:
Expand All @@ -424,180 +430,169 @@ class tinyBLAS {
: A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
}

void matmul(long m, long n) {
mnpack(0, m, 0, n);
}

private:
NOINLINE void mnpack(long m0, long m, long n0, long n) {
long mc, nc, mp, np;

bool matmul(long m, long n) {
#if VECTOR_REGISTERS == 32
switch ((MIN(m - m0, 5) << 4) | MIN(n - n0, 5)) {
case 0x55:
mc = 5;
nc = 5;
gemm<5, 5>(m0, m, n0, n);
break;
case 0x54:
case 0x53:
case 0x52:
case 0x45:
case 0x44:
case 0x43:
case 0x42:
case 0x35:
case 0x34:
case 0x33:
case 0x32:
case 0x25:
case 0x24:
case 0x23:
case 0x22:
mc = 2;
nc = 2;
gemm<2, 2>(m0, m, n0, n);
break;
case 0x51:
case 0x41:
case 0x31:
case 0x21:
mc = 2;
nc = 1;
gemm<2, 1>(m0, m, n0, n);
break;
case 0x15:
case 0x14:
case 0x13:
case 0x12:
mc = 1;
nc = 2;
gemm<1, 2>(m0, m, n0, n);
break;
case 0x11:
mc = 1;
nc = 1;
gemm<1, 1>(m0, m, n0, n);
break;
default:
return;
if (m % 8 == 0 && n < 4) {
mnpack<8, 3, 1>(m, n, n);
return true;
}
#endif

#if VECTOR_REGISTERS == 16
switch ((MIN(m - m0, 4) << 4) | MIN(n - n0, 3)) {
case 0x43:
mc = 4;
nc = 3;
gemm<4, 3>(m0, m, n0, n);
break;
case 0x42:
case 0x33:
case 0x32:
case 0x23:
case 0x22:
mc = 2;
nc = 2;
gemm<2, 2>(m0, m, n0, n);
break;
case 0x41:
case 0x31:
case 0x21:
mc = 2;
nc = 1;
gemm<2, 1>(m0, m, n0, n);
break;
case 0x13:
case 0x12:
mc = 1;
nc = 2;
gemm<1, 2>(m0, m, n0, n);
break;
case 0x11:
mc = 1;
nc = 1;
gemm<1, 1>(m0, m, n0, n);
break;
default:
return;
if (m % 16 == 0) {
const long SIZE_N = BLOCK_SIZE<6>(n);
mnpack<4, 6, 4>(m, n, SIZE_N);
return true;
}
if (m % 8 == 0) {
const long SIZE_N = BLOCK_SIZE<6>(n);
mnpack<4, 6, 2>(m, n, SIZE_N);
return true;
}
if (m % 4 == 0) {
const long SIZE_N = BLOCK_SIZE<6>(n);
mnpack<4, 6, 1>(m, n, SIZE_N);
return true;
}
#else // VECTOR_REGISTERS == 16
if (m % 4 == 0 && n < 3) {
mnpack<4, 2, 1>(m, n, n);
return true;
}
if (m % 16 == 0) {
const long SIZE_N = BLOCK_SIZE<3>(n);
mnpack<4, 3, 4>(m, n, SIZE_N);
return true;
}
if (m % 8 == 0) {
const long SIZE_N = BLOCK_SIZE<3>(n);
mnpack<4, 3, 2>(m, n, SIZE_N);
return true;
}
if (m % 4 == 0) {
const long SIZE_N = BLOCK_SIZE<3>(n);
mnpack<4, 3, 1>(m, n, SIZE_N);
return true;
}
#endif
return false;
}

mp = m0 + (m - m0) / mc * mc;
np = n0 + (n - n0) / nc * nc;
mnpack(mp, m, n0, np);
mnpack(m0, m, np, n);
private:
// Compile-time dispatch on the column tile width: walks RN, RN-1, ..., 1 and
// calls gemm<RM, N, BM>(m, n) for the N that equals SIZE_N.
// If SIZE_N is not in [1, RN] no kernel is run and the call is a silent no-op.
template <int RM, int RN, int BM>
inline void mnpack(long m, long n, long SIZE_N) {
    if (SIZE_N == RN) {
        // Requested width matches this instantiation: run the kernel.
        gemm<RM, RN, BM>(m, n);
    } else if constexpr (RN > 1) {
        // Otherwise try the next narrower tile width.
        mnpack<RM, RN - 1, BM>(m, n, SIZE_N);
    }
    // else: RN == 1 and SIZE_N != 1 -- unsupported block size, nothing is done.
}

template <int RM, int RN>
NOINLINE void gemm(long m0, long m, long n0, long n) {
D stack[bsr(k / CHUNK + 1) + 1][RN][RM];
long ytiles = RM > 1 ? (m - m0) / RM : 1;
long xtiles = RN > 1 ? (n - n0) / RN : 1;
long tiles = xtiles * ytiles;
long duty = (tiles + nth - 1) / nth;
long start = duty * ith;
long end = start + duty;
if (end > tiles)
end = tiles;
for (long job = start; job < end; ++job) {
long ii = m0 + job / xtiles * RM;
long jj = n0 + job % xtiles * RN;

size_t chunk, sp = 0;
int i, j, rule, step = 2;
for (chunk = 0; chunk + KN * CHUNK * 4 <= k; chunk += KN * CHUNK * 4, step += 2, ++sp) {

D Cv[RN][RM] = {};
for (long l = 0; l < KN * CHUNK * 4; l += KN)
inline void gemm_bloc(long ii, long jj, long l, D Cv[RN][RM]) {
// help compiler for op order.
if constexpr (RM <= RN) {
V Av[RM];
#pragma GCC unroll 100
for (j = 0; j < RN; ++j)
for (int64_t i = 0; i < RM; ++i) {
Av[i] = load<V>(A + lda * (ii + i) + l);
}
#pragma GCC unroll 100
for (i = 0; i < RM; ++i)
Cv[j][i] = madd(load<V>(INDEX(A, lda, ii + i, chunk + l)), //
load<V>(INDEX(B, ldb, jj + j, chunk + l)), //
Cv[j][i]);

for (rule = bsr(step & -step); --rule;)
for (--sp, j = 0; j < RN; ++j)
for (i = 0; i < RM; ++i)
Cv[j][i] += stack[sp][j][i];

for (j = 0; j < RN; ++j)
for (i = 0; i < RM; ++i)
stack[sp][j][i] = Cv[j][i];
for (int64_t j = 0; j < RN; ++j) {
V Bv = load<V>(B + ldb * (jj + j) + l);
#pragma GCC unroll 100
for (int64_t i = 0; i < RM; ++i) {
Cv[j][i] = madd(Av[i], Bv, Cv[j][i]);
}
}

D Cv[RN][RM] = {};
for (; chunk + KN <= k; chunk += KN)
} else {
V Bv[RN];
#pragma GCC unroll 100
for (j = 0; j < RN; ++j)
for (int64_t j = 0; j < RN; ++j) {
Bv[j] = load<V>(B + ldb * (jj + j) + l);
}
#pragma GCC unroll 100
for (i = 0; i < RM; ++i)
Cv[j][i] = madd(load<V>(INDEX(A, lda, ii + i, chunk)), //
load<V>(INDEX(B, ldb, jj + j, chunk)), //
Cv[j][i]);
for (int64_t i = 0; i < RM; ++i) {
V Av = load<V>(A + lda * (ii + i) + l);
#pragma GCC unroll 100
for (int64_t j = 0; j < RN; ++j) {
Cv[j][i] = madd(Av, Bv[j], Cv[j][i]);
}
}
}
}

template <int RM, int RN>
inline void gemm_bloc(long ii, long jj) {
D stack[bsr(k / CHUNK + 1) + 1][RN][RM];
long chunk, sp = 0;
int i, j, rule, step = 2;
for (chunk = 0; chunk + KN * CHUNK * 4 <= k; chunk += KN * CHUNK * 4, step += 2, ++sp) {

while (sp--)
for (j = 0; j < RN; ++j)
D Cv[RN][RM] = {};
for (long l = 0; l < KN * CHUNK * 4; l += KN)
gemm_bloc<RM, RN>(ii, jj, chunk + l, Cv);

for (rule = bsr(step & -step); --rule;)
for (--sp, j = 0; j < RN; ++j)
for (i = 0; i < RM; ++i)
Cv[j][i] += stack[sp][j][i];

float Cf[RN][RM];
for (j = 0; j < RN; ++j)
for (i = 0; i < RM; ++i)
Cf[j][i] = hsum(Cv[j][i]);
stack[sp][j][i] = Cv[j][i];
}

for (; chunk < k; ++chunk)
for (j = 0; j < RN; ++j)
for (i = 0; i < RM; ++i)
Cf[j][i] = fmaf(load<float>(INDEX(A, lda, ii + i, chunk)), //
load<float>(INDEX(B, ldb, jj + j, chunk)), //
Cf[j][i]);
D Cv[RN][RM] = {};
for (; chunk + KN <= k; chunk += KN)
gemm_bloc<RM, RN>(ii, jj, chunk, Cv);

while (sp--)
for (j = 0; j < RN; ++j)
for (i = 0; i < RM; ++i)
Cv[j][i] += stack[sp][j][i];

float Cf[RN][RM];
for (j = 0; j < RN; ++j)
for (i = 0; i < RM; ++i)
Cf[j][i] = hsum(Cv[j][i]);

for (; chunk < k; ++chunk)
for (j = 0; j < RN; ++j)
for (i = 0; i < RM; ++i)
store(INDEX(C, ldc, jj + j, ii + i), Cf[j][i]);
Cf[j][i] = fmaf(load<float>(INDEX(A, lda, ii + i, chunk)), //
load<float>(INDEX(B, ldb, jj + j, chunk)), //
Cf[j][i]);

for (j = 0; j < RN; ++j)
for (i = 0; i < RM; ++i)
store(INDEX(C, ldc, jj + j, ii + i), Cf[j][i]);
}

// Splits the m x n output into jobs and runs the share of jobs belonging to
// participant `ith` out of `nth`.
//
// Rows are grouped in bands of RM * BM rows (BM consecutive RM-row tiles per
// job). Columns are cut into `col_tiles` tiles: the first `wide_tiles` are RN
// wide, the remaining ones are RN - 1 wide, so the widths add up to n.
// Any rows past the last full band of RM * BM are not visited here
// (presumably the caller guarantees m % (RM * BM) == 0 -- TODO confirm).
template <int RM, int RN, int BM>
NOINLINE void gemm(long m, long n) {
    const long row_bands = m / (RM * BM);
    const long col_tiles = (n + RN - 1) / RN;
    // Number of column tiles that keep the full width RN.
    const long wide_tiles = col_tiles - (col_tiles * RN - n);

    // Contiguous slice of jobs for this participant.
    const long total_jobs = col_tiles * row_bands;
    const long per_worker = (total_jobs + nth - 1) / nth;
    const long first_job = per_worker * ith;
    long last_job = first_job + per_worker;
    if (last_job > total_jobs) {
        last_job = total_jobs;
    }

    for (int64_t job = first_job; job < last_job; ++job) {
        const int64_t band = job / col_tiles;
        const int64_t tile = job % col_tiles;
        const bool is_wide = tile < wide_tiles;
        for (int64_t bi = 0; bi < BM; ++bi) {
            const int64_t row = (band * BM + bi) * RM;
            if (is_wide) {
                gemm_bloc<RM, RN>(row, tile * RN);
            } else if constexpr (RN > 1) {
                // Narrow tiles start after all the wide ones.
                gemm_bloc<RM, RN - 1>(row, wide_tiles * RN + (tile - wide_tiles) * (RN - 1));
            }
        }
    }
}

Expand Down
Loading
Loading