Skip to content

Commit

Permalink
update tests; rework DBConfig::isDBPathInMemory
Browse files Browse the repository at this point in the history
  • Loading branch information
ray6080 committed Aug 5, 2024
1 parent ce9a96d commit abdccb1
Show file tree
Hide file tree
Showing 23 changed files with 192 additions and 191 deletions.
23 changes: 14 additions & 9 deletions extension/httpfs/src/http_config.cpp
Original file line number Diff line number Diff line change
@@ -1,25 +1,30 @@
#include "http_config.h"

#include "function/cast/functions/cast_from_string_functions.h"
#include "main/db_config.h"

namespace kuzu {
namespace httpfs {

// Builds the HTTP config from the client context's current settings.
// `context` must be non-null (checked by KU_ASSERT below).
HTTPConfig::HTTPConfig(main::ClientContext* context) {
KU_ASSERT(context != nullptr);
// NOTE(review): this hunk is shown without +/- markers. The first assignment (via
// HTTPCacheInMemoryConfig) appears to be the removed line and the second (via
// HTTPCacheFileConfig) its replacement; both read the HTTP_CACHE_FILE_OPTION setting as a bool.
cacheFile = context->getCurrentSetting(HTTPCacheInMemoryConfig::HTTP_CACHE_FILE_OPTION)
.getValue<bool>();
cacheFile =
context->getCurrentSetting(HTTPCacheFileConfig::HTTP_CACHE_FILE_OPTION).getValue<bool>();
}

// Applies the HTTP cache-file setting found in the environment to the client context.
// Reads the HTTP_CACHE_FILE_ENV_VAR environment variable through `context`; when it is
// non-empty, the string is cast to bool and stored as the HTTP_CACHE_FILE_OPTION extension
// option. An empty/unset variable leaves the option untouched.
// NOTE(review): this hunk is shown without +/- markers, so removed lines
// (cacheInMemoryStrVal / cacheInMemory / HTTPCacheInMemoryConfig) are interleaved with their
// replacements (cacheFileOptionStrVal / enableCacheFile / HTTPCacheFileConfig); the braces
// only balance once the removed lines are dropped.
void HTTPConfigEnvProvider::setOptionValue(main::ClientContext* context) {
auto cacheInMemoryStrVal =
context->getEnvVariable(HTTPCacheInMemoryConfig::HTTP_CACHE_FILE_ENV_VAR);
if (cacheInMemoryStrVal != "") {
bool cacheInMemory;
const auto cacheFileOptionStrVal =
context->getEnvVariable(HTTPCacheFileConfig::HTTP_CACHE_FILE_ENV_VAR);
if (cacheFileOptionStrVal != "") {
bool enableCacheFile;
// Parse the environment string into a bool.
// NOTE(review): what happens for a non-boolean string depends on CastString::operation,
// which is not visible here — confirm.
function::CastString::operation(
ku_string_t{cacheInMemoryStrVal.c_str(), cacheInMemoryStrVal.length()}, cacheInMemory);
context->setExtensionOption(HTTPCacheInMemoryConfig::HTTP_CACHE_FILE_OPTION,
Value::createValue(cacheInMemory));
ku_string_t{cacheFileOptionStrVal.c_str(), cacheFileOptionStrVal.length()},
enableCacheFile);
// Enabling the file cache is rejected when the database path is an in-memory one.
if (enableCacheFile && main::DBConfig::isDBPathInMemory(context->getDatabasePath())) {
throw Exception("Cannot enable HTTP file cache when database is in memory");
}
context->setExtensionOption(HTTPCacheFileConfig::HTTP_CACHE_FILE_OPTION,
Value::createValue(enableCacheFile));
}
}

Expand Down
4 changes: 2 additions & 2 deletions extension/httpfs/src/httpfs_extension.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ void HttpfsExtension::load(main::ClientContext* context) {
common::Value{(int64_t)10000});
db->addExtensionOption("s3_uploader_threads_limit", common::LogicalTypeID::INT64,
common::Value{(int64_t)50});
db->addExtensionOption(HTTPCacheInMemoryConfig::HTTP_CACHE_FILE_OPTION,
common::LogicalTypeID::BOOL, common::Value{HTTPCacheInMemoryConfig::DEFAULT_CACHE_FILE});
db->addExtensionOption(HTTPCacheFileConfig::HTTP_CACHE_FILE_OPTION, common::LogicalTypeID::BOOL,
common::Value{HTTPCacheFileConfig::DEFAULT_CACHE_FILE});
AWSEnvironmentCredentialsProvider::setOptionValue(context);
HTTPConfigEnvProvider::setOptionValue(context);
}
Expand Down
2 changes: 1 addition & 1 deletion extension/httpfs/src/include/http_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ struct HTTPConfig {
bool cacheFile;
};

struct HTTPCacheInMemoryConfig {
struct HTTPCacheFileConfig {
static constexpr const char* HTTP_CACHE_FILE_ENV_VAR = "HTTP_CACHE_FILE";
static constexpr const char* HTTP_CACHE_FILE_OPTION = "http_cache_file";
static constexpr bool DEFAULT_CACHE_FILE = false;
Expand Down
15 changes: 9 additions & 6 deletions src/catalog/catalog.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "common/serializer/serializer.h"
#include "common/string_format.h"
#include "function/built_in_function_utils.h"
#include "main/db_config.h"
#include "storage/storage_utils.h"
#include "storage/storage_version_info.h"
#include "transaction/transaction.h"
Expand All @@ -42,17 +43,19 @@ Catalog::Catalog() {
registerBuiltInFunctions();
}

Catalog::Catalog(std::string directory, VirtualFileSystem* fs) {
if (!directory.empty() && fs->fileOrPathExists(StorageUtils::getCatalogFilePath(fs, directory,
FileVersionType::ORIGINAL))) {
readFromFile(directory, fs, FileVersionType::ORIGINAL);
Catalog::Catalog(const std::string& directory, VirtualFileSystem* vfs) {
const auto isInMemMode = main::DBConfig::isDBPathInMemory(directory);
if (!isInMemMode && vfs->fileOrPathExists(StorageUtils::getCatalogFilePath(vfs, directory,
FileVersionType::ORIGINAL))) {
readFromFile(directory, vfs, FileVersionType::ORIGINAL);
} else {
tables = std::make_unique<CatalogSet>();
sequences = std::make_unique<CatalogSet>();
functions = std::make_unique<CatalogSet>();
types = std::make_unique<CatalogSet>();
if (!directory.empty()) {
saveToFile(directory, fs, FileVersionType::ORIGINAL);
if (!isInMemMode) {
// TODO(Guodong): Ideally we should be able to remove this line. Revisit here.
saveToFile(directory, vfs, FileVersionType::ORIGINAL);
}
}
registerBuiltInFunctions();
Expand Down
84 changes: 48 additions & 36 deletions src/include/catalog/catalog.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@
#include "common/cast.h"
#include "function/function.h"

// Forward declaration of main::DBConfig. It is declared with the `struct` class-key, which
// matches how DBConfig is defined in src/include/main/db_config.h in this commit.
namespace kuzu::main {
struct DBConfig;
} // namespace kuzu::main

namespace kuzu {
namespace main {
class AttachedKuzuDatabase;
Expand Down Expand Up @@ -47,54 +51,61 @@ class KUZU_API Catalog {
public:
// This is extended by DuckCatalog and PostgresCatalog.
Catalog();
Catalog(std::string directory, common::VirtualFileSystem* vfs);
Catalog(const std::string& directory, common::VirtualFileSystem* vfs);
virtual ~Catalog() = default;

// ----------------------------- Table Schemas ----------------------------
bool containsTable(transaction::Transaction* tx, const std::string& tableName) const;
bool containsTable(transaction::Transaction* transaction, const std::string& tableName) const;

common::table_id_t getTableID(transaction::Transaction* tx, const std::string& tableName) const;
std::vector<common::table_id_t> getNodeTableIDs(transaction::Transaction* tx) const;
std::vector<common::table_id_t> getRelTableIDs(transaction::Transaction* tx) const;
common::table_id_t getTableID(transaction::Transaction* transaction,
const std::string& tableName) const;
std::vector<common::table_id_t> getNodeTableIDs(transaction::Transaction* transaction) const;
std::vector<common::table_id_t> getRelTableIDs(transaction::Transaction* transaction) const;

// TODO: Should remove this.
std::string getTableName(transaction::Transaction* tx, common::table_id_t tableID) const;
TableCatalogEntry* getTableCatalogEntry(transaction::Transaction* tx,
std::string getTableName(transaction::Transaction* transaction,
common::table_id_t tableID) const;
TableCatalogEntry* getTableCatalogEntry(transaction::Transaction* transaction,
common::table_id_t tableID) const;
std::vector<NodeTableCatalogEntry*> getNodeTableEntries(transaction::Transaction* tx) const;
std::vector<RelTableCatalogEntry*> getRelTableEntries(transaction::Transaction* tx) const;
std::vector<RelGroupCatalogEntry*> getRelTableGroupEntries(transaction::Transaction* tx) const;
std::vector<RDFGraphCatalogEntry*> getRdfGraphEntries(transaction::Transaction* tx) const;
std::vector<TableCatalogEntry*> getTableEntries(transaction::Transaction* tx) const;
std::vector<TableCatalogEntry*> getTableEntries(transaction::Transaction* tx,
std::vector<NodeTableCatalogEntry*> getNodeTableEntries(
transaction::Transaction* transaction) const;
std::vector<RelTableCatalogEntry*> getRelTableEntries(
transaction::Transaction* transaction) const;
std::vector<RelGroupCatalogEntry*> getRelTableGroupEntries(
transaction::Transaction* transaction) const;
std::vector<RDFGraphCatalogEntry*> getRdfGraphEntries(
transaction::Transaction* transaction) const;
std::vector<TableCatalogEntry*> getTableEntries(transaction::Transaction* transaction) const;
std::vector<TableCatalogEntry*> getTableEntries(transaction::Transaction* transaction,
const common::table_id_vector_t& tableIDs) const;
bool tableInRDFGraph(transaction::Transaction* tx, common::table_id_t tableID) const;
bool tableInRelGroup(transaction::Transaction* tx, common::table_id_t tableID) const;
common::table_id_set_t getFwdRelTableIDs(transaction::Transaction* tx,
bool tableInRDFGraph(transaction::Transaction* transaction, common::table_id_t tableID) const;
bool tableInRelGroup(transaction::Transaction* transaction, common::table_id_t tableID) const;
common::table_id_set_t getFwdRelTableIDs(transaction::Transaction* transaction,
common::table_id_t nodeTableID) const;
common::table_id_set_t getBwdRelTableIDs(transaction::Transaction* tx,
common::table_id_set_t getBwdRelTableIDs(transaction::Transaction* transaction,
common::table_id_t nodeTableID) const;

common::table_id_t createTableSchema(transaction::Transaction* tx,
common::table_id_t createTableSchema(transaction::Transaction* transaction,
const binder::BoundCreateTableInfo& info);
void dropTableEntry(transaction::Transaction* tx, std::string name);
void dropTableEntry(transaction::Transaction* tx, common::table_id_t tableID);
void alterTableEntry(transaction::Transaction* tx, const binder::BoundAlterInfo& info);
void dropTableEntry(transaction::Transaction* transaction, std::string name);
void dropTableEntry(transaction::Transaction* transaction, common::table_id_t tableID);
void alterTableEntry(transaction::Transaction* transaction, const binder::BoundAlterInfo& info);

// ----------------------------- Sequences ----------------------------
bool containsSequence(transaction::Transaction* transaction,
const std::string& sequenceName) const;

common::sequence_id_t getSequenceID(transaction::Transaction* transaction,
const std::string& sequenceName) const;
SequenceCatalogEntry* getSequenceCatalogEntry(transaction::Transaction* tx,
SequenceCatalogEntry* getSequenceCatalogEntry(transaction::Transaction* transaction,
common::sequence_id_t sequenceID) const;
std::vector<SequenceCatalogEntry*> getSequenceEntries(transaction::Transaction* tx) const;
std::vector<SequenceCatalogEntry*> getSequenceEntries(
transaction::Transaction* transaction) const;

common::sequence_id_t createSequence(transaction::Transaction* tx,
common::sequence_id_t createSequence(transaction::Transaction* transaction,
const binder::BoundCreateSequenceInfo& info);
void dropSequence(transaction::Transaction* tx, std::string name);
void dropSequence(transaction::Transaction* tx, common::sequence_id_t sequenceID);
void dropSequence(transaction::Transaction* transaction, std::string name);
void dropSequence(transaction::Transaction* transaction, common::sequence_id_t sequenceID);

static std::string genSerialName(const std::string& tableName, const std::string& propertyName);

Expand All @@ -105,21 +116,22 @@ class KUZU_API Catalog {
bool containsType(transaction::Transaction* transaction, const std::string& typeName);

// ----------------------------- Functions ----------------------------
void addFunction(transaction::Transaction* tx, CatalogEntryType entryType, std::string name,
function::function_set functionSet);
void dropFunction(transaction::Transaction* tx, const std::string& name);
void addFunction(transaction::Transaction* transaction, CatalogEntryType entryType,
std::string name, function::function_set functionSet);
void dropFunction(transaction::Transaction* transaction, const std::string& name);
void addBuiltInFunction(CatalogEntryType entryType, std::string name,
function::function_set functionSet);
CatalogSet* getFunctions(transaction::Transaction* tx) const;
CatalogEntry* getFunctionEntry(transaction::Transaction* tx, const std::string& name);
std::vector<FunctionCatalogEntry*> getFunctionEntries(transaction::Transaction* tx) const;
CatalogSet* getFunctions(transaction::Transaction* transaction) const;
CatalogEntry* getFunctionEntry(transaction::Transaction* transaction, const std::string& name);
std::vector<FunctionCatalogEntry*> getFunctionEntries(
transaction::Transaction* transaction) const;

bool containsMacro(transaction::Transaction* tx, const std::string& macroName) const;
void addScalarMacroFunction(transaction::Transaction* tx, std::string name,
bool containsMacro(transaction::Transaction* transaction, const std::string& macroName) const;
void addScalarMacroFunction(transaction::Transaction* transaction, std::string name,
std::unique_ptr<function::ScalarMacroFunction> macro);
function::ScalarMacroFunction* getScalarMacroFunction(transaction::Transaction* tx,
function::ScalarMacroFunction* getScalarMacroFunction(transaction::Transaction* transaction,
const std::string& name) const;
std::vector<std::string> getMacroNames(transaction::Transaction* tx) const;
std::vector<std::string> getMacroNames(transaction::Transaction* transaction) const;

void checkpoint(const std::string& databasePath, common::VirtualFileSystem* fs) const;

Expand Down
4 changes: 2 additions & 2 deletions src/include/main/client_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ struct ExtensionOptions;
}

namespace main {
class DBConfig;
struct DBConfig;
class Database;
class DatabaseManager;
class AttachedKuzuDatabase;
Expand Down Expand Up @@ -63,7 +63,7 @@ class KUZU_API ClientContext {
const DBConfig* getDBConfig() const { return &dbConfig; }
DBConfig* getDBConfigUnsafe() { return &dbConfig; }
common::Value getCurrentSetting(const std::string& optionName);
bool isOptionSet(const std::string& name) const;
bool isOptionSet(const std::string& optionName) const;
// Timer and timeout
void interrupt() { activeQuery.interrupted = true; }
bool interrupted() const { return activeQuery.interrupted; }
Expand Down
13 changes: 7 additions & 6 deletions src/include/main/database.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

#include <memory>
#include <mutex>
#include <string>
#include <vector>

#include "common/api.h"
Expand Down Expand Up @@ -124,7 +123,7 @@ class Database {

KUZU_API catalog::Catalog* getCatalog() { return catalog.get(); }

ExtensionOption* getExtensionOption(std::string name);
ExtensionOption* getExtensionOption(std::string name) const;

const DBConfig& getConfig() const { return dbConfig; }

Expand All @@ -134,6 +133,11 @@ class Database {
uint64_t getNextQueryID();

private:
// Bundles a query-ID counter with the mutex stored alongside it.
// NOTE(review): presumably backs Database::getNextQueryID(), with queryIDLock guarding
// queryID — confirm in database.cpp.
struct QueryIDGenerator {
    // Counter value; starts at zero.
    uint64_t queryID{0};
    // Lock kept next to the counter (see review note above).
    std::mutex queryIDLock{};
};

void openLockFile();
void initAndLockDBDir();

Expand All @@ -151,10 +155,7 @@ class Database {
std::unique_ptr<extension::ExtensionOptions> extensionOptions;
std::unique_ptr<DatabaseManager> databaseManager;
common::case_insensitive_map_t<std::unique_ptr<storage::StorageExtension>> storageExtensions;
struct QueryIDGenerator {
uint64_t queryID = 0;
std::mutex queryIDLock;
} queryIDGenerator;
QueryIDGenerator queryIDGenerator;
};

} // namespace main
Expand Down
14 changes: 7 additions & 7 deletions src/include/main/db_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,20 +50,20 @@ struct ExtensionOption final : Option {
defaultValue{std::move(defaultValue)} {}
};

// Database-level configuration values, constructed from a SystemConfig.
// NOTE(review): this hunk is shown without +/- markers. The `class DBConfig { public: ... }`
// opening lines appear to be the removed declaration and `struct DBConfig { ... }` its
// replacement, with the data members now listed ahead of the constructor and static helpers.
class DBConfig {
public:
static ConfigurationOption* getOptionByName(const std::string& optionName);

explicit DBConfig(const SystemConfig& systemConfig);

struct DBConfig {
uint64_t bufferPoolSize;
uint64_t maxNumThreads;
bool enableCompression;
bool readOnly;
uint64_t maxDBSize;
bool enableMultiWrites = false;
// NOTE(review): the declaration below has no in-class default, unlike the `= false` form
// above. Confirm that DBConfig(const SystemConfig&) always initializes enableMultiWrites,
// otherwise it is read uninitialized.
bool enableMultiWrites;
bool autoCheckpoint;
uint64_t checkpointThreshold;

explicit DBConfig(const SystemConfig& systemConfig);

// Looks up a configuration option by its name.
// NOTE(review): the result for an unknown name is not visible here — confirm.
static ConfigurationOption* getOptionByName(const std::string& optionName);
// Reports whether `dbPath` denotes an in-memory database.
// NOTE(review): presumably true for an empty path, since it replaces the
// `directory.empty()` checks in Catalog's constructor — confirm in db_config.cpp.
static bool isDBPathInMemory(const std::string& dbPath);
};

} // namespace main
Expand Down
21 changes: 11 additions & 10 deletions src/include/storage/storage_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class StorageManager {

static void recover(main::ClientContext& clientContext);

void createTable(common::table_id_t tableID, catalog::Catalog* catalog,
void createTable(common::table_id_t tableID, const catalog::Catalog* catalog,
main::ClientContext* context);

void checkpoint(main::ClientContext& clientContext);
Expand All @@ -37,8 +37,8 @@ class StorageManager {
return tables.at(tableID).get();
}

WAL& getWAL();
ShadowFile& getShadowFile();
WAL& getWAL() const;
ShadowFile& getShadowFile() const;
BMFileHandle* getDataFH() const { return dataFH; }
BMFileHandle* getMetadataFH() const { return metadataFH; }
std::string getDatabasePath() const { return databasePath; }
Expand All @@ -48,19 +48,20 @@ class StorageManager {
uint64_t getEstimatedMemoryUsage();

private:
BMFileHandle* initFileHandle(const std::string& filename, common::VirtualFileSystem* vfs,
BMFileHandle* initFileHandle(const std::string& fileName, common::VirtualFileSystem* vfs,
main::ClientContext* context) const;

void loadTables(const catalog::Catalog& catalog, common::VirtualFileSystem* vfs,
main::ClientContext* context);
void createNodeTable(common::table_id_t tableID, catalog::NodeTableCatalogEntry* tableSchema,
void createNodeTable(common::table_id_t tableID, catalog::NodeTableCatalogEntry* nodeTableEntry,
main::ClientContext* context);
void createRelTable(common::table_id_t tableID, catalog::RelTableCatalogEntry* tableSchema,
catalog::Catalog* catalog, transaction::Transaction* transaction);
void createRelTableGroup(common::table_id_t tableID, catalog::RelGroupCatalogEntry* tableSchema,
catalog::Catalog* catalog, transaction::Transaction* transaction);
void createRelTable(common::table_id_t tableID, catalog::RelTableCatalogEntry* relTableEntry,
const catalog::Catalog* catalog, transaction::Transaction* transaction);
void createRelTableGroup(common::table_id_t tableID,
const catalog::RelGroupCatalogEntry* tableSchema, const catalog::Catalog* catalog,
transaction::Transaction* transaction);
void createRdfGraph(common::table_id_t tableID, catalog::RDFGraphCatalogEntry* tableSchema,
catalog::Catalog* catalog, main::ClientContext* context);
const catalog::Catalog* catalog, main::ClientContext* context);

private:
std::mutex mtx;
Expand Down
4 changes: 4 additions & 0 deletions src/main/attached_database.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ AttachedKuzuDatabase::AttachedKuzuDatabase(std::string dbPath, std::string dbNam
std::string dbType, ClientContext* clientContext)
: AttachedDatabase{std::move(dbName), std::move(dbType), nullptr /* catalog */} {
auto vfs = clientContext->getVFSUnsafe();
if (DBConfig::isDBPathInMemory(dbPath)) {
throw common::RuntimeException("Cannot attach an in-memory Kùzu database. Please give a "
"path to an on-disk Kùzu database directory.");
}
auto path = vfs->expandPath(clientContext, dbPath);
initCatalog(path, clientContext);
validateEmptyWAL(path, clientContext);
Expand Down
4 changes: 2 additions & 2 deletions src/main/client_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,9 @@ void ActiveQuery::reset() {

ClientContext::ClientContext(Database* database)
: dbConfig{database->dbConfig}, localDatabase{database} {
progressBar = std::make_unique<common::ProgressBar>();
progressBar = std::make_unique<ProgressBar>();
transactionContext = std::make_unique<TransactionContext>(*this);
randomEngine = std::make_unique<common::RandomEngine>();
randomEngine = std::make_unique<RandomEngine>();
remoteDatabase = nullptr;
#if defined(_WIN32)
clientConfig.homeDirectory = getEnvVariable("USERPROFILE");
Expand Down
Loading

0 comments on commit abdccb1

Please sign in to comment.