Skip to content

Commit

Permalink
Add partition sorting utilities
Browse files Browse the repository at this point in the history
  • Loading branch information
sjperkins committed Sep 27, 2024
1 parent 3c57c60 commit 1f108e8
Show file tree
Hide file tree
Showing 4 changed files with 293 additions and 0 deletions.
1 change: 1 addition & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ add_library(arcae SHARED
arcae/configuration.cc
arcae/descriptor.cc
arcae/data_partition.cc
arcae/group_sort.cc
arcae/isolated_table_proxy.cc
arcae/new_table_proxy.cc
arcae/read_impl.cc
Expand Down
236 changes: 236 additions & 0 deletions cpp/arcae/group_sort.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,236 @@
#include "group_sort.h"

#include <algorithm>
#include <cstdint>
#include <memory>
#include <numeric>
#include <queue>
#include <vector>

#include "arcae/type_traits.h"

#include "arrow/api.h"
#include "arrow/buffer.h"
#include "arrow/result.h"
#include "arrow/status.h"
#include "arrow/type_fwd.h"

using ::arrow::AllocateBuffer;
using ::arrow::Array;
using ::arrow::Buffer;
using ::arrow::DoubleArray;
using ::arrow::Int32Array;
using ::arrow::Int64Array;
using ::arrow::Result;
using ::arrow::Status;

using ::arcae::detail::AggregateAdapter;

namespace arcae {

namespace {

static constexpr char kArrayIsNull[] = "GroupSortData array is null";
static constexpr char kLengthMismatch[] = "GroupSortData length mismatch";
static constexpr char kHasNulls[] = "GroupSortData has nulls";

} // namespace

Result<std::shared_ptr<GroupSortData>> GroupSortData::Make(
const std::vector<std::shared_ptr<Array>>& groups, const std::shared_ptr<Array>& time,
const std::shared_ptr<Array>& ant1, const std::shared_ptr<Array>& ant2,
const std::shared_ptr<Array>& rows) {
if (time == nullptr || ant1 == nullptr || ant2 == nullptr || rows == nullptr)
return Status::Invalid(kArrayIsNull);
if (time->length() != ant1->length() || time->length() != ant2->length() ||
time->length() != rows->length())
return Status::Invalid(kLengthMismatch);

if (time->type() != arrow::float64())
return Status::Invalid("time column was not float64");
if (ant1->type() != arrow::int32()) return Status::Invalid("ant1 column was not int32");
if (ant2->type() != arrow::int32()) return Status::Invalid("ant2 column was not int32");
if (rows->type() != arrow::int64()) return Status::Invalid("row column was not int64");

if (time->data()->MayHaveNulls()) return Status::Invalid(kHasNulls);
if (ant1->data()->MayHaveNulls()) return Status::Invalid(kHasNulls);
if (ant2->data()->MayHaveNulls()) return Status::Invalid(kHasNulls);
if (rows->data()->MayHaveNulls()) return Status::Invalid(kHasNulls);

std::vector<std::shared_ptr<Int32Array>> groups_int32;
groups_int32.reserve(groups.size());

for (const auto& group : groups) {
if (group == nullptr) return Status::Invalid(kArrayIsNull);
if (time->length() != group->length()) return Status::Invalid(kLengthMismatch);
if (group->type() != arrow::int32())
return Status::Invalid("Grouping column was not int32");
if (group->data()->MayHaveNulls()) return Status::Invalid(kHasNulls);
groups_int32.push_back(std::dynamic_pointer_cast<arrow::Int32Array>(group));
}

return std::make_shared<AggregateAdapter<GroupSortData>>(
std::move(groups_int32), std::dynamic_pointer_cast<DoubleArray>(time),
std::dynamic_pointer_cast<Int32Array>(ant1),
std::dynamic_pointer_cast<Int32Array>(ant2),
std::dynamic_pointer_cast<Int64Array>(rows));
}

Result<std::shared_ptr<GroupSortData>> GroupSortData::Sort() const {
std::vector<const int*> groups;
groups.reserve(groups_.size());
for (const auto& g : groups_) groups.push_back(g->raw_values());
auto time = time_->raw_values();
auto ant1 = ant1_->raw_values();
auto ant2 = ant2_->raw_values();
auto rows = rows_->raw_values();
auto nrow = time_->length();

// Generate sort indices
std::vector<int64_t> index(nrow);
std::iota(std::begin(index), std::end(index), 0);
std::sort(std::begin(index), std::end(index), [&](std::int64_t l, std::int64_t r) {
for (std::size_t i = 0; i < groups.size(); ++i) {
if (groups[i][l] != groups[i][r]) {
return groups[i][l] < groups[i][r];
}
}
if (time[l] != time[r]) return time[l] < time[r];
if (ant1[l] != ant1[r]) return ant1[l] < ant1[r];
return ant2[l] < ant2[r];
});

// Allocate output buffers
std::vector<std::shared_ptr<Buffer>> group_buffers(groups.size());
std::vector<std::shared_ptr<Int32Array>> group_arrays(groups.size());
std::vector<arrow::util::span<std::int32_t>> group_spans(groups.size());
for (std::size_t g = 0; g < groups.size(); ++g) {
ARROW_ASSIGN_OR_RAISE(group_buffers[g], AllocateBuffer(nrow * sizeof(std::int32_t)));
group_arrays[g] = std::make_shared<Int32Array>(nrow, group_buffers[g]);
group_spans[g] = group_buffers[g]->mutable_span_as<std::int32_t>();
}

ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Buffer> time_buffer,
AllocateBuffer(nrow * sizeof(double)));
ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Buffer> ant1_buffer,
AllocateBuffer(nrow * sizeof(std::int32_t)));
ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Buffer> ant2_buffer,
AllocateBuffer(nrow * sizeof(std::int32_t)));
ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Buffer> rows_buffer,
AllocateBuffer(nrow * sizeof(std::int64_t)));

