Skip to content

Commit

Permalink
Review fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
benjaminwinger committed Jan 21, 2025
1 parent d6da832 commit d8e8f69
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 22 deletions.
16 changes: 4 additions & 12 deletions src/include/processor/operator/aggregate/hash_aggregate.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,9 @@ struct HashAggregateInfo {
class HashAggregateSharedState final : public BaseAggregateSharedState {

public:
explicit HashAggregateSharedState(main::ClientContext* context, HashAggregateInfo hashInfo,
explicit HashAggregateSharedState(main::ClientContext* context, HashAggregateInfo aggInfo,
const std::vector<function::AggregateFunction>& aggregateFunctions);

void initPartitions(main::ClientContext* context,
const std::vector<common::LogicalType>& keyDataTypes,
const std::vector<common::LogicalType>& payloadDataTypes,
const std::vector<common::LogicalType>& types);

~HashAggregateSharedState();

void appendTuple(std::span<uint8_t> tuple, common::hash_t hash);
Expand All @@ -64,9 +59,6 @@ class HashAggregateSharedState final : public BaseAggregateSharedState {

uint64_t getCurrentOffset() const { return currentOffset; }

// return whether limitNumber is exceeded
bool increaseAndCheckLimitCount(uint64_t num);

void setLimitNumber(uint64_t num) { limitNumber = num; }
uint64_t getLimitNumber() const { return limitNumber; }

Expand All @@ -81,14 +73,14 @@ class HashAggregateSharedState final : public BaseAggregateSharedState {

void assertFinalized() const;

const HashAggregateInfo& getInfo() const { return hashInfo; }
const HashAggregateInfo& getAggregateInfo() const { return aggInfo; }

protected:
std::tuple<const FactorizedTable*, common::offset_t> getPartitionForOffset(
common::offset_t offset) const;

public:
HashAggregateInfo hashInfo;
HashAggregateInfo aggInfo;
common::MPSCQueue<std::unique_ptr<common::InMemOverflowBuffer>> overflow;
struct Partition {
std::unique_ptr<AggregateHashTable> hashTable;
Expand All @@ -103,7 +95,7 @@ class HashAggregateSharedState final : public BaseAggregateSharedState {
table.resize(table.getNumTuplesPerBlock());
}
// numTuplesReserved may be greater than the capacity of the factorizedTable
// if threads try to write to it it while a new block is being allocated
// if threads try to write to it while a new block is being allocated
// So it should not be relied on for anything other than reserving tuples
std::atomic<uint64_t> numTuplesReserved;
// Set after the tuple has been written to the block.
Expand Down
14 changes: 7 additions & 7 deletions src/processor/operator/aggregate/hash_aggregate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,20 +32,20 @@ std::string HashAggregatePrintInfo::toString() const {
return result;
}
HashAggregateSharedState::HashAggregateSharedState(main::ClientContext* context,
HashAggregateInfo hashInfo, const std::vector<function::AggregateFunction>& aggregateFunctions)
: BaseAggregateSharedState{aggregateFunctions}, hashInfo{std::move(hashInfo)},
HashAggregateInfo aggInfo, const std::vector<function::AggregateFunction>& aggregateFunctions)
: BaseAggregateSharedState{aggregateFunctions}, aggInfo{std::move(aggInfo)},
globalPartitions{static_cast<size_t>(context->getMaxNumThreadForExec())},
limitNumber{common::INVALID_LIMIT}, numThreads{0},
memoryManager{context->getMemoryManager()} {

// When copying directly into factorizedTables the table's schema's internal mayContainNulls
// won't be updated and it's probably less work to just always check nulls
for (size_t i = 0; i < this->hashInfo.tableSchema.getNumColumns(); i++) {
this->hashInfo.tableSchema.setMayContainsNullsToTrue(i);
for (size_t i = 0; i < this->aggInfo.tableSchema.getNumColumns(); i++) {
this->aggInfo.tableSchema.setMayContainsNullsToTrue(i);
}
for (auto& partition : globalPartitions) {
partition.headBlock = new Partition::TupleBlock(context->getMemoryManager(),
this->hashInfo.tableSchema.copy());
this->aggInfo.tableSchema.copy());
}
}

Expand Down Expand Up @@ -88,7 +88,7 @@ HashAggregateInfo::HashAggregateInfo(const HashAggregateInfo& other)
void HashAggregateLocalState::init(HashAggregateSharedState* sharedState, ResultSet& resultSet,
main::ClientContext* context, std::vector<function::AggregateFunction>& aggregateFunctions,
std::vector<common::LogicalType> types) {
auto& info = sharedState->getInfo();
auto& info = sharedState->getAggregateInfo();
std::vector<LogicalType> keyDataTypes;
for (auto& pos : info.flatKeysPos) {
auto vector = resultSet.getValueVector(pos).get();
Expand Down Expand Up @@ -173,7 +173,7 @@ void HashAggregateSharedState::appendTuple(std::span<uint8_t> tuple, common::has
return;
} else {
// No more space in the block, allocate and replace it
auto* newBlock = new Partition::TupleBlock(memoryManager, hashInfo.tableSchema.copy());
auto* newBlock = new Partition::TupleBlock(memoryManager, aggInfo.tableSchema.copy());
if (partition.headBlock.compare_exchange_strong(block, newBlock)) {
// TODO(bmwinger): if the queuedTuples has at least a certain size (benchmark to see
// if there's a benefit to waiting for multiple blocks) then cycle through the queue
Expand Down
7 changes: 4 additions & 3 deletions test/test_files/agg/hash_large.test
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,14 @@
---- 1
3373
-STATEMENT MATCH (a:account)-[]->(b:account) WHERE a.ID > 1000000 RETURN a.ID, COUNT(b) as n ORDER BY n DESC LIMIT 5;
# TODO(bmwinger): Ordering is wrong here
-CHECK_ORDER
---- 5
18776017|2272
3359851|3373
5442012|2204
59804598|2467
7860742|2458
18776017|2272
5442012|2204
-STATEMENT MATCH (a:account)-[]->(b:account) WHERE a.ID > 1000000 RETURN a.ID, COUNT(b.ID) as n ORDER BY n, a.ID;
-CHECK_ORDER
---- 68949
<FILE>:hash_large.csv

0 comments on commit d8e8f69

Please sign in to comment.