diff --git a/CMakeLists.txt b/CMakeLists.txt index 30470957c47..e88771969f6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,6 +1,6 @@ cmake_minimum_required(VERSION 3.15) -project(Kuzu VERSION 0.7.1.1 LANGUAGES CXX C) +project(Kuzu VERSION 0.7.1.3 LANGUAGES CXX C) option(SINGLE_THREADED "Single-threaded mode" FALSE) if(SINGLE_THREADED) diff --git a/extension/fts/src/catalog/fts_index_catalog_entry.cpp b/extension/fts/src/catalog/fts_index_catalog_entry.cpp index 3439361274f..1898187673d 100644 --- a/extension/fts/src/catalog/fts_index_catalog_entry.cpp +++ b/extension/fts/src/catalog/fts_index_catalog_entry.cpp @@ -1,5 +1,7 @@ #include "catalog/fts_index_catalog_entry.h" +#include "common/serializer/buffered_reader.h" + namespace kuzu { namespace fts_extension { @@ -11,5 +13,27 @@ std::unique_ptr FTSIndexCatalogEntry::copy() const { return other; } +void FTSIndexCatalogEntry::serializeAuxInfo(common::Serializer& serializer) const { + auto len = sizeof(numDocs) + sizeof(avgDocLen) + config.getNumBytesForSerialization(); + serializer.serializeValue(len); + serializer.serializeValue(numDocs); + serializer.serializeValue(avgDocLen); + config.serialize(serializer); +} + +std::unique_ptr FTSIndexCatalogEntry::deserializeAuxInfo( + catalog::IndexCatalogEntry* indexCatalogEntry) { + auto reader = std::make_unique(indexCatalogEntry->getAuxBuffer(), + indexCatalogEntry->getAuxBufferSize()); + common::Deserializer deserializer{std::move(reader)}; + common::idx_t numDocs = 0; + deserializer.deserializeValue(numDocs); + double avgDocLen = 0; + deserializer.deserializeValue(avgDocLen); + auto config = FTSConfig::deserialize(deserializer); + return std::make_unique(indexCatalogEntry->getTableID(), + indexCatalogEntry->getIndexName(), numDocs, avgDocLen, config); +} + } // namespace fts_extension } // namespace kuzu diff --git a/extension/fts/src/fts_extension.cpp b/extension/fts/src/fts_extension.cpp index 8ae17b7a7de..44262b8f47f 100644 --- a/extension/fts/src/fts_extension.cpp +++ b/extension/fts/src/fts_extension.cpp @@ -1,9 +1,11 @@ #include "fts_extension.h" -#include "catalog/catalog_entry/catalog_entry_type.h" +#include "catalog/catalog.h" +#include "catalog/catalog_entry/index_catalog_entry.h" +#include "catalog/fts_index_catalog_entry.h" #include "function/create_fts_index.h" #include "function/drop_fts_index.h" -#include "function/query_fts.h" +#include "function/query_fts_index.h" #include "function/stem.h" #include "main/client_context.h" #include "main/database.h" @@ -11,12 +13,25 @@ namespace kuzu { namespace fts_extension { +static void initFTSEntries(transaction::Transaction* tx, catalog::Catalog& catalog) { + auto indexEntries = catalog.getIndexes()->getEntries(tx); + for (auto& [_, entry] : indexEntries) { + auto indexEntry = entry->ptrCast(); + if (indexEntry->getIndexType() == FTSIndexCatalogEntry::TYPE_NAME) { + auto deserializedEntry = FTSIndexCatalogEntry::deserializeAuxInfo(indexEntry); + catalog.dropIndex(tx, indexEntry->getTableID(), indexEntry->getIndexName()); + catalog.createIndex(tx, std::move(deserializedEntry)); + } + } +} + void FTSExtension::load(main::ClientContext* context) { auto& db = *context->getDatabase(); ADD_SCALAR_FUNC(StemFunction); - ADD_GDS_FUNC(QFTSFunction); + ADD_GDS_FUNC(QueryFTSFunction); db.addStandaloneCallFunction(CreateFTSFunction::name, CreateFTSFunction::getFunctionSet()); db.addStandaloneCallFunction(DropFTSFunction::name, DropFTSFunction::getFunctionSet()); + initFTSEntries(context->getTransaction(), *db.getCatalog()); } } // namespace fts_extension diff --git a/extension/fts/src/function/CMakeLists.txt b/extension/fts/src/function/CMakeLists.txt index 415beb4a0e3..8e4fa24bae1 100644 --- a/extension/fts/src/function/CMakeLists.txt +++ b/extension/fts/src/function/CMakeLists.txt @@ -4,7 +4,7 @@ add_library(kuzu_fts_function create_fts_index.cpp fts_config.cpp drop_fts_index.cpp - query_fts.cpp + query_fts_index.cpp fts_utils.cpp) set(FTS_OBJECT_FILES diff --git a/extension/fts/src/function/fts_config.cpp b/extension/fts/src/function/fts_config.cpp index 7060f09e766..ebcf4363f03 100644 --- a/extension/fts/src/function/fts_config.cpp +++ b/extension/fts/src/function/fts_config.cpp @@ -1,6 +1,8 @@ #include "function/fts_config.h" #include "common/exception/binder.h" +#include "common/serializer/deserializer.h" +#include "common/serializer/serializer.h" #include "common/string_utils.h" #include "function/stem.h" @@ -24,6 +26,21 @@ FTSConfig::FTSConfig(const function::optional_params_t& optionalParams) { } } +void FTSConfig::serialize(common::Serializer& serializer) const { + serializer.serializeValue(stemmer); +} + +FTSConfig FTSConfig::deserialize(common::Deserializer& deserializer) { + auto config = FTSConfig{}; + deserializer.deserializeValue(config.stemmer); + return config; +} + +// We store the length + the string itself. So the total size is sizeof(length) + string length. +uint64_t FTSConfig::getNumBytesForSerialization() const { + return sizeof(stemmer.size()) + stemmer.size(); +} + void K::validate(double value) { if (value < 0) { throw common::BinderException{ diff --git a/extension/fts/src/function/query_fts.cpp b/extension/fts/src/function/query_fts_index.cpp similarity index 91% rename from extension/fts/src/function/query_fts.cpp rename to extension/fts/src/function/query_fts_index.cpp index c43f71d6193..37c46635f5f 100644 --- a/extension/fts/src/function/query_fts.cpp +++ b/extension/fts/src/function/query_fts_index.cpp @@ -1,4 +1,4 @@ -#include "function/query_fts.h" +#include "function/query_fts_index.h" #include "binder/binder.h" #include "binder/expression/expression_util.h" @@ -27,20 +27,20 @@ namespace fts_extension { using namespace function; -struct QFTSGDSBindData final : public function::GDSBindData { +struct QueryFTSBindData final : public function::GDSBindData { std::vector terms; uint64_t numDocs; double_t avgDocLen; QueryFTSConfig config; common::table_id_t outputTableID; - QFTSGDSBindData(std::vector terms, graph::GraphEntry graphEntry, + QueryFTSBindData(std::vector terms, graph::GraphEntry graphEntry, std::shared_ptr docs, uint64_t numDocs, double_t avgDocLen, QueryFTSConfig config) : GDSBindData{std::move(graphEntry), std::move(docs)}, terms{std::move(terms)}, numDocs{numDocs}, avgDocLen{avgDocLen}, config{config}, outputTableID{nodeOutput->constCast().getSingleEntry()->getTableID()} {} - QFTSGDSBindData(const QFTSGDSBindData& other) + QueryFTSBindData(const QueryFTSBindData& other) : GDSBindData{other}, terms{other.terms}, numDocs{other.numDocs}, avgDocLen{other.avgDocLen}, config{other.config}, outputTableID{other.outputTableID} {} @@ -49,11 +49,11 @@ struct QFTSGDSBindData final : public function::GDSBindData { uint64_t getNumUniqueTerms() const; std::unique_ptr copy() const override { - return std::make_unique(*this); + return std::make_unique(*this); } }; -uint64_t QFTSGDSBindData::getNumUniqueTerms() const { +uint64_t QueryFTSBindData::getNumUniqueTerms() const { auto uniqueTerms = std::unordered_set{terms.begin(), terms.end()}; return uniqueTerms.size(); } @@ -156,7 +156,7 @@ void runFrontiersOnce(processor::ExecutionContext* executionContext, QFTSState& class QFTSOutputWriter { public: QFTSOutputWriter(storage::MemoryManager* mm, QFTSOutput* qFTSOutput, - const QFTSGDSBindData& bindData); + const QueryFTSBindData& bindData); void write(processor::FactorizedTable& scoreFT, nodeID_t docNodeID, uint64_t len, int64_t docsID); @@ -170,12 +170,12 @@ class QFTSOutputWriter { std::vector vectors; common::idx_t pos; storage::MemoryManager* mm; - const QFTSGDSBindData& bindData; + const QueryFTSBindData& bindData; uint64_t numUniqueTerms; }; QFTSOutputWriter::QFTSOutputWriter(storage::MemoryManager* mm, QFTSOutput* qFTSOutput, - const QFTSGDSBindData& bindData) + const QueryFTSBindData& bindData) : qFTSOutput{qFTSOutput}, docsVector{LogicalType::INTERNAL_ID(), mm}, scoreVector{LogicalType::UINT64(), mm}, mm{mm}, bindData{bindData} { auto state = DataChunkState::getSingleValueDataChunkState(); @@ -288,7 +288,7 @@ void QFTSVertexCompute::vertexCompute(const graph::VertexScanState::Chunk& chunk } } -void QFTSAlgorithm::exec(processor::ExecutionContext* executionContext) { +void QueryFTSAlgorithm::exec(processor::ExecutionContext* executionContext) { auto termsTableID = sharedState->graph->getNodeTableIDs()[0]; auto output = std::make_unique(); @@ -298,7 +298,7 @@ void QFTSAlgorithm::exec(processor::ExecutionContext* executionContext) { const graph::GraphEntry entry{{termsTableEntry}, {} /* relTableEntries */}; graph::OnDiskGraph graph(executionContext->clientContext, entry); auto termsComputeSharedState = - TermsComputeSharedState{bindData->ptrCast()->terms}; + TermsComputeSharedState{bindData->ptrCast()->terms}; TermsVertexCompute termsCompute{&termsComputeSharedState}; GDSUtils::runVertexCompute(executionContext, &graph, termsCompute, std::vector{"term", "df"}); @@ -323,24 +323,24 @@ void QFTSAlgorithm::exec(processor::ExecutionContext* executionContext) { runFrontiersOnce(executionContext, qFTSState, sharedState->graph.get(), catalog->getTableCatalogEntry(transaction, sharedState->graph->getRelTableIDs()[0]) - ->getPropertyIdx(QFTSAlgorithm::TERM_FREQUENCY_PROP_NAME)); + ->getPropertyIdx(QueryFTSAlgorithm::TERM_FREQUENCY_PROP_NAME)); // Do vertex compute to calculate the score for doc with the length property. - auto writer = - std::make_unique(mm, output.get(), *bindData->ptrCast()); + auto writer = std::make_unique(mm, output.get(), + *bindData->ptrCast()); auto writerVC = std::make_unique(mm, sharedState.get(), std::move(writer)); auto docsTableID = sharedState->graph->getNodeTableIDs()[1]; GDSUtils::runVertexCompute(executionContext, sharedState->graph.get(), *writerVC, docsTableID, - {QFTSAlgorithm::DOC_LEN_PROP_NAME, QFTSAlgorithm::DOC_ID_PROP_NAME}); + {QueryFTSAlgorithm::DOC_LEN_PROP_NAME, QueryFTSAlgorithm::DOC_ID_PROP_NAME}); sharedState->mergeLocalTables(); } static std::shared_ptr getScoreColumn(Binder* binder) { - return binder->createVariable(QFTSAlgorithm::SCORE_PROP_NAME, LogicalType::DOUBLE()); + return binder->createVariable(QueryFTSAlgorithm::SCORE_PROP_NAME, LogicalType::DOUBLE()); } -binder::expression_vector QFTSAlgorithm::getResultColumns(binder::Binder* binder) const { +binder::expression_vector QueryFTSAlgorithm::getResultColumns(binder::Binder* binder) const { expression_vector columns; auto& docsNode = bindData->getNodeOutput()->constCast(); columns.push_back(docsNode.getInternalID()); @@ -374,7 +374,7 @@ static std::vector getTerms(std::string& query, const std::string& return result; } -void QFTSAlgorithm::bind(const GDSBindInput& input, main::ClientContext& context) { +void QueryFTSAlgorithm::bind(const GDSBindInput& input, main::ClientContext& context) { auto inputTableName = getParamVal(input, 0); auto indexName = getParamVal(input, 1); auto query = getParamVal(input, 2); @@ -398,14 +398,14 @@ void QFTSAlgorithm::bind(const GDSBindInput& input, main::ClientContext& context auto appearsInEntry = context.getCatalog()->getTableCatalogEntry(context.getTransaction(), FTSUtils::getAppearsInTableName(tableEntry.getTableID(), indexName)); auto graphEntry = graph::GraphEntry({termsEntry, docsEntry}, {appearsInEntry}); - bindData = std::make_unique(std::move(terms), std::move(graphEntry), + bindData = std::make_unique(std::move(terms), std::move(graphEntry), nodeOutput, ftsIndexEntry.getNumDocs(), ftsIndexEntry.getAvgDocLen(), QueryFTSConfig{input.optionalParams}); } -function::function_set QFTSFunction::getFunctionSet() { +function::function_set QueryFTSFunction::getFunctionSet() { function_set result; - auto algo = std::make_unique(); + auto algo = std::make_unique(); result.push_back( std::make_unique(name, algo->getParameterTypeIDs(), std::move(algo))); return result; diff --git a/extension/fts/src/include/catalog/fts_index_catalog_entry.h b/extension/fts/src/include/catalog/fts_index_catalog_entry.h index e9536a052c5..ea47a5ebdab 100644 --- a/extension/fts/src/include/catalog/fts_index_catalog_entry.h +++ b/extension/fts/src/include/catalog/fts_index_catalog_entry.h @@ -14,7 +14,7 @@ class FTSIndexCatalogEntry final : public catalog::IndexCatalogEntry { FTSIndexCatalogEntry() = default; FTSIndexCatalogEntry(common::table_id_t tableID, std::string indexName, common::idx_t numDocs, double avgDocLen, FTSConfig config) - : catalog::IndexCatalogEntry{tableID, std::move(indexName)}, numDocs{numDocs}, + : catalog::IndexCatalogEntry{TYPE_NAME, tableID, std::move(indexName)}, numDocs{numDocs}, avgDocLen{avgDocLen}, config{std::move(config)} {} //===--------------------------------------------------------------------===// @@ -29,6 +29,14 @@ class FTSIndexCatalogEntry final : public catalog::IndexCatalogEntry { //===--------------------------------------------------------------------===// std::unique_ptr copy() const override; + void serializeAuxInfo(common::Serializer& serializer) const override; + + static std::unique_ptr deserializeAuxInfo( + catalog::IndexCatalogEntry* indexCatalogEntry); + +public: + static constexpr char TYPE_NAME[] = "FTS"; + private: common::idx_t numDocs = 0; double avgDocLen = 0; diff --git a/extension/fts/src/include/function/fts_config.h b/extension/fts/src/include/function/fts_config.h index 5ad456af7ba..a440fa2a3e2 100644 --- a/extension/fts/src/include/function/fts_config.h +++ b/extension/fts/src/include/function/fts_config.h @@ -21,6 +21,12 @@ struct FTSConfig { FTSConfig() = default; explicit FTSConfig(const function::optional_params_t& optionalParams); + + void serialize(common::Serializer& serializer) const; + + static FTSConfig deserialize(common::Deserializer& deserializer); + + uint64_t getNumBytesForSerialization() const; }; struct K { diff --git a/extension/fts/src/include/function/query_fts.h b/extension/fts/src/include/function/query_fts_index.h similarity index 83% rename from extension/fts/src/include/function/query_fts.h rename to extension/fts/src/include/function/query_fts_index.h index 0edbb448ca2..90a49023689 100644 --- a/extension/fts/src/include/function/query_fts.h +++ b/extension/fts/src/include/function/query_fts_index.h @@ -7,7 +7,7 @@ namespace kuzu { namespace fts_extension { -class QFTSAlgorithm : public function::GDSAlgorithm { +class QueryFTSAlgorithm : public function::GDSAlgorithm { public: static constexpr char SCORE_PROP_NAME[] = "score"; static constexpr char TERM_FREQUENCY_PROP_NAME[] = "tf"; @@ -15,8 +15,8 @@ class QFTSAlgorithm : public function::GDSAlgorithm { static constexpr char DOC_ID_PROP_NAME[] = "docID"; public: - QFTSAlgorithm() = default; - QFTSAlgorithm(const QFTSAlgorithm& other) : GDSAlgorithm{other} {} + QueryFTSAlgorithm() = default; + QueryFTSAlgorithm(const QueryFTSAlgorithm& other) : GDSAlgorithm{other} {} /* * Inputs include the following: @@ -34,7 +34,7 @@ class QFTSAlgorithm : public function::GDSAlgorithm { void exec(processor::ExecutionContext* executionContext) override; std::unique_ptr copy() const override { - return std::make_unique(*this); + return std::make_unique(*this); } binder::expression_vector getResultColumns(binder::Binder* binder) const override; @@ -42,7 +42,7 @@ class QFTSAlgorithm : public function::GDSAlgorithm { void bind(const function::GDSBindInput& input, main::ClientContext&) override; }; -struct QFTSFunction { +struct QueryFTSFunction { static constexpr const char* name = "QUERY_FTS_INDEX"; static function::function_set getFunctionSet(); diff --git a/extension/fts/test/test_files/fts_small.test b/extension/fts/test/test_files/fts_small.test index e5319636cc3..0cdb83750fa 100644 --- a/extension/fts/test/test_files/fts_small.test +++ b/extension/fts/test/test_files/fts_small.test @@ -163,3 +163,24 @@ Binder exception: BM25 model requires the Term Frequency Saturation(k) value to -STATEMENT CALL QUERY_FTS_INDEX('doc', 'docIdx', 'alice waterloo', conjunctive := true) RETURN _node.ID, score ---- 1 0|0.465323 + +-CASE fts_serialization +-SKIP_IN_MEM +-STATEMENT load extension "${KUZU_ROOT_DIRECTORY}/extension/fts/build/libfts.kuzu_extension" +---- ok +-STATEMENT CALL CREATE_FTS_INDEX('doc', 'docIdx', ['content', 'author', 'name']) +---- ok +-STATEMENT CALL QUERY_FTS_INDEX('doc', 'docIdx', 'Alice') RETURN _node.ID, score +---- 2 +0|0.271133 +3|0.209476 +-RELOADDB +-STATEMENT CALL QUERY_FTS_INDEX('doc', 'docIdx', 'Alice') RETURN _node.ID, score +---- error +Catalog exception: QUERY_FTS_INDEX function does not exist. +-STATEMENT load extension "${KUZU_ROOT_DIRECTORY}/extension/fts/build/libfts.kuzu_extension" +---- ok +-STATEMENT CALL QUERY_FTS_INDEX('doc', 'docIdx', 'Alice') RETURN _node.ID, score +---- 2 +0|0.271133 +3|0.209476 diff --git a/src/catalog/catalog.cpp b/src/catalog/catalog.cpp index 8a0fc65691b..23c172a2eb4 100644 --- a/src/catalog/catalog.cpp +++ b/src/catalog/catalog.cpp @@ -331,6 +331,10 @@ IndexCatalogEntry* Catalog::getIndex(const Transaction* transaction, common::tab ->ptrCast(); } +CatalogSet* Catalog::getIndexes() const { + return indexes.get(); +} + bool Catalog::containsIndex(const transaction::Transaction* transaction, common::table_id_t tableID, std::string indexName) const { return indexes->containsEntry(transaction, common::stringFormat("{}_{}", tableID, indexName)); @@ -473,6 +477,7 @@ void Catalog::saveToFile(const std::string& directory, VirtualFileSystem* fs, sequences->serialize(serializer); functions->serialize(serializer); types->serialize(serializer); + indexes->serialize(serializer); } void Catalog::readFromFile(const std::string& directory, VirtualFileSystem* fs, @@ -489,7 +494,7 @@ void Catalog::readFromFile(const std::string& directory, VirtualFileSystem* fs, sequences = CatalogSet::deserialize(deserializer); functions = CatalogSet::deserialize(deserializer); types = CatalogSet::deserialize(deserializer); - indexes = std::make_unique(); + indexes = CatalogSet::deserialize(deserializer); } void Catalog::registerBuiltInFunctions() { diff --git a/src/catalog/catalog_entry/CMakeLists.txt b/src/catalog/catalog_entry/CMakeLists.txt index c4a99a80414..1f17d473643 100644 --- a/src/catalog/catalog_entry/CMakeLists.txt +++ b/src/catalog/catalog_entry/CMakeLists.txt @@ -9,7 +9,8 @@ add_library(kuzu_catalog_entry rel_group_catalog_entry.cpp scalar_macro_catalog_entry.cpp type_catalog_entry.cpp - sequence_catalog_entry.cpp) + sequence_catalog_entry.cpp + index_catalog_entry.cpp) set(ALL_OBJECT_FILES ${ALL_OBJECT_FILES} $ diff --git a/src/catalog/catalog_entry/catalog_entry.cpp b/src/catalog/catalog_entry/catalog_entry.cpp index 06e0ab70b60..265d3093de2 100644 --- a/src/catalog/catalog_entry/catalog_entry.cpp +++ b/src/catalog/catalog_entry/catalog_entry.cpp @@ -1,5 +1,6 @@ #include "catalog/catalog_entry/catalog_entry.h" +#include "catalog/catalog_entry/index_catalog_entry.h" #include "catalog/catalog_entry/scalar_macro_catalog_entry.h" #include "catalog/catalog_entry/sequence_catalog_entry.h" #include "catalog/catalog_entry/table_catalog_entry.h" @@ -51,6 +52,9 @@ std::unique_ptr CatalogEntry::deserialize(common::Deserializer& de case CatalogEntryType::TYPE_ENTRY: { entry = TypeCatalogEntry::deserialize(deserializer); } break; + case CatalogEntryType::INDEX_ENTRY: { + entry = IndexCatalogEntry::deserialize(deserializer); + } break; default: KU_UNREACHABLE; } diff --git a/src/catalog/catalog_entry/index_catalog_entry.cpp b/src/catalog/catalog_entry/index_catalog_entry.cpp new file mode 100644 index 00000000000..abb8099463c --- /dev/null +++ b/src/catalog/catalog_entry/index_catalog_entry.cpp @@ -0,0 +1,39 @@ +#include "catalog/catalog_entry/index_catalog_entry.h" + +namespace kuzu { +namespace catalog { + +void IndexCatalogEntry::serialize(common::Serializer& serializer) const { + CatalogEntry::serialize(serializer); + serializer.write(type); + serializer.write(tableID); + serializer.write(indexName); + serializeAuxInfo(serializer); +} + +std::unique_ptr IndexCatalogEntry::deserialize( + common::Deserializer& deserializer) { + std::string type; + common::table_id_t tableID = common::INVALID_TABLE_ID; + std::string indexName; + deserializer.deserializeValue(type); + deserializer.deserializeValue(tableID); + deserializer.deserializeValue(indexName); + auto indexEntry = std::make_unique(type, tableID, std::move(indexName)); + uint64_t auxBufferSize = 0; + deserializer.deserializeValue(auxBufferSize); + indexEntry->auxBuffer = std::make_unique(auxBufferSize); + indexEntry->auxBufferSize = auxBufferSize; + deserializer.read(indexEntry->auxBuffer.get(), auxBufferSize); + return indexEntry; +} + +void IndexCatalogEntry::copyFrom(const CatalogEntry& other) { + CatalogEntry::copyFrom(other); + auto& otherTable = other.constCast(); + tableID = otherTable.tableID; + indexName = otherTable.indexName; +} + +} // namespace catalog +} // namespace kuzu diff --git a/src/include/catalog/catalog.h b/src/include/catalog/catalog.h index 6931c08a028..898e0b3111d 100644 --- a/src/include/catalog/catalog.h +++ b/src/include/catalog/catalog.h @@ -124,6 +124,7 @@ class KUZU_API Catalog { std::unique_ptr indexCatalogEntry); IndexCatalogEntry* getIndex(const transaction::Transaction*, common::table_id_t tableID, std::string indexName) const; + CatalogSet* getIndexes() const; bool containsIndex(const transaction::Transaction* transaction, common::table_id_t tableID, std::string indexName) const; void dropAllIndexes(transaction::Transaction* transaction, common::table_id_t tableID); diff --git a/src/include/catalog/catalog_entry/index_catalog_entry.h b/src/include/catalog/catalog_entry/index_catalog_entry.h index f1d1697c4c9..68b6933b1b4 100644 --- a/src/include/catalog/catalog_entry/index_catalog_entry.h +++ b/src/include/catalog/catalog_entry/index_catalog_entry.h @@ -1,6 +1,7 @@ #pragma once #include "catalog_entry.h" +#include "common/serializer/deserializer.h" namespace kuzu { namespace catalog { @@ -12,38 +13,48 @@ class KUZU_API IndexCatalogEntry : public CatalogEntry { //===--------------------------------------------------------------------===// IndexCatalogEntry() = default; - IndexCatalogEntry(common::table_id_t tableID, std::string indexName) + IndexCatalogEntry(std::string type, common::table_id_t tableID, std::string indexName) : CatalogEntry{CatalogEntryType::INDEX_ENTRY, common::stringFormat("{}_{}", tableID, indexName)}, - tableID{tableID}, indexName{std::move(indexName)} {} + type{std::move(type)}, tableID{tableID}, indexName{std::move(indexName)} {} + + std::string getIndexType() const { return type; } common::table_id_t getTableID() const { return tableID; } + std::string getIndexName() const { return indexName; } + + uint8_t* getAuxBuffer() const { return auxBuffer.get(); } + + uint64_t getAuxBufferSize() const { return auxBufferSize; } + //===--------------------------------------------------------------------===// // serialization & deserialization //===--------------------------------------------------------------------===// - void serialize(common::Serializer& /*serializer*/) const override {} - // TODO(Ziyi/Guodong) : If the database fails with loaded extensions, should we restart the db - // and reload previously loaded extensions? Currently, we don't have the mechanism to reload - // extensions during recovery, thus, we are not able to recover the indexes created by - // extensions. - static std::unique_ptr deserialize(common::Deserializer& /*deserializer*/) { - KU_UNREACHABLE; - } + // When serializing index entries to disk, we first write the fields of the base class, + // followed by the size (in bytes) of the auxiliary data and its content. + void serialize(common::Serializer& serializer) const override; + // During deserialization of index entries from disk, we first read the base class + // (IndexCatalogEntry). The auxiliary data is stored in auxBuffer, with its size in + // auxBufferSize. Once the extension is loaded, the corresponding indexes are reconstructed + // using the auxBuffer. + static std::unique_ptr deserialize(common::Deserializer& deserializer); - std::string toCypher(main::ClientContext* /*clientContext*/) const override { KU_UNREACHABLE; } - virtual std::unique_ptr copy() const = 0; + virtual void serializeAuxInfo(common::Serializer& /*serializer*/) const { KU_UNREACHABLE; } - void copyFrom(const CatalogEntry& other) override { - CatalogEntry::copyFrom(other); - auto& otherTable = other.constCast(); - tableID = otherTable.tableID; - indexName = otherTable.indexName; + std::string toCypher(main::ClientContext* /*clientContext*/) const override { KU_UNREACHABLE; } + virtual std::unique_ptr copy() const { + return std::make_unique(type, tableID, indexName); } + void copyFrom(const CatalogEntry& other) override; + protected: + std::string type; common::table_id_t tableID = common::INVALID_TABLE_ID; std::string indexName; + std::unique_ptr auxBuffer; + uint64_t auxBufferSize = 0; }; } // namespace catalog diff --git a/src/include/catalog/catalog_set.h b/src/include/catalog/catalog_set.h index 7c94bcb9bb8..67fdf0e7db6 100644 --- a/src/include/catalog/catalog_set.h +++ b/src/include/catalog/catalog_set.h @@ -22,13 +22,13 @@ class Transaction; using CatalogEntrySet = common::case_insensitive_map_t; namespace catalog { -class CatalogSet { +class KUZU_API CatalogSet { friend class storage::UndoBuffer; public: bool containsEntry(const transaction::Transaction* transaction, const std::string& name); CatalogEntry* getEntry(const transaction::Transaction* transaction, const std::string& name); - KUZU_API common::oid_t createEntry(transaction::Transaction* transaction, + common::oid_t createEntry(transaction::Transaction* transaction, std::unique_ptr entry); void dropEntry(transaction::Transaction* transaction, const std::string& name, common::oid_t oid); diff --git a/src/include/common/serializer/buffered_reader.h b/src/include/common/serializer/buffered_reader.h new file mode 100644 index 00000000000..63da1c56fb5 --- /dev/null +++ b/src/include/common/serializer/buffered_reader.h @@ -0,0 +1,26 @@ +#pragma once + +#include + +#include "common/serializer/reader.h" + +namespace kuzu { +namespace common { + +struct BufferReader : Reader { + BufferReader(uint8_t* data, size_t dataSize) : data(data), dataSize(dataSize), readSize(0) {} + + void read(uint8_t* outputData, uint64_t size) final { + memcpy(outputData, data + readSize, size); + readSize += size; + } + + bool finished() final { return readSize >= dataSize; } + + uint8_t* data; + size_t dataSize; + size_t readSize; +}; + +} // namespace common +} // namespace kuzu diff --git a/src/include/common/serializer/deserializer.h b/src/include/common/serializer/deserializer.h index f89a2592393..3a9361d7ff1 100644 --- a/src/include/common/serializer/deserializer.h +++ b/src/include/common/serializer/deserializer.h @@ -14,7 +14,7 @@ namespace kuzu { namespace common { -class Deserializer { +class KUZU_API Deserializer { public: explicit Deserializer(std::unique_ptr reader) : reader(std::move(reader)) {} diff --git a/src/include/common/serializer/serializer.h b/src/include/common/serializer/serializer.h index a9269552d59..5a3baa88bcb 100644 --- a/src/include/common/serializer/serializer.h +++ b/src/include/common/serializer/serializer.h @@ -13,7 +13,7 @@ namespace kuzu { namespace common { -class Serializer { +class KUZU_API Serializer { public: explicit Serializer(std::shared_ptr writer) : writer(std::move(writer)) {} diff --git a/src/include/parser/visitor/statement_read_write_analyzer.h b/src/include/parser/visitor/statement_read_write_analyzer.h index 7232362023c..0950cf3b58b 100644 --- a/src/include/parser/visitor/statement_read_write_analyzer.h +++ b/src/include/parser/visitor/statement_read_write_analyzer.h @@ -21,6 +21,7 @@ class StatementReadWriteAnalyzer final : public StatementVisitor { void visitStandaloneCall(const Statement& /*statement*/) override { readOnly = true; } void visitStandaloneCallFunction(const Statement& /*statement*/) override { readOnly = false; } void visitCreateMacro(const Statement& /*statement*/) override { readOnly = false; } + void visitExtension(const Statement& /*statement*/) override { readOnly = false; } void visitReadingClause(const ReadingClause* readingClause) override; void visitWithClause(const WithClause* withClause) override; diff --git a/test/storage/compression_test.cpp b/test/storage/compression_test.cpp index 7cab1e8babf..b4638d1a57b 100644 --- a/test/storage/compression_test.cpp +++ b/test/storage/compression_test.cpp @@ -4,6 +4,7 @@ #include "common/constants.h" #include "common/exception/not_implemented.h" #include "common/exception/storage.h" +#include "common/serializer/buffered_reader.h" #include "common/serializer/buffered_serializer.h" #include "common/serializer/deserializer.h" #include "common/serializer/reader.h" @@ -61,21 +62,6 @@ bool operator==(const CompressionMetadata& a, const CompressionMetadata& b) { return true; } -struct BufferReader : Reader { - BufferReader(uint8_t* data, size_t dataSize) : data(data), dataSize(dataSize), readSize(0) {} - - void read(uint8_t* outputData, uint64_t size) final { - memcpy(outputData, data + readSize, size); - readSize += size; - } - - bool finished() final { return (readSize >= dataSize); } - - uint8_t* data; - size_t dataSize; - size_t readSize; -}; - void testSerializeThenDeserialize(const CompressionMetadata& orig) { // test serializing/deserializing twice diff --git a/tools/python_api/test/test_extension.py b/tools/python_api/test/test_extension.py index 83d7f950231..bd1d99ef1a2 100644 --- a/tools/python_api/test/test_extension.py +++ b/tools/python_api/test/test_extension.py @@ -27,7 +27,7 @@ def extension_extension_dir_prefix() -> str: return extension_extension_dir_prefix -def test_extension_install_httpfs(conn_db_readonly: ConnDB, tmpdir: str, extension_extension_dir_prefix: str) -> None: +def test_extension_install_httpfs(conn_db_readwrite: ConnDB, tmpdir: str, extension_extension_dir_prefix: str) -> None: current_dir = Path(__file__).resolve().parent cmake_list_file = Path(current_dir).parent.parent.parent / "CMakeLists.txt" extension_version = None @@ -56,7 +56,7 @@ def test_extension_install_httpfs(conn_db_readonly: ConnDB, tmpdir: str, extensi temp_path = Path(tmpdir) / "libhttpfs.kuzu_extension" urllib.request.urlretrieve(download_url, temp_path) - conn, _ = conn_db_readonly + conn, _ = conn_db_readwrite conn.execute("INSTALL httpfs") assert Path.exists(extension_path)