diff --git a/blob.go b/blob.go index 6a59a858..0fd65b0b 100644 --- a/blob.go +++ b/blob.go @@ -92,8 +92,8 @@ func (b *Blob) Read(p []byte) (n int, err error) { want = avail } - defer b.c.arena.reset() - ptr := b.c.arena.new(uint64(want)) + ptr := b.c.new(uint64(want)) + defer b.c.free(ptr) r := b.c.call(b.c.api.blobRead, uint64(b.handle), uint64(ptr), uint64(want), uint64(b.offset)) @@ -158,8 +158,8 @@ func (b *Blob) WriteTo(w io.Writer) (n int64, err error) { // // https://sqlite.org/c3ref/blob_write.html func (b *Blob) Write(p []byte) (n int, err error) { - defer b.c.arena.reset() - ptr := b.c.arena.bytes(p) + ptr := b.c.newBytes(p) + defer b.c.free(ptr) r := b.c.call(b.c.api.blobWrite, uint64(b.handle), uint64(ptr), uint64(len(p)), uint64(b.offset)) diff --git a/sqlite.go b/sqlite.go index c3382db7..183f1248 100644 --- a/sqlite.go +++ b/sqlite.go @@ -185,6 +185,7 @@ func instantiateSQLite() (sqlt *sqlite, err error) { resultErrorBig: getFun("sqlite3_result_error_toobig"), createModule: getFun("sqlite3_create_module_go"), declareVTab: getFun("sqlite3_declare_vtab"), + vtabRHSValue: getFun("sqlite3_vtab_rhs_value"), } if err != nil { return nil, err @@ -411,6 +412,7 @@ type sqliteAPI struct { resultErrorBig api.Function createModule api.Function declareVTab api.Function + vtabRHSValue api.Function destructor uint32 } diff --git a/stmt.go b/stmt.go index f59cbfc5..98a27288 100644 --- a/stmt.go +++ b/stmt.go @@ -266,6 +266,7 @@ func (s *Stmt) bindRFC3339Nano(param int, value time.Time) error { // BindPointer binds a NULL to the prepared statement, just like [Stmt.BindNull], // but it also associates ptr with that NULL value such that it can be retrieved // within an application-defined SQL function using [Value.Pointer]. +// The leftmost SQL parameter has an index of 1. // // https://sqlite.org/c3ref/bind_blob.html func (s *Stmt) BindPointer(param int, ptr any) error { diff --git a/vtab.go b/vtab.go index e26c4e77..1eb15974 100644 --- a/vtab.go +++ b/vtab.go @@ -66,7 +66,7 @@ func implements[T any](typ reflect.Type) bool { } func (c *Conn) DeclareVtab(sql string) error { - defer c.arena.reset() + // defer c.arena.reset() sqlPtr := c.arena.string(sql) r := c.call(c.api.declareVTab, uint64(c.handle), uint64(sqlPtr)) return c.error(r) @@ -193,10 +193,11 @@ type VTabCursor interface { // // https://sqlite.org/c3ref/index_info.html type IndexInfo struct { - /* Inputs */ - Constraint []IndexConstraint - OrderBy []IndexOrderBy - /* Outputs */ + // Inputs + Constraint []IndexConstraint + OrderBy []IndexOrderBy + ColumnsUsed int64 + // Outputs ConstraintUsage []IndexConstraintUsage IdxNum int IdxStr string @@ -204,7 +205,9 @@ type IndexInfo struct { OrderByConsumed bool EstimatedCost float64 EstimatedRows int64 - ColumnsUsed int64 + // Internal + c *Conn + handle uint32 } // An IndexConstraint describes virtual table indexing constraint information. @@ -232,8 +235,28 @@ type IndexConstraintUsage struct { Omit bool } -func (idx *IndexInfo) load(ctx context.Context, mod api.Module, ptr uint32) { +// RHSValue returns the value of the right-hand operand of a constraint +// if the right-hand operand is known. +// +// https://sqlite.org/c3ref/vtab_rhs_value.html +func (idx *IndexInfo) RHSValue(column int) (*Value, error) { + // defer idx.c.arena.reset() + valPtr := idx.c.arena.new(ptrlen) + r := idx.c.call(idx.c.api.vtabRHSValue, + uint64(idx.handle), uint64(column), uint64(valPtr)) + if err := idx.c.error(r); err != nil { + return nil, err + } + return &Value{ + sqlite: idx.c.sqlite, + handle: util.ReadUint32(idx.c.mod, valPtr), + }, nil +} + +func (idx *IndexInfo) load() { // https://sqlite.org/c3ref/index_info.html + mod := idx.c.mod + ptr := idx.handle idx.Constraint = make([]IndexConstraint, util.ReadUint32(mod, ptr+0)) idx.ConstraintUsage = make([]IndexConstraintUsage, util.ReadUint32(mod, ptr+0)) @@ -242,9 +265,9 @@ func (idx *IndexInfo) load(ctx context.Context, mod api.Module, ptr uint32) { constraintPtr := util.ReadUint32(mod, ptr+4) for i := range idx.Constraint { idx.Constraint[i] = IndexConstraint{ - Column: int(util.ReadUint32(mod, constraintPtr+0)), + Column: int(int32(util.ReadUint32(mod, constraintPtr+0))), Op: IndexConstraintOp(util.ReadUint8(mod, constraintPtr+4)), - Usable: util.ReadUint8(mod, constraintPtr+8) != 0, + Usable: util.ReadUint8(mod, constraintPtr+5) != 0, } constraintPtr += 12 } @@ -252,15 +275,21 @@ func (idx *IndexInfo) load(ctx context.Context, mod api.Module, ptr uint32) { orderByPtr := util.ReadUint32(mod, ptr+12) for i := range idx.OrderBy { idx.OrderBy[i] = IndexOrderBy{ - Column: int(util.ReadUint32(mod, orderByPtr+0)), + Column: int(int32(util.ReadUint32(mod, orderByPtr+0))), Desc: util.ReadUint8(mod, orderByPtr+4) != 0, } orderByPtr += 8 } + + idx.EstimatedCost = util.ReadFloat64(mod, ptr+40) + idx.EstimatedRows = int64(util.ReadUint64(mod, ptr+48)) + idx.ColumnsUsed = int64(util.ReadUint64(mod, ptr+64)) } -func (idx *IndexInfo) save(ctx context.Context, mod api.Module, ptr uint32) { +func (idx *IndexInfo) save() { // https://sqlite.org/c3ref/index_info.html + mod := idx.c.mod + ptr := idx.handle usagePtr := util.ReadUint32(mod, ptr+16) for _, usage := range idx.ConstraintUsage { @@ -273,8 +302,7 @@ func (idx *IndexInfo) save(ctx context.Context, mod api.Module, ptr uint32) { util.WriteUint32(mod, ptr+20, uint32(idx.IdxNum)) if idx.IdxStr != "" { - db := ctx.Value(connKey{}).(*Conn) - util.WriteUint32(mod, ptr+24, db.newString(idx.IdxStr)) + util.WriteUint32(mod, ptr+24, idx.c.newString(idx.IdxStr)) util.WriteUint32(mod, ptr+28, 1) } if idx.OrderByConsumed { @@ -283,7 +311,6 @@ func (idx *IndexInfo) save(ctx context.Context, mod api.Module, ptr uint32) { util.WriteFloat64(mod, ptr+40, idx.EstimatedCost) util.WriteUint64(mod, ptr+48, uint64(idx.EstimatedRows)) util.WriteUint32(mod, ptr+56, uint32(idx.IdxFlags)) - util.WriteUint64(mod, ptr+64, uint64(idx.ColumnsUsed)) } // IndexConstraintOp is a virtual table constraint operator code. @@ -292,23 +319,23 @@ func (idx *IndexInfo) save(ctx context.Context, mod api.Module, ptr uint32) { type IndexConstraintOp uint8 const ( - Eq IndexConstraintOp = 2 - Gt IndexConstraintOp = 4 - Le IndexConstraintOp = 8 - Lt IndexConstraintOp = 16 - Ge IndexConstraintOp = 32 - Match IndexConstraintOp = 64 - Like IndexConstraintOp = 65 /* 3.10.0 and later */ - Glob IndexConstraintOp = 66 /* 3.10.0 and later */ - Regexp IndexConstraintOp = 67 /* 3.10.0 and later */ - Ne IndexConstraintOp = 68 /* 3.21.0 and later */ - IsNot IndexConstraintOp = 69 /* 3.21.0 and later */ - IsNotNull IndexConstraintOp = 70 /* 3.21.0 and later */ - IsNull IndexConstraintOp = 71 /* 3.21.0 and later */ - Is IndexConstraintOp = 72 /* 3.21.0 and later */ - Limit IndexConstraintOp = 73 /* 3.38.0 and later */ - Offset IndexConstraintOp = 74 /* 3.38.0 and later */ - Function IndexConstraintOp = 150 /* 3.25.0 and later */ + INDEX_CONSTRAINT_EQ IndexConstraintOp = 2 + INDEX_CONSTRAINT_GT IndexConstraintOp = 4 + INDEX_CONSTRAINT_LE IndexConstraintOp = 8 + INDEX_CONSTRAINT_LT IndexConstraintOp = 16 + INDEX_CONSTRAINT_GE IndexConstraintOp = 32 + INDEX_CONSTRAINT_MATCH IndexConstraintOp = 64 + INDEX_CONSTRAINT_LIKE IndexConstraintOp = 65 + INDEX_CONSTRAINT_GLOB IndexConstraintOp = 66 + INDEX_CONSTRAINT_REGEXP IndexConstraintOp = 67 + INDEX_CONSTRAINT_NE IndexConstraintOp = 68 + INDEX_CONSTRAINT_ISNOT IndexConstraintOp = 69 + INDEX_CONSTRAINT_ISNOTNULL IndexConstraintOp = 70 + INDEX_CONSTRAINT_ISNULL IndexConstraintOp = 71 + INDEX_CONSTRAINT_IS IndexConstraintOp = 72 + INDEX_CONSTRAINT_LIMIT IndexConstraintOp = 73 + INDEX_CONSTRAINT_OFFSET IndexConstraintOp = 74 + INDEX_CONSTRAINT_FUNCTION IndexConstraintOp = 150 ) // IndexScanFlag is a virtual table scan flag. @@ -317,22 +344,20 @@ const ( type IndexScanFlag uint32 const ( - Unique IndexScanFlag = 1 + INDEX_SCAN_UNIQUE IndexScanFlag = 1 ) func vtabReflectCallback(name string) func(_ context.Context, _ api.Module, _, _, _, _, _ uint32) uint32 { return func(ctx context.Context, mod api.Module, pMod, argc, argv, ppVTab, pzErr uint32) uint32 { - module := vtabGetHandle(ctx, mod, pMod) - db := ctx.Value(connKey{}).(*Conn) - arg := make([]reflect.Value, 1+argc) - arg[0] = reflect.ValueOf(db) + arg[0] = reflect.ValueOf(ctx.Value(connKey{})) for i := uint32(0); i < argc; i++ { ptr := util.ReadUint32(mod, argv+i*ptrlen) arg[i+1] = reflect.ValueOf(util.ReadString(mod, ptr, _MAX_STRING)) } + module := vtabGetHandle(ctx, mod, pMod) res := reflect.ValueOf(module).MethodByName(name).Call(arg) err, _ := res[1].Interface().(error) if err == nil { @@ -359,12 +384,14 @@ func vtabDestroyCallback(ctx context.Context, mod api.Module, pVTab uint32) uint func vtabBestIndexCallback(ctx context.Context, mod api.Module, pVTab, pIdxInfo uint32) uint32 { var info IndexInfo - info.load(ctx, mod, pIdxInfo) + info.handle = pIdxInfo + info.c = ctx.Value(connKey{}).(*Conn) + info.load() vtab := vtabGetHandle(ctx, mod, pVTab).(VTab) err := vtab.BestIndex(&info) - info.save(ctx, mod, pIdxInfo) + info.save() return vtabError(ctx, mod, pVTab, _VTAB_ERROR, err) } @@ -473,7 +500,11 @@ func cursorFilterCallback(ctx context.Context, mod api.Module, pCur, idxNum, idx cursor := vtabGetHandle(ctx, mod, pCur).(VTabCursor) db := ctx.Value(connKey{}).(*Conn) args := callbackArgs(db, argc, argv) - err := cursor.Filter(int(idxNum), util.ReadString(mod, idxStr, _MAX_STRING), args...) + var idxName string + if idxStr != 0 { + idxName = util.ReadString(mod, idxStr, _MAX_STRING) + } + err := cursor.Filter(int(idxNum), idxName, args...) return vtabError(ctx, mod, pCur, _CURSOR_ERROR, err) } diff --git a/vtab_test.go b/vtab_test.go index 04049347..9207503f 100644 --- a/vtab_test.go +++ b/vtab_test.go @@ -15,7 +15,7 @@ func ExampleCreateModule() { } defer db.Close() - err = sqlite3.CreateModule(db, "generate_series", seriesModule{}) + err = sqlite3.CreateModule(db, "generate_series", seriesTable{}) if err != nil { log.Fatal(err) } @@ -38,68 +38,57 @@ func ExampleCreateModule() { // 8 8 } -type seriesModule struct{} +type seriesTable struct{} -func (seriesModule) Connect(c *sqlite3.Conn, arg ...string) (*seriesTable, error) { - err := c.DeclareVtab(`CREATE TABLE x(value, start HIDDEN, stop HIDDEN, step HIDDEN)`) - if err != nil { - return nil, err - } - return &seriesTable{0, 0, 1}, nil -} - -type seriesTable struct { - start int64 - stop int64 - step int64 +func (seriesTable) Connect(c *sqlite3.Conn, arg ...string) (_ seriesTable, err error) { + err = c.DeclareVtab(`CREATE TABLE x(value, start HIDDEN, stop HIDDEN, step HIDDEN)`) + return } -func (*seriesTable) Disconnect() error { +func (seriesTable) Disconnect() error { return nil } -func (*seriesTable) BestIndex(idx *sqlite3.IndexInfo) error { - idx.IdxNum = 0 - idx.IdxStr = "default" - argv := 1 +func (seriesTable) BestIndex(idx *sqlite3.IndexInfo) error { for i, cst := range idx.Constraint { - if cst.Op == sqlite3.Eq { - idx.ConstraintUsage[i] = sqlite3.IndexConstraintUsage{ - ArgvIndex: argv, - Omit: true, + switch cst.Column { + case 1, 2, 3: // start, stop, step + if cst.Op == sqlite3.INDEX_CONSTRAINT_EQ && cst.Usable { + idx.ConstraintUsage[i] = sqlite3.IndexConstraintUsage{ + ArgvIndex: cst.Column, + Omit: true, + } } - argv++ } } return nil } -func (tab *seriesTable) Open() (sqlite3.VTabCursor, error) { - return &seriesCursor{tab, 0}, nil +func (seriesTable) Open() (sqlite3.VTabCursor, error) { + return &seriesCursor{}, nil } type seriesCursor struct { - *seriesTable + start int64 + stop int64 + step int64 value int64 } func (cur *seriesCursor) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error { - switch len(arg) { - case 0: - cur.seriesTable.start = 0 - cur.seriesTable.stop = 1000 - case 1: - cur.seriesTable.start = arg[0].Int64() - cur.seriesTable.stop = 1000 - case 2: - cur.seriesTable.start = arg[0].Int64() - cur.seriesTable.stop = arg[1].Int64() - case 3: - cur.seriesTable.start = arg[0].Int64() - cur.seriesTable.stop = arg[1].Int64() - cur.seriesTable.step = arg[2].Int64() + cur.start = 0 + cur.stop = 1000 + cur.step = 1 + if len(arg) > 0 { + cur.start = arg[0].Int64() + } + if len(arg) > 1 { + cur.stop = arg[1].Int64() + } + if len(arg) > 2 { + cur.step = arg[2].Int64() } - cur.value = cur.seriesTable.start + cur.value = cur.start return nil }