Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mpp 4892 test #47

Open
wants to merge 2 commits into
base: hotfix-350
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 57 additions & 30 deletions presto-main/src/main/java/io/prestosql/sql/analyzer/Analysis.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,10 @@
*/
package io.prestosql.sql.analyzer;

import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.HashMultimap;
import com.google.common.collect.HashMultiset;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ListMultimap;
import com.google.common.collect.Multimap;
import com.google.common.collect.Multiset;
import com.google.common.collect.Streams;
Expand Down Expand Up @@ -69,6 +67,7 @@
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Deque;
import java.util.HashSet;
import java.util.LinkedHashMap;
Expand All @@ -87,7 +86,6 @@
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static java.lang.String.format;
import static java.util.Collections.emptyList;
import static java.util.Collections.unmodifiableList;
import static java.util.Collections.unmodifiableMap;
import static java.util.Collections.unmodifiableSet;
import static java.util.Objects.requireNonNull;
Expand Down Expand Up @@ -139,11 +137,7 @@ public class Analysis

private final Map<NodeRef<Join>, Expression> joins = new LinkedHashMap<>();
private final Map<NodeRef<Join>, JoinUsingAnalysis> joinUsing = new LinkedHashMap<>();

private final ListMultimap<NodeRef<Node>, InPredicate> inPredicatesSubqueries = ArrayListMultimap.create();
private final ListMultimap<NodeRef<Node>, SubqueryExpression> scalarSubqueries = ArrayListMultimap.create();
private final ListMultimap<NodeRef<Node>, ExistsPredicate> existsSubqueries = ArrayListMultimap.create();
private final ListMultimap<NodeRef<Node>, QuantifiedComparisonExpression> quantifiedComparisonSubqueries = ArrayListMultimap.create();
private final Map<NodeRef<Node>, SubqueryAnalysis> subqueries = new LinkedHashMap<>();

private final Map<NodeRef<Table>, TableEntry> tables = new LinkedHashMap<>();

Expand Down Expand Up @@ -415,11 +409,11 @@ public Expression getJoinCriteria(Join join)

public void recordSubqueries(Node node, ExpressionAnalysis expressionAnalysis)
{
NodeRef<Node> key = NodeRef.of(node);
this.inPredicatesSubqueries.putAll(key, dereference(expressionAnalysis.getSubqueryInPredicates()));
this.scalarSubqueries.putAll(key, dereference(expressionAnalysis.getScalarSubqueries()));
this.existsSubqueries.putAll(key, dereference(expressionAnalysis.getExistsSubqueries()));
this.quantifiedComparisonSubqueries.putAll(key, dereference(expressionAnalysis.getQuantifiedComparisons()));
SubqueryAnalysis subqueries = this.subqueries.computeIfAbsent(NodeRef.of(node), key -> new SubqueryAnalysis());
subqueries.addInPredicates(dereference(expressionAnalysis.getSubqueryInPredicates()));
subqueries.addSubqueries(dereference(expressionAnalysis.getScalarSubqueries()));
subqueries.addExistsSubqueries(dereference(expressionAnalysis.getExistsSubqueries()));
subqueries.addQuantifiedComparisons(dereference(expressionAnalysis.getQuantifiedComparisons()));
}

private <T extends Node> List<T> dereference(Collection<NodeRef<T>> nodeRefs)
Expand All @@ -429,24 +423,9 @@ private <T extends Node> List<T> dereference(Collection<NodeRef<T>> nodeRefs)
.collect(toImmutableList());
}

public List<InPredicate> getInPredicateSubqueries(Node node)
{
return ImmutableList.copyOf(inPredicatesSubqueries.get(NodeRef.of(node)));
}

public List<SubqueryExpression> getScalarSubqueries(Node node)
public SubqueryAnalysis getSubqueries(Node node)
{
return ImmutableList.copyOf(scalarSubqueries.get(NodeRef.of(node)));
}

public List<ExistsPredicate> getExistsSubqueries(Node node)
{
return ImmutableList.copyOf(existsSubqueries.get(NodeRef.of(node)));
}

public List<QuantifiedComparisonExpression> getQuantifiedComparisonSubqueries(Node node)
{
return unmodifiableList(quantifiedComparisonSubqueries.get(NodeRef.of(node)));
return subqueries.computeIfAbsent(NodeRef.of(node), key -> new SubqueryAnalysis());
}

public void setWindowFunctions(QuerySpecification node, List<FunctionCall> functions)
Expand Down Expand Up @@ -1199,6 +1178,54 @@ public Optional<Field> getOrdinalityField()
}
}

