feat(offline): support UNION ALL/DISTINCT
aceforeverd committed Dec 1, 2023
1 parent 8b54dec commit 9b2abeb
Showing 8 changed files with 175 additions and 13 deletions.
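UNION ALL and UNION DISTINCT now work in the offline (Spark batch) engine; the online engine still rejects UNION DISTINCT (see the runner_builder.cc hunk below). A rough usage sketch, condensed from the new TestSetOperation suite at the end of this diff — the local SparkSession setup and table data are illustrative, not part of the commit:

import com._4paradigm.openmldb.batch.api.OpenmldbSession
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

// Illustrative local session; the tests obtain theirs from SparkTestSuite.
val spark = SparkSession.builder().master("local[*]").appName("union-demo").getOrCreate()
val sess = new OpenmldbSession(spark)

val schema = StructType(List(StructField("id", IntegerType), StructField("user", StringType)))
val df1 = spark.createDataFrame(spark.sparkContext.makeRDD(Seq(Row(1, "tom"), Row(2, "amy"))), schema)
val df2 = spark.createDataFrame(spark.sparkContext.makeRDD(Seq(Row(1, "tom"))), schema)
sess.registerTable("t1", df1)
sess.registerTable("t2", df2)

// UNION ALL keeps the duplicate row (3 rows); UNION DISTINCT de-duplicates (2 rows).
sess.sql("SELECT * FROM t1 UNION ALL SELECT * FROM t2").getSparkDf().show()
sess.sql("SELECT * FROM t1 UNION DISTINCT SELECT * FROM t2").getSparkDf().show()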
6 changes: 3 additions & 3 deletions hybridse/include/vm/physical_op.h
@@ -1423,7 +1423,7 @@ class PhysicalRequestJoinNode : public PhysicalBinaryNode {
class PhysicalSetOperationNode : public PhysicalOpNode {
public:
PhysicalSetOperationNode(node::SetOperationType type, absl::Span<PhysicalOpNode *const> inputs, bool distinct)
- : PhysicalOpNode(kPhysicalOpSetOperation, false), op_type_(type), distinct_(distinct) {
+ : PhysicalOpNode(kPhysicalOpSetOperation, false), set_type_(type), distinct_(distinct) {
for (auto n : inputs) {
AddProducer(n);
}
@@ -1435,7 +1435,7 @@ class PhysicalSetOperationNode : public PhysicalOpNode {
}
}

- if (group_optimized && op_type_ == node::SetOperationType::UNION) {
+ if (group_optimized && set_type_ == node::SetOperationType::UNION) {
output_type_ = kSchemaTypeGroup;
} else {
output_type_ = kSchemaTypeTable;
@@ -1452,7 +1452,7 @@

absl::StatusOr<ColProducerTraceInfo> TraceColID(absl::string_view col_name) const override;

- node::SetOperationType op_type_;
+ node::SetOperationType set_type_;
const bool distinct_ = false;
static PhysicalSetOperationNode *CastFrom(PhysicalOpNode *node);
};
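On the Scala side, the renamed member is read through the SWIG-generated wrapper accessors that the new SetOperationPlan.scala below relies on (getSet_type_(), getDistinct_()). A minimal sketch of inspecting the node from the offline planner, with an illustrative helper name:

import com._4paradigm.hybridse.node.SetOperationType
import com._4paradigm.hybridse.vm.PhysicalSetOperationNode

// Read the renamed set_type_ and the distinct_ flag through the generated wrapper,
// mirroring how SetOperationPlan.gen inspects the node.
def describeSetOp(node: PhysicalSetOperationNode): String = {
  val setType = node.getSet_type_() // e.g. SetOperationType.UNION
  val distinct = node.getDistinct_()
  s"$setType (distinct = $distinct)"
}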
2 changes: 1 addition & 1 deletion hybridse/src/passes/physical/group_and_sort_optimized.cc
@@ -729,7 +729,7 @@ bool GroupAndSortOptimized::KeysOptimizedImpl(const SchemasContext* root_schemas
}
vm::PhysicalSetOperationNode* opt_set = nullptr;
if (!plan_ctx_
- ->CreateOp<vm::PhysicalSetOperationNode>(&opt_set, set_op->op_type_, opt_inputs, set_op->distinct_)
+ ->CreateOp<vm::PhysicalSetOperationNode>(&opt_set, set_op->set_type_, opt_inputs, set_op->distinct_)

.isOK()) {
return false;

}
4 changes: 2 additions & 2 deletions hybridse/src/vm/physical_op.cc
@@ -1234,13 +1234,13 @@ Status PhysicalSetOperationNode::InitSchema(PhysicalPlanContext* ctx) {
Status PhysicalSetOperationNode::WithNewChildren(node::NodeManager* nm, const std::vector<PhysicalOpNode*>& children,

PhysicalOpNode** out) {
absl::Span<PhysicalOpNode* const> sp = absl::MakeSpan(children);
- *out = nm->RegisterNode(new PhysicalSetOperationNode(op_type_, sp, distinct_));
+ *out = nm->RegisterNode(new PhysicalSetOperationNode(set_type_, sp, distinct_));

return Status::OK();
}

void PhysicalSetOperationNode::Print(std::ostream& output, const std::string& tab) const {
PhysicalOpNode::Print(output, tab);
- output << "(" << node::SetOperatorName(op_type_, distinct_) << ")\n";
+ output << "(" << node::SetOperatorName(set_type_, distinct_) << ")\n";
PrintChildren(output, tab);
}

7 changes: 6 additions & 1 deletion hybridse/src/vm/runner_builder.cc
@@ -523,8 +523,13 @@ ClusterTask RunnerBuilder::Build(PhysicalOpNode* node, Status& status) {
}
case kPhysicalOpSetOperation: {
auto set = dynamic_cast<vm::PhysicalSetOperationNode*>(node);
+ if (set->distinct_) {
+     status.msg = "online un-implemented: UNION DISTINCT";
+     status.code = common::kExecutionPlanError;
+     return fail;

+ }
auto set_runner =
- CreateRunner<SetOperationRunner>(id_++, node->schemas_ctx(), set->op_type_, set->distinct_);
+ CreateRunner<SetOperationRunner>(id_++, node->schemas_ctx(), set->set_type_, set->distinct_);
std::vector<ClusterTask> tasks;
for (auto n : node->GetProducers()) {
auto task = Build(n, status);
2 changes: 0 additions & 2 deletions hybridse/src/vm/transform.cc
@@ -340,8 +340,6 @@ Status BatchModeTransformer::TransformSetOperation(const node::SetOperationPlanN
PhysicalSetOperationNode** out) {
CHECK_TRUE(node != nullptr && out != nullptr, kPlanError, "Input node or output node is null");

- CHECK_TRUE(!node->distinct(), common::kPhysicalPlanError, "un-implemented: UNION DISTINCT");

std::vector<PhysicalOpNode*> inputs;
const SchemasContext* expect_sc = nullptr;
for (auto n : node->inputs()) {
8 changes: 4 additions & 4 deletions java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/SparkPlanner.scala
@@ -23,11 +23,11 @@ import com._4paradigm.hybridse.vm.{CoreAPI, Engine, PhysicalConstProjectNode, Ph
PhysicalDataProviderNode, PhysicalFilterNode, PhysicalGroupAggrerationNode, PhysicalGroupNode, PhysicalJoinNode,
PhysicalLimitNode, PhysicalLoadDataNode, PhysicalOpNode, PhysicalOpType, PhysicalProjectNode, PhysicalRenameNode,
PhysicalSelectIntoNode, PhysicalSimpleProjectNode, PhysicalSortNode, PhysicalTableProjectNode,
- PhysicalWindowAggrerationNode, ProjectType}
+ PhysicalWindowAggrerationNode, ProjectType, PhysicalSetOperationNode}
import com._4paradigm.openmldb.batch.api.OpenmldbSession
import com._4paradigm.openmldb.batch.nodes.{ConstProjectPlan, CreateTablePlan, DataProviderPlan, FilterPlan,
GroupByAggregationPlan, GroupByPlan, JoinPlan, LimitPlan, LoadDataPlan, RenamePlan, RowProjectPlan, SelectIntoPlan,
- SimpleProjectPlan, SortByPlan, WindowAggPlan}
+ SimpleProjectPlan, SortByPlan, WindowAggPlan, SetOperationPlan}
import com._4paradigm.openmldb.batch.utils.{DataTypeUtil, ExternalUdfUtil, GraphvizUtil, HybridseUtil, NodeIndexInfo,
NodeIndexType}
import com._4paradigm.openmldb.sdk.impl.SqlClusterExecutor
@@ -271,6 +271,8 @@ class SparkPlanner(session: SparkSession, config: OpenmldbBatchConfig, sparkAppN
SelectIntoPlan.gen(ctx, PhysicalSelectIntoNode.CastFrom(root), children.head)
case PhysicalOpType.kPhysicalCreateTable =>
CreateTablePlan.gen(ctx, PhysicalCreateTableNode.CastFrom(root))
+ case PhysicalOpType.kPhysicalOpSetOperation =>
+     SetOperationPlan.gen(ctx, PhysicalSetOperationNode.CastFrom(root), children)
case _ =>
throw new UnsupportedHybridSeException(s"Plan type $opType not supported")
}
@@ -399,5 +401,3 @@ class SparkPlanner(session: SparkSession, config: OpenmldbBatchConfig, sparkAppN
}
}
}


62 changes: 62 additions & 0 deletions java/openmldb-batch/src/main/scala/com/_4paradigm/openmldb/batch/nodes/SetOperationPlan.scala
@@ -0,0 +1,62 @@
/*
* Copyright 2021 4Paradigm
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com._4paradigm.openmldb.batch.nodes

import com._4paradigm.openmldb.batch.PlanContext
import com._4paradigm.hybridse.vm.PhysicalSetOperationNode
import com._4paradigm.hybridse.node.SetOperationType
import com._4paradigm.openmldb.batch.SparkInstance
import org.slf4j.LoggerFactory
import com._4paradigm.hybridse.sdk.HybridSeException

// UNION [ ALL | DISTINCT ] : YES
// EXCEPT : NO
// INTERSECT : NO
object SetOperationPlan {

private val logger = LoggerFactory.getLogger(this.getClass)

def gen(
ctx: PlanContext,
node: PhysicalSetOperationNode,
inputs: Array[SparkInstance]
): SparkInstance = {
val setType = node.getSet_type_()
if (setType != SetOperationType.UNION) {
throw new HybridSeException(s"Set Operation type $setType not supported")

}

if (inputs.size < 2) {
throw new HybridSeException(s"Set Operation requires input size >= 2")

}

val unionAll = inputs

.map(inst => inst.getDf())
.reduceLeft({ (acc, df) =>
{
acc.union(df)
}
})

val outputDf = if (node.getDistinct_()) {
unionAll.distinct()
} else {
unionAll
}

SparkInstance.createConsideringIndex(ctx, node.GetNodeId(), outputDf)
}
}

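SetOperationPlan maps UNION onto plain Spark operations: the child SparkInstances are folded with DataFrame.union, which is positional and keeps duplicates (UNION ALL), and distinct() is appended only when the node carries the DISTINCT flag. A minimal standalone sketch of the same strategy, with an illustrative helper name:

import org.apache.spark.sql.DataFrame

// UNION ALL: fold the inputs with union(); UNION DISTINCT: same fold, then distinct().
def unionInputs(inputs: Seq[DataFrame], distinct: Boolean): DataFrame = {
  require(inputs.size >= 2, "set operation expects at least two inputs")
  val unionAll = inputs.reduceLeft(_ union _)
  if (distinct) unionAll.distinct() else unionAll
}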
97 changes: 97 additions & 0 deletions java/openmldb-batch/src/test/scala/com/_4paradigm/openmldb/batch/TestSetOperation.scala
@@ -0,0 +1,97 @@
/*
* Copyright 2021 4Paradigm
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com._4paradigm.openmldb.batch

import com._4paradigm.openmldb.batch.api.OpenmldbSession
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{
IntegerType,
StringType,
StructField,
StructType
}
import com._4paradigm.openmldb.batch.utils.SparkUtil

class TestSetOperation extends SparkTestSuite {

test("Test UNION ALL") {
val spark = getSparkSession
val sess = new OpenmldbSession(spark)

val schema = StructType(
List(StructField("id", IntegerType), StructField("user", StringType))
)
val data1 = Seq(Row(1, "tom"), Row(2, "amy"))
val df1 = spark.createDataFrame(spark.sparkContext.makeRDD(data1), schema)
val data2 = Seq(Row(1, "tom"))
val df2 = spark.createDataFrame(spark.sparkContext.makeRDD(data2), schema)

sess.registerTable("t1", df1)
sess.registerTable("t2", df2)
df1.createOrReplaceTempView("t1")
df2.createOrReplaceTempView("t2")

val sqlText = "SELECT * FROM t1 UNION ALL SELECT * FROM t2"

val outputDf = sess.sql(sqlText)
outputDf.show()
val sparksqlOutputDf = sess.sparksql(sqlText)
sparksqlOutputDf.show()
assert(outputDf.getSparkDf().count() == 3)
assert(
SparkUtil.approximateDfEqual(
outputDf.getSparkDf(),
sparksqlOutputDf,
true
)
)
}

test("Test UNION DISTINCT") {
val spark = getSparkSession
val sess = new OpenmldbSession(spark)

val schema = StructType(
List(StructField("id", IntegerType), StructField("user", StringType))
)
val data1 = Seq(Row(1, "tom"), Row(2, "amy"))
val df1 = spark.createDataFrame(spark.sparkContext.makeRDD(data1), schema)
val data2 = Seq(Row(1, "tom"))
val df2 = spark.createDataFrame(spark.sparkContext.makeRDD(data2), schema)

sess.registerTable("t1", df1)
sess.registerTable("t2", df2)
df1.createOrReplaceTempView("t1")
df2.createOrReplaceTempView("t2")

val sqlText = "SELECT * FROM t1 UNION DISTINCT SELECT * FROM t2"

val outputDf = sess.sql(sqlText)
outputDf.show()
val sparksqlOutputDf = sess.sparksql(sqlText)
sparksqlOutputDf.show()
assert(outputDf.getSparkDf().count() == 2)
assert(
SparkUtil.approximateDfEqual(
outputDf.getSparkDf(),
sparksqlOutputDf,
true
)
)
}

}
