From 0cbbe8b5ff9f9e118ed4ea6c32e52bbdfc6fcd38 Mon Sep 17 00:00:00 2001 From: aceforeverd Date: Wed, 24 Jan 2024 14:01:47 +0000 Subject: [PATCH] feat: insert stmt support for map data type WIP! --- hybridse/include/codec/fe_schema_codec.h | 7 +- hybridse/include/sdk/base_impl.h | 8 +- hybridse/src/codegen/buf_ir_builder.cc | 2 +- hybridse/src/codegen/buf_ir_builder.h | 9 +- hybridse/src/codegen/insert_row_builder.cc | 149 ++++++++++++++++++ hybridse/src/codegen/insert_row_builder.h | 67 ++++++++ .../src/codegen/insert_row_builder_test.cc | 71 +++++++++ hybridse/src/plan/planner.cc | 2 +- hybridse/src/planv2/ast_node_converter.cc | 5 - hybridse/src/sdk/base_impl.cc | 2 +- hybridse/src/vm/engine.cc | 12 +- hybridse/src/vm/jit.h | 10 +- hybridse/src/vm/jit_wrapper.h | 3 +- src/cmd/sql_cmd_test.cc | 47 ++++++ src/codec/codec.cc | 10 ++ src/codec/codec.h | 2 + src/sdk/sql_cluster_router.cc | 90 +++++++++-- src/sdk/sql_cluster_router.h | 2 +- src/sdk/sql_insert_row.cc | 10 +- src/sdk/sql_insert_row.h | 34 ++++ 20 files changed, 491 insertions(+), 51 deletions(-) create mode 100644 hybridse/src/codegen/insert_row_builder.cc create mode 100644 hybridse/src/codegen/insert_row_builder.h create mode 100644 hybridse/src/codegen/insert_row_builder_test.cc diff --git a/hybridse/include/codec/fe_schema_codec.h b/hybridse/include/codec/fe_schema_codec.h index df8642de8fa..02c03c886ee 100644 --- a/hybridse/include/codec/fe_schema_codec.h +++ b/hybridse/include/codec/fe_schema_codec.h @@ -18,10 +18,7 @@ #define HYBRIDSE_INCLUDE_CODEC_FE_SCHEMA_CODEC_H_ #include -#include -#include #include -#include #include "vm/catalog.h" namespace hybridse { @@ -56,7 +53,7 @@ class SchemaCodec { if (it->name().size() >= 128) { return false; } - uint8_t name_size = (uint8_t)(it->name().size()); + uint8_t name_size = static_cast(it->name().size()); memcpy(cbuffer, static_cast(&name_size), 1); cbuffer += 1; memcpy(cbuffer, static_cast(it->name().c_str()), @@ -66,7 +63,7 @@ class SchemaCodec { return true; } - static bool Decode(const std::string& buf, vm::Schema* schema) { + static bool Decode(const std::string& buf, codec::Schema* schema) { if (schema == NULL) return false; if (buf.size() <= 0) return true; const char* buffer = buf.c_str(); diff --git a/hybridse/include/sdk/base_impl.h b/hybridse/include/sdk/base_impl.h index 5d1fd8bc842..524c41f5f0c 100644 --- a/hybridse/include/sdk/base_impl.h +++ b/hybridse/include/sdk/base_impl.h @@ -30,13 +30,13 @@ typedef ::google::protobuf::RepeatedPtrField< ::hybridse::type::TableDef> class SchemaImpl : public Schema { public: - explicit SchemaImpl(const vm::Schema& schema); + explicit SchemaImpl(const codec::Schema& schema); SchemaImpl() {} ~SchemaImpl(); - const vm::Schema& GetSchema() const { return schema_; } - inline void SetSchema(const vm::Schema& schema) { schema_ = schema; } + const codec::Schema& GetSchema() const { return schema_; } + inline void SetSchema(const codec::Schema& schema) { schema_ = schema; } int32_t GetColumnCnt() const; const std::string& GetColumnName(uint32_t index) const; @@ -46,7 +46,7 @@ class SchemaImpl : public Schema { const bool IsConstant(uint32_t index) const; private: - vm::Schema schema_; + codec::Schema schema_; }; class TableImpl : public Table { diff --git a/hybridse/src/codegen/buf_ir_builder.cc b/hybridse/src/codegen/buf_ir_builder.cc index 432a7b4b499..bbbf1ac0b92 100644 --- a/hybridse/src/codegen/buf_ir_builder.cc +++ b/hybridse/src/codegen/buf_ir_builder.cc @@ -275,7 +275,7 @@ bool BufNativeIRBuilder::BuildGetStringField(uint32_t col_idx, uint32_t offset, BufNativeEncoderIRBuilder::BufNativeEncoderIRBuilder(CodeGenContextBase* ctx, const std::map* outputs, - const vm::Schema* schema) + const codec::Schema* schema) : ctx_(ctx), outputs_(outputs), schema_(schema), diff --git a/hybridse/src/codegen/buf_ir_builder.h b/hybridse/src/codegen/buf_ir_builder.h index 52d5d83385c..0ec9e664baf 100644 --- a/hybridse/src/codegen/buf_ir_builder.h +++ b/hybridse/src/codegen/buf_ir_builder.h @@ -25,7 +25,6 @@ #include "codegen/row_ir_builder.h" #include "codegen/scope_var.h" #include "codegen/variable_ir_builder.h" -#include "vm/catalog.h" namespace hybridse { namespace codegen { @@ -33,7 +32,7 @@ namespace codegen { class BufNativeEncoderIRBuilder : public RowEncodeIRBuilder { public: BufNativeEncoderIRBuilder(CodeGenContextBase* ctx, const std::map* outputs, - const vm::Schema* schema); + const codec::Schema* schema); ~BufNativeEncoderIRBuilder() override; @@ -55,10 +54,6 @@ class BufNativeEncoderIRBuilder : public RowEncodeIRBuilder { ::llvm::Value* str_addr_space, ::llvm::Value* str_body_offset, uint32_t str_field_idx, ::llvm::Value** output); - // encode SQL map data type into row - base::Status AppendMapVal(const type::ColumnSchema& sc, llvm::Value* i8_ptr, uint32_t field_idx, - const NativeValue& val, llvm::Value* str_addr_space, llvm::Value* str_body_offset, - uint32_t str_field_idx, llvm::Value** next_str_body_offset); absl::StatusOr GetOrBuildAppendMapFn(const type::ColumnSchema& sc) const; base::Status AppendHeader(::llvm::Value* i8_ptr, ::llvm::Value* size, @@ -74,7 +69,7 @@ class BufNativeEncoderIRBuilder : public RowEncodeIRBuilder { private: CodeGenContextBase* ctx_; const std::map* outputs_; - const vm::Schema* schema_; + const codec::Schema* schema_; uint32_t str_field_start_offset_; // n = offset_vec_[i] is // schema_[i] is base type (except string): col encode offset in row diff --git a/hybridse/src/codegen/insert_row_builder.cc b/hybridse/src/codegen/insert_row_builder.cc new file mode 100644 index 00000000000..c52eec6a1d8 --- /dev/null +++ b/hybridse/src/codegen/insert_row_builder.cc @@ -0,0 +1,149 @@ +/** + * Copyright (c) 2024 OpenMLDB authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "codegen/insert_row_builder.h" + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "base/fe_status.h" +#include "codegen/buf_ir_builder.h" +#include "codegen/context.h" +#include "codegen/expr_ir_builder.h" +#include "node/node_manager.h" +#include "passes/resolve_fn_and_attrs.h" +#include "udf/default_udf_library.h" +#include "vm/engine.h" +#include "vm/jit_wrapper.h" + +namespace hybridse { +namespace codegen { + +InsertRowBuilder::InsertRowBuilder(const codec::Schema* schema) : schema_(schema) {} + +absl::Status InsertRowBuilder::Init() { + ::hybridse::vm::Engine::InitializeGlobalLLVM(); + + jit_ = std::unique_ptr(vm::HybridSeJitWrapper::Create()); + if (!jit_->Init()) { + jit_ = nullptr; + return absl::InternalError("fail to init jit"); + } + if (!vm::HybridSeJitWrapper::InitJitSymbols(jit_.get())) { + jit_ = nullptr; + return absl::InternalError("fail to init jit symbols"); + } + return absl::OkStatus(); +} + +absl::StatusOr> InsertRowBuilder::ComputeRow(const node::ExprListNode* values) { + EnsureInitialized(); + return ComputeRow(values->children_); +} + +absl::StatusOr> InsertRowBuilder::ComputeRow(absl::Span values) { + EnsureInitialized(); + + std::unique_ptr llvm_ctx = llvm::make_unique(); + std::unique_ptr llvm_module = llvm::make_unique("insert_row_builder", *llvm_ctx); + vm::SchemasContext empty_sc; + node::NodeManager nm; + codec::Schema empty_param_types; + CodeGenContext dump_ctx(llvm_module.get(), &empty_sc, &empty_param_types, &nm); + + auto library = udf::DefaultUdfLibrary::get(); + node::ExprAnalysisContext expr_ctx(&nm, library, &empty_sc, &empty_param_types); + passes::ResolveFnAndAttrs resolver(&expr_ctx); + + std::vector transformed; + for (auto& expr : values) { + node::ExprNode* out = nullptr; + CHECK_STATUS_TO_ABSL(resolver.VisitExpr(expr, &out)); + transformed.push_back(out); + } + + std::string fn_name = absl::StrCat("gen_insert_row_", fn_counter_++); + auto fs = BuildFn(&dump_ctx, fn_name, transformed); + CHECK_ABSL_STATUSOR(fs); + + llvm::Function* fn = fs.value(); + + if (!jit_->OptModule(llvm_module.get())) { + return absl::InternalError("fail to optimize module"); + } + + if (!jit_->AddModule(std::move(llvm_module), std::move(llvm_ctx))) { + return absl::InternalError("add llvm module failed"); + } + + auto c_fn = jit_->FindFunction(fn->getName()); + void (*encode)(int8_t**) = reinterpret_cast(const_cast(c_fn)); + + int8_t* insert_row = nullptr; + encode(&insert_row); + + auto managed_row = std::shared_ptr(insert_row, std::free); + + return managed_row; +} + +absl::StatusOr InsertRowBuilder::BuildFn(CodeGenContext* ctx, llvm::StringRef fn_name, + absl::Span values) { + llvm::Function* fn = ctx->GetModule()->getFunction(fn_name); + if (fn == nullptr) { + auto builder = ctx->GetBuilder(); + llvm::FunctionType* fnt = llvm::FunctionType::get(builder->getVoidTy(), + { + builder->getInt8PtrTy()->getPointerTo(), + }, + false); + + fn = llvm::Function::Create(fnt, llvm::GlobalValue::ExternalLinkage, fn_name, ctx->GetModule()); + FunctionScopeGuard fg(fn, ctx); + + llvm::Value* row_ptr_ptr = fn->arg_begin(); + + ExprIRBuilder expr_builder(ctx); + + std::map columns; + for (uint32_t i = 0; i < values.size(); ++i) { + auto expr = values[i]; + + NativeValue out; + auto s = expr_builder.Build(expr, &out); + CHECK_STATUS_TO_ABSL(s); + + columns[i] = out; + } + + BufNativeEncoderIRBuilder encode_builder(ctx, &columns, schema_); + CHECK_STATUS_TO_ABSL(encode_builder.Init()); + + encode_builder.BuildEncode(row_ptr_ptr); + + builder->CreateRetVoid(); + } + + return fn; +} + +// build the function that transform a single insert row values into encoded row +absl::StatusOr InsertRowBuilder::BuildEncodeFn() { return absl::OkStatus(); } +} // namespace codegen +} // namespace hybridse diff --git a/hybridse/src/codegen/insert_row_builder.h b/hybridse/src/codegen/insert_row_builder.h new file mode 100644 index 00000000000..83e8c1c2126 --- /dev/null +++ b/hybridse/src/codegen/insert_row_builder.h @@ -0,0 +1,67 @@ +/** + * Copyright (c) 2024 OpenMLDB authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef HYBRIDSE_SRC_CODEGEN_INSERT_ROW_BUILDER_H_ +#define HYBRIDSE_SRC_CODEGEN_INSERT_ROW_BUILDER_H_ + +#include + +#include "absl/status/statusor.h" +#include "codec/fe_row_codec.h" +#include "codegen/context.h" +#include "llvm/IR/Function.h" +#include "node/sql_node.h" +#include "vm/jit_wrapper.h" + +namespace hybridse { +namespace codegen { + +class InsertRowBuilder { + public: + explicit InsertRowBuilder(const codec::Schema* schema); + + absl::Status Init(); + + // compute the encoded row result for insert statement's single values expression list + // + // currently, expressions in insert values do not expect external source, so unsupported expressions + // will simply fail on resolving. + absl::StatusOr> ComputeRow(absl::Span values); + + absl::StatusOr> ComputeRow(const node::ExprListNode* values); + + private: + void EnsureInitialized() { assert(jit_ && "InsertRowBuilder not initialized"); } + + // build the function the will output the row from single insert values + // + // the function is just equivalent to C: `void fn(int8_t**)`. + // BuildFn returns different function with different name on every invocation + absl::StatusOr BuildFn(CodeGenContext* ctx, llvm::StringRef fn_name, + absl::Span); + + // build the function that transform a single insert row values into encoded row + absl::StatusOr BuildEncodeFn(); + + // CodeGenContextBase* ctx_; + const codec::Schema* schema_; + std::atomic fn_counter_ = 0; + + std::unique_ptr jit_; +}; +} // namespace codegen +} // namespace hybridse +#endif // HYBRIDSE_SRC_CODEGEN_INSERT_ROW_BUILDER_H_ diff --git a/hybridse/src/codegen/insert_row_builder_test.cc b/hybridse/src/codegen/insert_row_builder_test.cc new file mode 100644 index 00000000000..4924c175957 --- /dev/null +++ b/hybridse/src/codegen/insert_row_builder_test.cc @@ -0,0 +1,71 @@ +/** + * Copyright (c) 2024 OpenMLDB authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "codegen/insert_row_builder.h" + +#include + +#include "gtest/gtest.h" +#include "node/sql_node.h" +#include "plan/plan_api.h" +#include "vm/sql_ctx.h" + +namespace hybridse { +namespace codegen { + +class InsertRowBuilderTest : public ::testing::Test {}; + +TEST_F(InsertRowBuilderTest, encode) { + std::string sql = "insert into t1 values (1, map (1, '12'))"; + vm::SqlContext ctx; + ctx.sql = sql; + auto s = plan::PlanAPI::CreatePlanTreeFromScript(&ctx); + ASSERT_TRUE(s.isOK()) << s; + + auto* exprlist = dynamic_cast(ctx.logical_plan.front())->GetInsertNode()->values_[0]; + + codec::Schema sc; + { + auto col1 = sc.Add(); + col1->mutable_schema()->set_base_type(type::kInt32); + col1->set_type(type::kInt32); + } + + { + auto col = sc.Add(); + auto map_ty = col->mutable_schema()->mutable_map_type(); + map_ty->mutable_key_type()->set_base_type(type::kInt32); + map_ty->mutable_value_type()->set_base_type(type::kVarchar); + } + + InsertRowBuilder builder(&sc); + { + auto s = builder.Init(); + ASSERT_TRUE(s.ok()) << s; + } + + auto as = builder.ComputeRow(dynamic_cast(exprlist)); + ASSERT_TRUE(as.ok()) << as.status(); + + ASSERT_TRUE(as.value() != nullptr); +} +} // namespace codegen +} // namespace hybridse +// +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/hybridse/src/plan/planner.cc b/hybridse/src/plan/planner.cc index c60f8fce155..f9cee3c49fa 100644 --- a/hybridse/src/plan/planner.cc +++ b/hybridse/src/plan/planner.cc @@ -710,7 +710,7 @@ base::Status SimplePlanner::CreatePlanTree(const NodePointVector &parser_trees, break; } case node::kInsertStmt: { - CHECK_TRUE(is_batch_mode_, common::kPlanError, "Non-support INSERT Op in online serving"); + // CHECK_TRUE(is_batch_mode_, common::kPlanError, "Non-support INSERT Op in online serving"); node::PlanNode *insert_plan = nullptr; CHECK_STATUS(CreateInsertPlan(parser_tree, &insert_plan)) plan_trees.push_back(insert_plan); diff --git a/hybridse/src/planv2/ast_node_converter.cc b/hybridse/src/planv2/ast_node_converter.cc index 2c6225be9a8..6e60ece8143 100644 --- a/hybridse/src/planv2/ast_node_converter.cc +++ b/hybridse/src/planv2/ast_node_converter.cc @@ -1969,11 +1969,6 @@ base::Status ConvertInsertStatement(const zetasql::ASTInsertStatement* root, nod CHECK_TRUE(nullptr != row, common::kSqlAstError, "Un-support insert statement with null row") node::ExprListNode* row_values; CHECK_STATUS(ConvertExprNodeList(row->values(), node_manager, &row_values)) - for (auto expr : row_values->children_) { - CHECK_TRUE(nullptr != expr && - (node::kExprPrimary == expr->GetExprType() || node::kExprParameter == expr->GetExprType()), - common::kSqlAstError, "Un-support insert statement with un-const value") - } rows->AddChild(row_values); } diff --git a/hybridse/src/sdk/base_impl.cc b/hybridse/src/sdk/base_impl.cc index fe34f9cd1a8..153ceef8eef 100644 --- a/hybridse/src/sdk/base_impl.cc +++ b/hybridse/src/sdk/base_impl.cc @@ -24,7 +24,7 @@ namespace sdk { static const std::string EMPTY_STR; // NOLINT -SchemaImpl::SchemaImpl(const vm::Schema& schema) : schema_(schema) {} +SchemaImpl::SchemaImpl(const codec::Schema& schema) : schema_(schema) {} SchemaImpl::~SchemaImpl() {} diff --git a/hybridse/src/vm/engine.cc b/hybridse/src/vm/engine.cc index 0865655f3c1..ac9ee9dcaaf 100644 --- a/hybridse/src/vm/engine.cc +++ b/hybridse/src/vm/engine.cc @@ -15,19 +15,22 @@ */ #include "vm/engine.h" + #include #include #include + +#include "absl/time/clock.h" #include "boost/none.hpp" #include "codec/fe_row_codec.h" #include "gflags/gflags.h" #include "llvm-c/Target.h" #include "udf/default_udf_library.h" +#include "vm/internal/node_helper.h" #include "vm/local_tablet_handler.h" #include "vm/mem_catalog.h" -#include "vm/sql_compiler.h" -#include "vm/internal/node_helper.h" #include "vm/runner_ctx.h" +#include "vm/sql_compiler.h" DECLARE_bool(enable_spark_unsaferow_format); @@ -52,10 +55,15 @@ Engine::Engine(const std::shared_ptr& catalog) : cl_(catalog), options_ Engine::Engine(const std::shared_ptr& catalog, const EngineOptions& options) : cl_(catalog), options_(options), mu_(), lru_cache_() {} Engine::~Engine() {} + void Engine::InitializeGlobalLLVM() { + // not thread safe, but is generally fine to call multiple times if (LLVM_IS_INITIALIZED) return; + + absl::Time begin = absl::Now(); LLVMInitializeNativeTarget(); LLVMInitializeNativeAsmPrinter(); + LOG(INFO) << "initialize llvm native target and asm printer, takes " << absl::Now() - begin; LLVM_IS_INITIALIZED = true; } diff --git a/hybridse/src/vm/jit.h b/hybridse/src/vm/jit.h index 24bb9a74856..7af5f17ac0d 100644 --- a/hybridse/src/vm/jit.h +++ b/hybridse/src/vm/jit.h @@ -94,13 +94,15 @@ class HybridSeLlvmJitWrapper : public HybridSeJitWrapper { bool OptModule(::llvm::Module* module) override; - bool AddModule(std::unique_ptr module, - std::unique_ptr llvm_ctx) override; + bool AddModule(std::unique_ptr module, std::unique_ptr llvm_ctx) override; bool AddExternalFunction(const std::string& name, void* addr) override; - hybridse::vm::RawPtrHandle FindFunction( - const std::string& funcname) override; + hybridse::vm::RawPtrHandle FindFunction(const std::string& funcname) override; + + // llvm::Module* GetModule() { + // } + // llvm::LLVMContext* GetLlvmContext(); private: std::unique_ptr jit_; diff --git a/hybridse/src/vm/jit_wrapper.h b/hybridse/src/vm/jit_wrapper.h index 458cb28272d..b0bbb70c6ec 100644 --- a/hybridse/src/vm/jit_wrapper.h +++ b/hybridse/src/vm/jit_wrapper.h @@ -45,8 +45,7 @@ class HybridSeJitWrapper { bool AddModuleFromBuffer(const base::RawBuffer&); - virtual hybridse::vm::RawPtrHandle FindFunction( - const std::string& funcname) = 0; + virtual hybridse::vm::RawPtrHandle FindFunction(const std::string& funcname) = 0; static HybridSeJitWrapper* Create(const JitOptions& jit_options); static HybridSeJitWrapper* Create(); diff --git a/src/cmd/sql_cmd_test.cc b/src/cmd/sql_cmd_test.cc index cdff3943254..db468b58e7e 100644 --- a/src/cmd/sql_cmd_test.cc +++ b/src/cmd/sql_cmd_test.cc @@ -3238,6 +3238,53 @@ TEST_P(DBSDKTest, CreateIfNotExists) { ASSERT_TRUE(cs->GetNsClient()->DropDatabase("test2", msg)) << msg; } +TEST_P(DBSDKTest, MapTypeTable) { + auto cli = GetParam(); + cs = cli->cs; + sr = cli->sr; + absl::BitGen gen; + auto db = absl::StrCat("db_", absl::Uniform(gen, 0, std::numeric_limits::max())); + auto table = absl::StrCat("tb_", absl::Uniform(gen, 0, std::numeric_limits::max())); + + ProcessSQLs(sr, { + "set session execute_mode = 'online'", + absl::StrCat("create database ", db), + absl::StrCat("use ", db), + absl::Substitute("create table $0 (id string, val map)", table), + absl::Substitute("insert into $0 values ('1', map(12, '23')) ", table), + absl::Substitute("insert into $0 values ('4', map(99, '44')) ", table), + }); + absl::Cleanup clean = [&]() { + ProcessSQLs(sr, { + absl::Substitute("drop table $0", table), + absl::Substitute("drop database $0", db), + }); + }; + + // query + hybridse::sdk::Status status; + auto rs = sr->ExecuteSQL(absl::Substitute("select id, val[12] as ele from $0", table), &status); + ASSERT_TRUE(status.IsOK()) << status.ToString(); + ASSERT_EQ(rs->Size(), 2); + + while (rs->Next()) { + // result is unordered + std::string id; + ASSERT_TRUE(rs->GetAsString(0, id)); + std::string ele; + ASSERT_TRUE(rs->GetAsString(1, ele)); + + if (id == "1") { + EXPECT_EQ(ele, "23"); + } else if (id == "4") { + EXPECT_EQ(ele, "NULL"); + EXPECT_TRUE(rs->IsNULL(1)); + } else { + ASSERT_FALSE(true) << "should not reach"; + } + } +} + TEST_P(DBSDKTest, ShowComponents) { auto cli = GetParam(); cs = cli->cs; diff --git a/src/codec/codec.cc b/src/codec/codec.cc index 8d5e24bc8c8..858acfc374a 100644 --- a/src/codec/codec.cc +++ b/src/codec/codec.cc @@ -1152,5 +1152,15 @@ bool RowProject::Project(const int8_t* row_ptr, uint32_t size, int8_t** output_p return true; } +bool ColumnSupportLegacyCodec(const openmldb::common::ColumnDesc& col_desc) { + auto dt = col_desc.data_type(); + if (col_desc.has_schema()) { + dt = col_desc.schema().type(); + } + + return (dt >= openmldb::type::kBool && dt <= openmldb::type::kTimestamp) || dt == openmldb::type::kVarchar || + dt == openmldb::type::kString; +} + } // namespace codec } // namespace openmldb diff --git a/src/codec/codec.h b/src/codec/codec.h index 681c05ae2aa..b84289a2ce8 100644 --- a/src/codec/codec.h +++ b/src/codec/codec.h @@ -192,6 +192,8 @@ class RowView { std::vector offset_vec_; }; +bool ColumnSupportLegacyCodec(const openmldb::common::ColumnDesc&); + namespace v1 { inline int8_t GetAddrSpace(uint32_t size) { diff --git a/src/sdk/sql_cluster_router.cc b/src/sdk/sql_cluster_router.cc index bdad16cfc8c..f9bdf508835 100644 --- a/src/sdk/sql_cluster_router.cc +++ b/src/sdk/sql_cluster_router.cc @@ -60,6 +60,7 @@ #include "sdk/split.h" #include "udf/udf.h" #include "vm/catalog.h" +#include "codegen/insert_row_builder.h" DECLARE_string(bucket_size); DECLARE_uint32(replica_num); @@ -482,7 +483,8 @@ bool SQLClusterRouter::GetMultiRowInsertInfo(const std::string& db, const std::s ::hybridse::sdk::Status* status, std::shared_ptr<::openmldb::nameserver::TableInfo>* table_info, std::vector* default_maps, - std::vector* str_lengths, bool* put_if_absent) { + std::vector* str_lengths, bool* put_if_absent, + std::vector>* codegen_rows) { RET_FALSE_IF_NULL_AND_WARN(status, "output status is nullptr"); // TODO(hw): return status? RET_FALSE_IF_NULL_AND_WARN(table_info, "output table_info is nullptr"); @@ -524,6 +526,14 @@ bool SQLClusterRouter::GetMultiRowInsertInfo(const std::string& db, const std::s return false; } std::map column_map; + + bool insert_codegen = false; + for (int i = 0; i < (*table_info)->column_desc_size(); ++i) { + auto& col_desc = (*table_info)->column_desc(i); + if (!codec::ColumnSupportLegacyCodec(col_desc)) { + insert_codegen = true; + } + } for (size_t j = 0; j < insert_stmt->columns_.size(); ++j) { const std::string& col_name = insert_stmt->columns_[j]; bool find_flag = false; @@ -535,6 +545,7 @@ bool SQLClusterRouter::GetMultiRowInsertInfo(const std::string& db, const std::s } column_map.insert(std::make_pair(i, j)); find_flag = true; + break; } } @@ -544,6 +555,24 @@ bool SQLClusterRouter::GetMultiRowInsertInfo(const std::string& db, const std::s return false; } } + + ::hybridse::codec::Schema sc; + if (!schema::SchemaAdapter::ConvertSchema((*table_info)->column_desc(), &sc)) { + SET_STATUS_AND_WARN(status, StatusCode::kCmdError, "failed to convert table schema"); + return false; + } + // TODO(someone): + // 1. default value from table definition + // 2. parameters + ::hybridse::codegen::InsertRowBuilder insert_builder(&sc); + { + auto s = insert_builder.Init(); + if (!s.ok()) { + SET_STATUS_AND_WARN(status, StatusCode::kCmdError, s.ToString()); + return false; + } + } + size_t total_rows_size = insert_stmt->values_.size(); for (size_t i = 0; i < total_rows_size; i++) { hybridse::node::ExprNode* value = insert_stmt->values_[i]; @@ -554,23 +583,37 @@ bool SQLClusterRouter::GetMultiRowInsertInfo(const std::string& db, const std::s hybridse::node::ExprTypeName(value->GetExprType())); return false; } - uint32_t str_length = 0; - default_maps->push_back( - GetDefaultMap(*table_info, column_map, dynamic_cast<::hybridse::node::ExprListNode*>(value), &str_length)); - if (!default_maps->back()) { - SET_STATUS_AND_WARN(status, StatusCode::kCmdError, - "fail to parse row[" + std::to_string(i) + "]: " + value->GetExprString()); - return false; + if (insert_codegen) { + auto s = insert_builder.ComputeRow(dynamic_cast<::hybridse::node::ExprListNode*>(value)); + if (!s.ok()) { + SET_STATUS_AND_WARN(status, StatusCode::kCmdError, s.status().ToString()); + return false; + } + + codegen_rows->push_back(s.value()); + continue; + } else { + uint32_t str_length = 0; + default_maps->push_back(GetDefaultMap(*table_info, column_map, + dynamic_cast<::hybridse::node::ExprListNode*>(value), &str_length)); + if (!default_maps->back()) { + SET_STATUS_AND_WARN(status, StatusCode::kCmdError, + "fail to parse row[" + std::to_string(i) + "]: " + value->GetExprString()); + return false; + } + str_lengths->push_back(str_length); } - str_lengths->push_back(str_length); } - if (default_maps->empty() || str_lengths->empty()) { - SET_STATUS_AND_WARN(status, StatusCode::kCmdError, "default_maps or str_lengths are empty"); - return false; - } - if (default_maps->size() != str_lengths->size()) { - SET_STATUS_AND_WARN(status, StatusCode::kCmdError, "default maps isn't match with str_lengths"); - return false; + + if (!insert_codegen) { + if (default_maps->empty() || str_lengths->empty()) { + SET_STATUS_AND_WARN(status, StatusCode::kCmdError, "default_maps or str_lengths are empty"); + return false; + } + if (default_maps->size() != str_lengths->size()) { + SET_STATUS_AND_WARN(status, StatusCode::kCmdError, "default maps isn't match with str_lengths"); + return false; + } } return true; } @@ -1308,7 +1351,9 @@ bool SQLClusterRouter::ExecuteInsert(const std::string& db, const std::string& s std::vector default_maps; std::vector str_lengths; bool put_if_absent; - if (!GetMultiRowInsertInfo(db, sql, status, &table_info, &default_maps, &str_lengths, &put_if_absent)) { + std::vector> codegen_rows; + if (!GetMultiRowInsertInfo(db, sql, status, &table_info, &default_maps, &str_lengths, &put_if_absent, + &codegen_rows)) { CODE_PREPEND_AND_WARN(status, StatusCode::kCmdError, "Fail to get insert info"); return false; } @@ -1321,6 +1366,17 @@ bool SQLClusterRouter::ExecuteInsert(const std::string& db, const std::string& s "Fail to execute insert statement: fail to get " + table_info->name() + " tablets"); return false; } + if (!codegen_rows.empty()) { + for (auto& r : codegen_rows) { + auto row = std::make_shared(table_info, schema, r, put_if_absent); + if (!PutRow(table_info->tid(), row, tablets, status)) { + LOG(WARNING) << "fail to put row[" + << "] due to: " << status->msg; + continue; + } + } + return true; + } std::vector fails; for (size_t i = 0; i < default_maps.size(); i++) { auto row = std::make_shared(table_info, schema, default_maps[i], str_lengths[i], put_if_absent); diff --git a/src/sdk/sql_cluster_router.h b/src/sdk/sql_cluster_router.h index 0b9f6cca272..d7ba83090fa 100644 --- a/src/sdk/sql_cluster_router.h +++ b/src/sdk/sql_cluster_router.h @@ -320,7 +320,7 @@ class SQLClusterRouter : public SQLRouter { bool GetMultiRowInsertInfo(const std::string& db, const std::string& sql, ::hybridse::sdk::Status* status, std::shared_ptr<::openmldb::nameserver::TableInfo>* table_info, std::vector* default_maps, std::vector* str_lengths, - bool* put_if_absent); + bool* put_if_absent, std::vector>* codegen_rows); DefaultValueMap GetDefaultMap(const std::shared_ptr<::openmldb::nameserver::TableInfo>& table_info, const std::map& column_map, ::hybridse::node::ExprListNode* row, diff --git a/src/sdk/sql_insert_row.cc b/src/sdk/sql_insert_row.cc index 492bb80e49b..2f74d9fc330 100644 --- a/src/sdk/sql_insert_row.cc +++ b/src/sdk/sql_insert_row.cc @@ -90,6 +90,9 @@ SQLInsertRow::SQLInsertRow(std::shared_ptr<::openmldb::nameserver::TableInfo> ta } bool SQLInsertRow::Init(int str_length) { + if (is_codegen_row_) { + return true; + } str_size_ = str_length + default_string_length_; uint32_t row_size = rb_.CalTotalLength(str_size_); val_.resize(row_size); @@ -301,7 +304,12 @@ bool SQLInsertRow::AppendNULL() { return false; } -bool SQLInsertRow::IsComplete() { return rb_.IsComplete(); } +bool SQLInsertRow::IsComplete() { + if (is_codegen_row_) { + return true; + } + return rb_.IsComplete(); +} bool SQLInsertRow::Build() const { return str_size_ == 0; } diff --git a/src/sdk/sql_insert_row.h b/src/sdk/sql_insert_row.h index af18891587f..b6e40de730c 100644 --- a/src/sdk/sql_insert_row.h +++ b/src/sdk/sql_insert_row.h @@ -110,6 +110,38 @@ class SQLInsertRow { SQLInsertRow(std::shared_ptr<::openmldb::nameserver::TableInfo> table_info, std::shared_ptr schema, DefaultValueMap default_map, uint32_t default_str_length, std::vector hole_idx_arr, bool put_if_absent); + SQLInsertRow(std::shared_ptr<::openmldb::nameserver::TableInfo> table_info, + std::shared_ptr schema, std::shared_ptr codegen_row, bool put_if_absent) + : table_info_(table_info), + schema_(schema), + rb_(table_info->column_desc()), + put_if_absent_(put_if_absent), + is_codegen_row_(true) { + auto size = hybridse::codec::RowView::GetSize(codegen_row.get()); + val_ = std::string(reinterpret_cast(codegen_row.get()), size); + std::map column_name_map; + for (int idx = 0; idx < table_info_->column_desc_size(); idx++) { + column_name_map.emplace(table_info_->column_desc(idx).name(), idx); + } + if (table_info_->column_key_size() > 0) { + index_map_.clear(); + raw_dimensions_.clear(); + for (int idx = 0; idx < table_info_->column_key_size(); ++idx) { + const auto& index = table_info_->column_key(idx); + if (index.flag()) { + continue; + } + for (const auto& column : index.col_name()) { + index_map_[idx].push_back(column_name_map[column]); + raw_dimensions_[column_name_map[column]] = hybridse::codec::NONETOKEN; + } + if (!index.ts_name().empty()) { + ts_set_.insert(column_name_map[index.ts_name()]); + } + } + } + } + ~SQLInsertRow() = default; bool Init(int str_length); bool AppendBool(bool val); @@ -181,6 +213,8 @@ class SQLInsertRow { std::string val_; uint32_t str_size_; bool put_if_absent_; + + bool is_codegen_row_ = false; }; class SQLInsertRows {