Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Index entry serialization #4686

Merged
merged 10 commits into from
Jan 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
cmake_minimum_required(VERSION 3.15)

project(Kuzu VERSION 0.7.1.1 LANGUAGES CXX C)
project(Kuzu VERSION 0.7.1.3 LANGUAGES CXX C)

option(SINGLE_THREADED "Single-threaded mode" FALSE)
if(SINGLE_THREADED)
Expand Down
24 changes: 24 additions & 0 deletions extension/fts/src/catalog/fts_index_catalog_entry.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#include "catalog/fts_index_catalog_entry.h"

#include "common/serializer/buffered_reader.h"

namespace kuzu {
namespace fts_extension {

Expand All @@ -11,5 +13,27 @@ std::unique_ptr<catalog::IndexCatalogEntry> FTSIndexCatalogEntry::copy() const {
return other;
}

// Writes the FTS-specific part of this entry through `serializer`.
// Layout, in order: the payload size in bytes, numDocs, avgDocLen, then the config.
// The size value counts only what follows it (the two statistics plus the config bytes).
void FTSIndexCatalogEntry::serializeAuxInfo(common::Serializer& serializer) const {
    const auto statsBytes = sizeof(numDocs) + sizeof(avgDocLen);
    const auto payloadBytes = statsBytes + config.getNumBytesForSerialization();
    serializer.serializeValue(payloadBytes);

    serializer.serializeValue(numDocs);
    serializer.serializeValue(avgDocLen);
    config.serialize(serializer);
}

// Builds an FTSIndexCatalogEntry from the auxiliary buffer held by `indexCatalogEntry`.
// Values are read back in the order numDocs, avgDocLen, config; the table ID and index name
// are taken from `indexCatalogEntry` itself.
// NOTE(review): the size value that serializeAuxInfo writes first is not read here, so the
// aux buffer is presumably already stripped of it by whoever fills it — confirm.
std::unique_ptr<FTSIndexCatalogEntry> FTSIndexCatalogEntry::deserializeAuxInfo(
    catalog::IndexCatalogEntry* indexCatalogEntry) {
    common::Deserializer auxDeserializer{std::make_unique<common::BufferReader>(
        indexCatalogEntry->getAuxBuffer(), indexCatalogEntry->getAuxBufferSize())};

    common::idx_t docCount = 0;
    double meanDocLen = 0;
    auxDeserializer.deserializeValue(docCount);
    auxDeserializer.deserializeValue(meanDocLen);
    auto ftsConfig = FTSConfig::deserialize(auxDeserializer);

    return std::make_unique<FTSIndexCatalogEntry>(indexCatalogEntry->getTableID(),
        indexCatalogEntry->getIndexName(), docCount, meanDocLen, ftsConfig);
}

} // namespace fts_extension
} // namespace kuzu
21 changes: 18 additions & 3 deletions extension/fts/src/fts_extension.cpp
Original file line number Diff line number Diff line change
@@ -1,22 +1,37 @@
#include "fts_extension.h"

#include "catalog/catalog_entry/catalog_entry_type.h"
#include "catalog/catalog.h"
#include "catalog/catalog_entry/index_catalog_entry.h"
#include "catalog/fts_index_catalog_entry.h"
#include "function/create_fts_index.h"
#include "function/drop_fts_index.h"
#include "function/query_fts.h"
#include "function/query_fts_index.h"
#include "function/stem.h"
#include "main/client_context.h"
#include "main/database.h"

