Skip to content

Commit

Permalink
add logic for splitting expressions
Browse files Browse the repository at this point in the history
Signed-off-by: Andres Taylor <[email protected]>
  • Loading branch information
systay committed Jan 30, 2025
1 parent 7a695b0 commit ba53da7
Show file tree
Hide file tree
Showing 17 changed files with 208 additions and 62 deletions.
8 changes: 4 additions & 4 deletions go/vt/vtgate/engine/cached_size.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,19 @@ type ValuesJoin struct {
// of the Join. They can be any primitive.
Left, Right Primitive

Vars []int
RowConstructorArg string
Cols []int
ColNames []string
// ColumnsFromLHS are the offsets of columns from LHS we are copying over to the RHS
// []int{0,2} means that the first column in the t-o-t is the first offset from the left and the second column is the third offset
ColumnsFromLHS []int

// The name for the bind var containing the tuple-of-tuples being sent to the RHS
BindVarName string

// Cols tells use which side the output columns come from:
// negative numbers are offsets to the left, and positive to the right
Cols []int

// ColNames are the output column names
ColNames []string
}

// TryExecute performs a non-streaming exec.
Expand All @@ -60,22 +69,22 @@ func (jv *ValuesJoin) TryExecute(ctx context.Context, vcursor VCursor, bindVars
}
bv.Values = append(bv.Values, sqltypes.TupleToProto(vals))

bindVars[jv.RowConstructorArg] = bv
bindVars[jv.BindVarName] = bv
return jv.Right.GetFields(ctx, vcursor, bindVars)
}