auto time_span = time_buffer->mutable_span_as<double>();
auto ant1_span = ant1_buffer->mutable_span_as<std::int32_t>();
auto ant2_span = ant2_buffer->mutable_span_as<std::int32_t>();
auto rows_span = rows_buffer->mutable_span_as<std::int64_t>();

auto DoCopy = [&index, &nrow](auto out, auto in) {
for (std::int64_t r = 0; r < nrow; ++r) out[r] = in[index[r]];
};

for (std::size_t g = 0; g < groups.size(); ++g) DoCopy(group_spans[g], groups[g]);
DoCopy(time_span, time);
DoCopy(ant1_span, ant1);
DoCopy(ant2_span, ant2);
DoCopy(rows_span, rows);

return std::make_shared<AggregateAdapter<GroupSortData>>(
std::move(group_arrays),
std::make_shared<DoubleArray>(nrow, std::move(time_buffer)),
std::make_shared<Int32Array>(nrow, std::move(ant1_buffer)),
std::make_shared<Int32Array>(nrow, std::move(ant2_buffer)),
std::make_shared<Int64Array>(nrow, std::move(rows_buffer)));
}

Result<std::shared_ptr<GroupSortData>> MergeGroups(
const std::vector<std::shared_ptr<GroupSortData>>& group_data) {
if (group_data.empty())
return std::make_shared<AggregateAdapter<GroupSortData>>(
GroupSortData::GroupsType{}, nullptr, nullptr, nullptr, nullptr);

struct MergeData {
std::size_t gd;
GroupSortData* group;
std::int64_t r;

double time(std::int64_t r) const { return group->time()[r]; }
std::int32_t ant1(std::int64_t r) const { return group->ant1()[r]; }
std::int32_t ant2(std::int64_t r) const { return group->ant2()[r]; }

bool operator<(const MergeData& rhs) const {
for (std::size_t g = 0; g < group->nGroups(); ++g) {
auto lhs_group = group->group(g)[r];
auto rhs_group = rhs.group->group(g)[rhs.r];
if (lhs_group != rhs_group) return lhs_group < rhs_group;
}
if (time(r) != rhs.time(rhs.r)) return time(r) < rhs.time(rhs.r);
if (ant1(r) != rhs.ant1(rhs.r)) return ant1(r) < rhs.ant1(rhs.r);
return ant2(r) < rhs.ant2(rhs.r);
}
};

std::int64_t nrows = 0;
// TOOD: Check for consistency across data here
auto ngroups = group_data[0]->nGroups();
for (const auto& g : group_data) nrows += g->nRows();
std::priority_queue<MergeData> queue;

std::vector<std::shared_ptr<Buffer>> group_buffers(ngroups);
std::vector<std::shared_ptr<Int32Array>> group_arrays(ngroups);
std::vector<arrow::util::span<std::int32_t>> group_spans(ngroups);

for (std::size_t gd = 0; gd < group_data.size(); ++gd) {
queue.emplace(MergeData{gd, group_data[gd].get(), 0});
}

for (std::size_t g = 0; g < ngroups; ++g) {
ARROW_ASSIGN_OR_RAISE(group_buffers[g], AllocateBuffer(nrows * sizeof(std::int32_t)));
group_spans[g] = group_buffers[g]->mutable_span_as<std::int32_t>();
group_arrays[g] = std::make_shared<Int32Array>(nrows, group_buffers[g]);
}

ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Buffer> time_buffer,
AllocateBuffer(nrows * sizeof(double)));
ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Buffer> ant1_buffer,
AllocateBuffer(nrows * sizeof(std::int32_t)));
ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Buffer> ant2_buffer,
AllocateBuffer(nrows * sizeof(std::int32_t)));
ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Buffer> rows_buffer,
AllocateBuffer(nrows * sizeof(std::int64_t)));

