Skip to content

Commit

Permalink
[SYCL][E2E][Joint Matrix] OOB tests to support more shapes, layouts
Browse files Browse the repository at this point in the history
  • Loading branch information
YuriPlyakhin committed Jan 30, 2025
1 parent adeaea8 commit 817f6bd
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 54 deletions.
115 changes: 65 additions & 50 deletions sycl/test-e2e/Matrix/Inputs/joint_matrix_out_bounds_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,13 @@
#include <iostream>
#include <sycl/usm.hpp>

constexpr size_t TM = 8;
constexpr size_t TK = 16;

template <layout B_layout, unsigned int vnniFactor> class mult;

template <typename T1, typename T2, size_t NUM_ROWS_A, size_t NUM_COLS_A,
size_t NUM_ROWS_B, size_t NUM_COLS_B, size_t NUM_ROWS_C,
size_t NUM_COLS_C, layout B_layout, unsigned int vnniFactor>
template <typename T1, typename T2, size_t M, size_t N, size_t K, size_t TM,
size_t TN, size_t TK, layout A_layout, layout B_layout,
unsigned int vnniFactor>
void matrix_multiply(T1 *C, T2 *A, T2 *B, queue q) {
size_t M = NUM_ROWS_C;
size_t N = NUM_COLS_C;
size_t K = NUM_COLS_A;

assert(NUM_ROWS_C == NUM_ROWS_A && NUM_COLS_A == NUM_ROWS_B * vnniFactor);
// Add one iteration for the out of bounds dpas instruction
size_t NDRangeM = M / TM + (((M % TM) != 0) ? 1 : 0);
size_t NDRangeN = N / TN;
Expand All @@ -45,6 +38,7 @@ void matrix_multiply(T1 *C, T2 *A, T2 *B, queue q) {
auto pC =
address_space_cast<sycl::access::address_space::global_space,
sycl::access::decorated::no>(C);

// The submatrix API has to be accessed by all the workitems in a
// subgroup these functions will be called once by the subgroup no
// code divergence between the workitems
Expand All @@ -54,27 +48,40 @@ void matrix_multiply(T1 *C, T2 *A, T2 *B, queue q) {
const auto sg_starty = global_idy - spmd_item.get_local_id(1);

sub_group sg = spmd_item.get_sub_group();
joint_matrix<sub_group, bfloat16, use::a, TM, TK, layout::row_major>
sub_a;

// For B, since current implementation does not support non-packed
// layout, users need to specify the packed_b layout.
joint_matrix<sub_group, bfloat16, use::b, TK, TN, B_layout> sub_b;
joint_matrix<sub_group, float, use::accumulator, TM, TN> sub_c;
// bounds-checked load where width and height are added
joint_matrix<sub_group, T2, use::a, TM, TK, A_layout> sub_a;
joint_matrix<sub_group, T2, use::b, TK, TN, B_layout> sub_b;
joint_matrix<sub_group, T1, use::accumulator, TM, TN> sub_c;

// bounds-checked fill where width and height are added
ext::intel::experimental::matrix::joint_matrix_fill_checked(
sg, sub_c, 1, M, N, sg_startx * TM, sg_starty / sg_size * TN);

for (int k = 0; k < K; k += TK) {
// bounds-checked load where width and height are added
ext::intel::experimental::matrix::joint_matrix_load_checked(
sg, sub_a, pA, K, M, K, sg_startx * TM, k);
// Assume we alreay in vnni format.
// params order: Stride, Height, Width, CoordX, CoordY
if constexpr (A_layout == layout::row_major) {
ext::intel::experimental::matrix::joint_matrix_load_checked(
sg, sub_a, pA, K, M, K, sg_startx * TM, k);
} else {
ext::intel::experimental::matrix::joint_matrix_load_checked(
sg, sub_a, pA, M, K, M, k, sg_startx * TM);
}

// bounds-checked load where width and height are added
ext::intel::experimental::matrix::joint_matrix_load_checked(
sg, sub_b, pB, N * vnniFactor, K / vnniFactor, N * vnniFactor,
k / vnniFactor, sg_starty / sg_size * TN * vnniFactor);
// params order: Stride, Height, Width, CoordX, CoordY
if constexpr (B_layout != layout::col_major) {
ext::intel::experimental::matrix::joint_matrix_load_checked(
sg, sub_b, pB, N * vnniFactor, K / vnniFactor,
N * vnniFactor, k / vnniFactor,
sg_starty / sg_size * TN * vnniFactor);
} else {
ext::intel::experimental::matrix::joint_matrix_load_checked(
sg, sub_b, pB, K, N, K, sg_starty / sg_size * TN, k);
}

joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c);
}

// bounds-checked store where width and height are added
ext::intel::experimental::matrix::joint_matrix_store_checked(
sg, sub_c, pC, N, layout::row_major, M, N, sg_startx * TM,
Expand All @@ -83,42 +90,50 @@ void matrix_multiply(T1 *C, T2 *A, T2 *B, queue q) {
}).wait();
}

int main() {
template <typename Tab, typename Tc, size_t TM, size_t TN, size_t TK,
layout A_layout, layout B_layout, unsigned int vnniFactor>
void test() {
static constexpr size_t MATRIX_M = 1024 + 14;
static constexpr size_t MATRIX_N = 1024;
static constexpr unsigned int vnniFactor = 2;

queue q;
bfloat16 *A = malloc_shared<bfloat16>(MATRIX_M * MATRIX_K, q);
bfloat16 *B = malloc_shared<bfloat16>(MATRIX_K * MATRIX_N, q);
bfloat16 *vnniB = malloc_shared<bfloat16>(MATRIX_K * MATRIX_N, q);
float *C = malloc_shared<float>(MATRIX_M * MATRIX_N, q);
float *D = malloc_shared<float>(MATRIX_M * MATRIX_N, q);

matrix_rand(MATRIX_M, MATRIX_K, A, (bfloat16)5);
matrix_rand(MATRIX_K, MATRIX_N, B, (bfloat16)5);
matrix_fill(MATRIX_M, MATRIX_N, C, (float)1);
matrix_fill(MATRIX_M, MATRIX_N, D, (float)1);

matrix_vnni<bfloat16>(MATRIX_K, MATRIX_N, B, vnniB, vnniFactor);

// reference data
Tab *A = malloc_shared<Tab>(MATRIX_M * MATRIX_K, q);
Tab *B = malloc_shared<Tab>(MATRIX_K * MATRIX_N, q);
Tc *C = malloc_shared<Tc>(MATRIX_M * MATRIX_N, q);
Tc *D = malloc_shared<Tc>(MATRIX_M * MATRIX_N, q);
matrix_rand(MATRIX_M, MATRIX_K, A, (Tab)5);
matrix_rand(MATRIX_K, MATRIX_N, B, (Tab)5);
matrix_fill(MATRIX_M, MATRIX_N, D, (Tc)1);
matrix_multiply_ref(A, B, D, MATRIX_M, MATRIX_N, MATRIX_K);
matrix_multiply<float, bfloat16, MATRIX_M, MATRIX_K, MATRIX_K / vnniFactor,
MATRIX_N * vnniFactor, MATRIX_M, MATRIX_N,
layout::ext_intel_packed, vnniFactor>(C, A, vnniB, q);
bool res = matrix_compare(MATRIX_M, MATRIX_N, C, D);

matrix_multiply<float, bfloat16, MATRIX_M, MATRIX_K, MATRIX_K, MATRIX_N,
MATRIX_M, MATRIX_N, layout::row_major, 1>(C, A, B, q);
res = res && matrix_compare(MATRIX_M, MATRIX_N, C, D);

std::cout << (res ? "passed" : "failed") << std::endl;
// test data
if constexpr (A_layout == layout::row_major) {
if constexpr (B_layout == layout::row_major) {
matrix_multiply<Tc, Tab, MATRIX_M, MATRIX_N, MATRIX_K, TM, TN, TK,
A_layout, B_layout, vnniFactor>(C, A, B, q);
} else if constexpr (B_layout == layout::col_major) {
} else {
Tab *vnniB = malloc_shared<Tab>(MATRIX_K * MATRIX_N, q);
matrix_vnni<Tab>(MATRIX_K, MATRIX_N, B, vnniB, vnniFactor);
matrix_multiply<Tc, Tab, MATRIX_M, MATRIX_N, MATRIX_K, TM, TN, TK,
A_layout, B_layout, vnniFactor>(C, A, vnniB, q);
free(vnniB, q);
}
} else {
}

assert(matrix_compare(MATRIX_M, MATRIX_N, C, D));
std::cout << "passed" << std::endl;

free(A, q);
free(B, q);
free(vnniB, q);
free(C, q);
free(D, q);
}

return !res;
int main() {
test<bfloat16, float, 8, 16, 16, layout::row_major, layout::row_major, 1>();
test<bfloat16, float, 8, 16, 16, layout::row_major, layout::ext_intel_packed,
2>();
}
1 change: 0 additions & 1 deletion sycl/test-e2e/Matrix/SG32/joint_matrix_out_bounds.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
#include "common.hpp"

#define SG_SZ 32
constexpr size_t TN = 16;
constexpr size_t MATRIX_K = 1024 + 24;

#include "joint_matrix_out_bounds_impl.hpp"
1 change: 0 additions & 1 deletion sycl/test-e2e/Matrix/SG32/joint_matrix_unaligned_k.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
#include "common.hpp"

#define SG_SZ 32
constexpr size_t TN = 16;
static constexpr size_t MATRIX_K = 1024 + 14;

#include "joint_matrix_out_bounds_impl.hpp"
1 change: 0 additions & 1 deletion sycl/test-e2e/Matrix/joint_matrix_out_bounds.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

#include "common.hpp"

constexpr size_t TN = 16;
constexpr size_t MATRIX_K = 1024 + 24;

#include "joint_matrix_out_bounds_impl.hpp"
1 change: 0 additions & 1 deletion sycl/test-e2e/Matrix/joint_matrix_unaligned_k.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

#include "common.hpp"

constexpr size_t TN = 16;
static constexpr size_t MATRIX_K = 1024 + 14;

#include "joint_matrix_out_bounds_impl.hpp"

0 comments on commit 817f6bd

Please sign in to comment.