Skip to content

Commit

Permalink
Refactor list auxilary buffer
Browse files Browse the repository at this point in the history
  • Loading branch information
andyfengHKU committed Aug 10, 2024
1 parent f08f80d commit dbc4547
Show file tree
Hide file tree
Showing 9 changed files with 85 additions and 75 deletions.
28 changes: 24 additions & 4 deletions src/common/vector/value_vector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -533,9 +533,29 @@ void StringVector::copyToRowData(const ValueVector* vector, uint32_t pos, uint8_
}
}

void ListVector::copyListEntryAndBufferMetaData(ValueVector& vector, const ValueVector& other) {
auto& selVector = vector.state->getSelVector();
auto& otherSelVector = other.state->getSelVector();
KU_ASSERT(selVector.getSelSize() == otherSelVector.getSelSize());
// Copy list entries
for (auto i = 0u; i < otherSelVector.getSelSize(); ++i) {
auto pos = selVector[i];
auto otherPos = otherSelVector[i];
if (other.isNull(otherPos)) {
vector.setNull(pos, true);
} else {
vector.setValue(pos, other.getValue<list_entry_t>(otherPos));
}
}
// Copy buffer metadata
auto& buffer = getAuxBufferUnsafe(vector);
auto& otherBuffer = getAuxBuffer(other);
buffer.size = otherBuffer.size;
buffer.capacity = otherBuffer.capacity;
}

void ListVector::copyFromRowData(ValueVector* vector, uint32_t pos, const uint8_t* rowData) {
KU_ASSERT(vector->dataType.getPhysicalType() == PhysicalTypeID::LIST ||
vector->dataType.getPhysicalType() == PhysicalTypeID::ARRAY);
KU_ASSERT(validateType(*vector));
auto& srcKuList = *(ku_list_t*)rowData;
auto srcNullBytes = reinterpret_cast<uint8_t*>(srcKuList.overflowPtr);
auto srcListValues = srcNullBytes + NullBuffer::getNumBytesForNullValues(srcKuList.size);
Expand Down Expand Up @@ -593,8 +613,8 @@ void ListVector::copyFromVectorData(ValueVector* dstVector, uint8_t* dstData,
}
}

