Skip to content

Commit

Permalink
Reuse statement, API.
Browse files Browse the repository at this point in the history
  • Loading branch information
ncruces committed Dec 4, 2023
1 parent 8a0baed commit cd40213
Show file tree
Hide file tree
Showing 12 changed files with 144 additions and 83 deletions.
20 changes: 19 additions & 1 deletion const.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,8 @@ const (
PREPARE_NO_VTAB PrepareFlag = 0x04
)

// FunctionFlag is a flag that can be passed to [Conn.PrepareFlags].
// FunctionFlag is a flag that can be passed to
// [Conn.CreateFunction] and [Conn.CreateWindowFunction].
//
// https://sqlite.org/c3ref/c_deterministic.html
type FunctionFlag uint32
Expand All @@ -181,6 +182,23 @@ const (
INNOCUOUS FunctionFlag = 0x000200000
)

// StmtStatus name counter values associated with the [Stmt.Status] method.
//
// https://sqlite.org/c3ref/c_stmtstatus_counter.html
type StmtStatus uint32

const (
STMTSTATUS_FULLSCAN_STEP StmtStatus = 1
STMTSTATUS_SORT StmtStatus = 2
STMTSTATUS_AUTOINDEX StmtStatus = 3
STMTSTATUS_VM_STEP StmtStatus = 4
STMTSTATUS_REPREPARE StmtStatus = 5
STMTSTATUS_RUN StmtStatus = 6
STMTSTATUS_FILTER_MISS StmtStatus = 7
STMTSTATUS_FILTER_HIT StmtStatus = 8
STMTSTATUS_MEMUSED StmtStatus = 99
)

// Datatype is a fundamental datatype of SQLite.
//
// https://sqlite.org/c3ref/c_blob.html
Expand Down
23 changes: 9 additions & 14 deletions driver/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ func (c *conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, e
s.Close()
return nil, util.TailErr
}
return &stmt{s, c.Conn}, nil
return &stmt{s}, nil
}

