Skip to content

Commit

Permalink
scan local table one by one
Browse files Browse the repository at this point in the history
  • Loading branch information
yiyun-sj committed Aug 6, 2024
1 parent 7cf1dd5 commit d0cac92
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 67 deletions.
5 changes: 1 addition & 4 deletions src/include/storage/store/rel_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ struct LocalRelTableScanState;
struct RelTableScanState : TableScanState {
common::RelDataDirection direction;
common::ValueVector* boundNodeIDVector;
common::offset_t boundNodeOffset;
Column* csrOffsetColumn;
Column* csrLengthColumn;

Expand Down Expand Up @@ -44,8 +43,7 @@ struct RelTableScanState : TableScanState {
const std::vector<Column*>& columns, Column* csrOffsetCol, Column* csrLengthCol,
common::RelDataDirection direction, std::vector<ColumnPredicateSet> columnPredicateSets)
: TableScanState{columnIDs, columns, std::move(columnPredicateSets)}, direction{direction},
boundNodeIDVector{nullptr}, boundNodeOffset{common::INVALID_OFFSET},
csrOffsetColumn{csrOffsetCol}, csrLengthColumn{csrLengthCol},
boundNodeIDVector{nullptr}, csrOffsetColumn{csrOffsetCol}, csrLengthColumn{csrLengthCol},
localTableScanState{nullptr} {
nodeGroupScanState = std::make_unique<CSRNodeGroupScanState>(this->columnIDs.size());
if (!this->columnPredicateSets.empty()) {
Expand All @@ -56,7 +54,6 @@ struct RelTableScanState : TableScanState {
}

void resetState() override {
boundNodeOffset = common::INVALID_OFFSET;
nodeGroupScanState->resetState();
}
};
Expand Down
49 changes: 16 additions & 33 deletions src/storage/local_storage/local_rel_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,37 +144,17 @@ void LocalRelTable::checkIfNodeHasRels(ValueVector* srcNodeIDVector) const {
void LocalRelTable::initializeScan(TableScanState& state) {
auto& relScanState = state.cast<LocalRelTableScanState>();
KU_ASSERT(relScanState.source == TableScanSource::UNCOMMITTED);
auto& nodeSelVector = relScanState.boundNodeIDVector->state->getSelVector();
relScanState.nodeGroup = localNodeGroup.get();
relScanState.rowIndices.clear();
relScanState.batchSize = 0;
auto& nodeSelVector = relScanState.boundNodeIDVector->state->getSelVector();
auto& index = relScanState.direction == RelDataDirection::FWD ? fwdIndex : bwdIndex;
offset_t nodeOffset = relScanState.boundNodeOffset;
// collect all node ids that can be read from the same node group
while (relScanState.endNodeIdx < relScanState.totalNodeIdx &&
relScanState.batchSize < DEFAULT_VECTOR_CAPACITY) {
nodeOffset =
relScanState.boundNodeIDVector->readNodeOffset(nodeSelVector[relScanState.endNodeIdx]);
if (index.contains(nodeOffset)) {
auto numToScan = std::min(index[nodeOffset].size() - relScanState.nextRowToScan,
DEFAULT_VECTOR_CAPACITY - relScanState.batchSize);
relScanState.rowIndices.insert(relScanState.rowIndices.end(),
index[nodeOffset].begin() + relScanState.nextRowToScan,
index[nodeOffset].begin() + relScanState.nextRowToScan + numToScan);
KU_ASSERT(
std::is_sorted(relScanState.rowIndices.begin(), relScanState.rowIndices.end()));
relScanState.batchSize += numToScan;
if (numToScan < index[nodeOffset].size() - relScanState.nextRowToScan) {
relScanState.nextRowToScan += numToScan;
KU_ASSERT(relScanState.batchSize == DEFAULT_VECTOR_CAPACITY);
break;
} else {
relScanState.nextRowToScan = 0;
}
}
relScanState.endNodeIdx++;
offset_t nodeOffset =
relScanState.boundNodeIDVector->readNodeOffset(nodeSelVector[relScanState.endNodeIdx++]);
if (index.contains(nodeOffset)) {
relScanState.rowIndices = index[nodeOffset];
KU_ASSERT(std::is_sorted(relScanState.rowIndices.begin(), relScanState.rowIndices.end()));
} else {
relScanState.rowIndices.clear();
}
KU_ASSERT(relScanState.batchSize == relScanState.rowIndices.size());
}

std::vector<column_id_t> LocalRelTable::rewriteLocalColumnIDs(RelDataDirection direction,
Expand All @@ -196,19 +176,22 @@ column_id_t LocalRelTable::rewriteLocalColumnID(RelDataDirection direction, colu
}

bool LocalRelTable::scan(Transaction* transaction, TableScanState& state) const {
const auto& relScanState = state.cast<RelTableScanState>();
auto& relScanState = state.cast<RelTableScanState>();
KU_ASSERT(relScanState.localTableScanState);
auto& localScanState = *relScanState.localTableScanState;
if (localScanState.batchSize == 0) {
KU_ASSERT(localScanState.rowIndices.size() >= localScanState.nextRowToScan);
relScanState.batchSize = std::min(localScanState.rowIndices.size() - localScanState.nextRowToScan,
DEFAULT_VECTOR_CAPACITY);
if (relScanState.batchSize == 0) {
return false;
}
for (auto i = 0u; i < localScanState.batchSize; i++) {
for (auto i = 0u; i < relScanState.batchSize; i++) {
localScanState.rowIdxVector->setValue<row_idx_t>(i,
localScanState.rowIndices[localScanState.nextRowToScan + i]);
}
localScanState.rowIdxVector->state->getSelVectorUnsafe().setSelSize(localScanState.batchSize);
localScanState.rowIdxVector->state->getSelVectorUnsafe().setSelSize(relScanState.batchSize);
localNodeGroup->lookup(transaction, localScanState);
localScanState.nextRowToScan += localScanState.batchSize;
localScanState.nextRowToScan += relScanState.batchSize;
return true;
}

Expand Down
27 changes: 15 additions & 12 deletions src/storage/store/csr_node_group.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,45 +12,48 @@ namespace storage {

void CSRNodeGroup::initializeScanState(Transaction* transaction, TableScanState& state) {
auto& relScanState = state.cast<RelTableScanState>();
KU_ASSERT(nodeGroupIdx == StorageUtils::getNodeGroupIdx(relScanState.boundNodeOffset));
auto& nodeSelVector = relScanState.boundNodeIDVector->state->getSelVector();
KU_ASSERT(nodeGroupIdx == StorageUtils::getNodeGroupIdx(
relScanState.boundNodeIDVector->readNodeOffset(nodeSelVector[relScanState.currNodeIdx])
));
KU_ASSERT(relScanState.nodeGroupScanState);
auto& nodeGroupScanState = relScanState.nodeGroupScanState->cast<CSRNodeGroupScanState>();
relScanState.nodeGroupScanState->resetState();
relScanState.nodeGroupIdx = nodeGroupIdx;
nodeGroupScanState.source = CSRNodeGroupScanSource::NONE;
// Scan the csr header chunks from disk.
if (persistentChunkGroup) {
initializePersistentCSRHeader(transaction, relScanState, nodeGroupScanState);
}
// Queue all nodes to be scanned in the node group.
auto& nodeSelVector = relScanState.boundNodeIDVector->state->getSelVector();
nodeGroupScanState.nextRowToScan = 0;
const auto startNodeOffset = StorageUtils::getStartOffsetOfNodeGroup(nodeGroupIdx);
for (auto startNodeIdx = relScanState.currNodeIdx; startNodeIdx < relScanState.endNodeIdx;
startNodeIdx++) {
const auto offsetInGroup =
relScanState.boundNodeIDVector->readNodeOffset(nodeSelVector[startNodeIdx]) -
startNodeOffset;
for (auto i = relScanState.currNodeIdx; i < relScanState.endNodeIdx; i++) {
const auto offsetInGroup =
relScanState.boundNodeIDVector->readNodeOffset(nodeSelVector[i]) - startNodeOffset;
if (persistentChunkGroup) {
auto offset = nodeGroupScanState.csrHeader->getStartCSROffset(offsetInGroup);
auto length = nodeGroupScanState.csrHeader->getCSRLength(offsetInGroup);
if (length > 0) {
nodeGroupScanState.persistentCSRLists.emplace_back(offset, length);
nodeGroupScanState.source = CSRNodeGroupScanSource::COMMITTED_PERSISTENT;
}
}
if (csrIndex) {
auto& index = csrIndex->indices[offsetInGroup];
if (!index.isSequential) {
KU_ASSERT(std::is_sorted(index.rowIndices.begin(), index.rowIndices.end()));
}
if (index.rowIndices.size() > 0 &&
nodeGroupScanState.source == CSRNodeGroupScanSource::NONE) {
if (index.rowIndices.size() > 0) {
nodeGroupScanState.inMemCSRLists.push_back(index);
nodeGroupScanState.source = CSRNodeGroupScanSource::COMMITTED_IN_MEMORY;
}
}
}
if (!nodeGroupScanState.persistentCSRLists.empty()) {
nodeGroupScanState.source = CSRNodeGroupScanSource::COMMITTED_PERSISTENT;
} else if (!nodeGroupScanState.inMemCSRLists.empty()) {
nodeGroupScanState.source = CSRNodeGroupScanSource::COMMITTED_IN_MEMORY;
} else {
nodeGroupScanState.source = CSRNodeGroupScanSource::NONE;
}
}

void CSRNodeGroup::initializePersistentCSRHeader(Transaction* transaction,
Expand Down
26 changes: 10 additions & 16 deletions src/storage/store/rel_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,11 @@ void RelTable::initializeScanState(Transaction* transaction, TableScanState& sca
auto& nodeSelVector = relScanState.boundNodeIDVector->state->getSelVector();
relScanState.totalNodeIdx = nodeSelVector.getSelSize();
KU_ASSERT(relScanState.totalNodeIdx > 0);
relScanState.endNodeIdx = relScanState.currNodeIdx;
relScanState.boundNodeOffset =
KU_ASSERT(relScanState.endNodeIdx == relScanState.currNodeIdx);
KU_ASSERT(relScanState.endNodeIdx < relScanState.totalNodeIdx);
offset_t nodeOffset =
relScanState.boundNodeIDVector->readNodeOffset(nodeSelVector[relScanState.currNodeIdx]);
if (relScanState.boundNodeOffset >= StorageConstants::MAX_NUM_ROWS_IN_TABLE) {
if (nodeOffset >= StorageConstants::MAX_NUM_ROWS_IN_TABLE) {
// No more to read from committed
relScanState.nodeGroup = nullptr;
if (relScanState.localTableScanState) {
Expand All @@ -75,8 +76,7 @@ void RelTable::initializeScanState(Transaction* transaction, TableScanState& sca

relScanState.source = TableScanSource::COMMITTED;
relScanState.currentCSROffset = 0;
auto nodeGroupIdx = StorageUtils::getNodeGroupIdx(relScanState.boundNodeOffset);
offset_t nodeOffset = relScanState.boundNodeOffset;
auto nodeGroupIdx = StorageUtils::getNodeGroupIdx(nodeOffset);
// collect all node ids that can be read from the same node group
while (relScanState.endNodeIdx < relScanState.totalNodeIdx) {
nodeOffset =
Expand All @@ -101,15 +101,10 @@ void RelTable::initializeLocalRelScanState(RelTableScanState& relScanState) {
KU_ASSERT(relScanState.localTableScanState);
auto& localScanState = *relScanState.localTableScanState;
KU_ASSERT(localScanState.localRelTable);
localScanState.boundNodeOffset = relScanState.boundNodeOffset;
localScanState.currNodeIdx = relScanState.currNodeIdx;
localScanState.endNodeIdx = relScanState.endNodeIdx;
localScanState.totalNodeIdx = relScanState.totalNodeIdx;
localScanState.rowIdxVector->setState(relScanState.rowIdxVector->state);
localScanState.localRelTable->initializeScan(*relScanState.localTableScanState);
relScanState.currNodeIdx = localScanState.currNodeIdx;
relScanState.endNodeIdx = localScanState.endNodeIdx;
relScanState.batchSize = localScanState.batchSize;
}

bool RelTable::scanInternal(Transaction* transaction, TableScanState& scanState) {
Expand All @@ -124,7 +119,7 @@ bool RelTable::scanInternal(Transaction* transaction, TableScanState& scanState)
if (relScanState.currNodeIdx == relScanState.endNodeIdx) {
initializeScanState(transaction, relScanState);
}
relScanState.boundNodeOffset =
offset_t curNodeOffset =
relScanState.boundNodeIDVector->readNodeOffset(nodeIDSelVector[relScanState.currNodeIdx]);
row_idx_t posInLastCSR = 0;
row_idx_t currCSRSize = INVALID_OFFSET;
Expand All @@ -134,7 +129,7 @@ bool RelTable::scanInternal(Transaction* transaction, TableScanState& scanState)
auto& csrNodeGroupScanState =
relScanState.nodeGroupScanState->cast<CSRNodeGroupScanState>();
currCSRSize = relScanState.nodeGroup->cast<CSRNodeGroup>().getCSRLength(
csrNodeGroupScanState, relScanState.boundNodeOffset - startNodeOffset);
csrNodeGroupScanState, curNodeOffset - startNodeOffset);
posInLastCSR = csrNodeGroupScanState.nextRowToScan;
} break;
case TableScanSource::UNCOMMITTED: {
Expand All @@ -143,8 +138,8 @@ bool RelTable::scanInternal(Transaction* transaction, TableScanState& scanState)
auto localTable = relScanState.localTableScanState->localRelTable;
auto& index = relScanState.direction == RelDataDirection::FWD ? localTable->getFWDIndex() :
localTable->getBWDIndex();
if (index.contains(relScanState.boundNodeOffset)) {
currCSRSize = index[relScanState.boundNodeOffset].size();
if (index.contains(curNodeOffset)) {
currCSRSize = index[curNodeOffset].size();
}
} break;
case TableScanSource::NONE: {
Expand All @@ -161,8 +156,7 @@ bool RelTable::scanInternal(Transaction* transaction, TableScanState& scanState)
return false;
}
}
// This assumes nodeIDVector is initially unfiltered, which is not safe
// we should do this using similar logic to Flatten
// This assumes nodeIDVector is initially unfiltered
nodeIDSelVector.getMultableBuffer()[0] = nodeIDSelVector[relScanState.currNodeIdx];
nodeIDSelVector.setToFiltered(1);
if (relScanState.currentCSROffset == 0) {
Expand Down
6 changes: 4 additions & 2 deletions src/storage/store/rel_table_data.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,8 @@ std::pair<CSRNodeGroupScanSource, row_idx_t> RelTableData::findMatchingRow(Trans
scanState->IDVector = scanState->outputVectors[0];
scanState->rowIdxVector->state = scanState->IDVector->state;
scanState->source = TableScanSource::COMMITTED;
scanState->boundNodeOffset = boundNodeOffset;
scanState->currNodeIdx = 0;
scanState->endNodeIdx = 1;
scanState->nodeGroup = getNodeGroup(nodeGroupIdx);
scanState->nodeGroup->initializeScanState(transaction, *scanState);
row_idx_t matchingRowIdx = INVALID_ROW_IDX;
Expand Down Expand Up @@ -182,7 +183,8 @@ void RelTableData::checkIfNodeHasRels(Transaction* transaction,
scanState->outputVectors.push_back(scanChunk.getValueVector(0).get());
scanState->IDVector = scanState->outputVectors[0];
scanState->source = TableScanSource::COMMITTED;
scanState->boundNodeOffset = nodeOffset;
scanState->currNodeIdx = 0;
scanState->endNodeIdx = 1;
scanState->nodeGroup = getNodeGroup(nodeGroupIdx);
scanState->nodeGroup->initializeScanState(transaction, *scanState);
while (true) {
Expand Down

0 comments on commit d0cac92

Please sign in to comment.