Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
YuriPlyakhin committed Jan 31, 2025
1 parent 817f6bd commit b2e5d59
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 61 deletions.
21 changes: 8 additions & 13 deletions sycl/test-e2e/Matrix/Inputs/joint_matrix_out_bounds_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
#include <iostream>
#include <sycl/usm.hpp>

template <layout B_layout, unsigned int vnniFactor> class mult;
template <size_t K, layout B_layout, unsigned int vnniFactor> class mult;

template <typename T1, typename T2, size_t M, size_t N, size_t K, size_t TM,
size_t TN, size_t TK, layout A_layout, layout B_layout,
Expand All @@ -19,10 +19,10 @@ void matrix_multiply(T1 *C, T2 *A, T2 *B, queue q) {
// Add one iteration for the out of bounds dpas instruction
size_t NDRangeM = M / TM + (((M % TM) != 0) ? 1 : 0);
size_t NDRangeN = N / TN;
size_t sg_size = get_sg_size<mult<B_layout, vnniFactor>>(q);
size_t sg_size = get_sg_size<mult<K, B_layout, vnniFactor>>(q);

q.submit([&](handler &cgh) {
cgh.parallel_for<mult<B_layout, vnniFactor>>(
cgh.parallel_for<mult<K, B_layout, vnniFactor>>(
nd_range<2>({NDRangeM, NDRangeN * sg_size}, {1, 1 * sg_size}),
[=](nd_item<2> spmd_item)
#ifdef SG_SZ
Expand Down Expand Up @@ -90,11 +90,12 @@ void matrix_multiply(T1 *C, T2 *A, T2 *B, queue q) {
}).wait();
}

template <typename Tab, typename Tc, size_t TM, size_t TN, size_t TK,
layout A_layout, layout B_layout, unsigned int vnniFactor>
template <typename Tab, typename Tc, size_t MATRIX_M, size_t MATRIX_N,
size_t MATRIX_K, size_t TM, size_t TN, size_t TK, layout A_layout,
layout B_layout, unsigned int vnniFactor>
void test() {
static constexpr size_t MATRIX_M = 1024 + 14;
static constexpr size_t MATRIX_N = 1024;
std::cout << MATRIX_M << "x" << MATRIX_N << "x" << MATRIX_K << ", " << TM
<< "x" << TN << "x" << TK << ": ";
queue q;

// reference data
Expand Down Expand Up @@ -131,9 +132,3 @@ void test() {
free(C, q);
free(D, q);
}

int main() {
test<bfloat16, float, 8, 16, 16, layout::row_major, layout::row_major, 1>();
test<bfloat16, float, 8, 16, 16, layout::row_major, layout::ext_intel_packed,
2>();
}
17 changes: 14 additions & 3 deletions sycl/test-e2e/Matrix/SG32/joint_matrix_out_bounds.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,19 @@
// XFAIL-TRACKER: GSD-4181

#include "common.hpp"

#define SG_SZ 32
constexpr size_t MATRIX_K = 1024 + 24;

#include "joint_matrix_out_bounds_impl.hpp"

int main() {
test<bfloat16, float, 1024 + 14, 1024, 1024 + 24, 8, 16, 16,
layout::row_major, layout::row_major, 1>();
test<bfloat16, float, 1024 + 14, 1024, 1024 + 24, 8, 16, 16,
layout::row_major, layout::ext_intel_packed, 2>();

// unaligned k:
test<bfloat16, float, 1024 + 14, 1024, 1024 + 14, 8, 16, 16,
layout::row_major, layout::row_major, 1>();
test<bfloat16, float, 1024 + 14, 1024, 1024 + 14, 8, 16, 16,
layout::row_major, layout::ext_intel_packed, 2>();

}
24 changes: 0 additions & 24 deletions sycl/test-e2e/Matrix/SG32/joint_matrix_unaligned_k.cpp

This file was deleted.

15 changes: 13 additions & 2 deletions sycl/test-e2e/Matrix/joint_matrix_out_bounds.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,18 @@
// RUN: %{run} %t.out

#include "common.hpp"
#include "joint_matrix_out_bounds_impl.hpp"

constexpr size_t MATRIX_K = 1024 + 24;
int main() {
std::cout << "bf16:\n";
test<bfloat16, float, 1024 + 14, 1024, 1024 + 24, 8, 16, 16,
layout::row_major, layout::row_major, 1>();
test<bfloat16, float, 1024 + 14, 1024, 1024 + 24, 8, 16, 16,
layout::row_major, layout::ext_intel_packed, 2>();

#include "joint_matrix_out_bounds_impl.hpp"
// unaligned k:
test<bfloat16, float, 1024 + 14, 1024, 1024 + 14, 8, 16, 16,
layout::row_major, layout::row_major, 1>();
test<bfloat16, float, 1024 + 14, 1024, 1024 + 14, 8, 16, 16,
layout::row_major, layout::ext_intel_packed, 2>();
}
19 changes: 0 additions & 19 deletions sycl/test-e2e/Matrix/joint_matrix_unaligned_k.cpp

This file was deleted.

0 comments on commit b2e5d59

Please sign in to comment.