func (c *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
Expand Down Expand Up @@ -281,8 +281,7 @@ func (*conn) CheckNamedValue(arg *driver.NamedValue) error {
}

type stmt struct {
Stmt *sqlite3.Stmt
Conn *sqlite3.Conn
*sqlite3.Stmt
}

var (
Expand All @@ -292,10 +291,6 @@ var (
_ driver.NamedValueChecker = &stmt{}
)

func (s *stmt) Close() error {
return s.Stmt.Close()
}

func (s *stmt) NumInput() int {
n := s.Stmt.BindCount()
for i := 1; i <= n; i++ {
Expand All @@ -322,23 +317,23 @@ func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (drive
return nil, err
}

old := s.Conn.SetInterrupt(ctx)
defer s.Conn.SetInterrupt(old)
old := s.Stmt.Conn().SetInterrupt(ctx)
defer s.Stmt.Conn().SetInterrupt(old)

err = s.Stmt.Exec()
if err != nil {
return nil, err
}

return newResult(s.Conn), nil
return newResult(s.Stmt.Conn()), nil
}

func (s *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
err := s.setupBindings(args)
if err != nil {
return nil, err
}
return &rows{ctx, s.Stmt, s.Conn}, nil
return &rows{ctx, s.Stmt}, nil
}

func (s *stmt) setupBindings(args []driver.NamedValue) error {
Expand Down Expand Up @@ -442,10 +437,10 @@ func (r resultRowsAffected) RowsAffected() (int64, error) {
type rows struct {
ctx context.Context
Stmt *sqlite3.Stmt
Conn *sqlite3.Conn
}

func (r *rows) Close() error {
r.Stmt.ClearBindings()
return r.Stmt.Reset()
}

Expand All @@ -469,8 +464,8 @@ func (r *rows) ColumnTypeDatabaseTypeName(index int) string {
}

func (r *rows) Next(dest []driver.Value) error {
old := r.Conn.SetInterrupt(r.ctx)
defer r.Conn.SetInterrupt(old)
old := r.Stmt.Conn().SetInterrupt(r.ctx)
defer r.Stmt.Conn().SetInterrupt(old)

if !r.Stmt.Step() {
if err := r.Stmt.Err(); err != nil {
Expand Down
2 changes: 2 additions & 0 deletions embed/exports.txt
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,9 @@ sqlite3_result_value
sqlite3_result_zeroblob64
sqlite3_set_auxdata_go
sqlite3_step
sqlite3_stmt_busy
sqlite3_stmt_readonly
sqlite3_stmt_status
sqlite3_uri_key
sqlite3_uri_parameter
sqlite3_user_data
Expand Down
Binary file modified embed/sqlite3.wasm
Binary file not shown.
4 changes: 2 additions & 2 deletions ext/csv/csv.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,9 +172,9 @@ func (t *table) newReader() *csv.Reader {

type cursor struct {
table *table
rowID int64
row []string
csv *csv.Reader
row []string
rowID int64
}

func (c *cursor) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error {
Expand Down
2 changes: 1 addition & 1 deletion ext/lines/lines.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,11 @@ func (l lines) Open() (sqlite3.VTabCursor, error) {
}

type cursor struct {
reader bool
scanner *bufio.Scanner
closer io.Closer
rowID int64
eof bool
reader bool
}

func (c *cursor) Close() (err error) {
Expand Down
90 changes: 48 additions & 42 deletions ext/statement/stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,10 @@ func Register(db *sqlite3.Conn) {
sql = sql[1 : len-1]
}

table := &table{
db: db,
sql: sql,
}
err = table.declare()
table := &table{sql: sql}
err = table.declare(db)
if err != nil {
table.Close()
return nil, err
}
return table, nil
Expand All @@ -41,42 +39,40 @@ func Register(db *sqlite3.Conn) {
}

type table struct {
db *sqlite3.Conn
sql string
inputs int
outputs int
stmt *sqlite3.Stmt
sql string
inuse bool
}

func (t *table) declare() error {
stmt, tail, err := t.db.Prepare(t.sql)
func (t *table) declare(db *sqlite3.Conn) (err error) {
var tail string
t.stmt, tail, err = db.Prepare(t.sql)
if err != nil {
return err
}
defer stmt.Close()
if tail != "" {
return fmt.Errorf("statement: multiple statements")
}
if !stmt.ReadOnly() {
if !t.stmt.ReadOnly() {
return fmt.Errorf("statement: statement must be read only")
}

t.inputs = stmt.BindCount()
t.outputs = stmt.ColumnCount()

var sep = ""
var str strings.Builder
str.WriteString(`CREATE TABLE x(`)
for i := 0; i < t.outputs; i++ {
outputs := t.stmt.ColumnCount()
for i := 0; i < outputs; i++ {
str.WriteString(sep)
name := stmt.ColumnName(i)
name := t.stmt.ColumnName(i)
str.WriteString(sqlite3.QuoteIdentifier(name))
str.WriteByte(' ')
str.WriteString(stmt.ColumnDeclType(i))
str.WriteString(t.stmt.ColumnDeclType(i))
sep = ","
}
for i := 1; i <= t.inputs; i++ {
inputs := t.stmt.BindCount()
for i := 1; i <= inputs; i++ {
str.WriteString(sep)
name := stmt.BindName(i)
name := t.stmt.BindName(i)
if name == "" {
str.WriteString("[")
str.WriteString(strconv.Itoa(i))
Expand All @@ -87,22 +83,24 @@ func (t *table) declare() error {
}
sep = ","
}

str.WriteByte(')')
return t.db.DeclareVtab(str.String())
return db.DeclareVtab(str.String())
}

func (t *table) Close() error {
return t.stmt.Close()
}

func (t *table) BestIndex(idx *sqlite3.IndexInfo) error {
idx.OrderByConsumed = false
idx.EstimatedCost = 1
idx.EstimatedRows = 1
idx.EstimatedCost = 1000

var argvIndex = 1
var needIndex bool
var listIndex []int
outputs := t.stmt.ColumnCount()
for i, cst := range idx.Constraint {
// Skip if this is a constraint on one of our output columns.
if cst.Column < t.outputs {
if cst.Column < outputs {
continue
}

Expand All @@ -114,7 +112,7 @@ func (t *table) BestIndex(idx *sqlite3.IndexInfo) error {

// The non-zero argvIdx values must be contiguous.
// If they're not, build a list and serialize it through IdxStr.
nextIndex := cst.Column - t.outputs + 1
nextIndex := cst.Column - outputs + 1
idx.ConstraintUsage[i] = sqlite3.IndexConstraintUsage{
ArgvIndex: argvIndex,
Omit: true,
Expand All @@ -136,10 +134,15 @@ func (t *table) BestIndex(idx *sqlite3.IndexInfo) error {
return nil
}

func (t *table) Open() (sqlite3.VTabCursor, error) {
stmt, _, err := t.db.Prepare(t.sql)
if err != nil {
return nil, err
func (t *table) Open() (_ sqlite3.VTabCursor, err error) {
stmt := t.stmt
if !t.inuse {
t.inuse = true
} else {
stmt, _, err = t.stmt.Conn().Prepare(t.sql)
if err != nil {
return nil, err
}
}
return &cursor{table: t, stmt: stmt}, nil
}
Expand All @@ -153,26 +156,29 @@ type cursor struct {
stmt *sqlite3.Stmt
arg []sqlite3.Value
rowID int64
done bool
}

func (c *cursor) Close() error {
if c.stmt == c.table.stmt {
c.table.inuse = false
c.stmt.ClearBindings()
return c.stmt.Reset()
}
return c.stmt.Close()
}

func (c *cursor) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error {
c.arg = arg
c.rowID = 0
if err := c.stmt.ClearBindings(); err != nil {
return err
}
c.stmt.ClearBindings()
if err := c.stmt.Reset(); err != nil {
return err
}

var list []int
if idxStr != "" {
err := json.Unmarshal([]byte(idxStr), &list)
buf := unsafe.Slice(unsafe.StringData(idxStr), len(idxStr))
err := json.Unmarshal(buf, &list)
if err != nil {
return err
}
Expand All @@ -196,23 +202,23 @@ func (c *cursor) Next() error {
c.rowID++
return nil
}
c.done = true
return c.stmt.Err()
}

func (c *cursor) EOF() bool {
return c.done
return !c.stmt.Busy()
}

func (c *cursor) RowID() (int64, error) {
return c.rowID, nil
}

func (c *cursor) Column(ctx *sqlite3.Context, col int) error {
if col < c.table.outputs {
switch outputs := c.stmt.ColumnCount(); {
case col < outputs:
ctx.ResultValue(c.stmt.ColumnValue(col))
} else if col-c.table.outputs < len(c.arg) {
ctx.ResultValue(c.arg[col-c.table.outputs])
case col-outputs < len(c.arg):
ctx.ResultValue(c.arg[col-outputs])
}
return nil
}
7 changes: 5 additions & 2 deletions func.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func (c *Conn) CreateCollation(name string, fn func(a, b []byte) int) error {
// CreateFunction defines a new scalar SQL function.
//
// https://sqlite.org/c3ref/create_function.html
func (c *Conn) CreateFunction(name string, nArg int, flag FunctionFlag, fn func(ctx Context, arg ...Value)) error {
func (c *Conn) CreateFunction(name string, nArg int, flag FunctionFlag, fn ScalarFunction) error {
defer c.arena.mark()()
namePtr := c.arena.string(name)
funcPtr := util.AddHandle(c.ctx, fn)
Expand All @@ -42,6 +42,9 @@ func (c *Conn) CreateFunction(name string, nArg int, flag FunctionFlag, fn func(
return c.error(r)
}

// ScalarFunction is the type of a scalar SQL function.
type ScalarFunction func(ctx Context, arg ...Value)

// CreateWindowFunction defines a new aggregate or aggregate window SQL function.
// If fn returns a [WindowFunction], then an aggregate window function is created.
// If fn returns an [io.Closer], it will be called to free resources.
Expand Down Expand Up @@ -95,7 +98,7 @@ func compareCallback(ctx context.Context, mod api.Module, pApp, nKey1, pKey1, nK

func funcCallback(ctx context.Context, mod api.Module, pCtx, nArg, pArg uint32) {
db := ctx.Value(connKey{}).(*Conn)
fn := userDataHandle(db, pCtx).(func(ctx Context, arg ...Value))
fn := userDataHandle(db, pCtx).(ScalarFunction)
fn(Context{db, pCtx}, callbackArgs(db, nArg, pArg)...)
}

Expand Down
Loading

0 comments on commit cd40213

Please sign in to comment.