namespace kuzu {
namespace fts_extension {

// Replaces every FTS index entry in `catalog` with a fully typed FTSIndexCatalogEntry that is
// rebuilt from the entry's serialized auxiliary buffer. Entries whose index type is not
// FTSIndexCatalogEntry::TYPE_NAME are left untouched.
// NOTE(review): entries are dropped and re-created while walking `indexEntries`; this relies
// on getEntries() handing back a snapshot that is independent of the live catalog set —
// confirm.
static void initFTSEntries(transaction::Transaction* tx, catalog::Catalog& catalog) {
    auto indexEntries = catalog.getIndexes()->getEntries(tx);
    for (auto& [_, entry] : indexEntries) {
        auto indexEntry = entry->ptrCast<catalog::IndexCatalogEntry>();
        if (indexEntry->getIndexType() == FTSIndexCatalogEntry::TYPE_NAME) {
            // Deserialize first: the replacement is built from the entry that is about to be
            // dropped.
            auto deserializedEntry = FTSIndexCatalogEntry::deserializeAuxInfo(indexEntry);
            catalog.dropIndex(tx, indexEntry->getTableID(), indexEntry->getIndexName());
            catalog.createIndex(tx, std::move(deserializedEntry));
        }
    }
}

// Registers the FTS extension's functions on the database behind `context` and then rebuilds
// the FTS index catalog entries via initFTSEntries().
void FTSExtension::load(main::ClientContext* context) {
    // `db` is referenced by name inside the ADD_*_FUNC macros below as well as directly.
    auto& db = *context->getDatabase();
    ADD_SCALAR_FUNC(StemFunction);
    // The query function struct is named QueryFTSFunction; only that name is registered.
    ADD_GDS_FUNC(QueryFTSFunction);
    db.addStandaloneCallFunction(CreateFTSFunction::name, CreateFTSFunction::getFunctionSet());
    db.addStandaloneCallFunction(DropFTSFunction::name, DropFTSFunction::getFunctionSet());
    // Runs last, after all functions above have been registered.
    initFTSEntries(context->getTransaction(), *db.getCatalog());
}

} // namespace fts_extension
Expand Down
2 changes: 1 addition & 1 deletion extension/fts/src/function/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ add_library(kuzu_fts_function
create_fts_index.cpp
fts_config.cpp
drop_fts_index.cpp
query_fts.cpp
query_fts_index.cpp
fts_utils.cpp)

set(FTS_OBJECT_FILES
Expand Down
17 changes: 17 additions & 0 deletions extension/fts/src/function/fts_config.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#include "function/fts_config.h"

#include "common/exception/binder.h"
#include "common/serializer/deserializer.h"
#include "common/serializer/serializer.h"
#include "common/string_utils.h"
#include "function/stem.h"

Expand All @@ -24,6 +26,21 @@ FTSConfig::FTSConfig(const function::optional_params_t& optionalParams) {
}
}

// Writes this config through `serializer`. `stemmer` is the only field emitted.
void FTSConfig::serialize(common::Serializer& serializer) const {
    serializer.serializeValue(stemmer);
}

// Reads a config back from `deserializer`: starts from a value-initialized FTSConfig and
// overwrites its `stemmer` with the value read.
FTSConfig FTSConfig::deserialize(common::Deserializer& deserializer) {
    FTSConfig restored{};
    deserializer.deserializeValue(restored.stemmer);
    return restored;
}

// Number of bytes serialize() is expected to emit: a length prefix followed by the raw
// characters of `stemmer`.
// The prefix width is spelled as a fixed 64-bit integer rather than sizeof(stemmer.size()):
// the latter is sizeof(std::size_t), which differs between 32- and 64-bit targets, whereas a
// serialized size must not depend on the platform.
// NOTE(review): assumes the serializer's std::string overload writes a uint64_t length —
// confirm against common::Serializer.
uint64_t FTSConfig::getNumBytesForSerialization() const {
    return sizeof(uint64_t) + stemmer.size();
}

void K::validate(double value) {
if (value < 0) {
throw common::BinderException{
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#include "function/query_fts.h"
#include "function/query_fts_index.h"

#include "binder/binder.h"
#include "binder/expression/expression_util.h"
Expand Down Expand Up @@ -27,20 +27,20 @@ namespace fts_extension {

using namespace function;

struct QFTSGDSBindData final : public function::GDSBindData {
struct QueryFTSBindData final : public function::GDSBindData {
std::vector<std::string> terms;
uint64_t numDocs;
double_t avgDocLen;
QueryFTSConfig config;
common::table_id_t outputTableID;

QFTSGDSBindData(std::vector<std::string> terms, graph::GraphEntry graphEntry,
QueryFTSBindData(std::vector<std::string> terms, graph::GraphEntry graphEntry,
std::shared_ptr<binder::Expression> docs, uint64_t numDocs, double_t avgDocLen,
QueryFTSConfig config)
: GDSBindData{std::move(graphEntry), std::move(docs)}, terms{std::move(terms)},
numDocs{numDocs}, avgDocLen{avgDocLen}, config{config},
outputTableID{nodeOutput->constCast<NodeExpression>().getSingleEntry()->getTableID()} {}
QFTSGDSBindData(const QFTSGDSBindData& other)
QueryFTSBindData(const QueryFTSBindData& other)
: GDSBindData{other}, terms{other.terms}, numDocs{other.numDocs},
avgDocLen{other.avgDocLen}, config{other.config}, outputTableID{other.outputTableID} {}

Expand All @@ -49,11 +49,11 @@ struct QFTSGDSBindData final : public function::GDSBindData {
uint64_t getNumUniqueTerms() const;

std::unique_ptr<GDSBindData> copy() const override {
return std::make_unique<QFTSGDSBindData>(*this);
return std::make_unique<QueryFTSBindData>(*this);
}
};

uint64_t QFTSGDSBindData::getNumUniqueTerms() const {
uint64_t QueryFTSBindData::getNumUniqueTerms() const {
auto uniqueTerms = std::unordered_set<std::string>{terms.begin(), terms.end()};
return uniqueTerms.size();
}
Expand Down Expand Up @@ -156,7 +156,7 @@ void runFrontiersOnce(processor::ExecutionContext* executionContext, QFTSState&
class QFTSOutputWriter {
public:
QFTSOutputWriter(storage::MemoryManager* mm, QFTSOutput* qFTSOutput,
const QFTSGDSBindData& bindData);
const QueryFTSBindData& bindData);

void write(processor::FactorizedTable& scoreFT, nodeID_t docNodeID, uint64_t len,
int64_t docsID);
Expand All @@ -170,12 +170,12 @@ class QFTSOutputWriter {
std::vector<common::ValueVector*> vectors;
common::idx_t pos;
storage::MemoryManager* mm;
const QFTSGDSBindData& bindData;
const QueryFTSBindData& bindData;
uint64_t numUniqueTerms;
};

QFTSOutputWriter::QFTSOutputWriter(storage::MemoryManager* mm, QFTSOutput* qFTSOutput,
const QFTSGDSBindData& bindData)
const QueryFTSBindData& bindData)
: qFTSOutput{qFTSOutput}, docsVector{LogicalType::INTERNAL_ID(), mm},
scoreVector{LogicalType::UINT64(), mm}, mm{mm}, bindData{bindData} {
auto state = DataChunkState::getSingleValueDataChunkState();
Expand Down Expand Up @@ -288,7 +288,7 @@ void QFTSVertexCompute::vertexCompute(const graph::VertexScanState::Chunk& chunk
}
}

void QFTSAlgorithm::exec(processor::ExecutionContext* executionContext) {
void QueryFTSAlgorithm::exec(processor::ExecutionContext* executionContext) {
auto termsTableID = sharedState->graph->getNodeTableIDs()[0];
auto output = std::make_unique<QFTSOutput>();

Expand All @@ -298,7 +298,7 @@ void QFTSAlgorithm::exec(processor::ExecutionContext* executionContext) {
const graph::GraphEntry entry{{termsTableEntry}, {} /* relTableEntries */};
graph::OnDiskGraph graph(executionContext->clientContext, entry);
auto termsComputeSharedState =
TermsComputeSharedState{bindData->ptrCast<QFTSGDSBindData>()->terms};
TermsComputeSharedState{bindData->ptrCast<QueryFTSBindData>()->terms};
TermsVertexCompute termsCompute{&termsComputeSharedState};
GDSUtils::runVertexCompute(executionContext, &graph, termsCompute,
std::vector<std::string>{"term", "df"});
Expand All @@ -323,24 +323,24 @@ void QFTSAlgorithm::exec(processor::ExecutionContext* executionContext) {

runFrontiersOnce(executionContext, qFTSState, sharedState->graph.get(),
catalog->getTableCatalogEntry(transaction, sharedState->graph->getRelTableIDs()[0])
->getPropertyIdx(QFTSAlgorithm::TERM_FREQUENCY_PROP_NAME));
->getPropertyIdx(QueryFTSAlgorithm::TERM_FREQUENCY_PROP_NAME));

// Do vertex compute to calculate the score for doc with the length property.

auto writer =
std::make_unique<QFTSOutputWriter>(mm, output.get(), *bindData->ptrCast<QFTSGDSBindData>());
auto writer = std::make_unique<QFTSOutputWriter>(mm, output.get(),
*bindData->ptrCast<QueryFTSBindData>());
auto writerVC = std::make_unique<QFTSVertexCompute>(mm, sharedState.get(), std::move(writer));
auto docsTableID = sharedState->graph->getNodeTableIDs()[1];
GDSUtils::runVertexCompute(executionContext, sharedState->graph.get(), *writerVC, docsTableID,
{QFTSAlgorithm::DOC_LEN_PROP_NAME, QFTSAlgorithm::DOC_ID_PROP_NAME});
{QueryFTSAlgorithm::DOC_LEN_PROP_NAME, QueryFTSAlgorithm::DOC_ID_PROP_NAME});
sharedState->mergeLocalTables();
}

// Creates the output variable for the score column: named
// QueryFTSAlgorithm::SCORE_PROP_NAME and typed DOUBLE.
static std::shared_ptr<Expression> getScoreColumn(Binder* binder) {
    return binder->createVariable(QueryFTSAlgorithm::SCORE_PROP_NAME, LogicalType::DOUBLE());
}

binder::expression_vector QFTSAlgorithm::getResultColumns(binder::Binder* binder) const {
binder::expression_vector QueryFTSAlgorithm::getResultColumns(binder::Binder* binder) const {
expression_vector columns;
auto& docsNode = bindData->getNodeOutput()->constCast<NodeExpression>();
columns.push_back(docsNode.getInternalID());
Expand Down Expand Up @@ -374,7 +374,7 @@ static std::vector<std::string> getTerms(std::string& query, const std::string&
return result;
}

void QFTSAlgorithm::bind(const GDSBindInput& input, main::ClientContext& context) {
void QueryFTSAlgorithm::bind(const GDSBindInput& input, main::ClientContext& context) {
auto inputTableName = getParamVal(input, 0);
auto indexName = getParamVal(input, 1);
auto query = getParamVal(input, 2);
Expand All @@ -398,14 +398,14 @@ void QFTSAlgorithm::bind(const GDSBindInput& input, main::ClientContext& context
auto appearsInEntry = context.getCatalog()->getTableCatalogEntry(context.getTransaction(),
FTSUtils::getAppearsInTableName(tableEntry.getTableID(), indexName));
auto graphEntry = graph::GraphEntry({termsEntry, docsEntry}, {appearsInEntry});
bindData = std::make_unique<QFTSGDSBindData>(std::move(terms), std::move(graphEntry),
bindData = std::make_unique<QueryFTSBindData>(std::move(terms), std::move(graphEntry),
nodeOutput, ftsIndexEntry.getNumDocs(), ftsIndexEntry.getAvgDocLen(),
QueryFTSConfig{input.optionalParams});
}

function::function_set QFTSFunction::getFunctionSet() {
function::function_set QueryFTSFunction::getFunctionSet() {
function_set result;
auto algo = std::make_unique<QFTSAlgorithm>();
auto algo = std::make_unique<QueryFTSAlgorithm>();
result.push_back(
std::make_unique<GDSFunction>(name, algo->getParameterTypeIDs(), std::move(algo)));
return result;
Expand Down
10 changes: 9 additions & 1 deletion extension/fts/src/include/catalog/fts_index_catalog_entry.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class FTSIndexCatalogEntry final : public catalog::IndexCatalogEntry {
FTSIndexCatalogEntry() = default;
FTSIndexCatalogEntry(common::table_id_t tableID, std::string indexName, common::idx_t numDocs,
double avgDocLen, FTSConfig config)
: catalog::IndexCatalogEntry{tableID, std::move(indexName)}, numDocs{numDocs},
: catalog::IndexCatalogEntry{TYPE_NAME, tableID, std::move(indexName)}, numDocs{numDocs},
avgDocLen{avgDocLen}, config{std::move(config)} {}

//===--------------------------------------------------------------------===//
Expand All @@ -29,6 +29,14 @@ class FTSIndexCatalogEntry final : public catalog::IndexCatalogEntry {
//===--------------------------------------------------------------------===//
std::unique_ptr<catalog::IndexCatalogEntry> copy() const override;

void serializeAuxInfo(common::Serializer& serializer) const override;

static std::unique_ptr<FTSIndexCatalogEntry> deserializeAuxInfo(
catalog::IndexCatalogEntry* indexCatalogEntry);

public:
static constexpr char TYPE_NAME[] = "FTS";

private:
common::idx_t numDocs = 0;
double avgDocLen = 0;
Expand Down
6 changes: 6 additions & 0 deletions extension/fts/src/include/function/fts_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,12 @@ struct FTSConfig {

FTSConfig() = default;
explicit FTSConfig(const function::optional_params_t& optionalParams);

void serialize(common::Serializer& serializer) const;

static FTSConfig deserialize(common::Deserializer& deserializer);

uint64_t getNumBytesForSerialization() const;
};

struct K {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,16 @@
namespace kuzu {
namespace fts_extension {

class QFTSAlgorithm : public function::GDSAlgorithm {
class QueryFTSAlgorithm : public function::GDSAlgorithm {
public:
static constexpr char SCORE_PROP_NAME[] = "score";
static constexpr char TERM_FREQUENCY_PROP_NAME[] = "tf";
static constexpr char DOC_LEN_PROP_NAME[] = "len";
static constexpr char DOC_ID_PROP_NAME[] = "docID";

public:
QFTSAlgorithm() = default;
QFTSAlgorithm(const QFTSAlgorithm& other) : GDSAlgorithm{other} {}
QueryFTSAlgorithm() = default;
QueryFTSAlgorithm(const QueryFTSAlgorithm& other) : GDSAlgorithm{other} {}

/*
* Inputs include the following:
Expand All @@ -34,15 +34,15 @@ class QFTSAlgorithm : public function::GDSAlgorithm {
void exec(processor::ExecutionContext* executionContext) override;

std::unique_ptr<GDSAlgorithm> copy() const override {
return std::make_unique<QFTSAlgorithm>(*this);
return std::make_unique<QueryFTSAlgorithm>(*this);
}

binder::expression_vector getResultColumns(binder::Binder* binder) const override;

void bind(const function::GDSBindInput& input, main::ClientContext&) override;
};

struct QFTSFunction {
struct QueryFTSFunction {
static constexpr const char* name = "QUERY_FTS_INDEX";

static function::function_set getFunctionSet();
Expand Down
21 changes: 21 additions & 0 deletions extension/fts/test/test_files/fts_small.test
Original file line number Diff line number Diff line change
Expand Up @@ -163,3 +163,24 @@ Binder exception: BM25 model requires the Term Frequency Saturation(k) value to
-STATEMENT CALL QUERY_FTS_INDEX('doc', 'docIdx', 'alice waterloo', conjunctive := true) RETURN _node.ID, score
---- 1
0|0.465323

-CASE fts_serialization
-SKIP_IN_MEM
-STATEMENT load extension "${KUZU_ROOT_DIRECTORY}/extension/fts/build/libfts.kuzu_extension"
---- ok
-STATEMENT CALL CREATE_FTS_INDEX('doc', 'docIdx', ['content', 'author', 'name'])
---- ok
-STATEMENT CALL QUERY_FTS_INDEX('doc', 'docIdx', 'Alice') RETURN _node.ID, score
---- 2
0|0.271133
3|0.209476
-RELOADDB
-STATEMENT CALL QUERY_FTS_INDEX('doc', 'docIdx', 'Alice') RETURN _node.ID, score
---- error
Catalog exception: QUERY_FTS_INDEX function does not exist.
acquamarin marked this conversation as resolved.
Show resolved Hide resolved
-STATEMENT load extension "${KUZU_ROOT_DIRECTORY}/extension/fts/build/libfts.kuzu_extension"
---- ok
-STATEMENT CALL QUERY_FTS_INDEX('doc', 'docIdx', 'Alice') RETURN _node.ID, score
---- 2
0|0.271133
3|0.209476
7 changes: 6 additions & 1 deletion src/catalog/catalog.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,10 @@ IndexCatalogEntry* Catalog::getIndex(const Transaction* transaction, common::tab
->ptrCast<IndexCatalogEntry>();
}

// Returns a non-owning pointer to the catalog's index set; ownership stays with `indexes`.
CatalogSet* Catalog::getIndexes() const {
    return indexes.get();
}

bool Catalog::containsIndex(const transaction::Transaction* transaction, common::table_id_t tableID,
std::string indexName) const {
return indexes->containsEntry(transaction, common::stringFormat("{}_{}", tableID, indexName));
Expand Down Expand Up @@ -473,6 +477,7 @@ void Catalog::saveToFile(const std::string& directory, VirtualFileSystem* fs,
sequences->serialize(serializer);
functions->serialize(serializer);
types->serialize(serializer);
indexes->serialize(serializer);
}

void Catalog::readFromFile(const std::string& directory, VirtualFileSystem* fs,
Expand All @@ -489,7 +494,7 @@ void Catalog::readFromFile(const std::string& directory, VirtualFileSystem* fs,
sequences = CatalogSet::deserialize(deserializer);
functions = CatalogSet::deserialize(deserializer);
types = CatalogSet::deserialize(deserializer);
indexes = std::make_unique<CatalogSet>();
indexes = CatalogSet::deserialize(deserializer);
}

void Catalog::registerBuiltInFunctions() {
Expand Down
Loading
Loading