Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor list auxiliary buffer #4052

Merged
merged 2 commits into from
Aug 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 24 additions & 4 deletions src/common/vector/value_vector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -533,9 +533,29 @@ void StringVector::copyToRowData(const ValueVector* vector, uint32_t pos, uint8_
}
}

void ListVector::copyListEntryAndBufferMetaData(ValueVector& vector, const ValueVector& other) {
auto& selVector = vector.state->getSelVector();
auto& otherSelVector = other.state->getSelVector();
KU_ASSERT(selVector.getSelSize() == otherSelVector.getSelSize());
// Copy list entries
for (auto i = 0u; i < otherSelVector.getSelSize(); ++i) {
auto pos = selVector[i];
auto otherPos = otherSelVector[i];
if (other.isNull(otherPos)) {
vector.setNull(pos, true);
} else {
vector.setValue(pos, other.getValue<list_entry_t>(otherPos));
}
}
// Copy buffer metadata
auto& buffer = getAuxBufferUnsafe(vector);
auto& otherBuffer = getAuxBuffer(other);
buffer.size = otherBuffer.size;
buffer.capacity = otherBuffer.capacity;
}

void ListVector::copyFromRowData(ValueVector* vector, uint32_t pos, const uint8_t* rowData) {
KU_ASSERT(vector->dataType.getPhysicalType() == PhysicalTypeID::LIST ||
vector->dataType.getPhysicalType() == PhysicalTypeID::ARRAY);
KU_ASSERT(validateType(*vector));
auto& srcKuList = *(ku_list_t*)rowData;
auto srcNullBytes = reinterpret_cast<uint8_t*>(srcKuList.overflowPtr);
auto srcListValues = srcNullBytes + NullBuffer::getNumBytesForNullValues(srcKuList.size);
Expand Down Expand Up @@ -593,8 +613,8 @@ void ListVector::copyFromVectorData(ValueVector* dstVector, uint8_t* dstData,
}
}

