Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
YuriPlyakhin committed Jan 31, 2025
1 parent b2e5d59 commit 6dd1426
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 14 deletions.
37 changes: 24 additions & 13 deletions sycl/test-e2e/Matrix/Inputs/joint_matrix_out_bounds_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ void matrix_multiply(T1 *C, T2 *A, T2 *B, queue q) {
size_t NDRangeM = M / TM + (((M % TM) != 0) ? 1 : 0);
size_t NDRangeN = N / TN;
size_t sg_size = get_sg_size<mult<K, B_layout, vnniFactor>>(q);
std::cout << "SG size: " << sg_size << " ";

q.submit([&](handler &cgh) {
cgh.parallel_for<mult<K, B_layout, vnniFactor>>(
Expand Down Expand Up @@ -109,21 +110,31 @@ void test() {
matrix_multiply_ref(A, B, D, MATRIX_M, MATRIX_N, MATRIX_K);

// test data
if constexpr (A_layout == layout::row_major) {
if constexpr (B_layout == layout::row_major) {
matrix_multiply<Tc, Tab, MATRIX_M, MATRIX_N, MATRIX_K, TM, TN, TK,
A_layout, B_layout, vnniFactor>(C, A, B, q);
} else if constexpr (B_layout == layout::col_major) {
} else {
Tab *vnniB = malloc_shared<Tab>(MATRIX_K * MATRIX_N, q);
matrix_vnni<Tab>(MATRIX_K, MATRIX_N, B, vnniB, vnniFactor);
matrix_multiply<Tc, Tab, MATRIX_M, MATRIX_N, MATRIX_K, TM, TN, TK,
A_layout, B_layout, vnniFactor>(C, A, vnniB, q);
free(vnniB, q);
}
} else {
if constexpr (A_layout == layout::col_major) {
Tab *colA = malloc_shared<Tab>(MATRIX_K * MATRIX_M, q);
matrix_transpose(MATRIX_M, MATRIX_K, colA, A);
Tab *tmp = A;
A = colA;
free(tmp, q);
}

if constexpr (B_layout == layout::col_major) {
Tab *colB = malloc_shared<Tab>(MATRIX_N * MATRIX_K, q);
matrix_transpose(MATRIX_K, MATRIX_N, colB, B);
Tab *tmp = B;
B = colB;
free(tmp, q);
}

if constexpr (B_layout == layout::ext_intel_packed) {
Tab *vnniB = malloc_shared<Tab>(MATRIX_K * MATRIX_N, q);
matrix_vnni(MATRIX_K, MATRIX_N, B, vnniB, vnniFactor);
Tab *tmp = B;
B = vnniB;
free(tmp, q);
}

matrix_multiply<Tc, Tab, MATRIX_M, MATRIX_N, MATRIX_K, TM, TN, TK, A_layout, B_layout, vnniFactor>(C, A, B, q);
assert(matrix_compare(MATRIX_M, MATRIX_N, C, D));
std::cout << "passed" << std::endl;

Expand Down
5 changes: 4 additions & 1 deletion sycl/test-e2e/Matrix/joint_matrix_out_bounds.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,18 @@
#include "joint_matrix_out_bounds_impl.hpp"

// Test driver: runs four bf16 cases with A in row-major layout and B in either
// row-major or ext_intel_packed layout. Each case prints a label and then
// calls test<>(), which comes from joint_matrix_out_bounds_impl.hpp.
int main() {
// NOTE(review): the diff header for this file reports one deleted line; this
// generic "bf16:" banner looks like that deleted line (the per-case labels
// below replace it) -- confirm against the committed file.
std::cout << "bf16:\n";
// First pair of cases: third size argument is 1024 + 24.
std::cout << "bf16 A row major, B row major: ";
test<bfloat16, float, 1024 + 14, 1024, 1024 + 24, 8, 16, 16,
layout::row_major, layout::row_major, 1>();
std::cout << "bf16 A row major, B packed: ";
test<bfloat16, float, 1024 + 14, 1024, 1024 + 24, 8, 16, 16,
layout::row_major, layout::ext_intel_packed, 2>();

// unaligned k:
// Same two layout combinations, with the third size argument changed to
// 1024 + 14 (presumably MATRIX_K -- confirm against the test<> declaration).
std::cout << "bf16 A row major, B row major: ";
test<bfloat16, float, 1024 + 14, 1024, 1024 + 14, 8, 16, 16,
layout::row_major, layout::row_major, 1>();
std::cout << "bf16 A row major, B packed: ";
test<bfloat16, float, 1024 + 14, 1024, 1024 + 14, 8, 16, 16,
layout::row_major, layout::ext_intel_packed, 2>();
}
44 changes: 44 additions & 0 deletions sycl/test-e2e/Matrix/joint_matrix_out_bounds_colmajor.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
//==-------- joint_matrix_out_bounds_colmajor.cpp - DPC++ joint_matrix-----==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// REQUIRES: aspect-ext_intel_matrix
// UNSUPPORTED: gpu-intel-dg2, cpu

// RUN: %{build} -o %t.out
// RUN: %{run} %t.out

// RUN: %{build} -o %t32.out -DSG_SZ=32
// RUN: %{run} %t32.out

// XFAIL:gpu
// XFAIL-TRACKER: GSD-5768

#include "common.hpp"
#include "joint_matrix_out_bounds_impl.hpp"

// Test driver for the column-major variant: every case passes
// layout::col_major for both A and B, covering bf16, half and int8 inputs.
// The three cases are run twice, once per value of the third size argument.
int main() {
  // Size arguments shared by the cases below. Presumably these map to
  // MATRIX_M / MATRIX_N / MATRIX_K of test<> -- confirm against the
  // declaration in joint_matrix_out_bounds_impl.hpp.
  constexpr int kSizeM = 1024 + 14;
  constexpr int kSizeN = 1024;
  constexpr int kSizeK = 1024 + 24;
  constexpr int kSizeKUnaligned = 1024 + 14;

  std::cout << "bf16 A col major, B col major: ";
  test<bfloat16, float, kSizeM, kSizeN, kSizeK, 8, 16, 16, layout::col_major,
       layout::col_major, 1>();

  std::cout << "half A col major, B col major: ";
  test<half, float, kSizeM, kSizeN, kSizeK, 8, 16, 16, layout::col_major,
       layout::col_major, 1>();

  // The int8 case uses a different sixth tile argument (32) and a final
  // argument of 2 instead of 1.
  std::cout << "int8 A col major, B col major: ";
  test<int8_t, int32_t, kSizeM, kSizeN, kSizeK, 8, 16, 32, layout::col_major,
       layout::col_major, 2>();

  // unaligned k:
  std::cout << "bf16 A col major, B col major: ";
  test<bfloat16, float, kSizeM, kSizeN, kSizeKUnaligned, 8, 16, 16,
       layout::col_major, layout::col_major, 1>();

  std::cout << "half A col major, B col major: ";
  test<half, float, kSizeM, kSizeN, kSizeKUnaligned, 8, 16, 16,
       layout::col_major, layout::col_major, 1>();

  std::cout << "int8 A col major, B col major: ";
  test<int8_t, int32_t, kSizeM, kSizeN, kSizeKUnaligned, 8, 16, 32,
       layout::col_major, layout::col_major, 2>();
}

0 comments on commit 6dd1426

Please sign in to comment.