From 1f108e86c41b74c6ff4ba0142a4cc4f3dc626d42 Mon Sep 17 00:00:00 2001 From: Simon Perkins Date: Fri, 27 Sep 2024 16:15:53 +0200 Subject: [PATCH] Add partition sorting utilities --- cpp/CMakeLists.txt | 1 + cpp/arcae/group_sort.cc | 236 +++++++++++++++++++++++++++++++++++++++ cpp/arcae/group_sort.h | 51 +++++++++ cpp/tests/CMakeLists.txt | 5 + 4 files changed, 293 insertions(+) create mode 100644 cpp/arcae/group_sort.cc create mode 100644 cpp/arcae/group_sort.h diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 49c19a29..c2532047 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -6,6 +6,7 @@ add_library(arcae SHARED arcae/configuration.cc arcae/descriptor.cc arcae/data_partition.cc + arcae/group_sort.cc arcae/isolated_table_proxy.cc arcae/new_table_proxy.cc arcae/read_impl.cc diff --git a/cpp/arcae/group_sort.cc b/cpp/arcae/group_sort.cc new file mode 100644 index 00000000..02ab4643 --- /dev/null +++ b/cpp/arcae/group_sort.cc @@ -0,0 +1,236 @@ +#include "group_sort.h" + +#include +#include +#include +#include +#include +#include + +#include "arcae/type_traits.h" + +#include "arrow/api.h" +#include "arrow/buffer.h" +#include "arrow/result.h" +#include "arrow/status.h" +#include "arrow/type_fwd.h" + +using ::arrow::AllocateBuffer; +using ::arrow::Array; +using ::arrow::Buffer; +using ::arrow::DoubleArray; +using ::arrow::Int32Array; +using ::arrow::Int64Array; +using ::arrow::Result; +using ::arrow::Status; + +using ::arcae::detail::AggregateAdapter; + +namespace arcae { + +namespace { + +static constexpr char kArrayIsNull[] = "GroupSortData array is null"; +static constexpr char kLengthMismatch[] = "GroupSortData length mismatch"; +static constexpr char kHasNulls[] = "GroupSortData has nulls"; + +} // namespace + +Result> GroupSortData::Make( + const std::vector>& groups, const std::shared_ptr& time, + const std::shared_ptr& ant1, const std::shared_ptr& ant2, + const std::shared_ptr& rows) { + if (time == nullptr || ant1 == nullptr || ant2 == nullptr || rows == nullptr) + return Status::Invalid(kArrayIsNull); + if (time->length() != ant1->length() || time->length() != ant2->length() || + time->length() != rows->length()) + return Status::Invalid(kLengthMismatch); + + if (time->type() != arrow::float64()) + return Status::Invalid("time column was not float64"); + if (ant1->type() != arrow::int32()) return Status::Invalid("ant1 column was not int32"); + if (ant2->type() != arrow::int32()) return Status::Invalid("ant2 column was not int32"); + if (rows->type() != arrow::int64()) return Status::Invalid("row column was not int64"); + + if (time->data()->MayHaveNulls()) return Status::Invalid(kHasNulls); + if (ant1->data()->MayHaveNulls()) return Status::Invalid(kHasNulls); + if (ant2->data()->MayHaveNulls()) return Status::Invalid(kHasNulls); + if (rows->data()->MayHaveNulls()) return Status::Invalid(kHasNulls); + + std::vector> groups_int32; + groups_int32.reserve(groups.size()); + + for (const auto& group : groups) { + if (group == nullptr) return Status::Invalid(kArrayIsNull); + if (time->length() != group->length()) return Status::Invalid(kLengthMismatch); + if (group->type() != arrow::int32()) + return Status::Invalid("Grouping column was not int32"); + if (group->data()->MayHaveNulls()) return Status::Invalid(kHasNulls); + groups_int32.push_back(std::dynamic_pointer_cast(group)); + } + + return std::make_shared>( + std::move(groups_int32), std::dynamic_pointer_cast(time), + std::dynamic_pointer_cast(ant1), + std::dynamic_pointer_cast(ant2), + std::dynamic_pointer_cast(rows)); +} + +Result> GroupSortData::Sort() const { + std::vector groups; + groups.reserve(groups_.size()); + for (const auto& g : groups_) groups.push_back(g->raw_values()); + auto time = time_->raw_values(); + auto ant1 = ant1_->raw_values(); + auto ant2 = ant2_->raw_values(); + auto rows = rows_->raw_values(); + auto nrow = time_->length(); + + // Generate sort indices + std::vector index(nrow); + std::iota(std::begin(index), std::end(index), 0); + std::sort(std::begin(index), std::end(index), [&](std::int64_t l, std::int64_t r) { + for (std::size_t i = 0; i < groups.size(); ++i) { + if (groups[i][l] != groups[i][r]) { + return groups[i][l] < groups[i][r]; + } + } + if (time[l] != time[r]) return time[l] < time[r]; + if (ant1[l] != ant1[r]) return ant1[l] < ant1[r]; + return ant2[l] < ant2[r]; + }); + + // Allocate output buffers + std::vector> group_buffers(groups.size()); + std::vector> group_arrays(groups.size()); + std::vector> group_spans(groups.size()); + for (std::size_t g = 0; g < groups.size(); ++g) { + ARROW_ASSIGN_OR_RAISE(group_buffers[g], AllocateBuffer(nrow * sizeof(std::int32_t))); + group_arrays[g] = std::make_shared(nrow, group_buffers[g]); + group_spans[g] = group_buffers[g]->mutable_span_as(); + } + + ARROW_ASSIGN_OR_RAISE(std::shared_ptr time_buffer, + AllocateBuffer(nrow * sizeof(double))); + ARROW_ASSIGN_OR_RAISE(std::shared_ptr ant1_buffer, + AllocateBuffer(nrow * sizeof(std::int32_t))); + ARROW_ASSIGN_OR_RAISE(std::shared_ptr ant2_buffer, + AllocateBuffer(nrow * sizeof(std::int32_t))); + ARROW_ASSIGN_OR_RAISE(std::shared_ptr rows_buffer, + AllocateBuffer(nrow * sizeof(std::int64_t))); + + auto time_span = time_buffer->mutable_span_as(); + auto ant1_span = ant1_buffer->mutable_span_as(); + auto ant2_span = ant2_buffer->mutable_span_as(); + auto rows_span = rows_buffer->mutable_span_as(); + + auto DoCopy = [&index, &nrow](auto out, auto in) { + for (std::int64_t r = 0; r < nrow; ++r) out[r] = in[index[r]]; + }; + + for (std::size_t g = 0; g < groups.size(); ++g) DoCopy(group_spans[g], groups[g]); + DoCopy(time_span, time); + DoCopy(ant1_span, ant1); + DoCopy(ant2_span, ant2); + DoCopy(rows_span, rows); + + return std::make_shared>( + std::move(group_arrays), + std::make_shared(nrow, std::move(time_buffer)), + std::make_shared(nrow, std::move(ant1_buffer)), + std::make_shared(nrow, std::move(ant2_buffer)), + std::make_shared(nrow, std::move(rows_buffer))); +} + +Result> MergeGroups( + const std::vector>& group_data) { + if (group_data.empty()) + return std::make_shared>( + GroupSortData::GroupsType{}, nullptr, nullptr, nullptr, nullptr); + + struct MergeData { + std::size_t gd; + GroupSortData* group; + std::int64_t r; + + double time(std::int64_t r) const { return group->time()[r]; } + std::int32_t ant1(std::int64_t r) const { return group->ant1()[r]; } + std::int32_t ant2(std::int64_t r) const { return group->ant2()[r]; } + + bool operator<(const MergeData& rhs) const { + for (std::size_t g = 0; g < group->nGroups(); ++g) { + auto lhs_group = group->group(g)[r]; + auto rhs_group = rhs.group->group(g)[rhs.r]; + if (lhs_group != rhs_group) return lhs_group < rhs_group; + } + if (time(r) != rhs.time(rhs.r)) return time(r) < rhs.time(rhs.r); + if (ant1(r) != rhs.ant1(rhs.r)) return ant1(r) < rhs.ant1(rhs.r); + return ant2(r) < rhs.ant2(rhs.r); + } + }; + + std::int64_t nrows = 0; + // TOOD: Check for consistency across data here + auto ngroups = group_data[0]->nGroups(); + for (const auto& g : group_data) nrows += g->nRows(); + std::priority_queue queue; + + std::vector> group_buffers(ngroups); + std::vector> group_arrays(ngroups); + std::vector> group_spans(ngroups); + + for (std::size_t gd = 0; gd < group_data.size(); ++gd) { + queue.emplace(MergeData{gd, group_data[gd].get(), 0}); + } + + for (std::size_t g = 0; g < ngroups; ++g) { + ARROW_ASSIGN_OR_RAISE(group_buffers[g], AllocateBuffer(nrows * sizeof(std::int32_t))); + group_spans[g] = group_buffers[g]->mutable_span_as(); + group_arrays[g] = std::make_shared(nrows, group_buffers[g]); + } + + ARROW_ASSIGN_OR_RAISE(std::shared_ptr time_buffer, + AllocateBuffer(nrows * sizeof(double))); + ARROW_ASSIGN_OR_RAISE(std::shared_ptr ant1_buffer, + AllocateBuffer(nrows * sizeof(std::int32_t))); + ARROW_ASSIGN_OR_RAISE(std::shared_ptr ant2_buffer, + AllocateBuffer(nrows * sizeof(std::int32_t))); + ARROW_ASSIGN_OR_RAISE(std::shared_ptr rows_buffer, + AllocateBuffer(nrows * sizeof(std::int64_t))); + + auto time_span = time_buffer->mutable_span_as(); + auto ant1_span = ant1_buffer->mutable_span_as(); + auto ant2_span = ant2_buffer->mutable_span_as(); + auto rows_span = rows_buffer->mutable_span_as(); + + std::int64_t row = 0; + + while (!queue.empty()) { + auto [gd, dummy, gr] = queue.top(); + const auto& top_group = group_data[gd]; + queue.pop(); + + for (std::size_t g = 0; g < ngroups; ++g) { + group_spans[g][row] = top_group->group(g)[gr]; + } + + time_span[row] = top_group->time()[gr]; + ant1_span[row] = top_group->ant1()[gr]; + ant2_span[row] = top_group->ant2()[gr]; + rows_span[row] = top_group->rows()[gr]; + ++row; + + if (gr + 1 < top_group->nRows()) { + queue.emplace(MergeData{gd, top_group.get(), gr + 1}); + } + } + + return std::make_shared>( + std::move(group_arrays), + std::make_shared(nrows, std::move(time_buffer)), + std::make_shared(nrows, std::move(ant1_buffer)), + std::make_shared(nrows, std::move(ant2_buffer)), + std::make_shared(nrows, std::move(rows_buffer))); +} + +} // namespace arcae diff --git a/cpp/arcae/group_sort.h b/cpp/arcae/group_sort.h new file mode 100644 index 00000000..18785611 --- /dev/null +++ b/cpp/arcae/group_sort.h @@ -0,0 +1,51 @@ +#ifndef ARCAE_GROUP_SORT_H +#define ARCAE_GROUP_SORT_H + +#include +#include +#include + +#include +#include +#include + +namespace arcae { + +struct GroupSortData { + using GroupsType = std::vector>; + GroupsType groups_; + std::shared_ptr time_; + std::shared_ptr ant1_; + std::shared_ptr ant2_; + std::shared_ptr rows_; + + const std::int32_t* group(int g) const { return groups_[g]->raw_values(); } + const double* time() const { return time_->raw_values(); } + const std::int32_t* ant1() const { return ant1_->raw_values(); } + const std::int32_t* ant2() const { return ant2_->raw_values(); } + const std::int64_t* rows() const { return rows_->raw_values(); } + + // Create the GroupSortData from grouping and sorting arrays + static arrow::Result> Make( + const std::vector>& groups, + const std::shared_ptr& time, + const std::shared_ptr& ant1, + const std::shared_ptr& ant2, + const std::shared_ptr& rows); + + // Number of group columns + std::size_t nGroups() const { return groups_.size(); } + + // Number of rows in the group + std::int64_t nRows() const { return rows_->length(); } + + // Sort the Group + arrow::Result> Sort() const; +}; + +arrow::Result> MergeGroups( + const std::vector>& group_data); + +} // namespace arcae + +#endif // ARCAE_GROUP_SORT_H diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index 98ec69ca..4f53c8be 100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -31,6 +31,10 @@ add_executable(dev_transpose_test dev_transpose_test.cc) target_link_libraries(dev_transpose_test PRIVATE GTest::gtest_main arcae absl::time absl::str_format test_utils) add_test(dev_transpose_test dev_transpose_test) +add_executable(group_sort_test group_sort_test.cc) +target_link_libraries(group_sort_test PRIVATE GTest::gtest_main arcae test_utils) +add_test(group_sort_test group_sort_test) + add_executable(new_table_proxy_test new_table_proxy_test.cc) target_link_libraries(new_table_proxy_test PRIVATE GTest::gtest_main arcae test_utils) add_test(new_table_proxy_test new_table_proxy_test) @@ -38,6 +42,7 @@ add_test(new_table_proxy_test new_table_proxy_test) set_tests_properties(result_shape_test data_partition_test dev_transpose_test + group_sort_test isolated_table_proxy_test new_table_proxy_test parallel_write_test