From d6e684747f537f083113999e6d3f80f88304d85d Mon Sep 17 00:00:00 2001 From: "Plyakhin, Yury" Date: Fri, 31 Jan 2025 12:49:00 -0800 Subject: [PATCH] updat --- sycl/test-e2e/Matrix/Inputs/common.hpp | 7 +++ .../Inputs/joint_matrix_out_bounds_impl.hpp | 45 +++++++++++++++---- .../Matrix/SG32/joint_matrix_out_bounds.cpp | 15 ++----- .../Matrix/joint_matrix_out_bounds.cpp | 18 ++------ .../joint_matrix_out_bounds_colmajor.cpp | 22 +-------- 5 files changed, 53 insertions(+), 54 deletions(-) diff --git a/sycl/test-e2e/Matrix/Inputs/common.hpp b/sycl/test-e2e/Matrix/Inputs/common.hpp index dca215ae574d2..371f705bae07f 100644 --- a/sycl/test-e2e/Matrix/Inputs/common.hpp +++ b/sycl/test-e2e/Matrix/Inputs/common.hpp @@ -234,3 +234,10 @@ void matrix_print(unsigned int rows, unsigned int cols, T *mat) { std::cout << "\n"; } } + +template constexpr int vnni_factor() { + if constexpr (Layout != layout::ext_intel_packed) + return 1; + static_assert(sizeof(T) <= 4 && "Unsupported type in vnni_factor()."); + return 4 / sizeof(T); +} diff --git a/sycl/test-e2e/Matrix/Inputs/joint_matrix_out_bounds_impl.hpp b/sycl/test-e2e/Matrix/Inputs/joint_matrix_out_bounds_impl.hpp index ac9282e910dd7..70c6f934a3c68 100644 --- a/sycl/test-e2e/Matrix/Inputs/joint_matrix_out_bounds_impl.hpp +++ b/sycl/test-e2e/Matrix/Inputs/joint_matrix_out_bounds_impl.hpp @@ -9,22 +9,20 @@ #include #include -template -class mult; +template class mult; template + size_t TN, size_t TK, layout A_layout, layout B_layout> void matrix_multiply(T1 *C, T2 *A, T2 *B, queue q) { // Add one iteration for the out of bounds dpas instruction size_t NDRangeM = M / TM + (((M % TM) != 0) ? 1 : 0); size_t NDRangeN = N / TN; - size_t sg_size = get_sg_size>(q); + size_t sg_size = get_sg_size>(q); std::cout << "SG size: " << sg_size << " "; q.submit([&](handler &cgh) { - cgh.parallel_for>( + cgh.parallel_for>( nd_range<2>({NDRangeM, NDRangeN * sg_size}, {1, 1 * sg_size}), [=](nd_item<2> spmd_item) #ifdef SG_SZ @@ -72,6 +70,7 @@ void matrix_multiply(T1 *C, T2 *A, T2 *B, queue q) { // bounds-checked load where width and height are added // params order: Stride, Height, Width, CoordX, CoordY if constexpr (B_layout != layout::col_major) { + constexpr unsigned int vnniFactor = vnni_factor(); ext::intel::experimental::matrix::joint_matrix_load_checked( sg, sub_b, pB, N * vnniFactor, K / vnniFactor, N * vnniFactor, k / vnniFactor, @@ -94,7 +93,7 @@ void matrix_multiply(T1 *C, T2 *A, T2 *B, queue q) { template + layout B_layout> void test() { std::cout << MATRIX_M << "x" << MATRIX_N << "x" << MATRIX_K << ", " << TM << "x" << TN << "x" << TK << ": "; @@ -129,13 +128,14 @@ void test() { if constexpr (B_layout == layout::ext_intel_packed) { Tab *vnniB = malloc_shared(MATRIX_K * MATRIX_N, q); - matrix_vnni(MATRIX_K, MATRIX_N, B, vnniB, vnniFactor); + matrix_vnni(MATRIX_K, MATRIX_N, B, vnniB, vnni_factor()); Tab *tmp = B; B = vnniB; free(tmp, q); } - matrix_multiply(C, A, B, q); + matrix_multiply(C, A, B, q); assert(matrix_compare(MATRIX_M, MATRIX_N, C, D)); std::cout << "passed" << std::endl; @@ -144,3 +144,30 @@ void test() { free(C, q); free(D, q); } + +template void test_all() { + std::cout << "bf16: "; + test(); + std::cout << "half: "; + test(); + std::cout << "int8: "; + test(); + + // unaligned k: + std::cout << "bf16: "; + test(); + std::cout << "half: "; + test(); + + // row major A fails, so disabled. CMPLRLLVM-65239 + if constexpr (A_layout != layout::row_major) { + std::cout << "int8: "; + test(); + } +} diff --git a/sycl/test-e2e/Matrix/SG32/joint_matrix_out_bounds.cpp b/sycl/test-e2e/Matrix/SG32/joint_matrix_out_bounds.cpp index 132631cc640d1..cf982bed31562 100644 --- a/sycl/test-e2e/Matrix/SG32/joint_matrix_out_bounds.cpp +++ b/sycl/test-e2e/Matrix/SG32/joint_matrix_out_bounds.cpp @@ -21,15 +21,8 @@ #include "joint_matrix_out_bounds_impl.hpp" int main() { - test(); - test(); - - // unaligned k: - test(); - test(); - + std::cout << "A row major, B row major:\n"; + test_all(); + std::cout << "A row major, B packed:\n"; + test_all(); } diff --git a/sycl/test-e2e/Matrix/joint_matrix_out_bounds.cpp b/sycl/test-e2e/Matrix/joint_matrix_out_bounds.cpp index 72f172f16be92..4aa30435a2ffc 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_out_bounds.cpp +++ b/sycl/test-e2e/Matrix/joint_matrix_out_bounds.cpp @@ -16,18 +16,8 @@ #include "joint_matrix_out_bounds_impl.hpp" int main() { - std::cout << "bf16 A row major, B row major: "; - test(); - std::cout << "bf16 A row major, B packed: "; - test(); - - // unaligned k: - std::cout << "bf16 A row major, B row major: "; - test(); - std::cout << "bf16 A row major, B packed: "; - test(); + std::cout << "A row major, B row major:\n"; + test_all(); + std::cout << "A row major, B packed:\n"; + test_all(); } diff --git a/sycl/test-e2e/Matrix/joint_matrix_out_bounds_colmajor.cpp b/sycl/test-e2e/Matrix/joint_matrix_out_bounds_colmajor.cpp index a93a1701c0bfa..30cdb69730e39 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_out_bounds_colmajor.cpp +++ b/sycl/test-e2e/Matrix/joint_matrix_out_bounds_colmajor.cpp @@ -21,24 +21,6 @@ #include "joint_matrix_out_bounds_impl.hpp" int main() { - std::cout << "bf16 A col major, B col major: "; - test(); - std::cout << "half A col major, B col major: "; - test(); - std::cout << "int8 A col major, B col major: "; - test(); - - // unaligned k: - std::cout << "bf16 A col major, B col major: "; - test(); - std::cout << "half A col major, B col major: "; - test(); - std::cout << "int8 A col major, B col major: "; - test(); + std::cout << "A col major, B col major:\n"; + test_all(); }