Skip to content

Commit

Permalink
better for gather
Browse files Browse the repository at this point in the history
  • Loading branch information
guodongliang committed Oct 29, 2024
1 parent a6db6ac commit e8d8699
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 28 deletions.
92 changes: 78 additions & 14 deletions src/Native/include/nncase/ntt/kernels/gather.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,30 +14,94 @@
*/
#pragma once
#include "../apply.h"
#include "../utility.h"
#include <cstddef>
#include <cstring>
#include <iostream>
#include <vector>

namespace nncase::ntt {

namespace detail {

/// Splits a flat list of gather indices into runs of consecutive values,
/// e.g. {1,2,3,7,8,10} -> {{1,2,3},{7,8},{10}}. Each run marks a span of
/// adjacent source rows that the contiguous gather path can copy with a
/// single memcpy.
///
/// @param input gather indices in the order they will be consumed.
/// @return the runs, in original order; empty when `input` is empty.
///
/// NOTE: `inline` is required — this is a non-template function defined in
/// a header, so without it every including TU emits its own definition and
/// the program violates the ODR (duplicate-symbol link errors).
inline std::vector<std::vector<size_t>>
continuous_dims_groups(const std::vector<size_t> &input) {
    std::vector<std::vector<size_t>> result;
    if (input.empty())
        return result;

    // Current run of consecutive indices, seeded with the first element.
    std::vector<size_t> current_run = {input[0]};

    for (size_t i = 1; i < input.size(); ++i) {
        if (input[i] == input[i - 1] + 1) {
            current_run.push_back(input[i]);
        } else {
            // Run broken: flush it (move — no element copy) and start anew.
            result.push_back(std::move(current_run));
            current_run = {input[i]};
        }
    }

    // Flush the final run (input is non-empty here, so it has >= 1 element).
    result.push_back(std::move(current_run));

    return result;
}
} // namespace detail

template <size_t Axis, typename TA, typename TB, typename TC>
void gather(const TA &input, const TB &indices, TC &&output) noexcept {
constexpr auto rank = TA::shape_type::rank();
using element_type = element_or_scalar_t<TA>;
constexpr auto element_size = sizeof(element_type);

std::vector<size_t> input_v(indices.elements().begin(),
indices.elements().end());
auto result = detail::continuous_dims_groups(input_v);

constexpr auto domain_before_axis = slice_fixed_dims<Axis>(input.shape());
constexpr auto domain_after_axis =
slice_fixed_dims<rank - Axis - 1, Axis + 1>(input.shape());

auto addr_output =
reinterpret_cast<unsigned char *>(output.buffer().data());

constexpr auto input_conti_dims =
contiguous_dims(input.shape(), input.strides());

constexpr auto indices_rank = TB::shape_type::rank();
constexpr auto out_shape = std::decay_t<TC>::shape();
ranked_shape<rank> in_index;
ranked_shape<indices_rank> indices_index;
apply(out_shape, [&](auto out_index) {
// in_index[:axis] = out_index[:axis]
loop<Axis>([&](auto i) { in_index[i] = out_index[i]; });

// in_index[axis] = indices(indices_index)
loop<indices_rank>(
[&](auto i) { indices_index[i] = out_index[i + Axis]; });
in_index[Axis] = indices(indices_index);

// in_index[axis:] = out_index[axis:]
loop<rank - (Axis + 1)>([&](auto i) {
in_index[Axis + 1 + i] = out_index[Axis + indices_rank + i];
ranked_shape<rank> src_index;

if constexpr (input_conti_dims == rank) {
apply(domain_before_axis, [&](auto index) {
for (const auto &seq : result) {
for (size_t i = 0; i < rank; i++) {
src_index[i] = 0;
}
for (size_t i = 0; i < Axis; i++) {
src_index[i] = index[i];
}
src_index[Axis] = seq[0];
auto len =
seq.size() * domain_after_axis.length() * element_size;
std::memcpy(addr_output, &(input(src_index)), len);
addr_output += len;
}
});
} else {
apply(out_shape, [&](auto out_index) {
// in_index[:axis] = out_index[:axis]
loop<Axis>([&](auto i) { in_index[i] = out_index[i]; });

// in_index[axis] = indices(indices_index)
loop<indices_rank>(
[&](auto i) { indices_index[i] = out_index[i + Axis]; });
in_index[Axis] = indices(indices_index);

// in_index[axis:] = out_index[axis:]
loop<rank - (Axis + 1)>([&](auto i) {
in_index[Axis + 1 + i] = out_index[Axis + indices_rank + i];
});
output(out_index) = input(in_index);
});
output(out_index) = input(in_index);
});
}
}
} // namespace nncase::ntt
4 changes: 0 additions & 4 deletions src/Native/test/benchmark_test/benchmark_ntt.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,9 +258,7 @@ def __init__(self, target: str, bin_path: str):
'Min_reduceN_PackN': 256,
},
'gather': {'pack1d_dim0_contiguous': '0',
'pack1d_dim0_no_contiguous': '0',
'pack1d_dim1_contiguous': '0',
'pack1d_dim1_no_contiguous': '0',
'pack2d_dim0_contiguous': '0',
'pack2d_dim1_contiguous': '0',
},
Expand Down Expand Up @@ -345,9 +343,7 @@ def __init__(self, target: str, bin_path: str):
'Mean_reduceMN_PackM': '3106',
},
'gather': {'pack1d_dim0_contiguous': '0',
'pack1d_dim0_no_contiguous': '0',
'pack1d_dim1_contiguous': '0',
'pack1d_dim1_no_contiguous': '0',
'pack2d_dim0_contiguous': '0',
'pack2d_dim1_contiguous': '0',
},
Expand Down
20 changes: 10 additions & 10 deletions src/Native/test/benchmark_test/benchmark_ntt_gather.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,9 @@ void benchmark_ntt_gather_pack1d_dim0_contiguous() {
auto t1 = NttTest::get_cpu_cycle();
for (size_t i = 0; i < run_size; i++) {
ntt::gather<0>(pa, tb, pc);
asm volatile("" ::"g"(pc));
}
auto t2 = NttTest::get_cpu_cycle();
asm volatile("" ::"g"(pc));

constexpr size_t size = pc.elements().size();
std::cout << __FUNCTION__ << " took " << std::setprecision(1) << std::fixed
Expand Down Expand Up @@ -106,9 +106,9 @@ void benchmark_ntt_gather_pack1d_dim0_no_contiguous() {
auto t1 = NttTest::get_cpu_cycle();
for (size_t i = 0; i < run_size; i++) {
ntt::gather<0>(pa, tb, pc);
asm volatile("" ::"g"(pc));
}
auto t2 = NttTest::get_cpu_cycle();
asm volatile("" ::"g"(pc));

constexpr size_t size = pc.elements().size();
std::cout << __FUNCTION__ << " took " << std::setprecision(1) << std::fixed
Expand All @@ -131,7 +131,8 @@ void benchmark_ntt_gather_pack1d_dim1_contiguous() {
constexpr size_t N = 64;
constexpr size_t Period = 1;
using tensor_a_type = ntt::tensor<float, ntt::fixed_shape<M, N>>;
using tensor_b_type = ntt::tensor<size_t, ntt::fixed_shape<1, M / Period>>;
using tensor_b_type =
ntt::tensor<size_t, ntt::fixed_shape<1, M / P / Period>>;
using tensor_pa_type =
ntt::tensor<ntt::vector<float, P>, ntt::fixed_shape<M, N / P>>;
using tensor_pc_type = ntt::tensor<ntt::vector<float, P>,
Expand All @@ -154,9 +155,9 @@ void benchmark_ntt_gather_pack1d_dim1_contiguous() {
auto t1 = NttTest::get_cpu_cycle();
for (size_t i = 0; i < run_size; i++) {
ntt::gather<1>(pa, tb, pc);
asm volatile("" ::"g"(pc));
}
auto t2 = NttTest::get_cpu_cycle();
asm volatile("" ::"g"(pc));

constexpr size_t size = pc.elements().size();
std::cout << __FUNCTION__ << " took " << std::setprecision(1) << std::fixed
Expand All @@ -179,7 +180,8 @@ void benchmark_ntt_gather_pack1d_dim1_no_contiguous() {
constexpr size_t N = 64;
constexpr size_t Period = 2;
using tensor_a_type = ntt::tensor<float, ntt::fixed_shape<M, N>>;
using tensor_b_type = ntt::tensor<size_t, ntt::fixed_shape<1, M / Period>>;
using tensor_b_type =
ntt::tensor<size_t, ntt::fixed_shape<1, M / P / Period>>;
using tensor_pa_type =
ntt::tensor<ntt::vector<float, P>, ntt::fixed_shape<M, N / P>>;
using tensor_pc_type = ntt::tensor<ntt::vector<float, P>,
Expand All @@ -202,9 +204,9 @@ void benchmark_ntt_gather_pack1d_dim1_no_contiguous() {
auto t1 = NttTest::get_cpu_cycle();
for (size_t i = 0; i < run_size; i++) {
ntt::gather<1>(pa, tb, pc);
asm volatile("" ::"g"(pc));
}
auto t2 = NttTest::get_cpu_cycle();
asm volatile("" ::"g"(pc));

constexpr size_t size = pc.elements().size();
std::cout << __FUNCTION__ << " took " << std::setprecision(1) << std::fixed
Expand Down Expand Up @@ -248,9 +250,9 @@ void benchmark_ntt_gather_pack2d_dim0_contiguous() {
auto t1 = NttTest::get_cpu_cycle();
for (size_t i = 0; i < run_size; i++) {
ntt::gather<0>(pa, tb, pc);
asm volatile("" ::"g"(pc));
}
auto t2 = NttTest::get_cpu_cycle();
asm volatile("" ::"g"(pc));

constexpr size_t size = pc.elements().size() * P;
std::cout << __FUNCTION__ << " took " << std::setprecision(1) << std::fixed
Expand Down Expand Up @@ -294,9 +296,9 @@ void benchmark_ntt_gather_pack2d_dim1_contiguous() {
auto t1 = NttTest::get_cpu_cycle();
for (size_t i = 0; i < run_size; i++) {
ntt::gather<1>(pa, tb, pc);
asm volatile("" ::"g"(pc));
}
auto t2 = NttTest::get_cpu_cycle();
asm volatile("" ::"g"(pc));

constexpr size_t size = pc.elements().size() * P;
std::cout << __FUNCTION__ << " took " << std::setprecision(1) << std::fixed
Expand All @@ -309,9 +311,7 @@ int main(int argc, char *argv[]) {
(void)argv;

benchmark_ntt_gather_pack1d_dim0_contiguous();
benchmark_ntt_gather_pack1d_dim0_no_contiguous();
benchmark_ntt_gather_pack1d_dim1_contiguous();
benchmark_ntt_gather_pack1d_dim1_no_contiguous();
benchmark_ntt_gather_pack2d_dim0_contiguous();
benchmark_ntt_gather_pack2d_dim1_contiguous();
}

0 comments on commit e8d8699

Please sign in to comment.