Skip to content

Commit

Permalink
Refactor binary unroll and update rvv roofline of binary.
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangyang2057 committed Oct 25, 2024
1 parent 2bd2c6d commit 1a847f4
Show file tree
Hide file tree
Showing 7 changed files with 204 additions and 371 deletions.
98 changes: 56 additions & 42 deletions src/Native/include/nncase/ntt/arch/riscv64/ukernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,54 +23,68 @@
namespace nncase::ntt::ukernels {

// unary
// Specialization of u_unary_policy (third argument `true`) for ntt::ops::abs
// on vector<T, NTT_VLEN / sizeof(T) / 8>: unroll by 8.
template <class T>
struct u_unary_policy<
    ntt::ops::abs<vector<T, NTT_VLEN / sizeof(T) / 8>>,
    vector<T, NTT_VLEN / sizeof(T) / 8>,
    true> {
    static constexpr size_t unroll{8};
};

// Specialization of u_unary_policy (third argument `true`) for ntt::ops::ceil
// on vector<T, NTT_VLEN / sizeof(T) / 8>: unroll by 8.
template <class T>
struct u_unary_policy<
    ntt::ops::ceil<vector<T, NTT_VLEN / sizeof(T) / 8>>,
    vector<T, NTT_VLEN / sizeof(T) / 8>,
    true> {
    static constexpr size_t unroll{8};
};

// Specialization of u_unary_policy (third argument `true`) for ntt::ops::floor
// on vector<T, NTT_VLEN / sizeof(T) / 8>: unroll by 8.
template <class T>
struct u_unary_policy<
    ntt::ops::floor<vector<T, NTT_VLEN / sizeof(T) / 8>>,
    vector<T, NTT_VLEN / sizeof(T) / 8>,
    true> {
    static constexpr size_t unroll{8};
};

// Specialization of u_unary_policy (third argument `true`) for ntt::ops::neg
// on vector<T, NTT_VLEN / sizeof(T) / 8>: unroll by 8.
template <class T>
struct u_unary_policy<
    ntt::ops::neg<vector<T, NTT_VLEN / sizeof(T) / 8>>,
    vector<T, NTT_VLEN / sizeof(T) / 8>,
    true> {
    static constexpr size_t unroll{8};
};

// Specialization of u_unary_policy (third argument `true`) for ntt::ops::round
// on vector<T, NTT_VLEN / sizeof(T) / 8>: unroll by 8.
template <class T>
struct u_unary_policy<
    ntt::ops::round<vector<T, NTT_VLEN / sizeof(T) / 8>>,
    vector<T, NTT_VLEN / sizeof(T) / 8>,
    true> {
    static constexpr size_t unroll{8};
};

// Specialization of u_unary_policy (third argument `true`) for ntt::ops::sign
// on vector<T, NTT_VLEN / sizeof(T) / 8>: unroll by 8.
template <class T>
struct u_unary_policy<
    ntt::ops::sign<vector<T, NTT_VLEN / sizeof(T) / 8>>,
    vector<T, NTT_VLEN / sizeof(T) / 8>,
    true> {
    static constexpr size_t unroll{8};
};

// Specialization of u_unary_policy (third argument `true`) for
// ntt::ops::square on vector<T, NTT_VLEN / sizeof(T) / 8>: unroll by 8.
template <class T>
struct u_unary_policy<
    ntt::ops::square<vector<T, NTT_VLEN / sizeof(T) / 8>>,
    vector<T, NTT_VLEN / sizeof(T) / 8>,
    true> {
    static constexpr size_t unroll{8};
};

// Stamps out one u_unary_policy partial specialization (third argument `true`)
// for ntt::ops::<OP> applied to vector<T, NTT_VLEN / sizeof(T) / 8>, with
// `unroll` set to FACTOR. The macro only exists for the table below and is
// #undef'd immediately afterwards.
// NOTE(review): NTT_VLEN / sizeof(T) / 8 is presumably the lane count of one
// full vector register with NTT_VLEN given in bits -- confirm.
#define SPECIALIZE_U_UNARY(OP, FACTOR)                                         \
    template <class T>                                                         \
    struct u_unary_policy<                                                     \
        ntt::ops::OP<vector<T, NTT_VLEN / sizeof(T) / 8>>,                     \
        vector<T, NTT_VLEN / sizeof(T) / 8>, true> {                           \
        static constexpr size_t unroll{FACTOR};                                \
    };

// Unary ops covered by the policy; all use an unroll factor of 8.
SPECIALIZE_U_UNARY(abs, 8)
SPECIALIZE_U_UNARY(ceil, 8)
SPECIALIZE_U_UNARY(floor, 8)
SPECIALIZE_U_UNARY(neg, 8)
SPECIALIZE_U_UNARY(round, 8)
SPECIALIZE_U_UNARY(sign, 8)
SPECIALIZE_U_UNARY(square, 8)

#undef SPECIALIZE_U_UNARY

// binary
// Stamps out three u_binary_policy partial specializations (last argument
// `true`) for ntt::ops::<OP>, each setting `unroll` to FACTOR:
//   1. both operands are vector<Tn, NTT_VLEN / sizeof(Tn) / 8>;
//   2. left operand is a bare T1, right operand is such a vector;
//   3. left operand is such a vector, right operand is a bare T2.
// NOTE(review): the bare T1/T2 operands are presumably scalars -- the
// template itself does not constrain them; confirm against callers.
// The macro only exists for the table below and is #undef'd afterwards.
#define SPECIALIZE_U_BINARY(OP, FACTOR)                                        \
    /* vector OP vector */                                                     \
    template <class T1, class T2>                                              \
    struct u_binary_policy<                                                    \
        ntt::ops::OP<vector<T1, NTT_VLEN / sizeof(T1) / 8>,                    \
                     vector<T2, NTT_VLEN / sizeof(T2) / 8>>,                   \
        vector<T1, NTT_VLEN / sizeof(T1) / 8>,                                 \
        vector<T2, NTT_VLEN / sizeof(T2) / 8>, true> {                         \
        static constexpr size_t unroll{FACTOR};                                \
    };                                                                         \
                                                                               \
    /* T1 OP vector */                                                         \
    template <class T1, class T2>                                              \
    struct u_binary_policy<                                                    \
        ntt::ops::OP<T1, vector<T2, NTT_VLEN / sizeof(T2) / 8>>, T1,           \
        vector<T2, NTT_VLEN / sizeof(T2) / 8>, true> {                         \
        static constexpr size_t unroll{FACTOR};                                \
    };                                                                         \
                                                                               \
    /* vector OP T2 */                                                         \
    template <class T1, class T2>                                              \
    struct u_binary_policy<                                                    \
        ntt::ops::OP<vector<T1, NTT_VLEN / sizeof(T1) / 8>, T2>,               \
        vector<T1, NTT_VLEN / sizeof(T1) / 8>, T2, true> {                     \
        static constexpr size_t unroll{FACTOR};                                \
    };

// Binary ops covered by the policy; all use an unroll factor of 8.
SPECIALIZE_U_BINARY(add, 8)
SPECIALIZE_U_BINARY(sub, 8)
SPECIALIZE_U_BINARY(mul, 8)
SPECIALIZE_U_BINARY(div, 8)
SPECIALIZE_U_BINARY(max, 8)
SPECIALIZE_U_BINARY(min, 8)
SPECIALIZE_U_BINARY(mod, 8)
SPECIALIZE_U_BINARY(floor_mod, 8)

#undef SPECIALIZE_U_BINARY

// reduce
// Specialization of u_reduce_policy (third argument `true`) that applies to
// every reduce_op / element type combination: unroll by 8.
template <reduce_op Op, typename T>
struct u_reduce_policy<Op, T, true> {
    static constexpr size_t unroll{8};
};

// cast
// Full specialization of u_cast_policy for `true`: unroll by 8.
template <>
struct u_cast_policy<true> {
    static constexpr size_t unroll{8};
};

// matmul
template <>
struct u_matmul_policy<mamtul_pack_kind::no_pack, float, float, float, true> {
static constexpr size_t m0_tile = 1;
Expand Down
85 changes: 50 additions & 35 deletions src/Native/include/nncase/ntt/arch/x86_64/ukernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,41 +19,54 @@
namespace nncase::ntt::ukernels {

// unary
// Specialization of u_unary_policy (third argument `true`) for ntt::ops::abs
// on vector<T, 8>: unroll by 2.
template <class T>
struct u_unary_policy<
    ntt::ops::abs<vector<T, 8>>,
    vector<T, 8>,
    true> {
    static constexpr size_t unroll{2};
};

// Specialization of u_unary_policy (third argument `true`) for ntt::ops::ceil
// on vector<T, 8>: unroll by 2.
template <class T>
struct u_unary_policy<
    ntt::ops::ceil<vector<T, 8>>,
    vector<T, 8>,
    true> {
    static constexpr size_t unroll{2};
};

// Specialization of u_unary_policy (third argument `true`) for ntt::ops::floor
// on vector<T, 8>: unroll by 2.
template <class T>
struct u_unary_policy<
    ntt::ops::floor<vector<T, 8>>,
    vector<T, 8>,
    true> {
    static constexpr size_t unroll{2};
};

// Specialization of u_unary_policy (third argument `true`) for ntt::ops::neg
// on vector<T, 8>: unroll by 2.
template <class T>
struct u_unary_policy<
    ntt::ops::neg<vector<T, 8>>,
    vector<T, 8>,
    true> {
    static constexpr size_t unroll{2};
};

// Specialization of u_unary_policy (third argument `true`) for ntt::ops::round
// on vector<T, 8>: unroll by 2.
template <class T>
struct u_unary_policy<
    ntt::ops::round<vector<T, 8>>,
    vector<T, 8>,
    true> {
    static constexpr size_t unroll{2};
};

// Specialization of u_unary_policy (third argument `true`) for ntt::ops::sign
// on vector<T, 8>: unroll by 2.
template <class T>
struct u_unary_policy<
    ntt::ops::sign<vector<T, 8>>,
    vector<T, 8>,
    true> {
    static constexpr size_t unroll{2};
};

// Specialization of u_unary_policy (third argument `true`) for
// ntt::ops::square on vector<T, 8>: unroll by 2.
template <class T>
struct u_unary_policy<
    ntt::ops::square<vector<T, 8>>,
    vector<T, 8>,
    true> {
    static constexpr size_t unroll{2};
};

// Stamps out one u_unary_policy partial specialization (third argument `true`)
// for ntt::ops::<OP> applied to vector<T, 8>, with `unroll` set to FACTOR.
// The macro only exists for the table below and is #undef'd immediately
// afterwards.
#define SPECIALIZE_U_UNARY(OP, FACTOR)                                         \
    template <class T>                                                         \
    struct u_unary_policy<ntt::ops::OP<vector<T, 8>>, vector<T, 8>, true> {    \
        static constexpr size_t unroll{FACTOR};                                \
    };

// Unary ops covered by the policy; all use an unroll factor of 2.
SPECIALIZE_U_UNARY(abs, 2)
SPECIALIZE_U_UNARY(ceil, 2)
SPECIALIZE_U_UNARY(floor, 2)
SPECIALIZE_U_UNARY(neg, 2)
SPECIALIZE_U_UNARY(round, 2)
SPECIALIZE_U_UNARY(sign, 2)
SPECIALIZE_U_UNARY(square, 2)

#undef SPECIALIZE_U_UNARY

// binary
// Stamps out three u_binary_policy partial specializations (last argument
// `true`) for ntt::ops::<OP>, each setting `unroll` to FACTOR:
//   1. both operands are vector<Tn, 8>;
//   2. left operand is a bare T1, right operand is vector<T2, 8>;
//   3. left operand is vector<T1, 8>, right operand is a bare T2.
// NOTE(review): the bare T1/T2 operands are presumably scalars -- the
// template itself does not constrain them; confirm against callers.
// The macro only exists for the table below and is #undef'd afterwards.
#define SPECIALIZE_U_BINARY(OP, FACTOR)                                        \
    /* vector OP vector */                                                     \
    template <class T1, class T2>                                              \
    struct u_binary_policy<ntt::ops::OP<vector<T1, 8>, vector<T2, 8>>,         \
                           vector<T1, 8>, vector<T2, 8>, true> {               \
        static constexpr size_t unroll{FACTOR};                                \
    };                                                                         \
                                                                               \
    /* T1 OP vector */                                                         \
    template <class T1, class T2>                                              \
    struct u_binary_policy<ntt::ops::OP<T1, vector<T2, 8>>, T1,                \
                           vector<T2, 8>, true> {                              \
        static constexpr size_t unroll{FACTOR};                                \
    };                                                                         \
                                                                               \
    /* vector OP T2 */                                                         \
    template <class T1, class T2>                                              \
    struct u_binary_policy<ntt::ops::OP<vector<T1, 8>, T2>, vector<T1, 8>,     \
                           T2, true> {                                         \
        static constexpr size_t unroll{FACTOR};                                \
    };

// Binary ops covered by the policy; all use an unroll factor of 2.
SPECIALIZE_U_BINARY(add, 2)
SPECIALIZE_U_BINARY(sub, 2)
SPECIALIZE_U_BINARY(mul, 2)
SPECIALIZE_U_BINARY(div, 2)
SPECIALIZE_U_BINARY(max, 2)
SPECIALIZE_U_BINARY(min, 2)
SPECIALIZE_U_BINARY(mod, 2)
SPECIALIZE_U_BINARY(floor_mod, 2)

#undef SPECIALIZE_U_BINARY

// pack
template <size_t M, size_t N, size_t MStrides>
class u_pack<M, N, MStrides, true, float, vector<float, 8>> {
public:
Expand All @@ -75,10 +88,12 @@ class u_pack<M, N, MStrides, true, float, vector<float, 8>> {
}
};

// reduce
// Specialization of u_reduce_policy (third argument `true`) that applies to
// every reduce_op / element type combination: unroll by 8.
template <reduce_op Op, typename T>
struct u_reduce_policy<Op, T, true> {
    static constexpr size_t unroll{8};
};

// matmul
template <>
struct u_matmul_policy<mamtul_pack_kind::no_pack, float, float, float, true> {
static constexpr size_t m0_tile = 1;
Expand Down
Loading

0 comments on commit 1a847f4

Please sign in to comment.