void ListVector::appendDataVector(kuzu::common::ValueVector* dstVector,
kuzu::common::ValueVector* srcDataVector, uint64_t numValuesToAppend) {
void ListVector::appendDataVector(ValueVector* dstVector, ValueVector* srcDataVector,
uint64_t numValuesToAppend) {
auto offset = getDataVectorSize(dstVector);
resizeDataVector(dstVector, offset + numValuesToAppend);
auto dstDataVector = getDataVector(dstVector);
Expand Down
1 change: 1 addition & 0 deletions src/expression_evaluator/lambda_evaluator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ void ListLambdaEvaluator::evaluate() {
void ListLambdaEvaluator::resolveResultVector(const ResultSet&, MemoryManager* memoryManager) {
resultVector = std::make_shared<ValueVector>(expression->getDataType().copy(), memoryManager);
resultVector->state = children[0]->resultVector->state;
isResultFlat_ = children[0]->isResultFlat();
}

} // namespace evaluator
Expand Down
10 changes: 1 addition & 9 deletions src/function/list/list_transform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,9 @@ static void execFunc(const std::vector<std::shared_ptr<common::ValueVector>>& in
auto lambdaParamVector = listLambdaBindData->lambdaParamEvaluators[0]->resultVector.get();
lambdaParamVector->state->getSelVectorUnsafe().setSelSize(listSize);
listLambdaBindData->rootEvaluator->evaluate();
auto& listInputSelVector = inputVector->state->getSelVector();
// NOTE: the following can be done with a memcpy. But I think soon we will need to change
// to handle cases like
// MATCH (a:person) RETURN LIST_TRANSFORM([1,2,3], x->x + a.ID)
// So I'm leaving it in the naive form.
KU_ASSERT(input.size() == 2);
ListVector::setDataVector(&result, input[1]);
for (auto i = 0u; i < listInputSelVector.getSelSize(); ++i) {
auto pos = listInputSelVector[i];
result.setValue(pos, inputVector->getValue<list_entry_t>(pos));
}
ListVector::copyListEntryAndBufferMetaData(result, *inputVector);
}

function_set ListTransformFunction::getFunctionSet() {
Expand Down
26 changes: 1 addition & 25 deletions src/function/path/properties_function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,31 +51,7 @@ static void compileFunc(FunctionBindData* bindData,

static void execFunc(const std::vector<std::shared_ptr<ValueVector>>& parameters,
ValueVector& result, void* /*dataPtr*/) {
auto& resultSelVector = result.state->getSelVector();
if (parameters[0]->state->isFlat()) {
auto inputPos = parameters[0]->state->getSelVector()[0];
if (parameters[0]->isNull(inputPos)) {
for (auto i = 0u; i < resultSelVector.getSelSize(); ++i) {
auto pos = resultSelVector[i];
result.setNull(pos, true);
}
} else {
auto& listEntry = parameters[0]->getValue<list_entry_t>(inputPos);
for (auto i = 0u; i < resultSelVector.getSelSize(); ++i) {
auto pos = resultSelVector[i];
result.setValue(pos, listEntry);
}
}
} else {
for (auto i = 0u; i < resultSelVector.getSelSize(); ++i) {
auto pos = resultSelVector[i];
if (parameters[0]->isNull(pos)) {
result.setNull(pos, true);
} else {
result.setValue(pos, parameters[0]->getValue<list_entry_t>(pos));
}
}
}
ListVector::copyListEntryAndBufferMetaData(result, *parameters[0]);
}

function_set PropertiesFunction::getFunctionSet() {
Expand Down
10 changes: 10 additions & 0 deletions src/include/common/vector/auxiliary_buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,16 @@ class ValueVector;
class AuxiliaryBuffer {
public:
virtual ~AuxiliaryBuffer() = default;

template<class TARGET>
TARGET& cast() {
return common::ku_dynamic_cast<AuxiliaryBuffer&, TARGET&>(*this);
}

template<class TARGET>
const TARGET& constCast() const {
return common::ku_dynamic_cast<const AuxiliaryBuffer&, const TARGET&>(*this);
}
};

class StringAuxiliaryBuffer : public AuxiliaryBuffer {
Expand Down
70 changes: 37 additions & 33 deletions src/include/common/vector/value_vector.h
Original file line number Diff line number Diff line change
Expand Up @@ -146,61 +146,54 @@ struct KUZU_API BlobVector {
}
};

// Currently, ListVector is used for both VAR_LIST and ARRAY physical type
// ListVector is used for both LIST and ARRAY physical type
class KUZU_API ListVector {
public:
static const ListAuxiliaryBuffer& getAuxBuffer(const ValueVector& vector) {
return vector.auxiliaryBuffer->constCast<ListAuxiliaryBuffer>();
}
static ListAuxiliaryBuffer& getAuxBufferUnsafe(const ValueVector& vector) {
return vector.auxiliaryBuffer->cast<ListAuxiliaryBuffer>();
}
// If you call setDataVector during initialize, there must be a followed up
// copyListEntryAndBufferMetaData at runtime.
// TODO(Xiyang): try to merge setDataVector & copyListEntryAndBufferMetaData
andyfengHKU marked this conversation as resolved.
Show resolved Hide resolved
static void setDataVector(const ValueVector* vector, std::shared_ptr<ValueVector> dataVector) {
KU_ASSERT(vector->dataType.getPhysicalType() == PhysicalTypeID::LIST ||
vector->dataType.getPhysicalType() == PhysicalTypeID::ARRAY);
auto listBuffer =
ku_dynamic_cast<AuxiliaryBuffer*, ListAuxiliaryBuffer*>(vector->auxiliaryBuffer.get());
listBuffer->setDataVector(std::move(dataVector));
KU_ASSERT(validateType(*vector));
auto& listBuffer = getAuxBufferUnsafe(*vector);
listBuffer.setDataVector(std::move(dataVector));
}
static void copyListEntryAndBufferMetaData(ValueVector& vector, const ValueVector& other);
static ValueVector* getDataVector(const ValueVector* vector) {
KU_ASSERT(vector->dataType.getPhysicalType() == PhysicalTypeID::LIST ||
vector->dataType.getPhysicalType() == PhysicalTypeID::ARRAY);
return ku_dynamic_cast<AuxiliaryBuffer*, ListAuxiliaryBuffer*>(
vector->auxiliaryBuffer.get())
->getDataVector();
KU_ASSERT(validateType(*vector));
return getAuxBuffer(*vector).getDataVector();
}
static std::shared_ptr<ValueVector> getSharedDataVector(const ValueVector* vector) {
KU_ASSERT(vector->dataType.getPhysicalType() == PhysicalTypeID::LIST ||
vector->dataType.getPhysicalType() == PhysicalTypeID::ARRAY);
return ku_dynamic_cast<AuxiliaryBuffer*, ListAuxiliaryBuffer*>(
vector->auxiliaryBuffer.get())
->getSharedDataVector();
KU_ASSERT(validateType(*vector));
return getAuxBuffer(*vector).getSharedDataVector();
}
static uint64_t getDataVectorSize(const ValueVector* vector) {
KU_ASSERT(vector->dataType.getPhysicalType() == PhysicalTypeID::LIST ||
vector->dataType.getPhysicalType() == PhysicalTypeID::ARRAY);
return ku_dynamic_cast<AuxiliaryBuffer*, ListAuxiliaryBuffer*>(
vector->auxiliaryBuffer.get())
->getSize();
KU_ASSERT(validateType(*vector));
return getAuxBuffer(*vector).getSize();
}

static uint8_t* getListValues(const ValueVector* vector, const list_entry_t& listEntry) {
KU_ASSERT(vector->dataType.getPhysicalType() == PhysicalTypeID::LIST ||
vector->dataType.getPhysicalType() == PhysicalTypeID::ARRAY);
KU_ASSERT(validateType(*vector));
auto dataVector = getDataVector(vector);
return dataVector->getData() + dataVector->getNumBytesPerValue() * listEntry.offset;
}
static uint8_t* getListValuesWithOffset(const ValueVector* vector,
const list_entry_t& listEntry, offset_t elementOffsetInList) {
KU_ASSERT(vector->dataType.getPhysicalType() == PhysicalTypeID::LIST ||
vector->dataType.getPhysicalType() == PhysicalTypeID::ARRAY);
KU_ASSERT(validateType(*vector));
return getListValues(vector, listEntry) +
elementOffsetInList * getDataVector(vector)->getNumBytesPerValue();
}
static list_entry_t addList(ValueVector* vector, uint64_t listSize) {
KU_ASSERT(vector->dataType.getPhysicalType() == PhysicalTypeID::LIST ||
vector->dataType.getPhysicalType() == PhysicalTypeID::ARRAY);
return ku_dynamic_cast<AuxiliaryBuffer*, ListAuxiliaryBuffer*>(
vector->auxiliaryBuffer.get())
->addList(listSize);
KU_ASSERT(validateType(*vector));
return getAuxBufferUnsafe(*vector).addList(listSize);
}
static void resizeDataVector(ValueVector* vector, uint64_t numValues) {
ku_dynamic_cast<AuxiliaryBuffer*, ListAuxiliaryBuffer*>(vector->auxiliaryBuffer.get())
->resize(numValues);
KU_ASSERT(validateType(*vector));
getAuxBufferUnsafe(*vector).resize(numValues);
}

static void copyFromRowData(ValueVector* vector, uint32_t pos, const uint8_t* rowData);
Expand All @@ -211,6 +204,17 @@ class KUZU_API ListVector {
static void appendDataVector(ValueVector* dstVector, ValueVector* srcDataVector,
uint64_t numValuesToAppend);
static void sliceDataVector(ValueVector* vectorToSlice, uint64_t offset, uint64_t numValues);

private:
static bool validateType(const ValueVector& vector) {
switch (vector.dataType.getPhysicalType()) {
case PhysicalTypeID::LIST:
case PhysicalTypeID::ARRAY:
return true;
default:
return false;
}
}
};

class StructVector {
Expand Down
2 changes: 0 additions & 2 deletions src/include/expression_evaluator/expression_evaluator.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,6 @@ class ExpressionEvaluator {
EvaluatorType getEvaluatorType() const { return type; }

std::shared_ptr<binder::Expression> getExpression() const { return expression; }
const common::LogicalType& getResultDataType() const { return expression->getDataType(); }
void setResultFlat(bool val) { isResultFlat_ = val; }
bool isResultFlat() const { return isResultFlat_; }

const evaluator_vector_t& getChildren() const { return children; }
Expand Down
3 changes: 2 additions & 1 deletion src/planner/operator/factorization/flatten_resolver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,8 @@ void GroupDependencyAnalyzer::visitNodeOrRel(std::shared_ptr<binder::Expression>
auto& node = expr->constCast<NodeExpression>();
visit(node.getInternalID());
} break;
case LogicalTypeID::REL: {
case LogicalTypeID::REL:
case LogicalTypeID::RECURSIVE_REL: {
andyfengHKU marked this conversation as resolved.
Show resolved Hide resolved
auto& rel = expr->constCast<RelExpression>();
visit(rel.getSrcNode()->getInternalID());
visit(rel.getDstNode()->getInternalID());
Expand Down
10 changes: 9 additions & 1 deletion test/test_files/demo_db/demo_db.test
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
--

-CASE DemoDBTest

-LOG UndirectedRecursivePattern
-STATEMENT MATCH p = (u:User {name: 'Adam'})-[e*1..3]-(c:User {name: 'Adam'})
HINT (u JOIN e) JOIN c
Expand Down Expand Up @@ -455,3 +454,12 @@ Adam|Karissa
Adam|Zhang
Karissa|Zhang
Zhang|Noura

-STATEMENT MATCH (n)-[:LivesIn]->(:City)
WITH collect(n.name) AS names
MATCH p = (u:User {name: "Adam"})-[e:Follows*]->(u2:User {name: "Noura"})
WHERE size(list_filter(properties(nodes(e),'name'), x->x IN names)) > 0
RETURN properties(nodes(p),'name')
---- 2
[Adam,Karissa,Zhang,Noura]
[Adam,Zhang,Noura]