public static class SubqueryAnalysis
{
private final List<InPredicate> inPredicatesSubqueries = new ArrayList<>();
private final List<SubqueryExpression> subqueries = new ArrayList<>();
private final List<ExistsPredicate> existsSubqueries = new ArrayList<>();
private final List<QuantifiedComparisonExpression> quantifiedComparisonSubqueries = new ArrayList<>();

public void addInPredicates(List<InPredicate> expressions)
{
inPredicatesSubqueries.addAll(expressions);
}

public void addSubqueries(List<SubqueryExpression> expressions)
{
subqueries.addAll(expressions);
}

public void addExistsSubqueries(List<ExistsPredicate> expressions)
{
existsSubqueries.addAll(expressions);
}

public void addQuantifiedComparisons(List<QuantifiedComparisonExpression> expressions)
{
quantifiedComparisonSubqueries.addAll(expressions);
}

public List<InPredicate> getInPredicatesSubqueries()
{
return Collections.unmodifiableList(inPredicatesSubqueries);
}

public List<SubqueryExpression> getSubqueries()
{
return Collections.unmodifiableList(subqueries);
}

public List<ExistsPredicate> getExistsSubqueries()
{
return Collections.unmodifiableList(existsSubqueries);
}

public List<QuantifiedComparisonExpression> getQuantifiedComparisonSubqueries()
{
return Collections.unmodifiableList(quantifiedComparisonSubqueries);
}
}

public static final class AccessControlInfo
{
private final AccessControl accessControl;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ public class ExpressionAnalyzer
private final boolean isDescribe;

private final Map<NodeRef<FunctionCall>, ResolvedFunction> resolvedFunctions = new LinkedHashMap<>();
private final Set<NodeRef<SubqueryExpression>> scalarSubqueries = new LinkedHashSet<>();
private final Set<NodeRef<SubqueryExpression>> subqueries = new LinkedHashSet<>();
private final Set<NodeRef<ExistsPredicate>> existsSubqueries = new LinkedHashSet<>();
private final Map<NodeRef<Expression>, Type> expressionCoercions = new LinkedHashMap<>();
private final Set<NodeRef<Expression>> typeOnlyCoercions = new LinkedHashSet<>();
Expand Down Expand Up @@ -353,9 +353,9 @@ private Type analyze(Expression expression, Scope baseScope, Context context)
return visitor.process(expression, new StackableAstVisitor.StackableAstVisitorContext<>(context));
}

public Set<NodeRef<SubqueryExpression>> getScalarSubqueries()
public Set<NodeRef<SubqueryExpression>> getSubqueries()
{
return unmodifiableSet(scalarSubqueries);
return unmodifiableSet(subqueries);
}

public Set<NodeRef<ExistsPredicate>> getExistsSubqueries()
Expand Down Expand Up @@ -1511,7 +1511,7 @@ else if (previousNode instanceof QuantifiedComparisonExpression) {
quantifiedComparisons.add(NodeRef.of((QuantifiedComparisonExpression) previousNode));
}
else {
scalarSubqueries.add(NodeRef.of(node));
subqueries.add(NodeRef.of(node));
}

