Skip to content

Commit

Permalink
[SYCL][Joint Matrix Tests] Add fill/store/apply tests for 16x16x16, 32x64x16 (#12629)
Browse files Browse the repository at this point in the history
  • Loading branch information
YuriPlyakhin authored Mar 22, 2024
1 parent 75f6cd2 commit 84426d1
Show file tree
Hide file tree
Showing 8 changed files with 177 additions and 179 deletions.
13 changes: 5 additions & 8 deletions sycl/test-e2e/Matrix/SG32/element_wise_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,16 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// REQUIRES: matrix
// REQUIRES: aspect-ext_intel_matrix
// REQUIRES-INTEL-DRIVER: lin: 27501, win: 101.4943
// SG size = 32 is not currently supported for SYCL Joint Matrix by IGC on DG2
// UNSUPPORTED: gpu-intel-dg2

// RUN: %{build} -o %t.out
// RUN: %{run} %t.out

#include <iostream>
#include <sycl/sycl.hpp>
#include "../common.hpp"

using namespace sycl;
using namespace sycl::ext::oneapi::experimental::matrix;

constexpr size_t SG_SZ = 32;
constexpr size_t TN = 16;
#define SG_SZ 32

#include "../element_wise_ops_impl.hpp"
22 changes: 0 additions & 22 deletions sycl/test-e2e/Matrix/XMX8/element_wise_ops.cpp

This file was deleted.

25 changes: 14 additions & 11 deletions sycl/test-e2e/Matrix/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,21 +42,21 @@ float make_fp32(bfloat16 x) {
return *res;
}

template <typename Ta, typename Tb, typename Tc, unsigned int VF = 1>
template <typename Ta, typename Tb, typename Tc, unsigned int VF = 1,
typename F = std::nullptr_t>
void matrix_multiply_ref(Ta *A, Tb *B, Tc *C, int M, int N, int K,
bool transpose_c = false, bool colmajor_a = false,
bool colmajor_b = false) {
bool colmajor_b = false, F &&lambda = {}) {
for (unsigned int m = 0; m < M; m++) {
for (unsigned int n = 0; n < N; n++) {
for (unsigned int k = 0; k < K; k++) {
int c_ind = transpose_c ? (n * M + m) : m * N + n;
Tc acc = *(C + c_ind);

for (unsigned int k = 0; k < K; k++) {
int a_ind = colmajor_a ? (k * M + m) : m * K + k;
int b_ind = colmajor_b ? (n * K + k) : k * N + n;
int c_ind = transpose_c ? (n * M + m) : m * N + n;

Ta *va = (Ta *)(A + a_ind * VF);
Tb *vb = (Tb *)(B + b_ind * VF);
Tc acc = *(C + c_ind);

for (unsigned int i = 0; i < VF; i++) {
if constexpr (std::is_same_v<Ta, bfloat16> &&
Expand All @@ -74,9 +74,12 @@ void matrix_multiply_ref(Ta *A, Tb *B, Tc *C, int M, int N, int K,
else
assert(false && "Unsupported type in matrix_multiply_ref.");
}
}

*(C + c_ind) = acc;
if constexpr (!std::is_same_v<F, std::nullptr_t>) {
lambda(acc);
}
*(C + c_ind) = acc;
}
}
}
Expand Down Expand Up @@ -132,8 +135,7 @@ void matrix_rand(unsigned int rows, unsigned int cols, T *src, T val) {
if constexpr (std::is_same_v<T, bfloat16> || std::is_same_v<T, float> ||
std::is_same_v<T, double>) {
src[i * cols + j] = T(fdistr(dev));
} else if constexpr (std::is_same_v<T, int8_t> ||
std::is_same_v<T, int32_t>) {
} else if constexpr (std::is_integral_v<T>) {
src[i * cols + j] = T(idistr(dev));
} else {
assert(false && "Unsupported type in matrix_rand.");
Expand Down Expand Up @@ -170,8 +172,9 @@ bool matrix_compare(unsigned int rows, unsigned int cols, T1 *src, T2 *ref) {
}
} else if constexpr (exact || std::is_same_v<T1, int32_t>) {
if (src[i * cols + j] != ref[i * cols + j]) {
std::cout << "Incorrect result in matrix." << "i: " << i
<< ", j: " << j << ", Ref: " << ref[i * cols + j]
std::cout << "Incorrect result in matrix."
<< "i: " << i << ", j: " << j
<< ", Ref: " << ref[i * cols + j]
<< ", Val: " << src[i * cols + j] << "\n";
return false;
}
Expand Down
1 change: 0 additions & 1 deletion sycl/test-e2e/Matrix/element_wise_all_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,4 @@
// RUN: %{run} %t.out

#include "common.hpp"

#include "element_wise_all_ops_impl.hpp"
83 changes: 55 additions & 28 deletions sycl/test-e2e/Matrix/element_wise_all_ops_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

template <typename T, size_t NUM_ROWS, size_t NUM_COLS>
void assert_ops_ref(host_accessor<T, 2, access::mode::read> mat,
const float ref) {
Expand Down Expand Up @@ -105,8 +106,11 @@ void verify_op_c(const T l, const T r, const float ref, OP op) {

// Avoid same kernel name for different types
template <typename T, class name> class ewops_a {};
template <typename T, size_t NROWS, size_t NCOLS, size_t SROWS, size_t SCOLS>
void test_ewops_a() {
template <typename T, size_t SROWS, size_t SCOLS> void test_ewops_a() {
std::cout << "Test A " << SROWS << "x" << SCOLS << "\n";

static constexpr size_t NROWS = SROWS * 2;
static constexpr size_t NCOLS = SCOLS * 2;

verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS, ewops_a<T, class a_add>>(
T(5.0), T(2.0), 7.0, [](auto l, auto r) { return l + r; });
Expand Down Expand Up @@ -135,64 +139,87 @@ void test_ewops_a() {
T(5.0), T(2.0), 2.0,
[](auto l, auto r) { return l <= r ? T(3.0) : T(2.0); });
}

// Avoid same kernel name for different types and numbers of rows and columns
template <typename T, size_t COLS, class name> class ewops_c {};
template <typename T, size_t NROWS, size_t NCOLS, size_t SROWS, size_t SCOLS>
void test_ewops_c() {
template <typename T, size_t ROWS, size_t COLS, class name> class ewops_c {};
template <typename T, size_t SROWS, size_t SCOLS> void test_ewops_c() {
std::cout << "Test C " << SROWS << "x" << SCOLS << "\n";

verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS, ewops_c<T, SCOLS, class c_add>>(
static constexpr size_t NROWS = SROWS * 2;
static constexpr size_t NCOLS = SCOLS * 2;

verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS,
ewops_c<T, SROWS, SCOLS, class c_add>>(
T(5.0), T(2.0), 7.0, [](auto l, auto r) { return l + r; });
verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS, ewops_c<T, SCOLS, class c_sub>>(
verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS,
ewops_c<T, SROWS, SCOLS, class c_sub>>(
T(5.0), T(2.0), 3.0, [](auto l, auto r) { return l - r; });
verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS, ewops_c<T, SCOLS, class c_mul>>(
verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS,
ewops_c<T, SROWS, SCOLS, class c_mul>>(
T(5.0), T(2.0), 10.0, [](auto l, auto r) { return l * r; });
verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS, ewops_c<T, SCOLS, class c_div>>(
verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS,
ewops_c<T, SROWS, SCOLS, class c_div>>(
T(5.0), T(2.0), 2.5, [](auto l, auto r) { return l / r; });
verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS,
ewops_c<T, SCOLS, class c_logical>>(
ewops_c<T, SROWS, SCOLS, class c_logical>>(
T(5.0), T(5.0), 5.0, [](auto l, auto r) { return l == r ? l : T(1.0); });
verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS, ewops_c<T, SCOLS, class c_eq>>(
verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS,
ewops_c<T, SROWS, SCOLS, class c_eq>>(
T(5.0), T(4.0), 4.0, [](auto l, auto r) { return l == r ? l : r; });
verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS, ewops_c<T, SCOLS, class c_ne>>(
verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS,
ewops_c<T, SROWS, SCOLS, class c_ne>>(
T(5.0), T(5.0), 1.0, [](auto l, auto r) { return l != r ? l : T(1.0); });
verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS, ewops_c<T, SCOLS, class c_gt>>(
verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS,
ewops_c<T, SROWS, SCOLS, class c_gt>>(
T(5.0), T(2.0), 3.0,
[](auto l, auto r) { return l > r ? T(3.0) : T(2.0); });
verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS, ewops_c<T, SCOLS, class c_lt>>(
verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS,
ewops_c<T, SROWS, SCOLS, class c_lt>>(
T(5.0), T(2.0), 2.0,
[](auto l, auto r) { return l < r ? T(3.0) : T(2.0); });
verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS, ewops_c<T, SCOLS, class c_ge>>(
verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS,
ewops_c<T, SROWS, SCOLS, class c_ge>>(
T(5.0), T(2.0), 3.0,
[](auto l, auto r) { return l >= r ? T(3.0) : T(2.0); });
verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS, ewops_c<T, SCOLS, class c_le>>(
verify_op_c<T, NROWS, NCOLS, SROWS, SCOLS,
ewops_c<T, SROWS, SCOLS, class c_le>>(
T(5.0), T(2.0), 2.0,
[](auto l, auto r) { return l <= r ? T(3.0) : T(2.0); });
}

int main() {
static constexpr size_t TM = 8;

static constexpr size_t MATRIX_M = TM * 2;
static constexpr size_t MATRIX_N = 32;
static constexpr size_t MATRIX_K = 32;
queue q;
std::vector<combination> combinations =
q.get_device()
.get_info<sycl::ext::oneapi::experimental::info::device::
matrix_combinations>();

for (unsigned int i = 0; i < combinations.size(); i++) {
if (combinations[i].atype == matrix_type::bf16) {
if (combinations[i].nsize == 0 || combinations[i].nsize == 16) {
test_ewops_a<bfloat16, MATRIX_M, MATRIX_K, TM, 16>();
test_ewops_c<float, MATRIX_M, MATRIX_N, TM, 16>();
break;

if (combinations[i].nsize == 0 ||
(combinations[i].msize == 0 && combinations[i].nsize == 16)) {
test_ewops_a<bfloat16, 8, 16>();
test_ewops_c<float, 8, 16>();
}

if (combinations[i].msize == 16 && combinations[i].nsize == 16) {
test_ewops_c<float, 16, 16>();
}

// This combination is not currently supported for sub group size = 32 in IGC
#if (!defined(SG_SZ) || SG_SZ != 32)
if (combinations[i].msize == 32 && combinations[i].nsize == 64) {
test_ewops_c<float, 32, 64>();
}
#endif

if (combinations[i].nsize == 8) {
test_ewops_a<bfloat16, MATRIX_M, MATRIX_K, TM, 16>();
test_ewops_c<float, MATRIX_M, MATRIX_N, TM, 8>();
break;
test_ewops_a<bfloat16, 8, 16>();
test_ewops_c<float, 8, 8>();
}
}
}

return 0;
}
12 changes: 2 additions & 10 deletions sycl/test-e2e/Matrix/element_wise_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,10 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// REQUIRES: matrix
// REQUIRES: aspect-ext_intel_matrix

// RUN: %{build} -o %t.out
// RUN: %{run} %t.out

#include <iostream>
#include <sycl/sycl.hpp>

using namespace sycl;
using namespace sycl::ext::oneapi::experimental::matrix;

#define SG_SZ 16
constexpr size_t TN = 16;

#include "common.hpp"
#include "element_wise_ops_impl.hpp"
Loading

0 comments on commit 84426d1

Please sign in to comment.