Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
137079: sql/opt: implement vector search in the optimizer r=DrewKimball a=DrewKimball

#### sql/opt: implement vector search operators in the optimizer

This commit adds the optimizer representations for two operators that
will be used to support vector indexes:
1. `VectorSearchExpr` is used to search the vector index. It returns a
   set of candidate primary keys, which can be used to retrieve full
   vectors for re-ranking against the query vector.
2. `VectorPartitionSearchExpr` is used to prepare a mutation that
   modifies the vector index. It determines the partition that contains
   (or should contain) each input vector, so that the mutation "knows
   where to look".

The following commits will add optimizer tests for these expressions.

Epic: None
Release note: None

#### sql/opt: plan VectorPartitionSearch operators for mutations

This commit adds the logic that places `VectorPartitionSearch` operators
in the input of a mutation that modifies one or more vector indexes.
The operators are used to determine which partition contains, or should
contain, each vector involved in the mutation.

Epic: None
Release note: None

#### sql/opt: add rule to generate vector search operators

This commit adds the exploration rule `GenerateVectorSearch`, which
matches a `Limit` ordered by the indexed vector column of an input
`Scan` operator. The replacement expression takes the candidate primary
keys returned by the `VectorSearch` operator, and produces a KNN result
by fetching the full vectors, calculating distances to the query vector,
and feeding the result into a top-k operator.

Epic: None
Release note: None

Co-authored-by: Drew Kimball <[email protected]>
  • Loading branch information
craig[bot] and DrewKimball committed Jan 7, 2025
2 parents 833fbe4 + 00b4930 commit f93f61a
Show file tree
Hide file tree
Showing 28 changed files with 2,116 additions and 288 deletions.
2 changes: 1 addition & 1 deletion pkg/ccl/logictestccl/testdata/logic_test/explain_redact
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ insert p
query T
EXPLAIN (OPT, MEMO, REDACT) INSERT INTO p VALUES (1, 1)
----
memo (optimized, ~4KB, required=[presentation: info:9] [distribution: test])
memo (optimized, ~5KB, required=[presentation: info:9] [distribution: test])
├── G1: (explain G2 [distribution: test])
│ └── [presentation: info:9] [distribution: test]
│ ├── best: (explain G2="[distribution: test]" [distribution: test])
Expand Down
146 changes: 67 additions & 79 deletions pkg/sql/opt/exec/execbuilder/mutation.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,18 @@ func (b *Builder) buildMutationInput(
}
}