Type type = getOnlyElement(queryScope.getRelationType().getVisibleFields()).getType();
Expand Down Expand Up @@ -1892,7 +1892,7 @@ public static ExpressionAnalysis analyzeExpressions(
analyzer.getExpressionTypes(),
analyzer.getExpressionCoercions(),
analyzer.getSubqueryInPredicates(),
analyzer.getScalarSubqueries(),
analyzer.getSubqueries(),
analyzer.getExistsSubqueries(),
analyzer.getColumnReferences(),
analyzer.getTypeOnlyCoercions(),
Expand Down Expand Up @@ -1939,7 +1939,7 @@ public static ExpressionAnalysis analyzeExpression(
expressionTypes,
expressionCoercions,
analyzer.getSubqueryInPredicates(),
analyzer.getScalarSubqueries(),
analyzer.getSubqueries(),
analyzer.getExistsSubqueries(),
analyzer.getColumnReferences(),
analyzer.getTypeOnlyCoercions(),
Expand Down
10 changes: 10 additions & 0 deletions presto-main/src/main/java/io/prestosql/sql/analyzer/Scope.java
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,16 @@ public static final class Builder
private Optional<Scope> parent = Optional.empty();
private boolean queryBoundary;

public Builder like(Scope other)
{
relationId = other.relationId;
relationType = other.relation;
namedQueries.putAll(other.namedQueries);
parent = other.parent;
queryBoundary = other.queryBoundary;
return this;
}

public Builder withRelationType(RelationId relationId, RelationType relationType)
{
this.relationId = requireNonNull(relationId, "relationId is null");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import io.prestosql.metadata.Metadata;
import io.prestosql.sql.tree.ComparisonExpression;
import io.prestosql.sql.tree.Expression;
import io.prestosql.sql.tree.SymbolReference;
import io.prestosql.util.DisjointSet;

import java.util.ArrayList;
Expand Down Expand Up @@ -321,16 +322,16 @@ Expression getScopedCanonical(Expression expression, Predicate<Symbol> symbolSco
}

Collection<Expression> equivalences = equalitySets.get(canonicalIndex);
// if (expression instanceof SymbolReference) {
// boolean inScope = equivalences.stream()
// .filter(SymbolReference.class::isInstance)
// .map(Symbol::from)
// .anyMatch(symbolScope);
//
// if (!inScope) {
// return null;
// }
// }
if (expression instanceof SymbolReference) {
boolean inScope = equivalences.stream()
.filter(SymbolReference.class::isInstance)
.map(Symbol::from)
.anyMatch(symbolScope);

if (!inScope) {
return null;
}
}

Set<Expression> candidates = equivalences.stream()
.filter(e -> isScoped(e, symbolScope))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ public RelationPlan plan(Query query)
PlanBuilder builder = planQueryBody(query);

List<Expression> orderBy = analysis.getOrderByExpressions(query);
builder = subqueryPlanner.handleSubqueries(builder, orderBy, query);
builder = subqueryPlanner.handleSubqueries(builder, orderBy, analysis.getSubqueries(query));

List<SelectExpression> selectExpressions = analysis.getSelectExpressions(query);
List<Expression> outputs = selectExpressions.stream()
Expand Down Expand Up @@ -372,7 +372,7 @@ public RelationPlan plan(QuerySpecification node)
List<Expression> expressions = selectExpressions.stream()
.map(SelectExpression::getExpression)
.collect(toImmutableList());
builder = subqueryPlanner.handleSubqueries(builder, expressions, node);
builder = subqueryPlanner.handleSubqueries(builder, expressions, analysis.getSubqueries(node));

if (hasExpressionsToUnfold(selectExpressions)) {
// pre-project the folded expressions to preserve any non-deterministic semantics of functions that might be referenced
Expand Down Expand Up @@ -409,7 +409,7 @@ public RelationPlan plan(QuerySpecification node)
}

List<Expression> orderBy = analysis.getOrderByExpressions(node);
builder = subqueryPlanner.handleSubqueries(builder, orderBy, node);
builder = subqueryPlanner.handleSubqueries(builder, orderBy, analysis.getSubqueries(node));
builder = builder.appendProjections(Iterables.concat(orderBy, outputs), symbolAllocator, idAllocator);

builder = distinct(builder, node, outputs);
Expand Down Expand Up @@ -506,7 +506,7 @@ private PlanBuilder filter(PlanBuilder subPlan, Expression predicate, Node node)
return subPlan;
}

subPlan = subqueryPlanner.handleSubqueries(subPlan, predicate, node);
subPlan = subqueryPlanner.handleSubqueries(subPlan, predicate, analysis.getSubqueries(node));

return subPlan.withNewRoot(new FilterNode(idAllocator.getNextId(), subPlan.getRoot(), subPlan.rewrite(predicate)));
}
Expand Down Expand Up @@ -542,7 +542,7 @@ private PlanBuilder aggregate(PlanBuilder subPlan, QuerySpecification node)
inputBuilder.addAll(groupingSetAnalysis.getComplexExpressions());

List<Expression> inputs = inputBuilder.build();
subPlan = subqueryPlanner.handleSubqueries(subPlan, inputs, node);
subPlan = subqueryPlanner.handleSubqueries(subPlan, inputs, analysis.getSubqueries(node));
subPlan = subPlan.appendProjections(inputs, symbolAllocator, idAllocator);

// Add projection to coerce inputs to their site-specific types.
Expand Down Expand Up @@ -837,7 +837,7 @@ private PlanBuilder window(Node node, PlanBuilder subPlan, List<FunctionCall> wi

List<Expression> inputs = inputsBuilder.build();

subPlan = subqueryPlanner.handleSubqueries(subPlan, inputs, node);
subPlan = subqueryPlanner.handleSubqueries(subPlan, inputs, analysis.getSubqueries(node));
subPlan = subPlan.appendProjections(inputs, symbolAllocator, idAllocator);

// Add projection to coerce inputs to their site-specific types.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ private RelationPlan addRowFilters(Table node, RelationPlan plan)
.withScope(analysis.getAccessControlScope(node), plan.getFieldMappings()); // The fields in the access control scope has the same layout as those for the table scope

for (Expression filter : filters) {
planBuilder = subqueryPlanner.handleSubqueries(planBuilder, filter, filter);
planBuilder = subqueryPlanner.handleSubqueries(planBuilder, filter, analysis.getSubqueries(filter));

planBuilder = planBuilder.withNewRoot(new FilterNode(
idAllocator.getNextId(),
Expand Down Expand Up @@ -255,7 +255,7 @@ private RelationPlan addColumnMasks(Table table, RelationPlan plan)
Field field = plan.getDescriptor().getFieldByIndex(i);

for (Expression mask : columnMasks.getOrDefault(field.getName().get(), ImmutableList.of())) {
planBuilder = subqueryPlanner.handleSubqueries(planBuilder, mask, mask);
planBuilder = subqueryPlanner.handleSubqueries(planBuilder, mask, analysis.getSubqueries(mask));

Map<Symbol, Expression> assignments = new LinkedHashMap<>();
for (Symbol symbol : planBuilder.getRoot().getOutputSymbols()) {
Expand Down Expand Up @@ -405,8 +405,8 @@ else if (firstDependencies.stream().allMatch(right::canResolve) && secondDepende
}
}

leftPlanBuilder = subqueryPlanner.handleSubqueries(leftPlanBuilder, leftComparisonExpressions, node);
rightPlanBuilder = subqueryPlanner.handleSubqueries(rightPlanBuilder, rightComparisonExpressions, node);
leftPlanBuilder = subqueryPlanner.handleSubqueries(leftPlanBuilder, leftComparisonExpressions, analysis.getSubqueries(node));
rightPlanBuilder = subqueryPlanner.handleSubqueries(rightPlanBuilder, rightComparisonExpressions, analysis.getSubqueries(node));

// Add projections for join criteria
leftPlanBuilder = leftPlanBuilder.appendProjections(leftComparisonExpressions, symbolAllocator, idAllocator);
Expand Down Expand Up @@ -460,10 +460,10 @@ else if (firstDependencies.stream().allMatch(right::canResolve) && secondDepende
// t JOIN u ON t.x = (...) get's planned on the u side
// t JOIN u ON t.x + u.x = (...) get's planned on an arbitrary side
if (dependencies.stream().allMatch(left::canResolve)) {
leftPlanBuilder = subqueryPlanner.handleSubqueries(leftPlanBuilder, complexExpression, node);
leftPlanBuilder = subqueryPlanner.handleSubqueries(leftPlanBuilder, complexExpression, analysis.getSubqueries(node));
}
else {
rightPlanBuilder = subqueryPlanner.handleSubqueries(rightPlanBuilder, complexExpression, node);
rightPlanBuilder = subqueryPlanner.handleSubqueries(rightPlanBuilder, complexExpression, analysis.getSubqueries(node));
}
}
}
Expand Down Expand Up @@ -493,7 +493,7 @@ else if (firstDependencies.stream().allMatch(right::canResolve) && secondDepende
if (node.getType() == INNER) {
// rewrite all the other conditions using output symbols from left + right plan node.
PlanBuilder rootPlanBuilder = new PlanBuilder(translationMap, root);
rootPlanBuilder = subqueryPlanner.handleSubqueries(rootPlanBuilder, complexJoinExpressions, node);
rootPlanBuilder = subqueryPlanner.handleSubqueries(rootPlanBuilder, complexJoinExpressions, analysis.getSubqueries(node));

for (Expression expression : complexJoinExpressions) {
postInnerJoinConditions.add(rootPlanBuilder.rewrite(expression));
Expand Down Expand Up @@ -684,7 +684,7 @@ private RelationPlan planCorrelatedJoin(Join join, RelationPlan leftPlan, Latera

PlanBuilder planBuilder = subqueryPlanner.appendCorrelatedJoin(
leftPlanBuilder,
rightPlanBuilder,
rightPlanBuilder.getRoot(),
lateral.getQuery(),
CorrelatedJoinNode.Type.typeConvert(join.getType()),
rewrittenFilterCondition,
Expand Down
Loading