Skip to content

Commit

Permalink
Refactor RJOutputWriter (#4692)
Browse files Browse the repository at this point in the history
* Refactor RJOutputWriter

* Run clang-format

---------

Co-authored-by: CI Bot <[email protected]>
  • Loading branch information
andyfengHKU and andyfengHKU authored Jan 9, 2025
1 parent 2e07232 commit 8d69e48
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 70 deletions.
41 changes: 16 additions & 25 deletions src/function/gds/all_shortest_paths.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class PathMultiplicities {
using multiplicity_entry_t = std::atomic<uint64_t>;

public:
PathMultiplicities(std::unordered_map<common::table_id_t, uint64_t> numNodesMap,
PathMultiplicities(std::unordered_map<table_id_t, uint64_t> numNodesMap,
storage::MemoryManager* mm) {
for (auto& [tableID, numNodes] : numNodesMap) {
multiplicityArray.allocate(tableID, numNodes, mm);
Expand All @@ -37,33 +37,23 @@ class PathMultiplicities {
}
}

// Warning: This function should be called in a single threaded phase. That is because
// it fixes the target node table. This should not be done at a part of the computation when
// worker threads might be incrementing multiplicity using the incrementTargetMultiplicity
// function that assumes curTargetMultiplicities is already fixed to something.
void incrementMultiplicity(common::nodeID_t nodeID, uint64_t multiplicity) {
fixTargetNodeTable(nodeID.tableID);
auto curPtr = getCurTargetMultiplicities();
curPtr[nodeID.offset].fetch_add(multiplicity);
}

void incrementTargetMultiplicity(common::offset_t offset, uint64_t multiplicity) {
void incrementTargetMultiplicity(offset_t offset, uint64_t multiplicity) {
getCurTargetMultiplicities()[offset].fetch_add(multiplicity);
}

uint64_t getBoundMultiplicity(common::offset_t nodeOffset) {
uint64_t getBoundMultiplicity(offset_t nodeOffset) {
return getCurBoundMultiplicities()[nodeOffset].load(std::memory_order_relaxed);
}

uint64_t getTargetMultiplicity(common::offset_t nodeOffset) {
uint64_t getTargetMultiplicity(offset_t nodeOffset) {
return getCurTargetMultiplicities()[nodeOffset].load(std::memory_order_relaxed);
}

void fixBoundNodeTable(common::table_id_t tableID) {
void pinBoundTable(table_id_t tableID) {
curBoundMultiplicities.store(multiplicityArray.getData(tableID), std::memory_order_relaxed);
}

void fixTargetNodeTable(common::table_id_t tableID) {
void pinTargetTable(table_id_t tableID) {
curTargetMultiplicities.store(multiplicityArray.getData(tableID),
std::memory_order_relaxed);
}
Expand All @@ -90,28 +80,29 @@ class PathMultiplicities {
std::atomic<multiplicity_entry_t*> curBoundMultiplicities;
};

struct AllSPDestinationsOutputs : public SPOutputs {
struct AllSPDestinationsOutputs : public SPDestinationOutputs {
public:
AllSPDestinationsOutputs(nodeID_t sourceNodeID, std::shared_ptr<PathLengths> pathLengths,
std::shared_ptr<PathMultiplicities> multiplicities)
: SPOutputs{sourceNodeID, std::move(pathLengths)},
: SPDestinationOutputs{sourceNodeID, std::move(pathLengths)},
multiplicities{std::move(multiplicities)} {}

void initRJFromSource(nodeID_t source) override {
multiplicities->incrementMultiplicity(source, 1);
multiplicities->pinTargetTable(source.tableID);
multiplicities->incrementTargetMultiplicity(source.offset, 1);
}

void beginFrontierComputeBetweenTables(table_id_t curFrontierTableID,
table_id_t nextFrontierTableID) override {
// Note: We do not pin the node table for pathLengths, which is inherited from
// SPDestinationOutputs. See the comment on
// SPDestinationOutputs::beginFrontierComputeBetweenTables() for details.
multiplicities->fixBoundNodeTable(curFrontierTableID);
multiplicities->fixTargetNodeTable(nextFrontierTableID);
multiplicities->pinBoundTable(curFrontierTableID);
multiplicities->pinTargetTable(nextFrontierTableID);
};

void beginWritingOutputsForDstNodesInTable(table_id_t tableID) override {
pathLengths->pinCurFrontierTableID(tableID);
multiplicities->fixTargetNodeTable(tableID);
multiplicities->pinTargetTable(tableID);
}

public:
Expand Down Expand Up @@ -253,7 +244,7 @@ class AllSPDestinationsAlgorithm final : public SPAlgorithm {
std::make_unique<AllSPDestinationsOutputs>(sourceNodeID, frontier, multiplicities);
auto outputWriter = std::make_unique<AllSPDestinationsOutputWriter>(clientContext,
output.get(), sharedState->getOutputNodeMaskMap());
auto frontierPair = std::make_unique<SinglePathLengthsFrontierPair>(output->pathLengths);
auto frontierPair = std::make_unique<SinglePathLengthsFrontierPair>(frontier);
auto edgeCompute = std::make_unique<AllSPDestinationsEdgeCompute>(frontierPair.get(),
output->multiplicities.get());
return RJCompState(std::move(frontierPair), std::move(edgeCompute),
Expand Down Expand Up @@ -283,13 +274,13 @@ class AllSPPathsAlgorithm final : public SPAlgorithm {
auto clientContext = context->clientContext;
auto frontier = getPathLengthsFrontier(context, PathLengths::UNVISITED);
auto bfsGraph = getBFSGraph(context);
auto output = std::make_unique<PathsOutputs>(sourceNodeID, frontier, std::move(bfsGraph));
auto output = std::make_unique<PathsOutputs>(sourceNodeID, std::move(bfsGraph));
auto rjBindData = bindData->ptrCast<RJBindData>();
auto writerInfo = rjBindData->getPathWriterInfo();
writerInfo.pathNodeMask = sharedState->getPathNodeMaskMap();
auto outputWriter = std::make_unique<SPPathsOutputWriter>(clientContext, output.get(),
sharedState->getOutputNodeMaskMap(), writerInfo);
auto frontierPair = std::make_unique<SinglePathLengthsFrontierPair>(output->pathLengths);
auto frontierPair = std::make_unique<SinglePathLengthsFrontierPair>(frontier);
auto edgeCompute =
std::make_unique<AllSPPathsEdgeCompute>(frontierPair.get(), output->bfsGraph.get());
return RJCompState(std::move(frontierPair), std::move(edgeCompute),
Expand Down
10 changes: 3 additions & 7 deletions src/function/gds/output_writer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,6 @@ using namespace kuzu::processor;
namespace kuzu {
namespace function {

void PathsOutputs::beginWritingOutputsForDstNodesInTable(common::table_id_t tableID) {
pathLengths->pinCurFrontierTableID(tableID);
bfsGraph->pinTableID(tableID);
}

std::unique_ptr<common::ValueVector> GDSOutputWriter::createVector(const LogicalType& type,
storage::MemoryManager* mm) {
auto vector = std::make_unique<ValueVector>(type.copy(), mm);
Expand Down Expand Up @@ -392,7 +387,8 @@ DestinationsOutputWriter::DestinationsOutputWriter(main::ClientContext* context,
void DestinationsOutputWriter::write(processor::FactorizedTable& fTable, nodeID_t dstNodeID,
GDSOutputCounter* counter) {
auto length =
rjOutputs->ptrCast<SPOutputs>()->pathLengths->getMaskValueFromCurFrontier(dstNodeID.offset);
rjOutputs->ptrCast<SPDestinationOutputs>()->pathLengths->getMaskValueFromCurFrontier(
dstNodeID.offset);
dstNodeIDVector->setValue<nodeID_t>(0, dstNodeID);
setLength(lengthVector.get(), length);
fTable.append(vectors);
Expand All @@ -403,7 +399,7 @@ void DestinationsOutputWriter::write(processor::FactorizedTable& fTable, nodeID_
}

bool DestinationsOutputWriter::skipInternal(common::nodeID_t dstNodeID) const {
auto outputs = rjOutputs->ptrCast<SPOutputs>();
auto outputs = rjOutputs->ptrCast<SPDestinationOutputs>();
return dstNodeID == outputs->sourceNodeID || outputs->pathLengths->getMaskValueFromCurFrontier(
dstNodeID.offset) == PathLengths::UNVISITED;
}
Expand Down
19 changes: 3 additions & 16 deletions src/function/gds/single_shortest_paths.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,6 @@ using namespace kuzu::graph;
namespace kuzu {
namespace function {

struct SingleSPDestinationsOutputs : public SPOutputs {
SingleSPDestinationsOutputs(nodeID_t sourceNodeID, std::shared_ptr<PathLengths> pathLengths)
: SPOutputs(sourceNodeID, pathLengths) {}
// Note: We do not fix the node table for pathLengths, because PathLengths is a
// FrontierPair implementation and RJCompState will call beginFrontierComputeBetweenTables
// on FrontierPair (and RJOutputs). That's why this function is empty.
void beginFrontierComputeBetweenTables(table_id_t, table_id_t) override{};

void beginWritingOutputsForDstNodesInTable(table_id_t tableID) override {
pathLengths->pinCurFrontierTableID(tableID);
}
};

class SingleSPDestinationsEdgeCompute : public SPEdgeCompute {
public:
explicit SingleSPDestinationsEdgeCompute(SinglePathLengthsFrontierPair* frontierPair)
Expand Down Expand Up @@ -106,7 +93,7 @@ class SingleSPDestinationsAlgorithm : public SPAlgorithm {
RJCompState getRJCompState(ExecutionContext* context, nodeID_t sourceNodeID) override {
auto clientContext = context->clientContext;
auto frontier = getPathLengthsFrontier(context, PathLengths::UNVISITED);
auto output = std::make_unique<SingleSPDestinationsOutputs>(sourceNodeID, frontier);
auto output = std::make_unique<SPDestinationOutputs>(sourceNodeID, frontier);
auto outputWriter = std::make_unique<DestinationsOutputWriter>(clientContext, output.get(),
sharedState->getOutputNodeMaskMap());
auto frontierPair = std::make_unique<SinglePathLengthsFrontierPair>(output->pathLengths);
Expand Down Expand Up @@ -138,13 +125,13 @@ class SingleSPPathsAlgorithm : public SPAlgorithm {
auto clientContext = context->clientContext;
auto frontier = getPathLengthsFrontier(context, PathLengths::UNVISITED);
auto bfsGraph = getBFSGraph(context);
auto output = std::make_unique<PathsOutputs>(sourceNodeID, frontier, std::move(bfsGraph));
auto output = std::make_unique<PathsOutputs>(sourceNodeID, std::move(bfsGraph));
auto rjBindData = bindData->ptrCast<RJBindData>();
auto writerInfo = rjBindData->getPathWriterInfo();
writerInfo.pathNodeMask = sharedState->getPathNodeMaskMap();
auto outputWriter = std::make_unique<SPPathsOutputWriter>(clientContext, output.get(),
sharedState->getOutputNodeMaskMap(), writerInfo);
auto frontierPair = std::make_unique<SinglePathLengthsFrontierPair>(output->pathLengths);
auto frontierPair = std::make_unique<SinglePathLengthsFrontierPair>(frontier);
auto edgeCompute =
std::make_unique<SingleSPPathsEdgeCompute>(frontierPair.get(), output->bfsGraph.get());
return RJCompState(std::move(frontierPair), std::move(edgeCompute),
Expand Down
2 changes: 1 addition & 1 deletion src/function/gds/variable_length_path.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ class VarLenJoinsAlgorithm final : public RJAlgorithm {
auto clientContext = context->clientContext;
auto frontier = getPathLengthsFrontier(context, PathLengths::UNVISITED);
auto bfsGraph = getBFSGraph(context);
auto output = std::make_unique<PathsOutputs>(sourceNodeID, frontier, std::move(bfsGraph));
auto output = std::make_unique<PathsOutputs>(sourceNodeID, std::move(bfsGraph));
auto rjBindData = bindData->ptrCast<RJBindData>();
auto writerInfo = rjBindData->getPathWriterInfo();
writerInfo.pathNodeMask = sharedState->getPathNodeMaskMap();
Expand Down
44 changes: 23 additions & 21 deletions src/include/function/gds/output_writer.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,16 @@
#include "bfs_graph.h"
#include "common/enums/path_semantic.h"
#include "common/types/types.h"
#include "gds_frontier.h"
#include "processor/operator/gds_call_shared_state.h"
#include "processor/result/factorized_table.h"

namespace kuzu {
namespace function {

class PathLengths;

struct RJOutputs {
public:
common::nodeID_t sourceNodeID;

explicit RJOutputs(common::nodeID_t sourceNodeID) : sourceNodeID{sourceNodeID} {}
virtual ~RJOutputs() = default;

Expand All @@ -26,36 +26,38 @@ struct RJOutputs {
TARGET* ptrCast() {
return common::ku_dynamic_cast<TARGET*>(this);
}

public:
common::nodeID_t sourceNodeID;
};

struct SPOutputs : public RJOutputs {
public:
SPOutputs(common::nodeID_t sourceNodeID, std::shared_ptr<PathLengths> pathLengths)
struct SPDestinationOutputs : public RJOutputs {
std::shared_ptr<PathLengths> pathLengths;

SPDestinationOutputs(common::nodeID_t sourceNodeID, std::shared_ptr<PathLengths> pathLengths)
: RJOutputs{sourceNodeID}, pathLengths{std::move(pathLengths)} {}

public:
std::shared_ptr<PathLengths> pathLengths;
// Note: We do not pin the node table for pathLengths here, because PathLengths is a
// FrontierPair implementation and RJCompState will call beginFrontierComputeBetweenTables
// on FrontierPair (and RJOutputs). That is why this override is intentionally empty.
void beginFrontierComputeBetweenTables(common::table_id_t, common::table_id_t) override{};

void beginWritingOutputsForDstNodesInTable(common::table_id_t tableID) override {
pathLengths->pinCurFrontierTableID(tableID);
}
};

struct PathsOutputs : public SPOutputs {
PathsOutputs(common::nodeID_t sourceNodeID, std::shared_ptr<PathLengths> pathLengths,
std::unique_ptr<BFSGraph> bfsGraph)
: SPOutputs(sourceNodeID, std::move(pathLengths)), bfsGraph{std::move(bfsGraph)} {}
struct PathsOutputs : public RJOutputs {
std::unique_ptr<BFSGraph> bfsGraph;

PathsOutputs(common::nodeID_t sourceNodeID, std::unique_ptr<BFSGraph> bfsGraph)
: RJOutputs(sourceNodeID), bfsGraph{std::move(bfsGraph)} {}

void beginFrontierComputeBetweenTables(common::table_id_t,
common::table_id_t nextFrontierTableID) override {
// Note: Only bfsGraph is pinned here. PathsOutputs does not hold a PathLengths; the frontier
// is handed to the FrontierPair, on which RJCompState calls
// beginFrontierComputeBetweenTables() as well.
bfsGraph->pinTableID(nextFrontierTableID);
};

void beginWritingOutputsForDstNodesInTable(common::table_id_t tableID) override;

public:
std::unique_ptr<BFSGraph> bfsGraph;
void beginWritingOutputsForDstNodesInTable(common::table_id_t tableID) override {
bfsGraph->pinTableID(tableID);
}
};

class GDSOutputWriter {
Expand Down

0 comments on commit 8d69e48

Please sign in to comment.