diff --git a/src/Native/include/nncase/ntt/arch/riscv64/ukernels.h b/src/Native/include/nncase/ntt/arch/riscv64/ukernels.h index 61aa9a56c..a48671ce2 100644 --- a/src/Native/include/nncase/ntt/arch/riscv64/ukernels.h +++ b/src/Native/include/nncase/ntt/arch/riscv64/ukernels.h @@ -23,54 +23,68 @@ namespace nncase::ntt::ukernels { // unary -template -struct u_unary_policy>, - vector, true> { - static constexpr size_t unroll = 8; -}; - -template -struct u_unary_policy>, - vector, true> { - static constexpr size_t unroll = 8; -}; - -template -struct u_unary_policy>, - vector, true> { - static constexpr size_t unroll = 8; -}; - -template -struct u_unary_policy>, - vector, true> { - static constexpr size_t unroll = 8; -}; - -template -struct u_unary_policy>, - vector, true> { - static constexpr size_t unroll = 8; -}; - -template -struct u_unary_policy>, - vector, true> { - static constexpr size_t unroll = 8; -}; - -template -struct u_unary_policy>, - vector, true> { - static constexpr size_t unroll = 8; -}; - +#define SPECIALIZE_U_UNARY(op, unroll_num) \ + template \ + struct u_unary_policy>, \ + vector, true> { \ + static constexpr size_t unroll = unroll_num; \ + }; + +SPECIALIZE_U_UNARY(abs, 8) +SPECIALIZE_U_UNARY(ceil, 8) +SPECIALIZE_U_UNARY(floor, 8) +SPECIALIZE_U_UNARY(neg, 8) +SPECIALIZE_U_UNARY(round, 8) +SPECIALIZE_U_UNARY(sign, 8) +SPECIALIZE_U_UNARY(square, 8) + +#undef SPECIALIZE_U_UNARY + +// binary +#define SPECIALIZE_U_BINARY(op, unroll_num) \ + template \ + struct u_binary_policy< \ + ntt::ops::op, \ + vector>, \ + vector, \ + vector, true> { \ + static constexpr size_t unroll = unroll_num; \ + }; \ + \ + template \ + struct u_binary_policy< \ + ntt::ops::op>, T1, \ + vector, true> { \ + static constexpr size_t unroll = unroll_num; \ + }; \ + \ + template \ + struct u_binary_policy< \ + ntt::ops::op, T2>, \ + vector, T2, true> { \ + static constexpr size_t unroll = unroll_num; \ + }; + +SPECIALIZE_U_BINARY(add, 8) +SPECIALIZE_U_BINARY(sub, 8) +SPECIALIZE_U_BINARY(mul, 8) +SPECIALIZE_U_BINARY(div, 8) +SPECIALIZE_U_BINARY(max, 8) +SPECIALIZE_U_BINARY(min, 8) +SPECIALIZE_U_BINARY(mod, 8) +SPECIALIZE_U_BINARY(floor_mod, 8) + +#undef SPECIALIZE_U_BINARY + +// reduce template struct u_reduce_policy { static constexpr size_t unroll = 8; }; +// cast template <> struct u_cast_policy { static constexpr size_t unroll = 8; }; +// matmul template <> struct u_matmul_policy { static constexpr size_t m0_tile = 1; diff --git a/src/Native/include/nncase/ntt/arch/x86_64/ukernels.h b/src/Native/include/nncase/ntt/arch/x86_64/ukernels.h index 61431fde0..3b3a7a3b6 100644 --- a/src/Native/include/nncase/ntt/arch/x86_64/ukernels.h +++ b/src/Native/include/nncase/ntt/arch/x86_64/ukernels.h @@ -19,41 +19,54 @@ namespace nncase::ntt::ukernels { // unary -template -struct u_unary_policy>, vector, true> { - static constexpr size_t unroll = 2; -}; - -template -struct u_unary_policy>, vector, true> { - static constexpr size_t unroll = 2; -}; - -template -struct u_unary_policy>, vector, true> { - static constexpr size_t unroll = 2; -}; - -template -struct u_unary_policy>, vector, true> { - static constexpr size_t unroll = 2; -}; - -template -struct u_unary_policy>, vector, true> { - static constexpr size_t unroll = 2; -}; - -template -struct u_unary_policy>, vector, true> { - static constexpr size_t unroll = 2; -}; - -template -struct u_unary_policy>, vector, true> { - static constexpr size_t unroll = 2; -}; - +#define SPECIALIZE_U_UNARY(op, unroll_num) \ + template \ + struct u_unary_policy>, vector, true> { \ + static constexpr size_t unroll = unroll_num; \ + }; + +SPECIALIZE_U_UNARY(abs, 2) +SPECIALIZE_U_UNARY(ceil, 2) +SPECIALIZE_U_UNARY(floor, 2) +SPECIALIZE_U_UNARY(neg, 2) +SPECIALIZE_U_UNARY(round, 2) +SPECIALIZE_U_UNARY(sign, 2) +SPECIALIZE_U_UNARY(square, 2) + +#undef SPECIALIZE_U_UNARY + +// binary +#define SPECIALIZE_U_BINARY(op, unroll_num) \ + template \ + struct u_binary_policy, vector>, \ + vector, vector, true> { \ + static constexpr size_t unroll = unroll_num; \ + }; \ + \ + template \ + struct u_binary_policy>, T1, vector, \ + true> { \ + static constexpr size_t unroll = unroll_num; \ + }; \ + \ + template \ + struct u_binary_policy, T2>, vector, T2, \ + true> { \ + static constexpr size_t unroll = unroll_num; \ + }; + +SPECIALIZE_U_BINARY(add, 2) +SPECIALIZE_U_BINARY(sub, 2) +SPECIALIZE_U_BINARY(mul, 2) +SPECIALIZE_U_BINARY(div, 2) +SPECIALIZE_U_BINARY(max, 2) +SPECIALIZE_U_BINARY(min, 2) +SPECIALIZE_U_BINARY(mod, 2) +SPECIALIZE_U_BINARY(floor_mod, 2) + +#undef SPECIALIZE_U_BINARY + +// pack template class u_pack> { public: @@ -75,10 +88,12 @@ class u_pack> { } }; +// reduce template struct u_reduce_policy { static constexpr size_t unroll = 8; }; +// matmul template <> struct u_matmul_policy { static constexpr size_t m0_tile = 1; diff --git a/src/Native/include/nncase/ntt/kernels/binary.h b/src/Native/include/nncase/ntt/kernels/binary.h index 409002753..2614c048c 100644 --- a/src/Native/include/nncase/ntt/kernels/binary.h +++ b/src/Native/include/nncase/ntt/kernels/binary.h @@ -76,14 +76,14 @@ class binary_impl { // 1.1 Non broadcast if constexpr (is_same_seq(lhs_rest_dims, rhs_rest_dims)) { - return binary_non_broadcast(op, lhs_p, rhs_p, out_p, - lhs_rest_dims.length()); + return binary_non_broadcast(lhs_p, rhs_p, out_p, + lhs_rest_dims.length()); } else if constexpr (lhs_rest_dims.length() == 1) { - return binary_left_broadcast(op, *lhs_p, rhs_p, out_p, - rhs_rest_dims.length()); + return binary_left_broadcast(*lhs_p, rhs_p, out_p, + rhs_rest_dims.length()); } else if constexpr (rhs_rest_dims.length() == 1) { - return binary_right_broadcast(op, lhs_p, *rhs_p, out_p, - lhs_rest_dims.length()); + return binary_right_broadcast(lhs_p, *rhs_p, out_p, + lhs_rest_dims.length()); } } @@ -102,153 +102,26 @@ class binary_impl { } template - void binary_non_broadcast(Op &op, const TLhsElem *lhs, const TRhsElem *rhs, + void binary_non_broadcast(const TLhsElem *lhs, const TRhsElem *rhs, TOutElem *output, size_t extent) { - for (size_t i = 0; i < extent; i++) { - *output++ = op(*lhs++, *rhs++); - } + ntt::u_binary(lhs, 1, rhs, 1, output, + 1, extent); } template - void binary_left_broadcast(Op &op, const TLhsElem &lhs, const TRhsElem *rhs, + void binary_left_broadcast(const TLhsElem &lhs, const TRhsElem *rhs, TOutElem *output, size_t extent) { - for (size_t i = 0; i < extent; i++) { - *output++ = op(lhs, *rhs++); - } + ntt::u_binary(lhs, 0, rhs, 1, output, + 1, extent); } template - void binary_right_broadcast(Op &op, const TLhsElem *lhs, - const TRhsElem &rhs, TOutElem *output, - size_t extent) { - for (size_t i = 0; i < extent; i++) { - *output++ = op(*lhs++, rhs); - } + void binary_right_broadcast(const TLhsElem *lhs, const TRhsElem &rhs, + TOutElem *output, size_t extent) { + ntt::u_binary(lhs, 1, rhs, 0, output, + 1, extent); } }; - -#define BINARY_IMPL(OP) \ - template \ - class OP##_impl; \ - template \ - class OP##_impl, fixed_strides, \ - fixed_strides, \ - fixed_strides> { \ - public: \ - template \ - constexpr void operator()(const TIn1 &input1, const TIn2 &input2, \ - TOut &output) { \ - constexpr size_t rank = sizeof...(Dims); \ - ranked_shape index{}; \ - constexpr auto conti_dims = \ - std::min(contiguous_dims(fixed_shape{}, \ - fixed_strides{}), \ - contiguous_dims(fixed_shape{}, \ - fixed_strides{})); \ - apply( \ - index, input1, input2, output); \ - } \ - \ - private: \ - template \ - constexpr void apply(ranked_shape &index, const TIn1 &input1, \ - const TIn2 &input2, TOut &output) { \ - if constexpr (ContiguousDims == sizeof...(RestDims)) { \ - constexpr auto inner_size = \ - fixed_shape::length(); \ - auto input1_p = input1.elements().data() + \ - linear_offset(index, input1.strides()); \ - auto input2_p = input2.elements().data() + \ - linear_offset(index, input2.strides()); \ - auto output_p = output.elements().data() + \ - linear_offset(index, output.strides()); \ - OP##_contiguous(input1_p, input2_p, output_p); \ - } else { \ - apply_next(index, input1, input2, output); \ - } \ - } \ - \ - template \ - constexpr void apply_next(ranked_shape &index, \ - const TIn1 &input1, const TIn2 &input2, \ - TOut &output) { \ - for (index[Axis] = 0; index[Axis] < Dim; index[Axis]++) { \ - apply(index, input1, input2, output); \ - } \ - } \ - template \ - constexpr void OP##_contiguous(const T1 *input1, const T2 *input2, \ - TOut *output) { \ - ntt::u_##OP(input1, input2, 1, 1, output, 1, Extent); \ - } \ - }; \ - \ - template \ - class OP##_impl, In1Strides, In2Strides, OutStrides> { \ - public: \ - template \ - constexpr void operator()(const TIn1 &input1, const TIn2 &input2, \ - TOut &output) { \ - ranked_shape index{}; \ - auto conti_dims = \ - std::min(contiguous_dims(input1.shape(), input1.strides()), \ - contiguous_dims(input1.shape(), output.strides())); \ - apply(index, conti_dims, input1, input2, \ - output); \ - } \ - \ - private: \ - template \ - constexpr void apply(ranked_shape &index, size_t conti_dims, \ - const TIn1 &input1, const TIn2 &input2, \ - TOut &output) { \ - const auto outer_dims = Rank - conti_dims; \ - if (Axis >= outer_dims) { \ - size_t inner_size = 1; \ - for (size_t i = outer_dims; i < input1.shape().rank(); i++) \ - inner_size *= input1.shape()[i]; \ - auto input1_p = input1.buffer().data() + \ - linear_offset(index, input1.strides()); \ - auto input2_p = input2.buffer().data() + \ - linear_offset(index, input2.strides()); \ - auto output_p = output.buffer().data() + \ - linear_offset(index, output.strides()); \ - OP##_contiguous(input1_p, input2_p, output_p, inner_size); \ - } else if constexpr (Axis < Rank - 1) { \ - const auto dim = input1.shape()[Axis]; \ - for (index[Axis] = 0; index[Axis] < dim; index[Axis]++) { \ - apply(index, conti_dims, \ - input1, input2, output); \ - } \ - } \ - } \ - \ - template \ - constexpr void OP##_contiguous(const T1 *input1_p, const T2 *input2_p, \ - TOut *output_p, size_t extent) { \ - for (size_t i = 0; i < extent; i++) { \ - output_p[i] = \ - ntt::ops::OP()(input1_p[i], input2_p[i]); \ - } \ - } \ - }; - -BINARY_IMPL(add) -BINARY_IMPL(div) -BINARY_IMPL(max) -BINARY_IMPL(min) -BINARY_IMPL(mod) -BINARY_IMPL(mul) -BINARY_IMPL(sub) - } // namespace detail template