void ListVector::appendDataVector(kuzu::common::ValueVector* dstVector,
kuzu::common::ValueVector* srcDataVector, uint64_t numValuesToAppend) {
void ListVector::appendDataVector(ValueVector* dstVector,
ValueVector* srcDataVector, uint64_t numValuesToAppend) {
auto offset = getDataVectorSize(dstVector);
resizeDataVector(dstVector, offset + numValuesToAppend);
auto dstDataVector = getDataVector(dstVector);
Expand Down
1 change: 1 addition & 0 deletions src/expression_evaluator/lambda_evaluator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ void ListLambdaEvaluator::evaluate() {
void ListLambdaEvaluator::resolveResultVector(const ResultSet&, MemoryManager* memoryManager) {
resultVector = std::make_shared<ValueVector>(expression->getDataType().copy(), memoryManager);
resultVector->state = children[0]->resultVector->state;
isResultFlat_ = children[0]->isResultFlat();
}

} // namespace evaluator
Expand Down
10 changes: 1 addition & 9 deletions src/function/list/list_transform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,9 @@ static void execFunc(const std::vector<std::shared_ptr<common::ValueVector>>& in
auto lambdaParamVector = listLambdaBindData->lambdaParamEvaluators[0]->resultVector.get();
lambdaParamVector->state->getSelVectorUnsafe().setSelSize(listSize);
listLambdaBindData->rootEvaluator->evaluate();
auto& listInputSelVector = inputVector->state->getSelVector();
// NOTE: the following can be done with a memcpy. But I think soon we will need to change
// to handle cases like
// MATCH (a:person) RETURN LIST_TRANSFORM([1,2,3], x->x + a.ID)
// So I'm leaving it in the naive form.
KU_ASSERT(input.size() == 2);
ListVector::setDataVector(&result, input[1]);
for (auto i = 0u; i < listInputSelVector.getSelSize(); ++i) {
auto pos = listInputSelVector[i];
result.setValue(pos, inputVector->getValue<list_entry_t>(pos));
}
ListVector::copyListEntryAndBufferMetaData(result, *inputVector);
}

function_set ListTransformFunction::getFunctionSet() {
Expand Down
26 changes: 1 addition & 25 deletions src/function/path/properties_function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,31 +51,7 @@ static void compileFunc(FunctionBindData* bindData,

static void execFunc(const std::vector<std::shared_ptr<ValueVector>>& parameters,
ValueVector& result, void* /*dataPtr*/) {
auto& resultSelVector = result.state->getSelVector();
if (parameters[0]->state->isFlat()) {
auto inputPos = parameters[0]->state->getSelVector()[0];
if (parameters[0]->isNull(inputPos)) {
for (auto i = 0u; i < resultSelVector.getSelSize(); ++i) {
auto pos = resultSelVector[i];
result.setNull(pos, true);
}
} else {
auto& listEntry = parameters[0]->getValue<list_entry_t>(inputPos);
for (auto i = 0u; i < resultSelVector.getSelSize(); ++i) {
auto pos = resultSelVector[i];
result.setValue(pos, listEntry);
}
}
} else {
for (auto i = 0u; i < resultSelVector.getSelSize(); ++i) {
auto pos = resultSelVector[i];
if (parameters[0]->isNull(pos)) {
result.setNull(pos, true);
} else {
result.setValue(pos, parameters[0]->getValue<list_entry_t>(pos));
}
}
}
ListVector::copyListEntryAndBufferMetaData(result, *parameters[0]);
}

function_set PropertiesFunction::getFunctionSet() {
Expand Down
10 changes: 10 additions & 0 deletions src/include/common/vector/auxiliary_buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,16 @@ class ValueVector;
class AuxiliaryBuffer {
public:
virtual ~AuxiliaryBuffer() = default;

template<class TARGET>
TARGET& cast() {
return common::ku_dynamic_cast<AuxiliaryBuffer&, TARGET&>(*this);
}

template<class TARGET>
const TARGET& constCast() const {
return common::ku_dynamic_cast<const AuxiliaryBuffer&, const TARGET&>(*this);
}
};

class StringAuxiliaryBuffer : public AuxiliaryBuffer {
Expand Down
70 changes: 37 additions & 33 deletions src/include/common/vector/value_vector.h
Original file line number Diff line number Diff line change
Expand Up @@ -146,61 +146,54 @@ struct KUZU_API BlobVector {
}
};

// Currently, ListVector is used for both VAR_LIST and ARRAY physical type
// ListVector is used for both LIST and ARRAY physical type
class KUZU_API ListVector {
public:
static const ListAuxiliaryBuffer& getAuxBuffer(const ValueVector& vector) {
return vector.auxiliaryBuffer->constCast<ListAuxiliaryBuffer>();
}
static ListAuxiliaryBuffer& getAuxBufferUnsafe(const ValueVector& vector) {
return vector.auxiliaryBuffer->cast<ListAuxiliaryBuffer>();
}
// If you call setDataVector during initialize, there must be a followed up
// copyListEntryAndBufferMetaData at runtime.
// TODO(Xiyang): try to merge setDataVector & copyListEntryAndBufferMetaData
static void setDataVector(const ValueVector* vector, std::shared_ptr<ValueVector> dataVector) {
KU_ASSERT(vector->dataType.getPhysicalType() == PhysicalTypeID::LIST ||
vector->dataType.getPhysicalType() == PhysicalTypeID::ARRAY);
auto listBuffer =
ku_dynamic_cast<AuxiliaryBuffer*, ListAuxiliaryBuffer*>(vector->auxiliaryBuffer.get());
listBuffer->setDataVector(std::move(dataVector));
KU_ASSERT(validateType(*vector));
auto& listBuffer = getAuxBufferUnsafe(*vector);
listBuffer.setDataVector(std::move(dataVector));
}
static void copyListEntryAndBufferMetaData(ValueVector& vector, const ValueVector& other);
static ValueVector* getDataVector(const ValueVector* vector) {
KU_ASSERT(vector->dataType.getPhysicalType() == PhysicalTypeID::LIST ||
vector->dataType.getPhysicalType() == PhysicalTypeID::ARRAY);
return ku_dynamic_cast<AuxiliaryBuffer*, ListAuxiliaryBuffer*>(
vector->auxiliaryBuffer.get())
->getDataVector();
KU_ASSERT(validateType(*vector));
return getAuxBuffer(*vector).getDataVector();
}
static std::shared_ptr<ValueVector> getSharedDataVector(const ValueVector* vector) {
KU_ASSERT(vector->dataType.getPhysicalType() == PhysicalTypeID::LIST ||
vector->dataType.getPhysicalType() == PhysicalTypeID::ARRAY);
return ku_dynamic_cast<AuxiliaryBuffer*, ListAuxiliaryBuffer*>(
vector->auxiliaryBuffer.get())
->getSharedDataVector();
KU_ASSERT(validateType(*vector));
return getAuxBuffer(*vector).getSharedDataVector();
}
static uint64_t getDataVectorSize(const ValueVector* vector) {
KU_ASSERT(vector->dataType.getPhysicalType() == PhysicalTypeID::LIST ||
vector->dataType.getPhysicalType() == PhysicalTypeID::ARRAY);
return ku_dynamic_cast<AuxiliaryBuffer*, ListAuxiliaryBuffer*>(
vector->auxiliaryBuffer.get())
->getSize();
KU_ASSERT(validateType(*vector));
return getAuxBuffer(*vector).getSize();
}

static uint8_t* getListValues(const ValueVector* vector, const list_entry_t& listEntry) {
KU_ASSERT(vector->dataType.getPhysicalType() == PhysicalTypeID::LIST ||
vector->dataType.getPhysicalType() == PhysicalTypeID::ARRAY);
KU_ASSERT(validateType(*vector));
auto dataVector = getDataVector(vector);
return dataVector->getData() + dataVector->getNumBytesPerValue() * listEntry.offset;
}
static uint8_t* getListValuesWithOffset(const ValueVector* vector,
const list_entry_t& listEntry, offset_t elementOffsetInList) {
KU_ASSERT(vector->dataType.getPhysicalType() == PhysicalTypeID::LIST ||
vector->dataType.getPhysicalType() == PhysicalTypeID::ARRAY);
KU_ASSERT(validateType(*vector));
return getListValues(vector, listEntry) +
elementOffsetInList * getDataVector(vector)->getNumBytesPerValue();
}
static list_entry_t addList(ValueVector* vector, uint64_t listSize) {
KU_ASSERT(vector->dataType.getPhysicalType() == PhysicalTypeID::LIST ||
vector->dataType.getPhysicalType() == PhysicalTypeID::ARRAY);
return ku_dynamic_cast<AuxiliaryBuffer*, ListAuxiliaryBuffer*>(
vector->auxiliaryBuffer.get())
->addList(listSize);
KU_ASSERT(validateType(*vector));
return getAuxBufferUnsafe(*vector).addList(listSize);
}
static void resizeDataVector(ValueVector* vector, uint64_t numValues) {
ku_dynamic_cast<AuxiliaryBuffer*, ListAuxiliaryBuffer*>(vector->auxiliaryBuffer.get())
->resize(numValues);
KU_ASSERT(validateType(*vector));
getAuxBufferUnsafe(*vector).resize(numValues);
}

static void copyFromRowData(ValueVector* vector, uint32_t pos, const uint8_t* rowData);
Expand All @@ -211,6 +204,17 @@ class KUZU_API ListVector {
static void appendDataVector(ValueVector* dstVector, ValueVector* srcDataVector,
uint64_t numValuesToAppend);
static void sliceDataVector(ValueVector* vectorToSlice, uint64_t offset, uint64_t numValues);

private:
static bool validateType(const ValueVector& vector) {
switch (vector.dataType.getPhysicalType()) {
case PhysicalTypeID::LIST:
case PhysicalTypeID::ARRAY:
return true;
default:
return false;
}
}
};

class StructVector {
Expand Down
2 changes: 0 additions & 2 deletions src/include/expression_evaluator/expression_evaluator.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,6 @@ class ExpressionEvaluator {
EvaluatorType getEvaluatorType() const { return type; }

std::shared_ptr<binder::Expression> getExpression() const { return expression; }
const common::LogicalType& getResultDataType() const { return expression->getDataType(); }
void setResultFlat(bool val) { isResultFlat_ = val; }
bool isResultFlat() const { return isResultFlat_; }

const evaluator_vector_t& getChildren() const { return children; }
Expand Down
3 changes: 2 additions & 1 deletion src/planner/operator/factorization/flatten_resolver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,8 @@ void GroupDependencyAnalyzer::visitNodeOrRel(std::shared_ptr<binder::Expression>
auto& node = expr->constCast<NodeExpression>();
visit(node.getInternalID());
} break;
case LogicalTypeID::REL: {
case LogicalTypeID::REL:
case LogicalTypeID::RECURSIVE_REL: {
auto& rel = expr->constCast<RelExpression>();
visit(rel.getSrcNode()->getInternalID());
visit(rel.getDstNode()->getInternalID());
Expand Down
10 changes: 9 additions & 1 deletion test/test_files/demo_db/demo_db.test
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
--

-CASE DemoDBTest

-LOG UndirectedRecursivePattern
-STATEMENT MATCH p = (u:User {name: 'Adam'})-[e*1..3]-(c:User {name: 'Adam'})
HINT (u JOIN e) JOIN c
Expand Down Expand Up @@ -455,3 +454,12 @@ Adam|Karissa
Adam|Zhang
Karissa|Zhang
Zhang|Noura

-STATEMENT MATCH (n)-[:LivesIn]->(:City)
WITH collect(n.name) AS names
MATCH p = (u:User {name: "Adam"})-[e:Follows*]->(u2:User {name: "Noura"})
WHERE size(list_filter(properties(nodes(e),'name'), x->x IN names)) > 0
RETURN properties(nodes(p),'name')
---- 2
[Adam,Karissa,Zhang,Noura]
[Adam,Zhang,Noura]

0 comments on commit dbc4547

Please sign in to comment.