auto time_span = time_buffer->mutable_span_as<double>();
auto ant1_span = ant1_buffer->mutable_span_as<std::int32_t>();
auto ant2_span = ant2_buffer->mutable_span_as<std::int32_t>();
auto rows_span = rows_buffer->mutable_span_as<std::int64_t>();

std::int64_t row = 0;

while (!queue.empty()) {
auto [gd, dummy, gr] = queue.top();
const auto& top_group = group_data[gd];
queue.pop();

for (std::size_t g = 0; g < ngroups; ++g) {
group_spans[g][row] = top_group->group(g)[gr];
}

time_span[row] = top_group->time()[gr];
ant1_span[row] = top_group->ant1()[gr];
ant2_span[row] = top_group->ant2()[gr];
rows_span[row] = top_group->rows()[gr];
++row;

if (gr + 1 < top_group->nRows()) {
queue.emplace(MergeData{gd, top_group.get(), gr + 1});
}
}

return std::make_shared<AggregateAdapter<GroupSortData>>(
std::move(group_arrays),
std::make_shared<DoubleArray>(nrows, std::move(time_buffer)),
std::make_shared<Int32Array>(nrows, std::move(ant1_buffer)),
std::make_shared<Int32Array>(nrows, std::move(ant2_buffer)),
std::make_shared<Int64Array>(nrows, std::move(rows_buffer)));
}

} // namespace arcae
51 changes: 51 additions & 0 deletions cpp/arcae/group_sort.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
#ifndef ARCAE_GROUP_SORT_H
#define ARCAE_GROUP_SORT_H

#include <cstdint>
#include <memory>
#include <vector>

#include <arrow/array.h>
#include <arrow/status.h>
#include <arrow/type_fwd.h>

namespace arcae {

struct GroupSortData {
using GroupsType = std::vector<std::shared_ptr<arrow::Int32Array>>;
GroupsType groups_;
std::shared_ptr<arrow::DoubleArray> time_;
std::shared_ptr<arrow::Int32Array> ant1_;
std::shared_ptr<arrow::Int32Array> ant2_;
std::shared_ptr<arrow::Int64Array> rows_;

const std::int32_t* group(int g) const { return groups_[g]->raw_values(); }
const double* time() const { return time_->raw_values(); }
const std::int32_t* ant1() const { return ant1_->raw_values(); }
const std::int32_t* ant2() const { return ant2_->raw_values(); }
const std::int64_t* rows() const { return rows_->raw_values(); }

// Create the GroupSortData from grouping and sorting arrays
static arrow::Result<std::shared_ptr<GroupSortData>> Make(
const std::vector<std::shared_ptr<arrow::Array>>& groups,
const std::shared_ptr<arrow::Array>& time,
const std::shared_ptr<arrow::Array>& ant1,
const std::shared_ptr<arrow::Array>& ant2,
const std::shared_ptr<arrow::Array>& rows);

// Number of group columns
std::size_t nGroups() const { return groups_.size(); }

// Number of rows in the group
std::int64_t nRows() const { return rows_->length(); }

// Sort the Group
arrow::Result<std::shared_ptr<GroupSortData>> Sort() const;
};

arrow::Result<std::shared_ptr<GroupSortData>> MergeGroups(
const std::vector<std::shared_ptr<GroupSortData>>& group_data);

} // namespace arcae

#endif // ARCAE_GROUP_SORT_H
5 changes: 5 additions & 0 deletions cpp/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,18 @@ add_executable(dev_transpose_test dev_transpose_test.cc)
target_link_libraries(dev_transpose_test PRIVATE GTest::gtest_main arcae absl::time absl::str_format test_utils)
add_test(dev_transpose_test dev_transpose_test)

add_executable(group_sort_test group_sort_test.cc)
target_link_libraries(group_sort_test PRIVATE GTest::gtest_main arcae test_utils)
add_test(group_sort_test group_sort_test)

add_executable(new_table_proxy_test new_table_proxy_test.cc)
target_link_libraries(new_table_proxy_test PRIVATE GTest::gtest_main arcae test_utils)
add_test(new_table_proxy_test new_table_proxy_test)

set_tests_properties(result_shape_test
data_partition_test
dev_transpose_test
group_sort_test
isolated_table_proxy_test
new_table_proxy_test
parallel_write_test
Expand Down

0 comments on commit 1f108e8

Please sign in to comment.