for i, row := range lresult.Rows {
newRow := make(sqltypes.Row, 0, len(jv.Vars)+1) // +1 since we always add the row ID
newRow = append(newRow, sqltypes.NewInt64(int64(i))) // Adding the LHS row ID
newRow := make(sqltypes.Row, 0, len(jv.ColumnsFromLHS)+1) // +1 since we always add the row ID
newRow = append(newRow, sqltypes.NewInt64(int64(i))) // Adding the LHS row ID

for _, loffset := range jv.Vars {
for _, loffset := range jv.ColumnsFromLHS {
newRow = append(newRow, row[loffset])
}

bv.Values = append(bv.Values, sqltypes.TupleToProto(newRow))
}

bindVars[jv.RowConstructorArg] = bv
bindVars[jv.BindVarName] = bv
rresult, err := vcursor.ExecutePrimitive(ctx, jv.Right, bindVars, wantfields)
if err != nil {
return nil, err
Expand Down Expand Up @@ -143,8 +152,8 @@ func (jv *ValuesJoin) description() PrimitiveDescription {
OperatorType: "Join",
Variant: "Values",
Other: map[string]any{
"ValuesArg": jv.RowConstructorArg,
"Vars": jv.Vars,
"BindVarName": jv.BindVarName,
"ColumnsFromLHS": jv.ColumnsFromLHS,
},
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,12 @@ func TestJoinValuesExecute(t *testing.T) {
}

vjn := &ValuesJoin{
Left: leftPrim,
Right: rightPrim,
Vars: []int{0},
RowConstructorArg: "v",
Cols: []int{-1, -2, -3, -1, 1, 2},
ColNames: []string{"col1", "col2", "col3", "col4", "col5", "col6"},
Left: leftPrim,
Right: rightPrim,
ColumnsFromLHS: []int{0},
BindVarName: "v",
Cols: []int{-1, -2, -3, -1, 1, 2},
ColNames: []string{"col1", "col2", "col3", "col4", "col5", "col6"},
}

r, err := vjn.TryExecute(context.Background(), &noopVCursor{}, bv, true)
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/operators/aggregation_pushing.go
Original file line number Diff line number Diff line change
Expand Up @@ -509,7 +509,7 @@ func splitGroupingToLeftAndRight(
rhs.addGrouping(ctx, groupBy)
columns.addRight(groupBy.Inner)
case deps.IsSolvedBy(lhs.tableID.Merge(rhs.tableID)):
jc := breakExpressionInLHSandRHS(ctx, groupBy.Inner, lhs.tableID)
jc := breakApplyJoinExpressionInLHSandRHS(ctx, groupBy.Inner, lhs.tableID)
for _, lhsExpr := range jc.LHSExprs {
e := lhsExpr.Expr
lhs.addGrouping(ctx, NewGroupBy(e))
Expand Down
4 changes: 2 additions & 2 deletions go/vt/vtgate/planbuilder/operators/apply_join.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ func (aj *ApplyJoin) AddJoinPredicate(ctx *plancontext.PlanningContext, expr sql
rhs := aj.RHS
predicates := sqlparser.SplitAndExpression(nil, expr)
for _, pred := range predicates {
col := breakExpressionInLHSandRHS(ctx, pred, TableID(aj.LHS))
col := breakApplyJoinExpressionInLHSandRHS(ctx, pred, TableID(aj.LHS))
aj.JoinPredicates.add(col)
ctx.AddJoinPredicates(pred, col.RHSExpr)
rhs = rhs.AddPredicate(ctx, col.RHSExpr)
Expand Down Expand Up @@ -199,7 +199,7 @@ func (aj *ApplyJoin) getJoinColumnFor(ctx *plancontext.PlanningContext, orig *sq
case deps.IsSolvedBy(rhs):
col.RHSExpr = e
case deps.IsSolvedBy(both):
col = breakExpressionInLHSandRHS(ctx, e, TableID(aj.LHS))
col = breakApplyJoinExpressionInLHSandRHS(ctx, e, TableID(aj.LHS))
default:
panic(vterrors.VT13001(fmt.Sprintf("expression depends on tables outside this join: %s", sqlparser.String(e))))
}
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/operators/ast_to_op.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ func (jpc *joinPredicateCollector) inspectPredicate(
// then we can use this predicate to connect the subquery to the outer query
if !deps.IsSolvedBy(jpc.subqID) && deps.IsSolvedBy(jpc.totalID) {
jpc.predicates = append(jpc.predicates, predicate)
jc := breakExpressionInLHSandRHS(ctx, predicate, jpc.outerID)
jc := breakApplyJoinExpressionInLHSandRHS(ctx, predicate, jpc.outerID)
jpc.joinColumns = append(jpc.joinColumns, jc)
pred = jc.RHSExpr
}
Expand Down
21 changes: 19 additions & 2 deletions go/vt/vtgate/planbuilder/operators/expressions.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ import (
"vitess.io/vitess/go/vt/vtgate/semantics"
)

// breakExpressionInLHSandRHS takes an expression and
// breakApplyJoinExpressionInLHSandRHS takes an expression and
// extracts the parts that are coming from one of the sides into `ColName`s that are needed
func breakExpressionInLHSandRHS(
func breakApplyJoinExpressionInLHSandRHS(
ctx *plancontext.PlanningContext,
expr sqlparser.Expr,
lhs semantics.TableSet,
Expand Down Expand Up @@ -129,3 +129,20 @@ func getFirstSelect(selStmt sqlparser.TableStatement) *sqlparser.Select {
}
return firstSelect
}

func breakValuesJoinExpressionInLHS(ctx *plancontext.PlanningContext,
expr sqlparser.Expr,
lhs semantics.TableSet,
) (results []*sqlparser.ColName) {
_ = sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) {
col, ok := node.(*sqlparser.ColName)
if !ok {
return true, nil
}
if ctx.SemTable.RecursiveDeps(col) == lhs {
results = append(results, col)
}
return true, nil
}, expr)
return
}
58 changes: 58 additions & 0 deletions go/vt/vtgate/planbuilder/operators/expressions_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
/*
Copyright 2025 The Vitess Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package operators

import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"vitess.io/vitess/go/slice"
"vitess.io/vitess/go/vt/sqlparser"
"vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext"
"vitess.io/vitess/go/vt/vtgate/semantics"
)

func TestSplitComplexPredicateToLHS(t *testing.T) {
ast, err := sqlparser.NewTestParser().ParseExpr("l.foo + r.bar - l.baz / r.tata = 0")
require.NoError(t, err)
lID := semantics.SingleTableSet(0)
rID := semantics.SingleTableSet(1)
ctx := plancontext.CreateEmptyPlanningContext()
ctx.SemTable = semantics.EmptySemTable()
// simple sem analysis using the column prefix
_ = sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) {
col, ok := node.(*sqlparser.ColName)
if !ok {
return true, nil
}
if col.Qualifier.Name.String() == "l" {
ctx.SemTable.Recursive[col] = lID
} else {
ctx.SemTable.Recursive[col] = rID
}
return false, nil
}, ast)

lhsExprs := breakValuesJoinExpressionInLHS(ctx, ast, lID)
nodes := slice.Map(lhsExprs, func(from *sqlparser.ColName) string {
return sqlparser.String(from)
})

assert.Equal(t, []string{"l.foo", "l.baz"}, nodes)
}
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/operators/join.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ func addCTEPredicate(
}

func breakCTEExpressionInLhsAndRhs(ctx *plancontext.PlanningContext, pred sqlparser.Expr, lhsID semantics.TableSet) *plancontext.RecurseExpression {
col := breakExpressionInLHSandRHS(ctx, pred, lhsID)
col := breakApplyJoinExpressionInLHSandRHS(ctx, pred, lhsID)

lhsExprs := slice.Map(col.LHSExprs, func(bve BindVarExpr) plancontext.BindVarExpr {
col, ok := bve.Expr.(*sqlparser.ColName)
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/operators/op_to_ast.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ func buildValues(op *Values, qb *queryBuilder) {
Select: &sqlparser.ValuesStatement{
ListArg: sqlparser.NewListArg(op.Arg),
},
}, nil, op.Columns)
}, nil, op.getColsFromCtx(qb.ctx))
}

func buildDelete(op *Delete, qb *queryBuilder) {
Expand Down
15 changes: 8 additions & 7 deletions go/vt/vtgate/planbuilder/operators/op_to_ast_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ import (

func TestToSQLValues(t *testing.T) {
ctx := plancontext.CreateEmptyPlanningContext()
bindVarName := "toto"
ctx.ValuesJoinColumns[bindVarName] = sqlparser.Columns{sqlparser.NewIdentifierCI("user_id")}

tableName := sqlparser.NewTableName("x")
tableColumn := sqlparser.NewColName("id")
Expand All @@ -38,9 +40,8 @@ func TestToSQLValues(t *testing.T) {
},
Columns: []*sqlparser.ColName{tableColumn},
}),
Columns: sqlparser.Columns{sqlparser.NewIdentifierCI("user_id")},
Name: "t",
Arg: "toto",
Name: "t",
Arg: bindVarName,
}

stmt, _, err := ToAST(ctx, op)
Expand All @@ -63,6 +64,7 @@ func TestToSQLValues(t *testing.T) {
}

func TestToSQLValuesJoin(t *testing.T) {
// Build a SQL AST from a values join that has been pushed under a route
ctx := plancontext.CreateEmptyPlanningContext()
parser := sqlparser.NewTestParser()

Expand All @@ -83,7 +85,7 @@ func TestToSQLValuesJoin(t *testing.T) {
}

const argumentName = "v"

ctx.ValuesJoinColumns[argumentName] = sqlparser.Columns{sqlparser.NewIdentifierCI("id")}
rhsTableName := sqlparser.NewTableName("y")
rhsTableColumn := sqlparser.NewColName("tata")
rhsFilterPred, err := parser.ParseExpr("y.tata = 42")
Expand All @@ -100,9 +102,8 @@ func TestToSQLValuesJoin(t *testing.T) {
},
Columns: []*sqlparser.ColName{rhsTableColumn},
}),
Columns: sqlparser.Columns{sqlparser.NewIdentifierCI("id")},
Name: lhsTableName.Name.String(),
Arg: argumentName,
Name: lhsTableName.Name.String(),
Arg: argumentName,
}),
Predicates: []sqlparser.Expr{rhsFilterPred, rhsJoinFilterPred},
}
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/operators/subquery.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ func (sq *SubQuery) GetJoinColumns(ctx *plancontext.PlanningContext, outer Opera
}
sq.outerID = outerID
mapper := func(in sqlparser.Expr) (applyJoinColumn, error) {
return breakExpressionInLHSandRHS(ctx, in, outerID), nil
return breakApplyJoinExpressionInLHSandRHS(ctx, in, outerID), nil
}
joinPredicates, err := slice.MapWithError(sq.Predicates, mapper)
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/operators/subquery_planning.go
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ func extractLHSExpr(
lhs semantics.TableSet,
) func(expr sqlparser.Expr) sqlparser.Expr {
return func(expr sqlparser.Expr) sqlparser.Expr {
col := breakExpressionInLHSandRHS(ctx, expr, lhs)
col := breakApplyJoinExpressionInLHSandRHS(ctx, expr, lhs)
if col.IsPureLeft() {
panic(vterrors.VT13001("did not expect to find any predicates that do not need data from the inner here"))
}
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/operators/update.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ func prepareUpdateExpressionList(ctx *plancontext.PlanningContext, upd *sqlparse
for _, ue := range upd.Exprs {
target := ctx.SemTable.DirectDeps(ue.Name)
exprDeps := ctx.SemTable.RecursiveDeps(ue.Expr)
jc := breakExpressionInLHSandRHS(ctx, ue.Expr, exprDeps.Remove(target))
jc := breakApplyJoinExpressionInLHSandRHS(ctx, ue.Expr, exprDeps.Remove(target))
ueMap[target] = append(ueMap[target], updColumn{ue.Name, jc})
}

Expand Down
26 changes: 13 additions & 13 deletions go/vt/vtgate/planbuilder/operators/values.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,6 @@ limitations under the License.
package operators

import (
"fmt"
"slices"

"vitess.io/vitess/go/vt/sqlparser"
"vitess.io/vitess/go/vt/vterrors"
"vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext"
Expand All @@ -28,17 +25,12 @@ import (
type Values struct {
unaryOperator

Columns sqlparser.Columns
Name string
Arg string

// TODO: let's see if we want to have noColumns or no
// noColumns
Name string
Arg string
}

func (v *Values) Clone(inputs []Operator) Operator {
clone := *v
clone.Columns = slices.Clone(v.Columns)
return &clone
}

Expand All @@ -59,17 +51,25 @@ func (v *Values) FindCol(ctx *plancontext.PlanningContext, expr sqlparser.Expr,
if !ok {
return -1
}
for i, column := range v.Columns {
for i, column := range v.getColsFromCtx(ctx) {
if col.Name.Equal(column) {
return i
}
}
return -1
}

func (v *Values) getColsFromCtx(ctx *plancontext.PlanningContext) sqlparser.Columns {
columns, found := ctx.ValuesJoinColumns[v.Arg]
if !found {
panic(vterrors.VT13001("columns not found"))
}
return columns
}

func (v *Values) GetColumns(ctx *plancontext.PlanningContext) []*sqlparser.AliasedExpr {
var cols []*sqlparser.AliasedExpr
for _, column := range v.Columns {
for _, column := range v.getColsFromCtx(ctx) {
cols = append(cols, sqlparser.NewAliasedExpr(sqlparser.NewColName(column.String()), ""))
}
return cols
Expand All @@ -85,7 +85,7 @@ func (v *Values) GetSelectExprs(ctx *plancontext.PlanningContext) sqlparser.Sele
}

func (v *Values) ShortDescription() string {
return fmt.Sprintf("%s (%s)", v.Name, sqlparser.String(v.Columns))
return v.Name
}

func (v *Values) GetOrdering(ctx *plancontext.PlanningContext) []OrderBy {
Expand Down
Loading

0 comments on commit ba53da7

Please sign in to comment.