diff --git a/hybridse/src/CMakeLists.txt b/hybridse/src/CMakeLists.txt
index 80c5cc2a5a3..4f25d87ab70 100644
--- a/hybridse/src/CMakeLists.txt
+++ b/hybridse/src/CMakeLists.txt
@@ -48,6 +48,7 @@ hybridse_add_src_and_tests(vm)
 hybridse_add_src_and_tests(codec)
 hybridse_add_src_and_tests(case)
 hybridse_add_src_and_tests(passes)
+hybridse_add_src_and_tests(rewriter)
 
 get_property(SRC_FILE_LIST_STR GLOBAL PROPERTY PROP_SRC_FILE_LIST)
 string(REPLACE " " ";" SRC_FILE_LIST ${SRC_FILE_LIST_STR})
diff --git a/hybridse/src/rewriter/ast_rewriter.cc b/hybridse/src/rewriter/ast_rewriter.cc
new file mode 100644
index 00000000000..f0b77c75bde
--- /dev/null
+++ b/hybridse/src/rewriter/ast_rewriter.cc
@@ -0,0 +1,295 @@
+/**
+ * Copyright (c) 2024 4Paradigm
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "rewriter/ast_rewriter.h"
+
+#include <memory>
+#include <string>
+
+#include "absl/strings/ascii.h"
+#include "absl/strings/string_view.h"
+#include "plan/plan_api.h"
+#include "zetasql/parser/parse_tree_manual.h"
+#include "zetasql/parser/parser.h"
+#include "zetasql/parser/unparser.h"
+
+namespace hybridse {
+namespace rewriter {
+
+// Unparser that rewrites a known ANSI SQL pattern while printing the AST.
+//
+// The only rewrite so far turns a "LEFT JOIN + row_number() OVER w ... WHERE row_id = 1"
+// query into a LAST JOIN, see TryRewriteAsLastJoin(). Everything else is printed by the
+// stock zetasql unparser.
+class ANSISQLRewriteUnparser : public zetasql::parser::Unparser {
+ public:
+    explicit ANSISQLRewriteUnparser(std::string* unparsed) : zetasql::parser::Unparser(unparsed) {}
+    ~ANSISQLRewriteUnparser() override {}
+    ANSISQLRewriteUnparser(const ANSISQLRewriteUnparser&) = delete;
+    ANSISQLRewriteUnparser& operator=(const ANSISQLRewriteUnparser&) = delete;
+
+    // Prints `node` in LAST JOIN form when it matches the rewrite pattern, and falls back
+    // to the stock unparser output otherwise.
+    void visitASTSelect(const zetasql::ASTSelect* node, void* data) override {
+        if (!TryRewriteAsLastJoin(node, data)) {
+            zetasql::parser::Unparser::visitASTSelect(node, data);
+        }
+    }
+
+    // Prints `node` with the join type replaced by LAST: the lhs, then
+    // "[NATURAL] LAST <join hint> JOIN" ("," for a comma join), then the remaining children.
+    // If `order` is non-null, "ORDER BY <order>" is emitted right after the rhs table.
+    void visitASTJoinRewrited(const zetasql::ASTJoin* node, const zetasql::ASTPathExpression* order, void* data) {
+        node->child(0)->Accept(this, data);
+
+        if (node->join_type() == zetasql::ASTJoin::COMMA) {
+            print(",");
+        } else {
+            println();
+            if (node->natural()) {
+                print("NATURAL");
+            }
+            print("LAST");
+            print(node->GetSQLForJoinHint());
+
+            print("JOIN");
+        }
+        println();
+
+        // This will print hints, the rhs, and the ON or USING clause.
+        for (int i = 1; i < node->num_children(); i++) {
+            node->child(i)->Accept(this, data);
+            if (node->child(i) == node->rhs() && order) {
+                // optional order by after rhs
+                print("ORDER BY");
+                order->Accept(this, data);
+            }
+        }
+    }
+
+ private:
+    // Returns the last name of `column_side` if `literal_side` is the integer literal `1`
+    // and `column_side` is a path expression; returns an empty view otherwise.
+    //
+    // TODO(someone):
+    // 1. accept the literal image '1L' as well
+    // 2. take a table-qualified column path into account instead of its last name only
+    static absl::string_view ColumnComparedToOne(const zetasql::ASTExpression* literal_side,
+                                                 const zetasql::ASTExpression* column_side) {
+        auto literal = literal_side->GetAsOrNull<zetasql::ASTIntLiteral>();
+        auto column = column_side->GetAsOrNull<zetasql::ASTPathExpression>();
+        if (literal != nullptr && column != nullptr && literal->image() == "1") {
+            return column->last_name()->GetAsStringView();
+        }
+        return absl::string_view();
+    }
+
+    // Tries to print a query of the shape
+    //   SELECT ... FROM (
+    //     SELECT ..., row_number() OVER w AS row_id
+    //     FROM t1 LEFT JOIN t2 ON t1.key = t2.key
+    //     WINDOW w AS (PARTITION BY t1.key ORDER BY t2.ts DESC)
+    //   ) WHERE row_id = 1
+    // as
+    //   SELECT ... FROM t1 LAST JOIN t2 ORDER BY t2.ts ON t1.key = t2.key
+    //
+    // Returns true if the rewritten text was printed. Returns false, without printing
+    // anything, if `node` does not match the pattern.
+    //
+    // NOTE: only the join of the inner SELECT is printed; its select list and other
+    // clauses are not, and the outer select list is printed as written. Whether the
+    // result is still a valid query is not checked here.
+    bool TryRewriteAsLastJoin(const zetasql::ASTSelect* node, void* data) {
+        // 1. The filter condition is '<col> = 1' or '1 = <col>'.
+        if (node->where_clause() == nullptr) {
+            return false;
+        }
+        auto filter = node->where_clause()->expression()->GetAsOrNull<zetasql::ASTBinaryExpression>();
+        if (filter == nullptr || filter->op() != zetasql::ASTBinaryExpression::Op::EQ || filter->is_not()) {
+            return false;
+        }
+        absl::string_view filter_col = ColumnComparedToOne(filter->lhs(), filter->rhs());
+        if (filter_col.empty()) {
+            filter_col = ColumnComparedToOne(filter->rhs(), filter->lhs());
+        }
+        if (filter_col.empty()) {
+            return false;
+        }
+
+        // 2. FROM is a bare subquery (no WITH, ORDER BY or LIMIT) around a plain SELECT.
+        if (node->from_clause() == nullptr) {
+            return false;
+        }
+        auto sub = node->from_clause()->table_expression()->GetAsOrNull<zetasql::ASTTableSubquery>();
+        if (sub == nullptr) {
+            return false;
+        }
+        auto subquery = sub->subquery();
+        if (subquery->with_clause() != nullptr || subquery->order_by() != nullptr ||
+            subquery->limit_offset() != nullptr) {
+            return false;
+        }
+        // query_expr() is not always an ASTSelect (e.g. a set operation such as UNION),
+        // so the cast result must be checked before it is dereferenced.
+        auto select = subquery->query_expr()->GetAsOrNull<zetasql::ASTSelect>();
+        if (select == nullptr || select->window_clause() == nullptr || select->from_clause() == nullptr) {
+            return false;
+        }
+
+        // 3. The inner FROM is 't1 LEFT JOIN t2 ON t1.key = t2.key'.
+        auto join = select->from_clause()->table_expression()->GetAsOrNull<zetasql::ASTJoin>();
+        if (join == nullptr || join->join_type() != zetasql::ASTJoin::LEFT || join->on_clause() == nullptr) {
+            return false;
+        }
+        auto on_expr = join->on_clause()->expression()->GetAsOrNull<zetasql::ASTBinaryExpression>();
+        if (on_expr == nullptr || on_expr->is_not() || on_expr->op() != zetasql::ASTBinaryExpression::EQ) {
+            return false;
+        }
+        // Both sides of the join condition must be plain column paths.
+        auto join_lhs_key = on_expr->lhs()->GetAsOrNull<zetasql::ASTPathExpression>();
+        auto join_rhs_key = on_expr->rhs()->GetAsOrNull<zetasql::ASTPathExpression>();
+        if (join_lhs_key == nullptr || join_rhs_key == nullptr) {
+            return false;
+        }
+
+        // 4. The filtered column is 'row_number() OVER <window name>' in the inner select list.
+        //    Only the first select column carrying the alias is inspected.
+        absl::string_view window_name;
+        for (auto col : select->select_list()->columns()) {
+            if (col->alias() == nullptr || col->alias()->GetAsStringView() != filter_col) {
+                continue;
+            }
+            auto call = col->expression()->GetAsOrNull<zetasql::ASTAnalyticFunctionCall>();
+            if (call != nullptr && call->function() != nullptr && call->window_spec() != nullptr &&
+                call->window_spec()->base_window_name() != nullptr) {
+                auto func_name = call->function()->function();
+                if (func_name->num_names() == 1 &&
+                    absl::AsciiStrToLower(func_name->first_name()->GetAsStringView()) == "row_number") {
+                    window_name = call->window_spec()->base_window_name()->GetAsStringView();
+                }
+            }
+            break;
+        }
+        if (window_name.empty()) {
+            return false;
+        }
+
+        // 5. There is exactly one named window, it is the one used by row_number(), it has no
+        //    frame, and it has both PARTITION BY and ORDER BY.
+        if (select->window_clause()->windows().size() != 1) {
+            // targeting single window only
+            return false;
+        }
+        auto win = select->window_clause()->windows().front();
+        if (win->name()->GetAsStringView() != window_name) {
+            return false;
+        }
+        auto spec = win->window_spec();
+        if (spec->window_frame() != nullptr || spec->partition_by() == nullptr || spec->order_by() == nullptr) {
+            // TODO(someone): allow unbounded window frame
+            return false;
+        }
+
+        // PARTITION BY must mention the join lhs key (compared by last name only).
+        bool partition_meet = false;
+        for (auto expr : spec->partition_by()->partitioning_expressions()) {
+            auto key = expr->GetAsOrNull<zetasql::ASTPathExpression>();
+            if (key != nullptr &&
+                key->last_name()->GetAsStringView() == join_lhs_key->last_name()->GetAsStringView()) {
+                partition_meet = true;
+                break;
+            }
+        }
+        if (!partition_meet) {
+            return false;
+        }
+
+        // ORDER BY must be a single column path in descending order; it becomes the
+        // ORDER BY of the LAST JOIN.
+        if (spec->order_by()->ordering_expressions().size() != 1) {
+            return false;
+        }
+        auto ordering = spec->order_by()->ordering_expressions().front();
+        if (ordering->ordering_spec() != zetasql::ASTOrderingExpression::DESC) {
+            return false;
+        }
+        auto order_key = ordering->expression()->GetAsOrNull<zetasql::ASTPathExpression>();
+        if (order_key == nullptr) {
+            return false;
+        }
+
+        // 6. Print the rewritten query.
+        PrintOpenParenIfNeeded(node);
+        println();
+        print("SELECT");
+        if (node->hint() != nullptr) {
+            node->hint()->Accept(this, data);
+        }
+        if (node->anonymization_options() != nullptr) {
+            print("WITH ANONYMIZATION OPTIONS");
+            node->anonymization_options()->Accept(this, data);
+        }
+        if (node->distinct()) {
+            print("DISTINCT");
+        }
+
+        // Visit the remaining children in order. hint() and anonymization_options() were
+        // printed above, the FROM subquery is replaced by the LAST JOIN, and the
+        // WHERE clause matched in step 1 is not printed.
+        for (int i = 0; i < node->num_children(); ++i) {
+            const zetasql::ASTNode* child = node->child(i);
+            if (child == node->from_clause()) {
+                println();
+                print("FROM");
+                visitASTJoinRewrited(join, order_key, data);
+            } else if (child != node->hint() && child != node->anonymization_options() &&
+                       child != node->where_clause()) {
+                child->Accept(this, data);
+            }
+        }
+
+        println();
+        PrintCloseParenIfNeeded(node);
+        return true;
+    }
+};
+
+// Parses `query` and returns its SQL text after rewriting.
+//
+// - Returns the parse error if `query` cannot be parsed.
+// - A query statement is re-printed through ANSISQLRewriteUnparser, so its formatting
+//   changes even when no rewrite pattern matches.
+// - Any other kind of statement is returned unchanged.
+absl::StatusOr<std::string> Rewrite(absl::string_view query) {
+    std::unique_ptr<zetasql::ParserOutput> ast;
+    auto s = hybridse::plan::ParseStatement(query, &ast);
+    if (!s.ok()) {
+        return s;
+    }
+
+    if (ast->statement() && ast->statement()->node_kind() == zetasql::AST_QUERY_STATEMENT) {
+        std::string unparsed;
+        ANSISQLRewriteUnparser unparser(&unparsed);
+        ast->statement()->Accept(&unparser, nullptr);
+        unparser.FlushLine();
+        return unparsed;
+    }
+
+    return std::string(query);
+}
+
+}  // namespace rewriter
+}  // namespace hybridse
diff --git a/hybridse/src/rewriter/ast_rewriter.h b/hybridse/src/rewriter/ast_rewriter.h
new file mode 100644
index 00000000000..17ea7ad0d04
--- /dev/null
+++ b/hybridse/src/rewriter/ast_rewriter.h
@@ -0,0 +1,32 @@
+/**
+ * Copyright (c) 2024 4Paradigm
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef HYBRIDSE_SRC_REWRITER_AST_REWRITER_H_
+#define HYBRIDSE_SRC_REWRITER_AST_REWRITER_H_
+
+#include <string>
+
+#include "absl/status/statusor.h"
+
+namespace hybridse {
+namespace rewriter {
+
+absl::StatusOr<std::string> Rewrite(absl::string_view query);
+
+}  // namespace rewriter
+}  // namespace hybridse
+
+#endif  // HYBRIDSE_SRC_REWRITER_AST_REWRITER_H_
diff --git a/hybridse/src/rewriter/ast_rewriter_test.cc b/hybridse/src/rewriter/ast_rewriter_test.cc
new file mode 100644
index 00000000000..38e0c7b3115
--- /dev/null
+++ b/hybridse/src/rewriter/ast_rewriter_test.cc
@@ -0,0 +1,83 @@
+/**
+ * Copyright 2024 OpenMLDB authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "rewriter/ast_rewriter.h"
+
+#include <memory>
+#include <string>
+
+#include "gtest/gtest.h"
+#include "plan/plan_api.h"
+#include "zetasql/parser/parser.h"
+
+namespace hybridse {
+namespace rewriter {
+
+class ASTRewriterTest : public ::testing::Test {};
+
+// The LEFT JOIN + row_number() = 1 pattern is printed as a LAST JOIN, and the printed
+// text must parse again.
+TEST_F(ASTRewriterTest, LastJoin) {
+    std::string str = R"(
+SELECT id, val, k, ts, idr, valr FROM (
+  SELECT t1.*, t2.id as idr, t2.val as valr, row_number() over w as row_id
+  FROM t1 LEFT JOIN t2 ON t1.k = t2.k
+  WINDOW w as (PARTITION BY t1.id,t1.k order by t2.ts desc)
+) t WHERE row_id = 1)";
+
+    auto s = hybridse::rewriter::Rewrite(str);
+    ASSERT_TRUE(s.ok()) << s.status();
+
+    ASSERT_EQ(R"(SELECT
+  id,
+  val,
+  k,
+  ts,
+  idr,
+  valr
+FROM t1
+LAST JOIN
+t2 ORDER BY t2.ts
+ON t1.k = t2.k
+)",
+              s.value());
+
+    std::unique_ptr<zetasql::ParserOutput> out;
+    auto ss = ::hybridse::plan::ParseStatement(s.value(), &out);
+    ASSERT_TRUE(ss.ok()) << ss;
+}
+
+// A statement that is not a query is handed back byte for byte.
+TEST_F(ASTRewriterTest, NonQueryStatementIsReturnedUnchanged) {
+    std::string str = "CREATE TABLE t1 (c1 int)";
+    auto s = hybridse::rewriter::Rewrite(str);
+    ASSERT_TRUE(s.ok()) << s.status();
+    ASSERT_EQ(str, s.value());
+}
+
+// Text that does not parse yields an error status instead of a rewritten query.
+TEST_F(ASTRewriterTest, ParseFailureIsReported) {
+    auto s = hybridse::rewriter::Rewrite("SELECT FROM WHERE");
+    ASSERT_FALSE(s.ok());
+}
+
+}  // namespace rewriter
+}  // namespace hybridse
+
+int main(int argc, char** argv) {
+    ::testing::InitGoogleTest(&argc, argv);
+    return RUN_ALL_TESTS();
+}