From 559eb25244d488b3ca8a5dcba189ac27d5d71dd1 Mon Sep 17 00:00:00 2001 From: "Min(Dongmin Yu)" Date: Sat, 30 Apr 2022 00:28:35 +0900 Subject: [PATCH 1/2] Revert "Fix unexpected predicate elimination in correlated subquery (#42)" This reverts commit eae7c0f61f08bfe7ece6d98139671355bc010bf9. --- .../sql/planner/EqualityInference.java | 21 ++++++++++--------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/EqualityInference.java b/presto-main/src/main/java/io/prestosql/sql/planner/EqualityInference.java index c219427895ef..54506c012b29 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/EqualityInference.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/EqualityInference.java @@ -24,6 +24,7 @@ import io.prestosql.metadata.Metadata; import io.prestosql.sql.tree.ComparisonExpression; import io.prestosql.sql.tree.Expression; +import io.prestosql.sql.tree.SymbolReference; import io.prestosql.util.DisjointSet; import java.util.ArrayList; @@ -321,16 +322,16 @@ Expression getScopedCanonical(Expression expression, Predicate symbolSco } Collection equivalences = equalitySets.get(canonicalIndex); -// if (expression instanceof SymbolReference) { -// boolean inScope = equivalences.stream() -// .filter(SymbolReference.class::isInstance) -// .map(Symbol::from) -// .anyMatch(symbolScope); -// -// if (!inScope) { -// return null; -// } -// } + if (expression instanceof SymbolReference) { + boolean inScope = equivalences.stream() + .filter(SymbolReference.class::isInstance) + .map(Symbol::from) + .anyMatch(symbolScope); + + if (!inScope) { + return null; + } + } Set candidates = equivalences.stream() .filter(e -> isScoped(e, symbolScope)) From ac60b1927a637a79cf1ad0791182355c4f9e643c Mon Sep 17 00:00:00 2001 From: "Min(Dongmin Yu)" Date: Sat, 30 Apr 2022 00:59:03 +0900 Subject: [PATCH 2/2] Test some subquery planning changes --- .../io/prestosql/sql/analyzer/Analysis.java | 87 ++++++++++++------- .../sql/analyzer/ExpressionAnalyzer.java | 12 +-- .../java/io/prestosql/sql/analyzer/Scope.java | 10 +++ .../prestosql/sql/planner/QueryPlanner.java | 12 +-- .../sql/planner/RelationPlanner.java | 16 ++-- .../sql/planner/SubqueryPlanner.java | 66 ++++++++++---- .../tests/AbstractTestEngineOnlyQueries.java | 6 -- 7 files changed, 135 insertions(+), 74 deletions(-) diff --git a/presto-main/src/main/java/io/prestosql/sql/analyzer/Analysis.java b/presto-main/src/main/java/io/prestosql/sql/analyzer/Analysis.java index 7dc216dc2300..986d09297c42 100644 --- a/presto-main/src/main/java/io/prestosql/sql/analyzer/Analysis.java +++ b/presto-main/src/main/java/io/prestosql/sql/analyzer/Analysis.java @@ -13,12 +13,10 @@ */ package io.prestosql.sql.analyzer; -import com.google.common.collect.ArrayListMultimap; import com.google.common.collect.HashMultimap; import com.google.common.collect.HashMultiset; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import com.google.common.collect.ListMultimap; import com.google.common.collect.Multimap; import com.google.common.collect.Multiset; import com.google.common.collect.Streams; @@ -69,6 +67,7 @@ import java.util.ArrayDeque; import java.util.ArrayList; import java.util.Collection; +import java.util.Collections; import java.util.Deque; import java.util.HashSet; import java.util.LinkedHashMap; @@ -87,7 +86,6 @@ import static com.google.common.collect.ImmutableSet.toImmutableSet; import static java.lang.String.format; import static java.util.Collections.emptyList; -import static java.util.Collections.unmodifiableList; import static java.util.Collections.unmodifiableMap; import static java.util.Collections.unmodifiableSet; import static java.util.Objects.requireNonNull; @@ -139,11 +137,7 @@ public class Analysis private final Map, Expression> joins = new LinkedHashMap<>(); private final Map, JoinUsingAnalysis> joinUsing = new LinkedHashMap<>(); - - private final ListMultimap, InPredicate> inPredicatesSubqueries = ArrayListMultimap.create(); - private final ListMultimap, SubqueryExpression> scalarSubqueries = ArrayListMultimap.create(); - private final ListMultimap, ExistsPredicate> existsSubqueries = ArrayListMultimap.create(); - private final ListMultimap, QuantifiedComparisonExpression> quantifiedComparisonSubqueries = ArrayListMultimap.create(); + private final Map, SubqueryAnalysis> subqueries = new LinkedHashMap<>(); private final Map, TableEntry> tables = new LinkedHashMap<>(); @@ -415,11 +409,11 @@ public Expression getJoinCriteria(Join join) public void recordSubqueries(Node node, ExpressionAnalysis expressionAnalysis) { - NodeRef key = NodeRef.of(node); - this.inPredicatesSubqueries.putAll(key, dereference(expressionAnalysis.getSubqueryInPredicates())); - this.scalarSubqueries.putAll(key, dereference(expressionAnalysis.getScalarSubqueries())); - this.existsSubqueries.putAll(key, dereference(expressionAnalysis.getExistsSubqueries())); - this.quantifiedComparisonSubqueries.putAll(key, dereference(expressionAnalysis.getQuantifiedComparisons())); + SubqueryAnalysis subqueries = this.subqueries.computeIfAbsent(NodeRef.of(node), key -> new SubqueryAnalysis()); + subqueries.addInPredicates(dereference(expressionAnalysis.getSubqueryInPredicates())); + subqueries.addSubqueries(dereference(expressionAnalysis.getScalarSubqueries())); + subqueries.addExistsSubqueries(dereference(expressionAnalysis.getExistsSubqueries())); + subqueries.addQuantifiedComparisons(dereference(expressionAnalysis.getQuantifiedComparisons())); } private List dereference(Collection> nodeRefs) @@ -429,24 +423,9 @@ private List dereference(Collection> nodeRefs) .collect(toImmutableList()); } - public List getInPredicateSubqueries(Node node) - { - return ImmutableList.copyOf(inPredicatesSubqueries.get(NodeRef.of(node))); - } - - public List getScalarSubqueries(Node node) + public SubqueryAnalysis getSubqueries(Node node) { - return ImmutableList.copyOf(scalarSubqueries.get(NodeRef.of(node))); - } - - public List getExistsSubqueries(Node node) - { - return ImmutableList.copyOf(existsSubqueries.get(NodeRef.of(node))); - } - - public List getQuantifiedComparisonSubqueries(Node node) - { - return unmodifiableList(quantifiedComparisonSubqueries.get(NodeRef.of(node))); + return subqueries.computeIfAbsent(NodeRef.of(node), key -> new SubqueryAnalysis()); } public void setWindowFunctions(QuerySpecification node, List functions) @@ -1199,6 +1178,54 @@ public Optional getOrdinalityField() } } + public static class SubqueryAnalysis + { + private final List inPredicatesSubqueries = new ArrayList<>(); + private final List subqueries = new ArrayList<>(); + private final List existsSubqueries = new ArrayList<>(); + private final List quantifiedComparisonSubqueries = new ArrayList<>(); + + public void addInPredicates(List expressions) + { + inPredicatesSubqueries.addAll(expressions); + } + + public void addSubqueries(List expressions) + { + subqueries.addAll(expressions); + } + + public void addExistsSubqueries(List expressions) + { + existsSubqueries.addAll(expressions); + } + + public void addQuantifiedComparisons(List expressions) + { + quantifiedComparisonSubqueries.addAll(expressions); + } + + public List getInPredicatesSubqueries() + { + return Collections.unmodifiableList(inPredicatesSubqueries); + } + + public List getSubqueries() + { + return Collections.unmodifiableList(subqueries); + } + + public List getExistsSubqueries() + { + return Collections.unmodifiableList(existsSubqueries); + } + + public List getQuantifiedComparisonSubqueries() + { + return Collections.unmodifiableList(quantifiedComparisonSubqueries); + } + } + public static final class AccessControlInfo { private final AccessControl accessControl; diff --git a/presto-main/src/main/java/io/prestosql/sql/analyzer/ExpressionAnalyzer.java b/presto-main/src/main/java/io/prestosql/sql/analyzer/ExpressionAnalyzer.java index 6c6544e72b1a..93df7fbec467 100644 --- a/presto-main/src/main/java/io/prestosql/sql/analyzer/ExpressionAnalyzer.java +++ b/presto-main/src/main/java/io/prestosql/sql/analyzer/ExpressionAnalyzer.java @@ -216,7 +216,7 @@ public class ExpressionAnalyzer private final boolean isDescribe; private final Map, ResolvedFunction> resolvedFunctions = new LinkedHashMap<>(); - private final Set> scalarSubqueries = new LinkedHashSet<>(); + private final Set> subqueries = new LinkedHashSet<>(); private final Set> existsSubqueries = new LinkedHashSet<>(); private final Map, Type> expressionCoercions = new LinkedHashMap<>(); private final Set> typeOnlyCoercions = new LinkedHashSet<>(); @@ -353,9 +353,9 @@ private Type analyze(Expression expression, Scope baseScope, Context context) return visitor.process(expression, new StackableAstVisitor.StackableAstVisitorContext<>(context)); } - public Set> getScalarSubqueries() + public Set> getSubqueries() { - return unmodifiableSet(scalarSubqueries); + return unmodifiableSet(subqueries); } public Set> getExistsSubqueries() @@ -1511,7 +1511,7 @@ else if (previousNode instanceof QuantifiedComparisonExpression) { quantifiedComparisons.add(NodeRef.of((QuantifiedComparisonExpression) previousNode)); } else { - scalarSubqueries.add(NodeRef.of(node)); + subqueries.add(NodeRef.of(node)); } Type type = getOnlyElement(queryScope.getRelationType().getVisibleFields()).getType(); @@ -1892,7 +1892,7 @@ public static ExpressionAnalysis analyzeExpressions( analyzer.getExpressionTypes(), analyzer.getExpressionCoercions(), analyzer.getSubqueryInPredicates(), - analyzer.getScalarSubqueries(), + analyzer.getSubqueries(), analyzer.getExistsSubqueries(), analyzer.getColumnReferences(), analyzer.getTypeOnlyCoercions(), @@ -1939,7 +1939,7 @@ public static ExpressionAnalysis analyzeExpression( expressionTypes, expressionCoercions, analyzer.getSubqueryInPredicates(), - analyzer.getScalarSubqueries(), + analyzer.getSubqueries(), analyzer.getExistsSubqueries(), analyzer.getColumnReferences(), analyzer.getTypeOnlyCoercions(), diff --git a/presto-main/src/main/java/io/prestosql/sql/analyzer/Scope.java b/presto-main/src/main/java/io/prestosql/sql/analyzer/Scope.java index ba57a5a5a78b..5248e226a596 100644 --- a/presto-main/src/main/java/io/prestosql/sql/analyzer/Scope.java +++ b/presto-main/src/main/java/io/prestosql/sql/analyzer/Scope.java @@ -339,6 +339,16 @@ public static final class Builder private Optional parent = Optional.empty(); private boolean queryBoundary; + public Builder like(Scope other) + { + relationId = other.relationId; + relationType = other.relation; + namedQueries.putAll(other.namedQueries); + parent = other.parent; + queryBoundary = other.queryBoundary; + return this; + } + public Builder withRelationType(RelationId relationId, RelationType relationType) { this.relationId = requireNonNull(relationId, "relationId is null"); diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/QueryPlanner.java b/presto-main/src/main/java/io/prestosql/sql/planner/QueryPlanner.java index 38ff34d161e4..f44dc40de71c 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/QueryPlanner.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/QueryPlanner.java @@ -174,7 +174,7 @@ public RelationPlan plan(Query query) PlanBuilder builder = planQueryBody(query); List orderBy = analysis.getOrderByExpressions(query); - builder = subqueryPlanner.handleSubqueries(builder, orderBy, query); + builder = subqueryPlanner.handleSubqueries(builder, orderBy, analysis.getSubqueries(query)); List selectExpressions = analysis.getSelectExpressions(query); List outputs = selectExpressions.stream() @@ -372,7 +372,7 @@ public RelationPlan plan(QuerySpecification node) List expressions = selectExpressions.stream() .map(SelectExpression::getExpression) .collect(toImmutableList()); - builder = subqueryPlanner.handleSubqueries(builder, expressions, node); + builder = subqueryPlanner.handleSubqueries(builder, expressions, analysis.getSubqueries(node)); if (hasExpressionsToUnfold(selectExpressions)) { // pre-project the folded expressions to preserve any non-deterministic semantics of functions that might be referenced @@ -409,7 +409,7 @@ public RelationPlan plan(QuerySpecification node) } List orderBy = analysis.getOrderByExpressions(node); - builder = subqueryPlanner.handleSubqueries(builder, orderBy, node); + builder = subqueryPlanner.handleSubqueries(builder, orderBy, analysis.getSubqueries(node)); builder = builder.appendProjections(Iterables.concat(orderBy, outputs), symbolAllocator, idAllocator); builder = distinct(builder, node, outputs); @@ -506,7 +506,7 @@ private PlanBuilder filter(PlanBuilder subPlan, Expression predicate, Node node) return subPlan; } - subPlan = subqueryPlanner.handleSubqueries(subPlan, predicate, node); + subPlan = subqueryPlanner.handleSubqueries(subPlan, predicate, analysis.getSubqueries(node)); return subPlan.withNewRoot(new FilterNode(idAllocator.getNextId(), subPlan.getRoot(), subPlan.rewrite(predicate))); } @@ -542,7 +542,7 @@ private PlanBuilder aggregate(PlanBuilder subPlan, QuerySpecification node) inputBuilder.addAll(groupingSetAnalysis.getComplexExpressions()); List inputs = inputBuilder.build(); - subPlan = subqueryPlanner.handleSubqueries(subPlan, inputs, node); + subPlan = subqueryPlanner.handleSubqueries(subPlan, inputs, analysis.getSubqueries(node)); subPlan = subPlan.appendProjections(inputs, symbolAllocator, idAllocator); // Add projection to coerce inputs to their site-specific types. @@ -837,7 +837,7 @@ private PlanBuilder window(Node node, PlanBuilder subPlan, List wi List inputs = inputsBuilder.build(); - subPlan = subqueryPlanner.handleSubqueries(subPlan, inputs, node); + subPlan = subqueryPlanner.handleSubqueries(subPlan, inputs, analysis.getSubqueries(node)); subPlan = subPlan.appendProjections(inputs, symbolAllocator, idAllocator); // Add projection to coerce inputs to their site-specific types. diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/RelationPlanner.java b/presto-main/src/main/java/io/prestosql/sql/planner/RelationPlanner.java index acb3db103b36..12cf97c21863 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/RelationPlanner.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/RelationPlanner.java @@ -226,7 +226,7 @@ private RelationPlan addRowFilters(Table node, RelationPlan plan) .withScope(analysis.getAccessControlScope(node), plan.getFieldMappings()); // The fields in the access control scope has the same layout as those for the table scope for (Expression filter : filters) { - planBuilder = subqueryPlanner.handleSubqueries(planBuilder, filter, filter); + planBuilder = subqueryPlanner.handleSubqueries(planBuilder, filter, analysis.getSubqueries(filter)); planBuilder = planBuilder.withNewRoot(new FilterNode( idAllocator.getNextId(), @@ -255,7 +255,7 @@ private RelationPlan addColumnMasks(Table table, RelationPlan plan) Field field = plan.getDescriptor().getFieldByIndex(i); for (Expression mask : columnMasks.getOrDefault(field.getName().get(), ImmutableList.of())) { - planBuilder = subqueryPlanner.handleSubqueries(planBuilder, mask, mask); + planBuilder = subqueryPlanner.handleSubqueries(planBuilder, mask, analysis.getSubqueries(mask)); Map assignments = new LinkedHashMap<>(); for (Symbol symbol : planBuilder.getRoot().getOutputSymbols()) { @@ -405,8 +405,8 @@ else if (firstDependencies.stream().allMatch(right::canResolve) && secondDepende } } - leftPlanBuilder = subqueryPlanner.handleSubqueries(leftPlanBuilder, leftComparisonExpressions, node); - rightPlanBuilder = subqueryPlanner.handleSubqueries(rightPlanBuilder, rightComparisonExpressions, node); + leftPlanBuilder = subqueryPlanner.handleSubqueries(leftPlanBuilder, leftComparisonExpressions, analysis.getSubqueries(node)); + rightPlanBuilder = subqueryPlanner.handleSubqueries(rightPlanBuilder, rightComparisonExpressions, analysis.getSubqueries(node)); // Add projections for join criteria leftPlanBuilder = leftPlanBuilder.appendProjections(leftComparisonExpressions, symbolAllocator, idAllocator); @@ -460,10 +460,10 @@ else if (firstDependencies.stream().allMatch(right::canResolve) && secondDepende // t JOIN u ON t.x = (...) get's planned on the u side // t JOIN u ON t.x + u.x = (...) get's planned on an arbitrary side if (dependencies.stream().allMatch(left::canResolve)) { - leftPlanBuilder = subqueryPlanner.handleSubqueries(leftPlanBuilder, complexExpression, node); + leftPlanBuilder = subqueryPlanner.handleSubqueries(leftPlanBuilder, complexExpression, analysis.getSubqueries(node)); } else { - rightPlanBuilder = subqueryPlanner.handleSubqueries(rightPlanBuilder, complexExpression, node); + rightPlanBuilder = subqueryPlanner.handleSubqueries(rightPlanBuilder, complexExpression, analysis.getSubqueries(node)); } } } @@ -493,7 +493,7 @@ else if (firstDependencies.stream().allMatch(right::canResolve) && secondDepende if (node.getType() == INNER) { // rewrite all the other conditions using output symbols from left + right plan node. PlanBuilder rootPlanBuilder = new PlanBuilder(translationMap, root); - rootPlanBuilder = subqueryPlanner.handleSubqueries(rootPlanBuilder, complexJoinExpressions, node); + rootPlanBuilder = subqueryPlanner.handleSubqueries(rootPlanBuilder, complexJoinExpressions, analysis.getSubqueries(node)); for (Expression expression : complexJoinExpressions) { postInnerJoinConditions.add(rootPlanBuilder.rewrite(expression)); @@ -684,7 +684,7 @@ private RelationPlan planCorrelatedJoin(Join join, RelationPlan leftPlan, Latera PlanBuilder planBuilder = subqueryPlanner.appendCorrelatedJoin( leftPlanBuilder, - rightPlanBuilder, + rightPlanBuilder.getRoot(), lateral.getQuery(), CorrelatedJoinNode.Type.typeConvert(join.getType()), rewrittenFilterCondition, diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/SubqueryPlanner.java b/presto-main/src/main/java/io/prestosql/sql/planner/SubqueryPlanner.java index 2fdd061c14ee..c5277793bb33 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/SubqueryPlanner.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/SubqueryPlanner.java @@ -20,13 +20,19 @@ import com.google.common.graph.Traverser; import io.prestosql.Session; import io.prestosql.metadata.Metadata; +import io.prestosql.spi.type.Type; import io.prestosql.sql.analyzer.Analysis; +import io.prestosql.sql.analyzer.Field; +import io.prestosql.sql.analyzer.RelationType; import io.prestosql.sql.analyzer.Scope; +import io.prestosql.sql.analyzer.TypeSignatureTranslator; import io.prestosql.sql.planner.plan.ApplyNode; import io.prestosql.sql.planner.plan.Assignments; import io.prestosql.sql.planner.plan.CorrelatedJoinNode; import io.prestosql.sql.planner.plan.EnforceSingleRowNode; +import io.prestosql.sql.planner.plan.PlanNode; import io.prestosql.sql.planner.plan.ProjectNode; +import io.prestosql.sql.tree.Cast; import io.prestosql.sql.tree.ComparisonExpression; import io.prestosql.sql.tree.ExistsPredicate; import io.prestosql.sql.tree.Expression; @@ -38,6 +44,7 @@ import io.prestosql.sql.tree.QuantifiedComparisonExpression; import io.prestosql.sql.tree.QuantifiedComparisonExpression.Quantifier; import io.prestosql.sql.tree.Query; +import io.prestosql.sql.tree.Row; import io.prestosql.sql.tree.SubqueryExpression; import io.prestosql.type.TypeCoercion; @@ -104,27 +111,27 @@ class SubqueryPlanner this.recursiveSubqueries = recursiveSubqueries; } - public PlanBuilder handleSubqueries(PlanBuilder builder, Collection expressions, Node node) + public PlanBuilder handleSubqueries(PlanBuilder builder, Collection expressions, Analysis.SubqueryAnalysis subqueries) { for (Expression expression : expressions) { - builder = handleSubqueries(builder, expression, node); + builder = handleSubqueries(builder, expression, subqueries); } return builder; } - public PlanBuilder handleSubqueries(PlanBuilder builder, Expression expression, Node node) + public PlanBuilder handleSubqueries(PlanBuilder builder, Expression expression, Analysis.SubqueryAnalysis subqueries) { - for (Cluster cluster : cluster(builder.getScope(), selectSubqueries(builder, expression, analysis.getInPredicateSubqueries(node)))) { - builder = planInPredicate(builder, cluster, node); + for (Cluster cluster : cluster(builder.getScope(), selectSubqueries(builder, expression, subqueries.getInPredicatesSubqueries()))) { + builder = planInPredicate(builder, cluster, subqueries); } - for (Cluster cluster : cluster(builder.getScope(), selectSubqueries(builder, expression, analysis.getScalarSubqueries(node)))) { + for (Cluster cluster : cluster(builder.getScope(), selectSubqueries(builder, expression, subqueries.getSubqueries()))) { builder = planScalarSubquery(builder, cluster); } - for (Cluster cluster : cluster(builder.getScope(), selectSubqueries(builder, expression, analysis.getExistsSubqueries(node)))) { + for (Cluster cluster : cluster(builder.getScope(), selectSubqueries(builder, expression, subqueries.getExistsSubqueries()))) { builder = planExists(builder, cluster); } - for (Cluster cluster : cluster(builder.getScope(), selectSubqueries(builder, expression, analysis.getQuantifiedComparisonSubqueries(node)))) { - builder = planQuantifiedComparison(builder, cluster, node); + for (Cluster cluster : cluster(builder.getScope(), selectSubqueries(builder, expression, subqueries.getQuantifiedComparisonSubqueries()))) { + builder = planQuantifiedComparison(builder, cluster, subqueries); } return builder; @@ -172,7 +179,7 @@ private Collection> cluster(Scope scope, List< .collect(toImmutableList()); } - private PlanBuilder planInPredicate(PlanBuilder subPlan, Cluster cluster, Node node) + private PlanBuilder planInPredicate(PlanBuilder subPlan, Cluster cluster, Analysis.SubqueryAnalysis subqueries) { // Plan one of the predicates from the cluster InPredicate predicate = cluster.getRepresentative(); @@ -181,7 +188,7 @@ private PlanBuilder planInPredicate(PlanBuilder subPlan, Cluster cl SubqueryExpression subquery = (SubqueryExpression) predicate.getValueList(); Symbol output = symbolAllocator.newSymbol(predicate, BOOLEAN); - subPlan = handleSubqueries(subPlan, value, node); + subPlan = handleSubqueries(subPlan, value, subqueries); subPlan = planInPredicate(subPlan, value, subquery, output, predicate); return new PlanBuilder( @@ -240,18 +247,41 @@ private PlanBuilder planScalarSubquery(PlanBuilder subPlan, Cluster fieldMappings = relationPlan.getFieldMappings(); + Symbol column; + if (descriptor.getVisibleFieldCount() > 1) { + column = symbolAllocator.newSymbol("row", type); + + ImmutableList.Builder fields = ImmutableList.builder(); + for (int i = 0; i < descriptor.getAllFieldCount(); i++) { + Field field = descriptor.getFieldByIndex(i); + if (!field.isHidden()) { + fields.add(fieldMappings.get(i).toSymbolReference()); + } + } + + Expression expression = new Cast(new Row(fields.build()), TypeSignatureTranslator.toSqlType(type)); + + root = new ProjectNode(idAllocator.getNextId(), root, Assignments.of(column, expression)); + } + else { + column = getOnlyElement(fieldMappings); + } return appendCorrelatedJoin( subPlan, - subqueryPlan, + root, scalarSubquery.getQuery(), CorrelatedJoinNode.Type.INNER, TRUE_LITERAL, - mapAll(cluster, subPlan.getScope(), getOnlyElement(relationPlan.getFieldMappings()))); + mapAll(cluster, subPlan.getScope(), column)); } - public PlanBuilder appendCorrelatedJoin(PlanBuilder subPlan, PlanBuilder subqueryPlan, Query query, CorrelatedJoinNode.Type type, Expression filterCondition, Map, Symbol> mappings) + public PlanBuilder appendCorrelatedJoin(PlanBuilder subPlan, PlanNode subquery, Query query, CorrelatedJoinNode.Type type, Expression filterCondition, Map, Symbol> mappings) { return new PlanBuilder( subPlan.getTranslations() @@ -259,7 +289,7 @@ public PlanBuilder appendCorrelatedJoin(PlanBuilder subPlan, PlanBuilder subquer new CorrelatedJoinNode( idAllocator.getNextId(), subPlan.getRoot(), - subqueryPlan.getRoot(), + subquery, subPlan.getRoot().getOutputSymbols(), type, filterCondition, @@ -292,7 +322,7 @@ private RelationPlan planSubquery(Expression subquery, TranslationMap outerConte .process(subquery, null); } - private PlanBuilder planQuantifiedComparison(PlanBuilder subPlan, Cluster cluster, Node node) + private PlanBuilder planQuantifiedComparison(PlanBuilder subPlan, Cluster cluster, Analysis.SubqueryAnalysis subqueries) { // Plan one of the predicates from the cluster QuantifiedComparisonExpression quantifiedComparison = cluster.getRepresentative(); @@ -302,7 +332,7 @@ private PlanBuilder planQuantifiedComparison(PlanBuilder subPlan, Cluster