From 7ce75033e6ba4d894695e8434eb9b28ba3750fab Mon Sep 17 00:00:00 2001 From: aceforeverd Date: Wed, 24 Jan 2024 07:21:03 +0000 Subject: [PATCH] feat: create table with map type columns --- hybridse/include/node/sql_node.h | 3 + hybridse/src/codegen/block_ir_builder.cc | 12 +- hybridse/src/codegen/buf_ir_builder.cc | 2 +- hybridse/src/codegen/ir_base_builder_test.h | 6 +- hybridse/src/codegen/map_ir_builder.cc | 212 ++++++++++++-------- hybridse/src/codegen/map_ir_builder.h | 8 +- hybridse/src/codegen/struct_ir_builder.cc | 27 +-- hybridse/src/codegen/struct_ir_builder.h | 4 +- src/proto/common.proto | 9 + src/proto/type.proto | 4 +- src/schema/schema_adapter.cc | 179 +++++++++++++++-- src/schema/schema_adapter.h | 23 ++- src/sdk/node_adapter.cc | 13 +- 13 files changed, 371 insertions(+), 131 deletions(-) diff --git a/hybridse/include/node/sql_node.h b/hybridse/include/node/sql_node.h index 14e139bdd3f..366a6c2b76d 100644 --- a/hybridse/include/node/sql_node.h +++ b/hybridse/include/node/sql_node.h @@ -1888,6 +1888,9 @@ class ColumnDefNode : public SqlNode { std::string GetColumnName() const { return column_name_; } + const ColumnSchemaNode *schema() const { return schema_; } + + // deprecated, use ColumnDefNode::schema instead DataType GetColumnType() const { return schema_->type(); } const ExprNode* GetDefaultValue() const { return schema_->default_value(); } diff --git a/hybridse/src/codegen/block_ir_builder.cc b/hybridse/src/codegen/block_ir_builder.cc index 818229553ca..200a8f9f732 100644 --- a/hybridse/src/codegen/block_ir_builder.cc +++ b/hybridse/src/codegen/block_ir_builder.cc @@ -290,16 +290,18 @@ bool BlockIRBuilder::BuildReturnStmt(const ::hybridse::node::FnReturnStmt *node, } ::llvm::Value *value = value_wrapper.GetValue(&builder); if (TypeIRBuilder::IsStructPtr(value->getType())) { - StructTypeIRBuilder *struct_builder = - StructTypeIRBuilder::CreateStructTypeIRBuilder(block->getModule(), - value->getType()); + auto struct_builder = StructTypeIRBuilder::CreateStructTypeIRBuilder(block->getModule(), value->getType()); + if (!struct_builder.ok()) { + status.code = kCodegenError; + status.msg = struct_builder.status().ToString(); + return false; + } NativeValue ret_value; if (!var_ir_builder.LoadRetStruct(&ret_value, status)) { LOG(WARNING) << "fail to load ret struct address"; return false; } - if (!struct_builder->CopyFrom(block, value, - ret_value.GetValue(&builder))) { + if (!struct_builder.value()->CopyFrom(block, value, ret_value.GetValue(&builder))) { return false; } value = builder.getInt1(true); diff --git a/hybridse/src/codegen/buf_ir_builder.cc b/hybridse/src/codegen/buf_ir_builder.cc index 2a1f81af4c7..432a7b4b499 100644 --- a/hybridse/src/codegen/buf_ir_builder.cc +++ b/hybridse/src/codegen/buf_ir_builder.cc @@ -529,7 +529,7 @@ absl::StatusOr BufNativeEncoderIRBuilder::GetOrBuildAppendMapFn auto bs = ctx_->CreateBranchNot(is_null, [&]() -> base::Status { auto row_ptr = BuildGetPtrOffset(sub_builder, i8_ptr, str_body_offset); CHECK_TRUE(row_ptr.ok(), common::kCodegenError, row_ptr.status().ToString()); - auto sz = map_builder.Encode(ctx_, map_ptr, row_ptr.value()); + auto sz = map_builder.Encode(ctx_, row_ptr.value(), map_ptr); CHECK_TRUE(sz.ok(), common::kCodegenError, sz.status().ToString()); sub_builder->CreateStore(sz.value(), encode_sz_alloca); return {}; diff --git a/hybridse/src/codegen/ir_base_builder_test.h b/hybridse/src/codegen/ir_base_builder_test.h index af29e4fd56c..494cfdb0818 100644 --- a/hybridse/src/codegen/ir_base_builder_test.h +++ b/hybridse/src/codegen/ir_base_builder_test.h @@ -360,7 +360,11 @@ void ModuleFunctionBuilderWithFullInfo::ExpandApplyArg( if (TypeIRBuilder::IsStructPtr(expect_ty)) { auto struct_builder = StructTypeIRBuilder::CreateStructTypeIRBuilder(function->getEntryBlock().getModule(), expect_ty); - struct_builder->CreateDefault(&function->getEntryBlock(), + if (!struct_builder.ok()) { + LOG(WARNING) << struct_builder.status(); + return; + } + struct_builder.value()->CreateDefault(&function->getEntryBlock(), &alloca); arg = builder.CreateSelect( is_null, alloca, builder.CreatePointerCast(arg, expect_ty)); diff --git a/hybridse/src/codegen/map_ir_builder.cc b/hybridse/src/codegen/map_ir_builder.cc index e54543040fd..27e6944c102 100644 --- a/hybridse/src/codegen/map_ir_builder.cc +++ b/hybridse/src/codegen/map_ir_builder.cc @@ -197,7 +197,7 @@ absl::StatusOr MapIRBuilder::ExtractElement(CodeGenContextBase* ctx ctx->GetBuilder()->getInt1Ty()->getPointerTo() // output is null ptr }, false); - fn = llvm::Function::Create(fnt, llvm::Function::ExternalLinkage, fn_name, ctx->GetModule()); + fn = llvm::Function::Create(fnt, llvm::GlobalValue::ExternalLinkage, fn_name, ctx->GetModule()); FunctionScopeGuard fg(fn, ctx); @@ -362,6 +362,53 @@ absl::StatusOr MapIRBuilder::MapKeys(CodeGenContextBase* ctx, const return out; } +absl::StatusOr MapIRBuilder::BuildEncodeByteSizeFn(CodeGenContextBase* ctx) const { + std::string fn_name = absl::StrCat("calc_encode_map_sz_", GetIRTypeName(struct_type_)); + llvm::Function* fn = ctx->GetModule()->getFunction(fn_name); + auto builder = ctx->GetBuilder(); + if (fn == nullptr) { + llvm::FunctionType* fnt = llvm::FunctionType::get(builder->getInt32Ty(), // return size + { + struct_type_->getPointerTo(), + }, + false); + fn = llvm::Function::Create(fnt, llvm::GlobalValue::ExternalLinkage, fn_name, ctx->GetModule()); + FunctionScopeGuard fg(fn, ctx); + + llvm::Value* raw = fn->arg_begin(); + // map_size + [key_ele_sz * map_size] + [val_ele_sz * map_sz] + [sizeof(bool) * map_size] + llvm::Value* final_size = CodecSizeForPrimitive(builder, builder->getInt32Ty()); + + auto elements = Load(ctx, raw); + if (!elements.ok()) { + return elements.status(); + } + + if (elements->size() != FIELDS_CNT) { + return absl::FailedPreconditionError( + absl::Substitute("element count error, expect $0, got $1", FIELDS_CNT, elements->size())); + } + auto& elements_vec = elements.value(); + auto& map_size = elements_vec[0]; + auto& key_vec = elements_vec[1]; + auto& value_vec = elements_vec[2]; + auto& value_null_vec = elements_vec[3]; + + auto keys_sz = CalEncodeSizeForArray(ctx, key_vec, map_size); + CHECK_ABSL_STATUSOR(keys_sz); + auto values_sz = CalEncodeSizeForArray(ctx, value_vec, map_size); + CHECK_ABSL_STATUSOR(values_sz); + auto values_null_sz = CalEncodeSizeForArray(ctx, value_null_vec, map_size); + CHECK_ABSL_STATUSOR(values_null_sz); + + builder->CreateRet(builder->CreateAdd( + final_size, + builder->CreateAdd(keys_sz.value(), builder->CreateAdd(values_sz.value(), values_null_sz.value())))); + } + + return fn; +} + absl::StatusOr MapIRBuilder::CalEncodeByteSize(CodeGenContextBase* ctx, llvm::Value* raw) const { auto builder = ctx->GetBuilder(); if (!raw->getType()->isPointerTy() || raw->getType()->getPointerElementType() != struct_type_) { @@ -370,33 +417,11 @@ absl::StatusOr MapIRBuilder::CalEncodeByteSize(CodeGenContextBase* GetLlvmObjectString(raw->getType()))); } - // map_size + [key_ele_sz * map_size] + [val_ele_sz * map_sz] + [sizeof(bool) * map_size] - llvm::Value* final_size = CodecSizeForPrimitive(builder, builder->getInt32Ty()); + auto fns = BuildEncodeByteSizeFn(ctx); - auto elements = Load(ctx, raw); - if (!elements.ok()) { - return elements.status(); - } + CHECK_ABSL_STATUSOR(fns); - if (elements->size() != FIELDS_CNT) { - return absl::FailedPreconditionError( - absl::Substitute("element count error, expect $0, got $1", FIELDS_CNT, elements->size())); - } - auto& elements_vec = elements.value(); - auto& map_size = elements_vec[0]; - auto& key_vec = elements_vec[1]; - auto& value_vec = elements_vec[2]; - auto& value_null_vec = elements_vec[3]; - - auto keys_sz = CalEncodeSizeForArray(ctx, key_vec, map_size); - CHECK_ABSL_STATUSOR(keys_sz); - auto values_sz = CalEncodeSizeForArray(ctx, value_vec, map_size); - CHECK_ABSL_STATUSOR(values_sz); - auto values_null_sz = CalEncodeSizeForArray(ctx, value_null_vec, map_size); - CHECK_ABSL_STATUSOR(values_null_sz); - - return builder->CreateAdd( - final_size, builder->CreateAdd(keys_sz.value(), builder->CreateAdd(values_sz.value(), values_null_sz.value()))); + return builder->CreateCall(fns.value(), {raw}); } absl::StatusOr MapIRBuilder::CalEncodeSizeForArray(CodeGenContextBase* ctx, llvm::Value* arr_ptr, @@ -429,7 +454,7 @@ absl::StatusOr MapIRBuilder::CalEncodeSizeForArray(CodeGenContextB builder->getInt32Ty() // arr size }, false); - fn = llvm::Function::Create(fnt, llvm::Function::ExternalLinkage, fn_name, ctx->GetModule()); + fn = llvm::Function::Create(fnt, llvm::GlobalValue::ExternalLinkage, fn_name, ctx->GetModule()); FunctionScopeGuard fg(fn, ctx); auto sub_builder = ctx->GetBuilder(); @@ -508,10 +533,82 @@ absl::StatusOr MapIRBuilder::TypeEncodeByteSize(CodeGenContextBase return absl::UnimplementedError(absl::StrCat("encode type ", GetLlvmObjectString(ele_type))); } -absl::StatusOr MapIRBuilder::Encode(CodeGenContextBase* ctx, llvm::Value* map_ptr, - llvm::Value* row_ptr) const { +absl::StatusOr MapIRBuilder::BuildEncodeFn(CodeGenContextBase* ctx) const { + std::string fn_name = absl::StrCat("encode_map_", GetIRTypeName(struct_type_)); + llvm::Function* fn = ctx->GetModule()->getFunction(fn_name); + + auto builder = ctx->GetBuilder(); + if (fn == nullptr) { + llvm::FunctionType* fnt = llvm::FunctionType::get(builder->getInt32Ty(), // encoded byte size + { + builder->getInt8PtrTy(), // row ptr + struct_type_->getPointerTo(), // map ptr + }, + false); + fn = llvm::Function::Create(fnt, llvm::GlobalValue::ExternalLinkage, fn_name, ctx->GetModule()); + + FunctionScopeGuard fg(fn, ctx); + + llvm::Value* row_ptr = fn->arg_begin(); + llvm::Value* map_ptr = fn->arg_begin() + 1; + llvm::Value* written = builder->getInt32(0); + + auto elements = Load(ctx, map_ptr); + if (!elements.ok()) { + return elements.status(); + } + + if (elements->size() != FIELDS_CNT) { + return absl::FailedPreconditionError( + absl::Substitute("element count error, expect $0, got $1", FIELDS_CNT, elements->size())); + } + + auto& elements_vec = elements.value(); + auto& map_size = elements_vec[0]; + auto& key_vec = elements_vec[1]; + auto& value_vec = elements_vec[2]; + auto& value_null_vec = elements_vec[3]; + + // *(int32*) row_ptr = map_size + { + CHECK_ABSL_STATUS(BuildStoreOffset(builder, row_ptr, builder->getInt32(0), map_size)); + + written = builder->CreateAdd(written, builder->getInt32(4)); + } + { + // *(key_type[map_size]) (row_ptr + 4) = key_vec + auto row_ptr_with_offset = BuildGetPtrOffset(builder, row_ptr, written); + CHECK_ABSL_STATUSOR(row_ptr_with_offset); + auto s = EncodeArray(ctx, row_ptr_with_offset.value(), key_vec, map_size); + CHECK_ABSL_STATUSOR(s); + written = builder->CreateAdd(written, s.value()); + } + { + // *(value_type[map_size]) (row_ptr + ?) = value_vec + auto row_ptr_with_offset = BuildGetPtrOffset(builder, row_ptr, written); + CHECK_ABSL_STATUSOR(row_ptr_with_offset); + auto s = EncodeArray(ctx, row_ptr_with_offset.value(), value_vec, map_size); + CHECK_ABSL_STATUSOR(s); + written = builder->CreateAdd(written, s.value()); + } + { + // *(bool[map_size]) (row_ptr + ?) = value_null_vec + auto row_ptr_with_offset = BuildGetPtrOffset(builder, row_ptr, written); + CHECK_ABSL_STATUSOR(row_ptr_with_offset); + // TODO(someone): alignment issue, bitwise operation for better performance ? + auto s = EncodeArray(ctx, row_ptr_with_offset.value(), value_null_vec, map_size); + CHECK_ABSL_STATUSOR(s); + written = builder->CreateAdd(written, s.value()); + } + + builder->CreateRet(written); + } + return fn; +} + +absl::StatusOr MapIRBuilder::Encode(CodeGenContextBase* ctx, llvm::Value* row_ptr, + llvm::Value* map_ptr) const { auto builder = ctx->GetBuilder(); - llvm::Value* written = builder->getInt32(0); if (row_ptr->getType() != builder->getInt8Ty()->getPointerTo()) { return absl::FailedPreconditionError( @@ -525,55 +622,10 @@ absl::StatusOr MapIRBuilder::Encode(CodeGenContextBase* ctx, llvm: GetLlvmObjectString(map_ptr->getType()->getPointerElementType()))); } - auto elements = Load(ctx, map_ptr); - if (!elements.ok()) { - return elements.status(); - } - - if (elements->size() != FIELDS_CNT) { - return absl::FailedPreconditionError( - absl::Substitute("element count error, expect $0, got $1", FIELDS_CNT, elements->size())); - } - - auto& elements_vec = elements.value(); - auto& map_size = elements_vec[0]; - auto& key_vec = elements_vec[1]; - auto& value_vec = elements_vec[2]; - auto& value_null_vec = elements_vec[3]; - - // *(int32*) row_ptr = map_size - { - CHECK_ABSL_STATUS(BuildStoreOffset(builder, row_ptr, builder->getInt32(0), map_size)); - - written = builder->CreateAdd(written, builder->getInt32(4)); - } - { - // *(key_type[map_size]) (row_ptr + 4) = key_vec - auto row_ptr_with_offset = BuildGetPtrOffset(builder, row_ptr, written); - CHECK_ABSL_STATUSOR(row_ptr_with_offset); - auto s = EncodeArray(ctx, row_ptr_with_offset.value(), key_vec, map_size); - CHECK_ABSL_STATUSOR(s); - written = builder->CreateAdd(written, s.value()); - } - { - // *(value_type[map_size]) (row_ptr + ?) = value_vec - auto row_ptr_with_offset = BuildGetPtrOffset(builder, row_ptr, written); - CHECK_ABSL_STATUSOR(row_ptr_with_offset); - auto s = EncodeArray(ctx, row_ptr_with_offset.value(), value_vec, map_size); - CHECK_ABSL_STATUSOR(s); - written = builder->CreateAdd(written, s.value()); - } - { - // *(bool[map_size]) (row_ptr + ?) = value_null_vec - auto row_ptr_with_offset = BuildGetPtrOffset(builder, row_ptr, written); - CHECK_ABSL_STATUSOR(row_ptr_with_offset); - // TODO(someone): alignment issue, bitwise operation for better performance ? - auto s = EncodeArray(ctx, row_ptr_with_offset.value(), value_null_vec, map_size); - CHECK_ABSL_STATUSOR(s); - written = builder->CreateAdd(written, s.value()); - } + auto fns = BuildEncodeFn(ctx); + CHECK_ABSL_STATUSOR(fns); - return written; + return builder->CreateCall(fns.value(), {row_ptr, map_ptr}); } absl::StatusOr MapIRBuilder::EncodeArray(CodeGenContextBase* ctx_, llvm::Value* row_ptr, @@ -737,7 +789,7 @@ absl::StatusOr MapIRBuilder::DecodeArrayValue(CodeGenContextBase* }, false); - fn = llvm::Function::Create(fnt, llvm::Function::ExternalLinkage, fn_name, ctx->GetModule()); + fn = llvm::Function::Create(fnt, llvm::GlobalValue::ExternalLinkage, fn_name, ctx->GetModule()); FunctionScopeGuard fg(fn, ctx); auto* sub_builder = ctx->GetBuilder(); @@ -855,7 +907,7 @@ absl::StatusOr MapIRBuilder::GetOrBuildEncodeArrFunction(CodeGe builder->getInt32Ty(), {builder->getInt8Ty()->getPointerTo(), ele_type->getPointerTo(), builder->getInt32Ty()}, false); - fn = llvm::Function::Create(fnt, llvm::Function::ExternalLinkage, fn_name, ctx->GetModule()); + fn = llvm::Function::Create(fnt, llvm::GlobalValue::ExternalLinkage, fn_name, ctx->GetModule()); // enter function FunctionScopeGuard fg(fn, ctx); diff --git a/hybridse/src/codegen/map_ir_builder.h b/hybridse/src/codegen/map_ir_builder.h index 1a7413f3ad5..f7063bcde45 100644 --- a/hybridse/src/codegen/map_ir_builder.h +++ b/hybridse/src/codegen/map_ir_builder.h @@ -17,8 +17,6 @@ #ifndef HYBRIDSE_SRC_CODEGEN_MAP_IR_BUILDER_H_ #define HYBRIDSE_SRC_CODEGEN_MAP_IR_BUILDER_H_ -#include - #include "codegen/struct_ir_builder.h" namespace hybridse { @@ -43,9 +41,13 @@ class MapIRBuilder final : public StructTypeIRBuilder { absl::StatusOr CalEncodeByteSize(CodeGenContextBase* ctx, llvm::Value*) const; + absl::StatusOr BuildEncodeByteSizeFn(CodeGenContextBase* ctx) const; + // Encode the `map_ptr` into `row_ptr`, returns byte size written on success // `row_ptr` is ensured to have enough space - absl::StatusOr Encode(CodeGenContextBase*, llvm::Value* map_ptr, llvm::Value* row_ptr) const; + absl::StatusOr Encode(CodeGenContextBase*, llvm::Value* row_ptr, llvm::Value* map_ptr) const; + + absl::StatusOr BuildEncodeFn(CodeGenContextBase*) const; // Decode the stored map value at address row_ptr absl::StatusOr Decode(CodeGenContextBase*, llvm::Value* row_ptr) const; diff --git a/hybridse/src/codegen/struct_ir_builder.cc b/hybridse/src/codegen/struct_ir_builder.cc index d616522931a..0d08e89aefb 100644 --- a/hybridse/src/codegen/struct_ir_builder.cc +++ b/hybridse/src/codegen/struct_ir_builder.cc @@ -31,31 +31,34 @@ StructTypeIRBuilder::StructTypeIRBuilder(::llvm::Module* m) StructTypeIRBuilder::~StructTypeIRBuilder() {} bool StructTypeIRBuilder::StructCopyFrom(::llvm::BasicBlock* block, ::llvm::Value* src, ::llvm::Value* dist) { - StructTypeIRBuilder* struct_builder = CreateStructTypeIRBuilder(block->getModule(), src->getType()); - bool ok = struct_builder->CopyFrom(block, src, dist); - delete struct_builder; - return ok; + auto struct_builder = CreateStructTypeIRBuilder(block->getModule(), src->getType()); + if (struct_builder.ok()) { + return struct_builder.value()->CopyFrom(block, src, dist); + } + return false; } -StructTypeIRBuilder* StructTypeIRBuilder::CreateStructTypeIRBuilder(::llvm::Module* m, ::llvm::Type* type) { +absl::StatusOr> StructTypeIRBuilder::CreateStructTypeIRBuilder( + ::llvm::Module* m, ::llvm::Type* type) { node::DataType base_type; if (!GetBaseType(type, &base_type)) { - return nullptr; + return absl::UnimplementedError( + absl::StrCat("fail to create struct type ir builder for ", GetLlvmObjectString(type))); } switch (base_type) { case node::kTimestamp: - return new TimestampIRBuilder(m); + return std::make_unique(m); case node::kDate: - return new DateIRBuilder(m); + return std::make_unique(m); case node::kVarchar: - return new StringIRBuilder(m); + return std::make_unique(m); default: { - LOG(WARNING) << "fail to create struct type ir builder for " << DataTypeName(base_type); - return nullptr; + break; } } - return nullptr; + return absl::UnimplementedError( + absl::StrCat("fail to create struct type ir builder for ", GetLlvmObjectString(type))); } absl::StatusOr StructTypeIRBuilder::CreateNull(::llvm::BasicBlock* block) { diff --git a/hybridse/src/codegen/struct_ir_builder.h b/hybridse/src/codegen/struct_ir_builder.h index 9e5437f5158..4c09e488ce9 100644 --- a/hybridse/src/codegen/struct_ir_builder.h +++ b/hybridse/src/codegen/struct_ir_builder.h @@ -19,6 +19,7 @@ #include #include +#include #include "absl/status/statusor.h" #include "base/fe_status.h" @@ -33,7 +34,8 @@ class StructTypeIRBuilder : public TypeIRBuilder { explicit StructTypeIRBuilder(::llvm::Module*); ~StructTypeIRBuilder(); - static StructTypeIRBuilder* CreateStructTypeIRBuilder(::llvm::Module*, ::llvm::Type*); + static absl::StatusOr> CreateStructTypeIRBuilder(::llvm::Module*, + ::llvm::Type*); static bool StructCopyFrom(::llvm::BasicBlock* block, ::llvm::Value* src, ::llvm::Value* dist); virtual bool CopyFrom(::llvm::BasicBlock* block, ::llvm::Value* src, ::llvm::Value* dist) = 0; diff --git a/src/proto/common.proto b/src/proto/common.proto index 1bc539766c3..8241e646f34 100755 --- a/src/proto/common.proto +++ b/src/proto/common.proto @@ -41,12 +41,21 @@ message DbTableNamePair { required string db_name = 1; required string table_name = 2; } +message TableColumnSchema { + optional string name = 1; + optional openmldb.type.DataType type = 2; + repeated TableColumnSchema type_fields = 3; + optional bool not_null = 4 [default = false]; +} + message ColumnDesc { required string name = 1; optional openmldb.type.DataType data_type = 2; optional bool not_null = 3 [default = false]; optional bool is_constant = 4 [default = false]; optional string default_value = 5; + // replacing ColumnDesc::data_type and ColumnDesc::not_null + optional TableColumnSchema schema = 6; } message TTLSt { diff --git a/src/proto/type.proto b/src/proto/type.proto index 5bb6b7faa67..83b80631ca1 100755 --- a/src/proto/type.proto +++ b/src/proto/type.proto @@ -34,6 +34,8 @@ enum DataType { // reserve 9, 10, 11, 12 kVarchar = 13; kString = 14; + kArray = 15; + kMap = 16; } enum IndexType { @@ -77,4 +79,4 @@ enum ProcedureType { enum NotifyType { kTable = 1; kGlobalVar = 2; -} \ No newline at end of file +} diff --git a/src/schema/schema_adapter.cc b/src/schema/schema_adapter.cc index d35061c3886..9d6fa8ffcee 100644 --- a/src/schema/schema_adapter.cc +++ b/src/schema/schema_adapter.cc @@ -20,12 +20,15 @@ #include #include #include +#include "absl/status/status.h" #include "glog/logging.h" +#include "proto/fe_type.pb.h" +#include "proto/type.pb.h" namespace openmldb { namespace schema { -bool SchemaAdapter::ConvertSchemaAndIndex(const ::hybridse::vm::Schema& sql_schema, +bool SchemaAdapter::ConvertSchemaAndIndex(const ::hybridse::codec::Schema& sql_schema, const ::hybridse::vm::IndexList& index, PBSchema* schema_output, PBIndex* index_output) { if (nullptr == schema_output || nullptr == index_output) { @@ -56,8 +59,8 @@ bool SchemaAdapter::ConvertSchemaAndIndex(const ::hybridse::vm::Schema& sql_sche return true; } -bool SchemaAdapter::SubSchema(const ::hybridse::vm::Schema* schema, - const ::google::protobuf::RepeatedField& projection, hybridse::vm::Schema* output) { +bool SchemaAdapter::SubSchema(const ::hybridse::codec::Schema* schema, + const ::google::protobuf::RepeatedField& projection, hybridse::codec::Schema* output) { if (output == nullptr) { LOG(WARNING) << "output ptr is null"; return false; @@ -70,12 +73,12 @@ bool SchemaAdapter::SubSchema(const ::hybridse::vm::Schema* schema, return true; } std::shared_ptr<::hybridse::sdk::Schema> SchemaAdapter::ConvertSchema(const PBSchema& schema) { - ::hybridse::vm::Schema vm_schema; + ::hybridse::codec::Schema vm_schema; ConvertSchema(schema, &vm_schema); return std::make_shared<::hybridse::sdk::SchemaImpl>(vm_schema); } -bool SchemaAdapter::ConvertSchema(const PBSchema& schema, ::hybridse::vm::Schema* output) { +bool SchemaAdapter::ConvertSchema(const PBSchema& schema, ::hybridse::codec::Schema* output) { if (output == nullptr) { LOG(WARNING) << "output ptr is null"; return false; @@ -85,23 +88,18 @@ bool SchemaAdapter::ConvertSchema(const PBSchema& schema, ::hybridse::vm::Schema return false; } for (int32_t i = 0; i < schema.size(); i++) { - const common::ColumnDesc& column = schema.Get(i); - ::hybridse::type::ColumnDef* new_column = output->Add(); - new_column->set_name(column.name()); - new_column->set_is_not_null(column.not_null()); - new_column->set_is_constant(column.is_constant()); - ::hybridse::type::Type type; - if (!ConvertType(column.data_type(), &type)) { - LOG(WARNING) << "type " << ::openmldb::type::DataType_Name(column.data_type()) - << " is not supported"; + const common::ColumnDesc& table_column = schema.Get(i); + ::hybridse::type::ColumnDef* sql_column = output->Add(); + auto s = ConvertColumn(table_column, sql_column); + if (!s.ok()) { + LOG(WARNING) << s.ToString(); return false; } - new_column->set_type(type); } return true; } -bool SchemaAdapter::ConvertSchema(const ::hybridse::vm::Schema& hybridse_schema, PBSchema* schema) { +bool SchemaAdapter::ConvertSchema(const ::hybridse::codec::Schema& hybridse_schema, PBSchema* schema) { if (schema == nullptr) { LOG(WARNING) << "schema is null"; return false; @@ -155,6 +153,62 @@ bool SchemaAdapter::ConvertType(hybridse::node::DataType hybridse_type, openmldb return true; } +absl::Status SchemaAdapter::ConvertType(const hybridse::node::ColumnSchemaNode* sc, common::TableColumnSchema* tbs) { + if (sc == nullptr) { + return absl::InvalidArgumentError("paramter null"); + } + switch (sc->type()) { + case hybridse::node::kBool: + tbs->set_type(openmldb::type::kBool); + break; + case hybridse::node::kInt16: + tbs->set_type(openmldb::type::kSmallInt); + break; + case hybridse::node::kInt32: + tbs->set_type(openmldb::type::kInt); + break; + case hybridse::node::kInt64: + tbs->set_type(openmldb::type::kBigInt); + break; + case hybridse::node::kFloat: + tbs->set_type(openmldb::type::kFloat); + break; + case hybridse::node::kDouble: + tbs->set_type(openmldb::type::kDouble); + break; + case hybridse::node::kDate: + tbs->set_type(openmldb::type::kDate); + break; + case hybridse::node::kTimestamp: + tbs->set_type(openmldb::type::kTimestamp); + break; + case hybridse::node::kVarchar: + tbs->set_type(openmldb::type::kVarchar); + break; + case hybridse::node::kArray: { + tbs->set_type(openmldb::type::kArray); + break; + } + case hybridse::node::kMap: { + tbs->set_type(openmldb::type::kMap); + break; + } + default: + return absl::UnimplementedError(absl::StrCat("unsupported type: ", sc->DebugString())); + } + + for (auto& field_type : sc->generics()) { + auto* field = tbs->add_type_fields(); + auto s = ConvertType(field_type, field); + if (!s.ok()) { + return s; + } + } + + tbs->set_not_null(sc->not_null()); + return absl::OkStatus(); +} + bool SchemaAdapter::ConvertType(openmldb::type::DataType type, hybridse::node::DataType* hybridse_type) { if (hybridse_type == nullptr) { return false; @@ -358,6 +412,99 @@ bool SchemaAdapter::ConvertColumn(const hybridse::type::ColumnDef& sql_column, o return true; } +absl::Status SchemaAdapter::ConvertColumn(const openmldb::common::ColumnDesc& column, + hybridse::type::ColumnDef* sql_column) { + if (column.has_schema()) { + // new schema field + auto s = ConvertSchema(column.schema(), sql_column->mutable_schema()); + if (!s.ok()) { + return s; + } + } else { + // fallback use data_type and not_null + ::hybridse::type::Type ty; + if (!ConvertType(column.data_type(), &ty)) { + return absl::InternalError(absl::StrCat("failed to convert type: ", column.DebugString())); + } + auto sc = sql_column->mutable_schema(); + sc->set_base_type(ty); + sc->set_is_not_null(column.not_null()); + } + + if (sql_column->schema().has_base_type()) { + sql_column->set_type(sql_column->schema().base_type()); + } + sql_column->set_is_not_null(sql_column->schema().is_not_null()); + + sql_column->set_name(column.name()); + sql_column->set_is_constant(column.is_constant()); + return absl::OkStatus(); +} + +absl::Status SchemaAdapter::ConvertSchema(const openmldb::common::TableColumnSchema& ts, + hybridse::type::ColumnSchema* sc) { + switch (ts.type()) { + case openmldb::type::kBool: + sc->set_base_type(::hybridse::type::kBool); + break; + case openmldb::type::kSmallInt: + sc->set_base_type(::hybridse::type::kInt16); + break; + case openmldb::type::kInt: + sc->set_base_type(::hybridse::type::kInt32); + break; + case openmldb::type::kBigInt: + sc->set_base_type(::hybridse::type::kInt64); + break; + case openmldb::type::kFloat: + sc->set_base_type(::hybridse::type::kFloat); + break; + case openmldb::type::kDouble: + sc->set_base_type(::hybridse::type::kDouble); + break; + case openmldb::type::kDate: + sc->set_base_type(::hybridse::type::kDate); + break; + case openmldb::type::kTimestamp: + sc->set_base_type(::hybridse::type::kTimestamp); + break; + case openmldb::type::kVarchar: + case openmldb::type::kString: + sc->set_base_type(::hybridse::type::kVarchar); + break; + + case openmldb::type::kArray: { + auto arr_ty = sc->mutable_array_type(); + if (ts.type_fields_size() != 1) { + return absl::FailedPreconditionError( + absl::StrCat("array type requires type_fields size=1, got size=", ts.type_fields_size())); + } + auto s = ConvertSchema(ts.type_fields().Get(0), arr_ty->mutable_ele_type()); + if (!s.ok()) { + return s; + } + break; + } + case openmldb::type::kMap: { + auto map_ty = sc->mutable_map_type(); + if (ts.type_fields_size() != 2) { + return absl::FailedPreconditionError( + absl::StrCat("map type requires type_fields size=2, got size=", ts.type_fields_size())); + } + auto s = ConvertSchema(ts.type_fields().Get(0), map_ty->mutable_key_type()); + s.Update(ConvertSchema(ts.type_fields().Get(1), map_ty->mutable_value_type())); + if (!s.ok()) { + return s; + } + break; + } + } + + sc->set_is_not_null(ts.not_null()); + + return absl::OkStatus(); +} + std::map SchemaAdapter::GetColMap(const nameserver::TableInfo& table_info) { std::map col_map; for (const auto& col : table_info.column_desc()) { diff --git a/src/schema/schema_adapter.h b/src/schema/schema_adapter.h index c14e366e8de..ecd62369ac4 100644 --- a/src/schema/schema_adapter.h +++ b/src/schema/schema_adapter.h @@ -28,28 +28,31 @@ #include "proto/tablet.pb.h" #include "schema/index_util.h" #include "vm/catalog.h" +#include "node/sql_node.h" namespace openmldb { namespace schema { class SchemaAdapter { public: - static bool ConvertSchemaAndIndex(const ::hybridse::vm::Schema& sql_schema, + static bool ConvertSchemaAndIndex(const ::hybridse::codec::Schema& sql_schema, const ::hybridse::vm::IndexList& index, PBSchema* schema_output, PBIndex* index_output); - static bool SubSchema(const ::hybridse::vm::Schema* schema, + static bool SubSchema(const ::hybridse::codec::Schema* schema, const ::google::protobuf::RepeatedField& projection, - hybridse::vm::Schema* output); + hybridse::codec::Schema* output); - static bool ConvertSchema(const PBSchema& schema, ::hybridse::vm::Schema* output); + static bool ConvertSchema(const PBSchema& schema, ::hybridse::codec::Schema* output); static std::shared_ptr<::hybridse::sdk::Schema> ConvertSchema(const PBSchema& schema); - static bool ConvertSchema(const ::hybridse::vm::Schema& hybridse_schema, PBSchema* schema); + static bool ConvertSchema(const ::hybridse::codec::Schema& hybridse_schema, PBSchema* schema); static bool ConvertType(hybridse::node::DataType hybridse_type, openmldb::type::DataType* type); + static absl::Status ConvertType(const hybridse::node::ColumnSchemaNode* sc, common::TableColumnSchema* tbs); + static bool ConvertType(openmldb::type::DataType type, hybridse::node::DataType* hybridse_type); static bool ConvertType(hybridse::type::Type hybridse_type, openmldb::type::DataType* openmldb_type); @@ -68,6 +71,16 @@ class SchemaAdapter { private: static bool ConvertColumn(const hybridse::type::ColumnDef& sql_column, openmldb::common::ColumnDesc* column); + + // table column definition to SQL type. + // + // NOTE NOT ALL fields from table column are convertable to SQL type, be aware the difference between + // 'table_column_definition' and 'type' from parser. + // For example common::ColumnDesc::default_value does not have corresponding field in hybridse::type::ColumnDef. + static absl::Status ConvertColumn(const openmldb::common::ColumnDesc& column, hybridse::type::ColumnDef* sql_column) + ABSL_ATTRIBUTE_NONNULL(); + static absl::Status ConvertSchema(const openmldb::common::TableColumnSchema&, hybridse::type::ColumnSchema*) + ABSL_ATTRIBUTE_NONNULL(); }; } // namespace schema diff --git a/src/sdk/node_adapter.cc b/src/sdk/node_adapter.cc index 2a7960741a8..063a3071861 100644 --- a/src/sdk/node_adapter.cc +++ b/src/sdk/node_adapter.cc @@ -313,16 +313,17 @@ bool NodeAdapter::TransformToTableDef(::hybridse::node::CreatePlanNode* create_n return false; } add_column_desc->set_name(column_def->GetColumnName()); - add_column_desc->set_not_null(column_def->GetIsNotNull()); column_names.insert(std::make_pair(column_def->GetColumnName(), add_column_desc)); - openmldb::type::DataType data_type; - if (!openmldb::schema::SchemaAdapter::ConvertType(column_def->GetColumnType(), &data_type)) { - status->msg = "column type " + - hybridse::node::DataTypeName(column_def->GetColumnType()) + " is not supported"; + auto s = openmldb::schema::SchemaAdapter::ConvertType(column_def->schema(), + add_column_desc->mutable_schema()); + if (!s.ok()) { + status->msg = s.ToString(); status->code = hybridse::common::kUnsupportSql; return false; } - add_column_desc->set_data_type(data_type); + add_column_desc->set_data_type(add_column_desc->schema().type()); + add_column_desc->set_not_null(add_column_desc->schema().not_null()); + auto default_val = column_def->GetDefaultValue(); if (default_val) { if (default_val->GetExprType() != hybridse::node::kExprPrimary) {