Skip to content

Commit

Permalink
feat: add ruy sgemm implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
ebraraktas committed Mar 14, 2024
1 parent bfa0cb3 commit 85166a5
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/cpu/backend.cc
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ namespace ctranslate2 {
#endif

#ifdef CT2_WITH_RUY
if (is_int8) {
if (is_int8 || compute_type == ComputeType::FLOAT32) {
return GemmBackend::RUY;
}
#endif
Expand Down
54 changes: 54 additions & 0 deletions src/cpu/primitives.cc
Original file line number Diff line number Diff line change
Expand Up @@ -692,6 +692,60 @@ namespace ctranslate2 {
}
#endif

#ifdef CT2_WITH_RUY
case cpu::GemmBackend::RUY: {
  // Single-precision GEMM through ruy: C <- alpha * op(A) * op(B) + beta * C.
  // Only densely packed operands are handled; any other leading dimension
  // is rejected below.
  if (lda != (transpose_a ? m : k)
      || ldb != (transpose_b ? k : n)
      || ldc != n)
    throw std::invalid_argument("Ruy GEMM does not support custom leading dimensions");

  ruy::Context *context = cpu::get_ruy_context();

  // A transposed row-major operand is described to ruy as a column-major
  // matrix with the logical (already transposed) shape.
  const ruy::Order order_a = transpose_a ? ruy::Order::kColMajor : ruy::Order::kRowMajor;
  const ruy::Order order_b = transpose_b ? ruy::Order::kColMajor : ruy::Order::kRowMajor;

  ruy::Matrix<float> lhs;
  ruy::MakeSimpleLayout(m, k, order_a, lhs.mutable_layout());
  lhs.set_data(a);

  ruy::Matrix<float> rhs;
  ruy::MakeSimpleLayout(k, n, order_b, rhs.mutable_layout());
  rhs.set_data(b);

  ruy::Matrix<float> dst;
  ruy::MakeSimpleLayout(m, n, ruy::Order::kRowMajor, dst.mutable_layout());
  dst.set_data(c);

  auto size = m * n;
  float *tmp_c = nullptr;

  if (beta != 0.0f) {
    // ruy::Mul overwrites C, so the `beta * C` term has to be computed
    // into scratch storage beforehand.
    // The term is deliberately NOT passed through MulParams::set_bias():
    // ruy's bias is a per-channel vector (one value per destination row by
    // default), not a full m x n matrix, so it cannot express `beta * C`.
    // Keeping the term separate also avoids dividing by `alpha`, which
    // would break for alpha == 0.
    tmp_c = static_cast<float*>(allocator.allocate(size * sizeof (float)));
    mul(beta, c, tmp_c, size);
  }

  ruy::MulParams<float, float> mul_params;
  ruy::Mul(lhs, rhs, mul_params, context, &dst);  // C <- op(A) * op(B)

  if (alpha != 1.0f) {
    mul(alpha, c, size);  // C <- alpha * op(A) * op(B)
  }

  if (tmp_c) {
    // C <- alpha * op(A) * op(B) + beta * C_original
    for (decltype(size) i = 0; i < size; ++i) {
      c[i] += tmp_c[i];
    }
    allocator.free(tmp_c);
  }
  break;
}
#endif

#ifdef CT2_WITH_OPENBLAS
case cpu::GemmBackend::OPENBLAS: {
cblas_sgemm(CblasRowMajor,
Expand Down

0 comments on commit 85166a5

Please sign in to comment.