// Currently, the execution engine requires one input column for each fetch,
// insert, update, and delete expression, so use ensureColumns to map and
// reorder columns so that they correspond to target table columns.
// For example:
//
// UPDATE xyz SET x=1, y=1
//
// Here, the input has just one column (because the constant is shared), and
// so must be mapped to two separate update columns.
//
// TODO(andyk): Using ensureColumns here can result in an extra Render.
// Upgrade execution engine to not require this.
input, inputCols, err = b.ensureColumns(
input, inputCols, inputExpr, colList,
inputExpr.ProvidedPhysical().Ordering, true, /* reuseInputCols */
Expand Down Expand Up @@ -88,10 +100,10 @@ func (b *Builder) buildInsert(ins *memo.InsertExpr) (_ execPlan, outputCols colO
}
// Construct list of columns that only contains columns that need to be
// inserted (e.g. delete-only mutation columns don't need to be inserted).
colList := make(opt.ColList, 0, len(ins.InsertCols)+len(ins.CheckCols)+len(ins.PartialIndexPutCols))
colList = appendColsWhenPresent(colList, ins.InsertCols)
colList = appendColsWhenPresent(colList, ins.CheckCols)
colList = appendColsWhenPresent(colList, ins.PartialIndexPutCols)
colList := appendColsWhenPresent(
ins.InsertCols, ins.CheckCols, ins.PartialIndexPutCols,
ins.VectorIndexPutPartitionCols, ins.VectorIndexPutCentroidCols,
)
input, _, err := b.buildMutationInput(ins, ins.Input, colList, &ins.MutationPrivate)
if err != nil {
return execPlan{}, colOrdMap{}, err
Expand Down Expand Up @@ -321,10 +333,10 @@ func (b *Builder) tryBuildFastPathInsert(
}
}

colList := make(opt.ColList, 0, len(ins.InsertCols)+len(ins.CheckCols)+len(ins.PartialIndexPutCols))
colList = appendColsWhenPresent(colList, ins.InsertCols)
colList = appendColsWhenPresent(colList, ins.CheckCols)
colList = appendColsWhenPresent(colList, ins.PartialIndexPutCols)
// Collect the non-zero input columns the fast-path insert consumes. This uses
// the same column lists as buildInsert: an INSERT only adds vector index
// entries, so the "put" partition and centroid columns are the ones listed
// here, not the "del" partition columns.
colList := appendColsWhenPresent(
ins.InsertCols, ins.CheckCols, ins.PartialIndexPutCols,
ins.VectorIndexPutPartitionCols, ins.VectorIndexPutCentroidCols,
)
rows, err := b.buildValuesRows(values)
if err != nil {
return execPlan{}, colOrdMap{}, false, err
Expand Down Expand Up @@ -392,32 +404,20 @@ func rearrangeColumns(
}

func (b *Builder) buildUpdate(upd *memo.UpdateExpr) (_ execPlan, outputCols colOrdMap, err error) {
// Currently, the execution engine requires one input column for each fetch
// and update expression, so use ensureColumns to map and reorder columns so
// that they correspond to target table columns. For example:
//
// UPDATE xyz SET x=1, y=1
//
// Here, the input has just one column (because the constant is shared), and
// so must be mapped to two separate update columns.
//
// TODO(andyk): Using ensureColumns here can result in an extra Render.
// Upgrade execution engine to not require this.
cnt := len(upd.FetchCols) + len(upd.UpdateCols) + len(upd.PassthroughCols) +
len(upd.CheckCols) + len(upd.PartialIndexPutCols) + len(upd.PartialIndexDelCols)
colList := make(opt.ColList, 0, cnt)
colList = appendColsWhenPresent(colList, upd.FetchCols)
colList = appendColsWhenPresent(colList, upd.UpdateCols)
// The RETURNING clause of the Update can refer to the columns
// in any of the FROM tables. As a result, the Update may need
// to passthrough those columns so the projection above can use
// them.
var neededPassThroughCols opt.OptionalColList
if upd.NeedResults() {
colList = append(colList, upd.PassthroughCols...)
}
colList = appendColsWhenPresent(colList, upd.CheckCols)
colList = appendColsWhenPresent(colList, upd.PartialIndexPutCols)
colList = appendColsWhenPresent(colList, upd.PartialIndexDelCols)
// The RETURNING clause of the Update can refer to the columns
// in any of the FROM tables. As a result, the Update may need
// to passthrough those columns so the projection above can use
// them.
neededPassThroughCols = opt.OptionalColList(upd.PassthroughCols)
}
colList := appendColsWhenPresent(
upd.FetchCols, upd.UpdateCols, neededPassThroughCols, upd.CheckCols,
upd.PartialIndexPutCols, upd.PartialIndexDelCols,
upd.VectorIndexPutPartitionCols, upd.VectorIndexPutCentroidCols,
upd.VectorIndexDelPartitionCols,
)

input, _, err := b.buildMutationInput(upd, upd.Input, colList, &upd.MutationPrivate)
if err != nil {
Expand Down Expand Up @@ -482,36 +482,16 @@ func (b *Builder) buildUpdate(upd *memo.UpdateExpr) (_ execPlan, outputCols colO
}

func (b *Builder) buildUpsert(ups *memo.UpsertExpr) (_ execPlan, outputCols colOrdMap, err error) {
// Currently, the execution engine requires one input column for each insert,
// fetch, and update expression, so use ensureColumns to map and reorder
// columns so that they correspond to target table columns. For example:
//
// INSERT INTO xyz (x, y) VALUES (1, 1)
// ON CONFLICT (x) DO UPDATE SET x=2, y=2
//
// Here, both insert values and update values come from the same input column
// (because the constants are shared), and so must be mapped to separate
// output columns.
//
// If CanaryCol = 0, then this is the "blind upsert" case, which uses a KV
// "Put" to insert new rows or blindly overwrite existing rows. Existing rows
// do not need to be fetched or separately updated (i.e. ups.FetchCols and
// ups.UpdateCols are both empty).
//
// TODO(andyk): Using ensureColumns here can result in an extra Render.
// Upgrade execution engine to not require this.
cnt := len(ups.InsertCols) + len(ups.FetchCols) + len(ups.UpdateCols) + len(ups.CheckCols) +
len(ups.PartialIndexPutCols) + len(ups.PartialIndexDelCols) + 1
colList := make(opt.ColList, 0, cnt)
colList = appendColsWhenPresent(colList, ups.InsertCols)
colList = appendColsWhenPresent(colList, ups.FetchCols)
colList = appendColsWhenPresent(colList, ups.UpdateCols)
if ups.CanaryCol != 0 {
colList = append(colList, ups.CanaryCol)
}
colList = appendColsWhenPresent(colList, ups.CheckCols)
colList = appendColsWhenPresent(colList, ups.PartialIndexPutCols)
colList = appendColsWhenPresent(colList, ups.PartialIndexDelCols)
colList := appendColsWhenPresent(
ups.InsertCols, ups.FetchCols, ups.UpdateCols, opt.OptionalColList{ups.CanaryCol},
ups.CheckCols, ups.PartialIndexPutCols, ups.PartialIndexDelCols,
ups.VectorIndexPutPartitionCols, ups.VectorIndexPutCentroidCols,
ups.VectorIndexDelPartitionCols,
)

input, inputCols, err := b.buildMutationInput(ups, ups.Input, colList, &ups.MutationPrivate)
if err != nil {
Expand Down Expand Up @@ -584,20 +564,17 @@ func (b *Builder) buildDelete(del *memo.DeleteExpr) (_ execPlan, outputCols colO
if ep, ok, err := b.tryBuildDeleteRange(del); err != nil || ok {
return ep, colOrdMap{}, err
}

// Ensure that order of input columns matches order of target table columns.
//
// TODO(andyk): Using ensureColumns here can result in an extra Render.
// Upgrade execution engine to not require this.
colList := make(opt.ColList, 0, len(del.FetchCols)+len(del.PassthroughCols)+len(del.PartialIndexDelCols))
colList = appendColsWhenPresent(colList, del.FetchCols)
// The RETURNING clause of the Delete can refer to the columns in any of the
// USING tables. As a result, the Update may need to passthrough those
// columns so the projection above can use them.
var neededPassThroughCols opt.OptionalColList
if del.NeedResults() {
colList = append(colList, del.PassthroughCols...)
}
colList = appendColsWhenPresent(colList, del.PartialIndexDelCols)
// The RETURNING clause of the Delete can refer to the columns
// in any of the USING tables. As a result, the Delete may need
// to passthrough those columns so the projection above can use
// them.
neededPassThroughCols = opt.OptionalColList(del.PassthroughCols)
}
colList := appendColsWhenPresent(
del.FetchCols, neededPassThroughCols, del.PartialIndexDelCols, del.VectorIndexDelPartitionCols,
)

input, _, err := b.buildMutationInput(del, del.Input, colList, &del.MutationPrivate)
if err != nil {
Expand Down Expand Up @@ -744,15 +721,26 @@ func (b *Builder) buildDeleteRange(del *memo.DeleteExpr) (execPlan, error) {
return execPlan{root: root}, nil
}

// appendColsWhenPresent appends non-zero column IDs from the src list into the
// dst list, and returns the possibly grown list.
func appendColsWhenPresent(dst opt.ColList, src opt.OptionalColList) opt.ColList {
for _, col := range src {
if col != 0 {
dst = append(dst, col)
// appendColsWhenPresent combines all non-zero column IDs from the given lists
// into a single column list, and returns the combined list.
func appendColsWhenPresent(lists ...opt.OptionalColList) opt.ColList {
var cnt int
for _, list := range lists {
for _, id := range list {
if id != 0 {
cnt++
}
}
}
combined := make(opt.ColList, 0, cnt)
for _, list := range lists {
for _, col := range list {
if col != 0 {
combined = append(combined, col)
}
}
}
return dst
return combined
}

// ordinalSetFromColList returns the set of ordinal positions of each non-zero
Expand Down
4 changes: 4 additions & 0 deletions pkg/sql/opt/exec/execbuilder/relational.go
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,10 @@ func (b *Builder) buildRelational(e memo.RelExpr) (_ execPlan, outputCols colOrd
case *memo.LockExpr:
ep, outputCols, err = b.buildLock(t)

case *memo.VectorSearchExpr, *memo.VectorPartitionSearchExpr:
err = unimplemented.New("vector index search",
"execution planning for vector index search is not yet implemented")

case *memo.BarrierExpr:
ep, outputCols, err = b.buildBarrier(t)

Expand Down
58 changes: 54 additions & 4 deletions pkg/sql/opt/memo/expr_format.go
Original file line number Diff line number Diff line change
Expand Up @@ -274,10 +274,11 @@ func (f *ExprFmtCtx) formatRelational(e RelExpr, tp treeprinter.Node) {

case *ScanExpr, *PlaceholderScanExpr, *IndexJoinExpr, *ShowTraceForSessionExpr,
*InsertExpr, *UpdateExpr, *UpsertExpr, *DeleteExpr, *LockExpr, *SequenceSelectExpr,
*WindowExpr, *OpaqueRelExpr, *OpaqueMutationExpr, *OpaqueDDLExpr,
*AlterTableSplitExpr, *AlterTableUnsplitExpr, *AlterTableUnsplitAllExpr,
*AlterTableRelocateExpr, *AlterRangeRelocateExpr, *ControlJobsExpr, *CancelQueriesExpr,
*CancelSessionsExpr, *CreateViewExpr, *ExportExpr, *ShowCompletionsExpr:
*WindowExpr, *VectorSearchExpr, *VectorPartitionSearchExpr, *OpaqueRelExpr,
*OpaqueMutationExpr, *OpaqueDDLExpr, *AlterTableSplitExpr, *AlterTableUnsplitExpr,
*AlterTableUnsplitAllExpr, *AlterTableRelocateExpr, *AlterRangeRelocateExpr,
*ControlJobsExpr, *CancelQueriesExpr, *CancelSessionsExpr, *CreateViewExpr,
*ExportExpr, *ShowCompletionsExpr:
fmt.Fprintf(f.Buffer, "%v", e.Op())
FormatPrivate(f, e.Private(), required)

Expand Down Expand Up @@ -648,6 +649,8 @@ func (f *ExprFmtCtx) formatRelational(e RelExpr, tp treeprinter.Node) {
f.formatMutationCols(e, tp, "return-mapping:", t.ReturnCols, t.Table)
f.formatOptionalColList(e, tp, "check columns:", t.CheckCols)
f.formatOptionalColList(e, tp, "partial index put columns:", t.PartialIndexPutCols)
f.formatOptionalColList(e, tp, "vector index put partition columns:", t.VectorIndexPutPartitionCols)
f.formatOptionalColList(e, tp, "vector index put centroid columns:", t.VectorIndexPutCentroidCols)
f.formatBeforeTriggers(tp, t.Table, tree.TriggerEventInsert)
f.formatMutationCommon(tp, &t.MutationPrivate)
}
Expand All @@ -665,6 +668,9 @@ func (f *ExprFmtCtx) formatRelational(e RelExpr, tp treeprinter.Node) {
f.formatOptionalColList(e, tp, "check columns:", t.CheckCols)
f.formatOptionalColList(e, tp, "partial index put columns:", t.PartialIndexPutCols)
f.formatOptionalColList(e, tp, "partial index del columns:", t.PartialIndexDelCols)
f.formatOptionalColList(e, tp, "vector index del partition columns:", t.VectorIndexDelPartitionCols)
f.formatOptionalColList(e, tp, "vector index put partition columns:", t.VectorIndexPutPartitionCols)
f.formatOptionalColList(e, tp, "vector index put centroid columns:", t.VectorIndexPutCentroidCols)
f.formatBeforeTriggers(tp, t.Table, tree.TriggerEventUpdate)
f.formatMutationCommon(tp, &t.MutationPrivate)
}
Expand All @@ -689,6 +695,9 @@ func (f *ExprFmtCtx) formatRelational(e RelExpr, tp treeprinter.Node) {
f.formatOptionalColList(e, tp, "check columns:", t.CheckCols)
f.formatOptionalColList(e, tp, "partial index put columns:", t.PartialIndexPutCols)
f.formatOptionalColList(e, tp, "partial index del columns:", t.PartialIndexDelCols)
f.formatOptionalColList(e, tp, "vector index del partition columns:", t.VectorIndexDelPartitionCols)
f.formatOptionalColList(e, tp, "vector index put partition columns:", t.VectorIndexPutPartitionCols)
f.formatOptionalColList(e, tp, "vector index put centroid columns:", t.VectorIndexPutCentroidCols)
f.formatBeforeTriggers(tp, t.Table, tree.TriggerEventInsert, tree.TriggerEventUpdate)
f.formatMutationCommon(tp, &t.MutationPrivate)
}
Expand All @@ -702,6 +711,7 @@ func (f *ExprFmtCtx) formatRelational(e RelExpr, tp treeprinter.Node) {
f.formatMutationCols(e, tp, "return-mapping:", t.ReturnCols, t.Table)
f.formatOptionalColList(e, tp, "passthrough columns", opt.OptionalColList(t.PassthroughCols))
f.formatOptionalColList(e, tp, "partial index del columns:", t.PartialIndexDelCols)
f.formatOptionalColList(e, tp, "vector index del partition columns:", t.VectorIndexDelPartitionCols)
f.formatBeforeTriggers(tp, t.Table, tree.TriggerEventDelete)
f.formatMutationCommon(tp, &t.MutationPrivate)
}
Expand Down Expand Up @@ -732,6 +742,37 @@ func (f *ExprFmtCtx) formatRelational(e RelExpr, tp treeprinter.Node) {
}
}

case *VectorSearchExpr:
if c := t.PrefixConstraint; c != nil {
if c.IsContradiction() {
tp.Childf("prefix constraint: contradiction")
} else if c.Spans.Count() == 1 {
tp.Childf(
"prefix constraint: %s: %s", c.Columns.String(),
cat.MaybeMarkRedactable(c.Spans.Get(0).String(), f.RedactableValues),
)
} else {
n := tp.Childf("prefix constraint: %s", c.Columns.String())
for i := 0; i < c.Spans.Count(); i++ {
n.Child(cat.MaybeMarkRedactable(c.Spans.Get(i).String(), f.RedactableValues))
}
}
}
tp.Childf("target nearest neighbors: %d", t.TargetNeighborCount)

case *VectorPartitionSearchExpr:
if len(t.PrefixKeyCols) > 0 {
tp.Childf("prefix key columns: %v", t.PrefixKeyCols)
}
tp.Childf("query vector column: %s", f.ColumnString(t.QueryVectorCol))
if !t.PrimaryKeyCols.Empty() {
tp.Childf("primary key columns: %v", t.PrimaryKeyCols)
}
tp.Childf("partition col: %s", f.ColumnString(t.PartitionCol))
if t.CentroidCol != 0 {
tp.Childf("centroid col: %s", f.ColumnString(t.CentroidCol))
}

case *CreateTableExpr:
fmtFlags := tree.FmtSimple
if f.RedactableValues {
Expand Down Expand Up @@ -1435,6 +1476,9 @@ func (f *ExprFmtCtx) formatIndex(tabID opt.TableID, idxOrd cat.IndexOrdinal, rev
if index.IsInverted() {
f.Buffer.WriteString(",inverted")
}
if index.IsVector() {
f.Buffer.WriteString(",vector")
}
if _, isPartial := index.Predicate(); isPartial {
f.Buffer.WriteString(",partial")
}
Expand Down Expand Up @@ -1874,6 +1918,12 @@ func FormatPrivate(f *ExprFmtCtx, private interface{}, physProps *physical.Requi
fmt.Fprintf(f.Buffer, " ordering=%s", t.Ordering)
}

case *VectorSearchPrivate:
f.formatIndex(t.Table, t.Index, false /* reverse */)

case *VectorPartitionSearchPrivate:
f.formatIndex(t.Table, t.Index, false /* reverse */)

case *props.OrderingChoice:
if !t.Any() {
fmt.Fprintf(f.Buffer, " ordering=%s", t)
Expand Down
Loading

0 comments on commit f93f61a

Please sign in to comment.