Skip to content

Commit

Permalink
Index entry serialization (#4686)
Browse files Browse the repository at this point in the history
  • Loading branch information
acquamarin authored Jan 8, 2025
1 parent decc6d5 commit 6c78434
Show file tree
Hide file tree
Showing 23 changed files with 237 additions and 72 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
cmake_minimum_required(VERSION 3.15)

project(Kuzu VERSION 0.7.1.1 LANGUAGES CXX C)
project(Kuzu VERSION 0.7.1.3 LANGUAGES CXX C)

option(SINGLE_THREADED "Single-threaded mode" FALSE)
if(SINGLE_THREADED)
Expand Down
24 changes: 24 additions & 0 deletions extension/fts/src/catalog/fts_index_catalog_entry.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#include "catalog/fts_index_catalog_entry.h"

#include "common/serializer/buffered_reader.h"

namespace kuzu {
namespace fts_extension {

Expand All @@ -11,5 +13,27 @@ std::unique_ptr<catalog::IndexCatalogEntry> FTSIndexCatalogEntry::copy() const {
return other;
}

void FTSIndexCatalogEntry::serializeAuxInfo(common::Serializer& serializer) const {
auto len = sizeof(numDocs) + sizeof(avgDocLen) + config.getNumBytesForSerialization();
serializer.serializeValue(len);
serializer.serializeValue(numDocs);
serializer.serializeValue(avgDocLen);
config.serialize(serializer);
}

std::unique_ptr<FTSIndexCatalogEntry> FTSIndexCatalogEntry::deserializeAuxInfo(
catalog::IndexCatalogEntry* indexCatalogEntry) {
auto reader = std::make_unique<common::BufferReader>(indexCatalogEntry->getAuxBuffer(),
indexCatalogEntry->getAuxBufferSize());
common::Deserializer deserializer{std::move(reader)};
common::idx_t numDocs = 0;
deserializer.deserializeValue(numDocs);
double avgDocLen = 0;
deserializer.deserializeValue(avgDocLen);
auto config = FTSConfig::deserialize(deserializer);
return std::make_unique<FTSIndexCatalogEntry>(indexCatalogEntry->getTableID(),
indexCatalogEntry->getIndexName(), numDocs, avgDocLen, config);
}

} // namespace fts_extension
} // namespace kuzu
21 changes: 18 additions & 3 deletions extension/fts/src/fts_extension.cpp
Original file line number Diff line number Diff line change
@@ -1,22 +1,37 @@
#include "fts_extension.h"

#include "catalog/catalog_entry/catalog_entry_type.h"
#include "catalog/catalog.h"
#include "catalog/catalog_entry/index_catalog_entry.h"
#include "catalog/fts_index_catalog_entry.h"
#include "function/create_fts_index.h"
#include "function/drop_fts_index.h"
#include "function/query_fts.h"
#include "function/query_fts_index.h"
#include "function/stem.h"
#include "main/client_context.h"
#include "main/database.h"

namespace kuzu {
namespace fts_extension {

static void initFTSEntries(transaction::Transaction* tx, catalog::Catalog& catalog) {
auto indexEntries = catalog.getIndexes()->getEntries(tx);
for (auto& [_, entry] : indexEntries) {
auto indexEntry = entry->ptrCast<catalog::IndexCatalogEntry>();
if (indexEntry->getIndexType() == FTSIndexCatalogEntry::TYPE_NAME) {
auto deserializedEntry = FTSIndexCatalogEntry::deserializeAuxInfo(indexEntry);
catalog.dropIndex(tx, indexEntry->getTableID(), indexEntry->getIndexName());
catalog.createIndex(tx, std::move(deserializedEntry));
}
}
}

void FTSExtension::load(main::ClientContext* context) {
auto& db = *context->getDatabase();
ADD_SCALAR_FUNC(StemFunction);
ADD_GDS_FUNC(QFTSFunction);
ADD_GDS_FUNC(QueryFTSFunction);
db.addStandaloneCallFunction(CreateFTSFunction::name, CreateFTSFunction::getFunctionSet());
db.addStandaloneCallFunction(DropFTSFunction::name, DropFTSFunction::getFunctionSet());
initFTSEntries(context->getTransaction(), *db.getCatalog());
}

} // namespace fts_extension
Expand Down
2 changes: 1 addition & 1 deletion extension/fts/src/function/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ add_library(kuzu_fts_function
create_fts_index.cpp
fts_config.cpp
drop_fts_index.cpp
query_fts.cpp
query_fts_index.cpp
fts_utils.cpp)

set(FTS_OBJECT_FILES
Expand Down
17 changes: 17 additions & 0 deletions extension/fts/src/function/fts_config.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#include "function/fts_config.h"

#include "common/exception/binder.h"
#include "common/serializer/deserializer.h"
#include "common/serializer/serializer.h"
#include "common/string_utils.h"
#include "function/stem.h"

Expand All @@ -24,6 +26,21 @@ FTSConfig::FTSConfig(const function::optional_params_t& optionalParams) {
}
}

void FTSConfig::serialize(common::Serializer& serializer) const {
serializer.serializeValue(stemmer);
}

FTSConfig FTSConfig::deserialize(common::Deserializer& deserializer) {
auto config = FTSConfig{};
deserializer.deserializeValue(config.stemmer);
return config;
}

// We store the length + the string itself. So the total size is sizeof(length) + string length.
uint64_t FTSConfig::getNumBytesForSerialization() const {
return sizeof(stemmer.size()) + stemmer.size();
}

void K::validate(double value) {
if (value < 0) {
throw common::BinderException{
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#include "function/query_fts.h"
#include "function/query_fts_index.h"

#include "binder/binder.h"
#include "binder/expression/expression_util.h"
Expand Down Expand Up @@ -27,20 +27,20 @@ namespace fts_extension {

using namespace function;

struct QFTSGDSBindData final : public function::GDSBindData {
struct QueryFTSBindData final : public function::GDSBindData {
std::vector<std::string> terms;
uint64_t numDocs;
double_t avgDocLen;
QueryFTSConfig config;
common::table_id_t outputTableID;

QFTSGDSBindData(std::vector<std::string> terms, graph::GraphEntry graphEntry,
QueryFTSBindData(std::vector<std::string> terms, graph::GraphEntry graphEntry,
std::shared_ptr<binder::Expression> docs, uint64_t numDocs, double_t avgDocLen,
QueryFTSConfig config)
: GDSBindData{std::move(graphEntry), std::move(docs)}, terms{std::move(terms)},
numDocs{numDocs}, avgDocLen{avgDocLen}, config{config},
outputTableID{nodeOutput->constCast<NodeExpression>().getSingleEntry()->getTableID()} {}
QFTSGDSBindData(const QFTSGDSBindData& other)
QueryFTSBindData(const QueryFTSBindData& other)
: GDSBindData{other}, terms{other.terms}, numDocs{other.numDocs},
avgDocLen{other.avgDocLen}, config{other.config}, outputTableID{other.outputTableID} {}

Expand All @@ -49,11 +49,11 @@ struct QFTSGDSBindData final : public function::GDSBindData {
uint64_t getNumUniqueTerms() const;

std::unique_ptr<GDSBindData> copy() const override {
return std::make_unique<QFTSGDSBindData>(*this);
return std::make_unique<QueryFTSBindData>(*this);
}
};

uint64_t QFTSGDSBindData::getNumUniqueTerms() const {
uint64_t QueryFTSBindData::getNumUniqueTerms() const {
auto uniqueTerms = std::unordered_set<std::string>{terms.begin(), terms.end()};
return uniqueTerms.size();
}
Expand Down Expand Up @@ -156,7 +156,7 @@ void runFrontiersOnce(processor::ExecutionContext* executionContext, QFTSState&
class QFTSOutputWriter {
public:
QFTSOutputWriter(storage::MemoryManager* mm, QFTSOutput* qFTSOutput,
const QFTSGDSBindData& bindData);
const QueryFTSBindData& bindData);

void write(processor::FactorizedTable& scoreFT, nodeID_t docNodeID, uint64_t len,
int64_t docsID);
Expand All @@ -170,12 +170,12 @@ class QFTSOutputWriter {
std::vector<common::ValueVector*> vectors;
common::idx_t pos;
storage::MemoryManager* mm;
const QFTSGDSBindData& bindData;
const QueryFTSBindData& bindData;
uint64_t numUniqueTerms;
};

QFTSOutputWriter::QFTSOutputWriter(storage::MemoryManager* mm, QFTSOutput* qFTSOutput,
const QFTSGDSBindData& bindData)
const QueryFTSBindData& bindData)
: qFTSOutput{qFTSOutput}, docsVector{LogicalType::INTERNAL_ID(), mm},
scoreVector{LogicalType::UINT64(), mm}, mm{mm}, bindData{bindData} {
auto state = DataChunkState::getSingleValueDataChunkState();
Expand Down Expand Up @@ -288,7 +288,7 @@ void QFTSVertexCompute::vertexCompute(const graph::VertexScanState::Chunk& chunk
}
}

void QFTSAlgorithm::exec(processor::ExecutionContext* executionContext) {
void QueryFTSAlgorithm::exec(processor::ExecutionContext* executionContext) {
auto termsTableID = sharedState->graph->getNodeTableIDs()[0];
auto output = std::make_unique<QFTSOutput>();

Expand All @@ -298,7 +298,7 @@ void QFTSAlgorithm::exec(processor::ExecutionContext* executionContext) {
const graph::GraphEntry entry{{termsTableEntry}, {} /* relTableEntries */};
graph::OnDiskGraph graph(executionContext->clientContext, entry);
auto termsComputeSharedState =
TermsComputeSharedState{bindData->ptrCast<QFTSGDSBindData>()->terms};
TermsComputeSharedState{bindData->ptrCast<QueryFTSBindData>()->terms};
TermsVertexCompute termsCompute{&termsComputeSharedState};
GDSUtils::runVertexCompute(executionContext, &graph, termsCompute,
std::vector<std::string>{"term", "df"});
Expand All @@ -323,24 +323,24 @@ void QFTSAlgorithm::exec(processor::ExecutionContext* executionContext) {

runFrontiersOnce(executionContext, qFTSState, sharedState->graph.get(),
catalog->getTableCatalogEntry(transaction, sharedState->graph->getRelTableIDs()[0])
->getPropertyIdx(QFTSAlgorithm::TERM_FREQUENCY_PROP_NAME));
->getPropertyIdx(QueryFTSAlgorithm::TERM_FREQUENCY_PROP_NAME));

// Do vertex compute to calculate the score for doc with the length property.

auto writer =
std::make_unique<QFTSOutputWriter>(mm, output.get(), *bindData->ptrCast<QFTSGDSBindData>());
auto writer = std::make_unique<QFTSOutputWriter>(mm, output.get(),
*bindData->ptrCast<QueryFTSBindData>());
auto writerVC = std::make_unique<QFTSVertexCompute>(mm, sharedState.get(), std::move(writer));
auto docsTableID = sharedState->graph->getNodeTableIDs()[1];
GDSUtils::runVertexCompute(executionContext, sharedState->graph.get(), *writerVC, docsTableID,
{QFTSAlgorithm::DOC_LEN_PROP_NAME, QFTSAlgorithm::DOC_ID_PROP_NAME});
{QueryFTSAlgorithm::DOC_LEN_PROP_NAME, QueryFTSAlgorithm::DOC_ID_PROP_NAME});
sharedState->mergeLocalTables();
}

static std::shared_ptr<Expression> getScoreColumn(Binder* binder) {
return binder->createVariable(QFTSAlgorithm::SCORE_PROP_NAME, LogicalType::DOUBLE());
return binder->createVariable(QueryFTSAlgorithm::SCORE_PROP_NAME, LogicalType::DOUBLE());
}

binder::expression_vector QFTSAlgorithm::getResultColumns(binder::Binder* binder) const {
binder::expression_vector QueryFTSAlgorithm::getResultColumns(binder::Binder* binder) const {
expression_vector columns;
auto& docsNode = bindData->getNodeOutput()->constCast<NodeExpression>();
columns.push_back(docsNode.getInternalID());
Expand Down Expand Up @@ -374,7 +374,7 @@ static std::vector<std::string> getTerms(std::string& query, const std::string&
return result;
}

void QFTSAlgorithm::bind(const GDSBindInput& input, main::ClientContext& context) {
void QueryFTSAlgorithm::bind(const GDSBindInput& input, main::ClientContext& context) {
auto inputTableName = getParamVal(input, 0);
auto indexName = getParamVal(input, 1);
auto query = getParamVal(input, 2);
Expand All @@ -398,14 +398,14 @@ void QFTSAlgorithm::bind(const GDSBindInput& input, main::ClientContext& context
auto appearsInEntry = context.getCatalog()->getTableCatalogEntry(context.getTransaction(),
FTSUtils::getAppearsInTableName(tableEntry.getTableID(), indexName));
auto graphEntry = graph::GraphEntry({termsEntry, docsEntry}, {appearsInEntry});
bindData = std::make_unique<QFTSGDSBindData>(std::move(terms), std::move(graphEntry),
bindData = std::make_unique<QueryFTSBindData>(std::move(terms), std::move(graphEntry),
nodeOutput, ftsIndexEntry.getNumDocs(), ftsIndexEntry.getAvgDocLen(),
QueryFTSConfig{input.optionalParams});
}

function::function_set QFTSFunction::getFunctionSet() {
function::function_set QueryFTSFunction::getFunctionSet() {
function_set result;
auto algo = std::make_unique<QFTSAlgorithm>();
auto algo = std::make_unique<QueryFTSAlgorithm>();
result.push_back(
std::make_unique<GDSFunction>(name, algo->getParameterTypeIDs(), std::move(algo)));
return result;
Expand Down
10 changes: 9 additions & 1 deletion extension/fts/src/include/catalog/fts_index_catalog_entry.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class FTSIndexCatalogEntry final : public catalog::IndexCatalogEntry {
FTSIndexCatalogEntry() = default;
FTSIndexCatalogEntry(common::table_id_t tableID, std::string indexName, common::idx_t numDocs,
double avgDocLen, FTSConfig config)
: catalog::IndexCatalogEntry{tableID, std::move(indexName)}, numDocs{numDocs},
: catalog::IndexCatalogEntry{TYPE_NAME, tableID, std::move(indexName)}, numDocs{numDocs},
avgDocLen{avgDocLen}, config{std::move(config)} {}

//===--------------------------------------------------------------------===//
Expand All @@ -29,6 +29,14 @@ class FTSIndexCatalogEntry final : public catalog::IndexCatalogEntry {
//===--------------------------------------------------------------------===//
std::unique_ptr<catalog::IndexCatalogEntry> copy() const override;

void serializeAuxInfo(common::Serializer& serializer) const override;

static std::unique_ptr<FTSIndexCatalogEntry> deserializeAuxInfo(
catalog::IndexCatalogEntry* indexCatalogEntry);

public:
static constexpr char TYPE_NAME[] = "FTS";

private:
common::idx_t numDocs = 0;
double avgDocLen = 0;
Expand Down
6 changes: 6 additions & 0 deletions extension/fts/src/include/function/fts_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,12 @@ struct FTSConfig {

FTSConfig() = default;
explicit FTSConfig(const function::optional_params_t& optionalParams);

void serialize(common::Serializer& serializer) const;

static FTSConfig deserialize(common::Deserializer& deserializer);

uint64_t getNumBytesForSerialization() const;
};

struct K {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,16 @@
namespace kuzu {
namespace fts_extension {

class QFTSAlgorithm : public function::GDSAlgorithm {
class QueryFTSAlgorithm : public function::GDSAlgorithm {
public:
static constexpr char SCORE_PROP_NAME[] = "score";
static constexpr char TERM_FREQUENCY_PROP_NAME[] = "tf";
static constexpr char DOC_LEN_PROP_NAME[] = "len";
static constexpr char DOC_ID_PROP_NAME[] = "docID";

public:
QFTSAlgorithm() = default;
QFTSAlgorithm(const QFTSAlgorithm& other) : GDSAlgorithm{other} {}
QueryFTSAlgorithm() = default;
QueryFTSAlgorithm(const QueryFTSAlgorithm& other) : GDSAlgorithm{other} {}

/*
* Inputs include the following:
Expand All @@ -34,15 +34,15 @@ class QFTSAlgorithm : public function::GDSAlgorithm {
void exec(processor::ExecutionContext* executionContext) override;

std::unique_ptr<GDSAlgorithm> copy() const override {
return std::make_unique<QFTSAlgorithm>(*this);
return std::make_unique<QueryFTSAlgorithm>(*this);
}

binder::expression_vector getResultColumns(binder::Binder* binder) const override;

void bind(const function::GDSBindInput& input, main::ClientContext&) override;
};

struct QFTSFunction {
struct QueryFTSFunction {
static constexpr const char* name = "QUERY_FTS_INDEX";

static function::function_set getFunctionSet();
Expand Down
21 changes: 21 additions & 0 deletions extension/fts/test/test_files/fts_small.test
Original file line number Diff line number Diff line change
Expand Up @@ -163,3 +163,24 @@ Binder exception: BM25 model requires the Term Frequency Saturation(k) value to
-STATEMENT CALL QUERY_FTS_INDEX('doc', 'docIdx', 'alice waterloo', conjunctive := true) RETURN _node.ID, score
---- 1
0|0.465323

-CASE fts_serialization
-SKIP_IN_MEM
-STATEMENT load extension "${KUZU_ROOT_DIRECTORY}/extension/fts/build/libfts.kuzu_extension"
---- ok
-STATEMENT CALL CREATE_FTS_INDEX('doc', 'docIdx', ['content', 'author', 'name'])
---- ok
-STATEMENT CALL QUERY_FTS_INDEX('doc', 'docIdx', 'Alice') RETURN _node.ID, score
---- 2
0|0.271133
3|0.209476
-RELOADDB
-STATEMENT CALL QUERY_FTS_INDEX('doc', 'docIdx', 'Alice') RETURN _node.ID, score
---- error
Catalog exception: QUERY_FTS_INDEX function does not exist.
-STATEMENT load extension "${KUZU_ROOT_DIRECTORY}/extension/fts/build/libfts.kuzu_extension"
---- ok
-STATEMENT CALL QUERY_FTS_INDEX('doc', 'docIdx', 'Alice') RETURN _node.ID, score
---- 2
0|0.271133
3|0.209476
7 changes: 6 additions & 1 deletion src/catalog/catalog.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,10 @@ IndexCatalogEntry* Catalog::getIndex(const Transaction* transaction, common::tab
->ptrCast<IndexCatalogEntry>();
}

CatalogSet* Catalog::getIndexes() const {
return indexes.get();
}

bool Catalog::containsIndex(const transaction::Transaction* transaction, common::table_id_t tableID,
std::string indexName) const {
return indexes->containsEntry(transaction, common::stringFormat("{}_{}", tableID, indexName));
Expand Down Expand Up @@ -473,6 +477,7 @@ void Catalog::saveToFile(const std::string& directory, VirtualFileSystem* fs,
sequences->serialize(serializer);
functions->serialize(serializer);
types->serialize(serializer);
indexes->serialize(serializer);
}

void Catalog::readFromFile(const std::string& directory, VirtualFileSystem* fs,
Expand All @@ -489,7 +494,7 @@ void Catalog::readFromFile(const std::string& directory, VirtualFileSystem* fs,
sequences = CatalogSet::deserialize(deserializer);
functions = CatalogSet::deserialize(deserializer);
types = CatalogSet::deserialize(deserializer);
indexes = std::make_unique<CatalogSet>();
indexes = CatalogSet::deserialize(deserializer);
}

void Catalog::registerBuiltInFunctions() {
Expand Down
Loading

0 comments on commit 6c78434

Please sign in to comment.