diff --git a/ext/blob/blob.go b/ext/blob/blob.go index e9a6e79d..9b1b7d66 100644 --- a/ext/blob/blob.go +++ b/ext/blob/blob.go @@ -7,7 +7,15 @@ import ( "github.com/ncruces/go-sqlite3" ) -// Register registers the blob_open SQL function. +// Register registers the blob_open SQL function: +// +// blob_open(schema, table, column, rowid, flags, callback, args...) +// +// The callback must be a [sqlite3.Pointer] to an [OpenCallback]. +// Any optional args will be passed to the callback, +// along with the [sqlite3.Blob] handle. +// +// https://sqlite.org/c3ref/blob.html func Register(db *sqlite3.Conn) { db.CreateFunction("blob_open", -1, sqlite3.DETERMINISTIC|sqlite3.DIRECTONLY, openBlob) @@ -56,4 +64,5 @@ func openBlob(ctx sqlite3.Context, arg ...sqlite3.Value) { ctx.SetAuxData(2, blob) } +// OpenCallback is the type for the blob_open callback. type OpenCallback func(*sqlite3.Blob, ...sqlite3.Value) error diff --git a/func.go b/func.go index 53e53cbd..e7ecc60d 100644 --- a/func.go +++ b/func.go @@ -25,11 +25,7 @@ func (c *Conn) CreateCollation(name string, fn func(a, b []byte) int) error { funcPtr := util.AddHandle(c.ctx, fn) r := c.call(c.api.createCollation, uint64(c.handle), uint64(namePtr), uint64(funcPtr)) - if err := c.error(r); err != nil { - util.DelHandle(c.ctx, funcPtr) - return err - } - return nil + return c.error(r) } // CreateFunction defines a new scalar SQL function. diff --git a/sqlite.go b/sqlite.go index 4dc68cdc..6b025c88 100644 --- a/sqlite.go +++ b/sqlite.go @@ -413,11 +413,11 @@ type sqliteAPI struct { func exportCallbacks(env wazero.HostModuleBuilder) wazero.HostModuleBuilder { util.ExportFuncII(env, "go_progress", callbackProgress) util.ExportFuncVI(env, "go_destroy", callbackDestroy) - util.ExportFuncIIIIII(env, "go_compare", callbackCompare) util.ExportFuncVIII(env, "go_func", callbackFunc) util.ExportFuncVIII(env, "go_step", callbackStep) util.ExportFuncVI(env, "go_final", callbackFinal) util.ExportFuncVI(env, "go_value", callbackValue) util.ExportFuncVIII(env, "go_inverse", callbackInverse) + util.ExportFuncIIIIII(env, "go_compare", callbackCompare) return env } diff --git a/sqlite3/func.c b/sqlite3/func.c index f7241fa1..c10bbed1 100644 --- a/sqlite3/func.c +++ b/sqlite3/func.c @@ -2,54 +2,48 @@ #include "sqlite3.h" -int go_compare(void *, int, const void *, int, const void *); +typedef void *go_handle; + +void go_destroy(go_handle); + +static_assert(sizeof(go_handle) == 4, "Unexpected size"); + void go_func(sqlite3_context *, int, sqlite3_value **); void go_step(sqlite3_context *, int, sqlite3_value **); void go_final(sqlite3_context *); void go_value(sqlite3_context *); void go_inverse(sqlite3_context *, int, sqlite3_value **); -void go_destroy(void *); -int sqlite3_create_collation_go(sqlite3 *db, const char *zName, void *pApp) { - return sqlite3_create_collation_v2(db, zName, SQLITE_UTF8, pApp, go_compare, - go_destroy); +int go_compare(go_handle, int, const void *, int, const void *); + +int sqlite3_create_collation_go(sqlite3 *db, const char *name, go_handle app) { + int rc = sqlite3_create_collation_v2(db, name, SQLITE_UTF8, app, go_compare, + go_destroy); + if (rc) go_destroy(app); + return rc; } -int sqlite3_create_function_go(sqlite3 *db, const char *zName, int nArg, - int flags, void *pApp) { - return sqlite3_create_function_v2(db, zName, nArg, SQLITE_UTF8 | flags, pApp, +int sqlite3_create_function_go(sqlite3 *db, const char *name, int argc, + int flags, go_handle app) { + return sqlite3_create_function_v2(db, name, argc, SQLITE_UTF8 | flags, app, go_func, /*step=*/NULL, /*final=*/NULL, go_destroy); } -int sqlite3_create_aggregate_function_go(sqlite3 *db, const char *zName, - int nArg, int flags, void *pApp) { - return sqlite3_create_window_function(db, zName, nArg, SQLITE_UTF8 | flags, - pApp, go_step, go_final, /*value=*/NULL, +int sqlite3_create_aggregate_function_go(sqlite3 *db, const char *name, + int argc, int flags, go_handle app) { + return sqlite3_create_window_function(db, name, argc, SQLITE_UTF8 | flags, + app, go_step, go_final, /*value=*/NULL, /*inverse=*/NULL, go_destroy); } -int sqlite3_create_window_function_go(sqlite3 *db, const char *zName, int nArg, - int flags, void *pApp) { - return sqlite3_create_window_function(db, zName, nArg, SQLITE_UTF8 | flags, - pApp, go_step, go_final, go_value, +int sqlite3_create_window_function_go(sqlite3 *db, const char *name, int argc, + int flags, go_handle app) { + return sqlite3_create_window_function(db, name, argc, SQLITE_UTF8 | flags, + app, go_step, go_final, go_value, go_inverse, go_destroy); } -void sqlite3_set_auxdata_go(sqlite3_context *ctx, int iArg, void *pAux) { - sqlite3_set_auxdata(ctx, iArg, pAux, go_destroy); -} - -#define GO_POINTER_TYPE "github.com/ncruces/go-sqlite3.Pointer" - -int sqlite3_bind_pointer_go(sqlite3_stmt *stmt, int i, void *pApp) { - return sqlite3_bind_pointer(stmt, i, pApp, GO_POINTER_TYPE, go_destroy); -} - -void sqlite3_result_pointer_go(sqlite3_context *ctx, void *pApp) { - sqlite3_result_pointer(ctx, pApp, GO_POINTER_TYPE, go_destroy); -} - -void *sqlite3_value_pointer_go(sqlite3_value *val) { - return sqlite3_value_pointer(val, GO_POINTER_TYPE); -} +void sqlite3_set_auxdata_go(sqlite3_context *ctx, int i, go_handle aux) { + sqlite3_set_auxdata(ctx, i, aux, go_destroy); +} \ No newline at end of file diff --git a/sqlite3/main.c b/sqlite3/main.c index 81b2e216..61899768 100644 --- a/sqlite3/main.c +++ b/sqlite3/main.c @@ -1,7 +1,5 @@ // Amalgamation #include "sqlite3.c" -// VFS -#include "vfs.c" // Extensions #include "ext/anycollseq.c" #include "ext/base64.c" @@ -11,8 +9,11 @@ #include "ext/uint.c" #include "ext/uuid.c" #include "func.c" +#include "pointer.c" #include "progress.c" #include "time.c" +#include "vfs.c" +// #include "vtab.c" __attribute__((constructor)) void init() { sqlite3_initialize(); diff --git a/sqlite3/pointer.c b/sqlite3/pointer.c new file mode 100644 index 00000000..d9a317e0 --- /dev/null +++ b/sqlite3/pointer.c @@ -0,0 +1,16 @@ + +#include "sqlite3.h" + +#define GO_POINTER_TYPE "github.com/ncruces/go-sqlite3.Pointer" + +int sqlite3_bind_pointer_go(sqlite3_stmt *stmt, int i, go_handle app) { + return sqlite3_bind_pointer(stmt, i, app, GO_POINTER_TYPE, go_destroy); +} + +void sqlite3_result_pointer_go(sqlite3_context *ctx, go_handle app) { + sqlite3_result_pointer(ctx, app, GO_POINTER_TYPE, go_destroy); +} + +go_handle sqlite3_value_pointer_go(sqlite3_value *val) { + return sqlite3_value_pointer(val, GO_POINTER_TYPE); +} \ No newline at end of file diff --git a/sqlite3/vfs.c b/sqlite3/vfs.c index ec6cdf2b..896f86f0 100644 --- a/sqlite3/vfs.c +++ b/sqlite3/vfs.c @@ -60,7 +60,7 @@ static int go_open_wrapper(sqlite3_vfs *vfs, sqlite3_filename zName, struct go_file { sqlite3_file base; - int handle; + go_handle handle; }; int sqlite3_os_init() { diff --git a/sqlite3/vtab.c b/sqlite3/vtab.c new file mode 100644 index 00000000..6daee2b1 --- /dev/null +++ b/sqlite3/vtab.c @@ -0,0 +1,215 @@ +#include + +#include "sqlite3.h" + +// https://github.com/JuliaLang/julia/blob/v1.9.4/src/julia.h#L67-L68 +#define container_of(ptr, type, member) \ + ((type *)((char *)(ptr)-offsetof(type, member))) + +#define SQLITE_MOD_CREATOR_GO /*******/ 0x01 +#define SQLITE_VTAB_UPDATER_GO /******/ 0x02 +#define SQLITE_VTAB_RENAMER_GO /******/ 0x04 +#define SQLITE_VTAB_OVERLOADER_GO /***/ 0x08 +#define SQLITE_VTAB_CHECKER_GO /******/ 0x10 +#define SQLITE_VTAB_TX_GO /***********/ 0x20 +#define SQLITE_VTAB_SAVEPOINTER_GO /**/ 0x40 + +int go_mod_create(sqlite3_module *, int argc, const char *const *argv, + sqlite3_vtab **, char **pzErr); +int go_mod_connect(sqlite3_module *, int argc, const char *const *argv, + sqlite3_vtab **, char **pzErr); + +int go_vtab_disconnect(sqlite3_vtab *); +int go_vtab_destroy(sqlite3_vtab *); +int go_vtab_best_index(sqlite3_vtab *, sqlite3_index_info *); +int go_vtab_open(sqlite3_vtab *, sqlite3_vtab_cursor **); + +int go_cur_close(sqlite3_vtab_cursor *); +int go_cur_filter(sqlite3_vtab_cursor *, int idxNum, const char *idxStr, + int argc, sqlite3_value **argv); +int go_cur_next(sqlite3_vtab_cursor *); +int go_cur_eof(sqlite3_vtab_cursor *); +int go_cur_column(sqlite3_vtab_cursor *, sqlite3_context *, int); +int go_cur_rowid(sqlite3_vtab_cursor *, sqlite3_int64 *pRowid); + +int go_vtab_update(sqlite3_vtab *, int, sqlite3_value **, sqlite3_int64 *); +int go_vtab_rename(sqlite3_vtab *, const char *zNew); +int go_vtab_find_function(sqlite3_vtab *, int nArg, const char *zName, + go_handle *pxFunc); + +int go_vtab_begin(sqlite3_vtab *); +int go_vtab_sync(sqlite3_vtab *); +int go_vtab_commit(sqlite3_vtab *); +int go_vtab_rollback(sqlite3_vtab *); + +int go_vtab_savepoint(sqlite3_vtab *, int); +int go_vtab_release(sqlite3_vtab *, int); +int go_vtab_rollback_to(sqlite3_vtab *, int); + +int go_vtab_integrity(sqlite3_vtab *, const char *zSchema, const char *zTabName, + int mFlags, char **pzErr); + +struct go_module { + go_handle handle; + sqlite3_module base; +}; + +struct go_vtab { + go_handle handle; + sqlite3_vtab base; +}; + +struct go_cursor { + go_handle handle; + sqlite3_vtab_cursor base; +}; + +static void go_mod_destroy(void *pAux) { + struct go_module *mod = (struct go_module *)pAux; + void *handle = mod->handle; + free(mod); + go_destroy(handle); +} + +static int go_mod_create_wrapper(sqlite3 *db, void *pAux, int argc, + const char *const *argv, sqlite3_vtab **ppVTab, + char **pzErr) { + struct go_vtab *vtab = calloc(1, sizeof(struct go_vtab)); + if (vtab == NULL) return SQLITE_NOMEM; + *ppVTab = &vtab->base; + + struct go_module *mod = (struct go_module *)pAux; + int rc = go_mod_create(&mod->base, argc, argv, ppVTab, pzErr); + if (rc) { + if (*pzErr) *pzErr = sqlite3_mprintf("%s", *pzErr); + free(vtab); + } + return rc; +} + +static int go_mod_connect_wrapper(sqlite3 *db, void *pAux, int argc, + const char *const *argv, + sqlite3_vtab **ppVTab, char **pzErr) { + struct go_vtab *vtab = calloc(1, sizeof(struct go_vtab)); + if (vtab == NULL) return SQLITE_NOMEM; + *ppVTab = &vtab->base; + + struct go_module *mod = (struct go_module *)pAux; + int rc = go_mod_connect(&mod->base, argc, argv, ppVTab, pzErr); + if (rc) { + free(vtab); + if (*pzErr) *pzErr = sqlite3_mprintf("%s", *pzErr); + } + return rc; +} + +static int go_vtab_disconnect_wrapper(sqlite3_vtab *pVTab) { + struct go_vtab *vtab = container_of(pVTab, struct go_vtab, base); + int rc = go_vtab_disconnect(pVTab); + free(vtab); + return rc; +} + +static int go_vtab_destroy_wrapper(sqlite3_vtab *pVTab) { + struct go_vtab *vtab = container_of(pVTab, struct go_vtab, base); + int rc = go_vtab_destroy(pVTab); + free(vtab); + return rc; +} + +static int go_vtab_open_wrapper(sqlite3_vtab *pVTab, + sqlite3_vtab_cursor **ppCursor) { + struct go_cursor *cur = calloc(1, sizeof(struct go_cursor)); + if (cur == NULL) return SQLITE_NOMEM; + *ppCursor = &cur->base; + + int rc = go_vtab_open(pVTab, ppCursor); + if (rc) free(cur); + return rc; +} + +static int go_cur_close_wrapper(sqlite3_vtab_cursor *pCursor) { + struct go_cursor *cur = container_of(pCursor, struct go_cursor, base); + int rc = go_cur_close(pCursor); + free(cur); + return rc; +} + +static int go_vtab_find_function_wrapper( + sqlite3_vtab *pVTab, int nArg, const char *zName, + void (**pxFunc)(sqlite3_context *, int, sqlite3_value **), void **ppArg) { + struct go_vtab *vtab = container_of(pVTab, struct go_vtab, base); + + go_handle handle; + int rc = go_vtab_find_function(pVTab, nArg, zName, &handle); + if (rc) { + *pxFunc = go_func; + *ppArg = handle; + } + return rc; +} + +static int go_vtab_integrity_wrapper(sqlite3_vtab *pVTab, const char *zSchema, + const char *zTabName, int mFlags, + char **pzErr) { + int rc = go_vtab_integrity(pVTab, zSchema, zTabName, mFlags, pzErr); + if (rc && *pzErr) *pzErr = sqlite3_mprintf("%s", *pzErr); + return rc; +} + +int sqlite3_create_module_go(sqlite3 *db, const char *zName, int flags, + void *handle) { + struct go_module *mod = malloc(sizeof(struct go_module)); + if (mod == NULL) { + go_destroy(handle); + return SQLITE_NOMEM; + } + + mod->handle = handle; + mod->base = (sqlite3_module){ + .iVersion = 4, + .xConnect = go_mod_connect_wrapper, + .xDisconnect = go_vtab_disconnect_wrapper, + .xBestIndex = go_vtab_best_index, + .xOpen = go_vtab_open_wrapper, + .xClose = go_cur_close_wrapper, + .xFilter = go_cur_filter, + .xNext = go_cur_next, + .xEof = go_cur_eof, + .xColumn = go_cur_column, + .xRowid = go_cur_rowid, + }; + if (flags & SQLITE_MOD_CREATOR_GO) { + mod->base.xCreate = go_mod_create_wrapper; + mod->base.xDestroy = go_vtab_destroy_wrapper; + } + if (flags & SQLITE_VTAB_UPDATER_GO) { + mod->base.xUpdate = go_vtab_update; + } + if (flags & SQLITE_VTAB_RENAMER_GO) { + mod->base.xRename = go_vtab_rename; + } + if (flags & SQLITE_VTAB_OVERLOADER_GO) { + mod->base.xFindFunction = go_vtab_find_function_wrapper; + } + if (flags & SQLITE_VTAB_CHECKER_GO) { + mod->base.xIntegrity = go_vtab_integrity_wrapper; + } + if (flags & SQLITE_VTAB_TX_GO) { + mod->base.xBegin = go_vtab_begin; + mod->base.xSync = go_vtab_sync; + mod->base.xCommit = go_vtab_commit; + mod->base.xRollback = go_vtab_rollback; + } + if (flags & SQLITE_VTAB_SAVEPOINTER_GO) { + mod->base.xSavepoint = go_vtab_savepoint; + mod->base.xRelease = go_vtab_release; + mod->base.xRollbackTo = go_vtab_rollback_to; + } + + return sqlite3_create_module_v2(db, zName, &mod->base, mod, go_mod_destroy); +} + +static_assert(offsetof(struct go_module, base) == 4, "Unexpected offset"); +static_assert(offsetof(struct go_vtab, base) == 4, "Unexpected offset"); +static_assert(offsetof(struct go_cursor, base) == 4, "Unexpected offset"); \ No newline at end of file diff --git a/vtab.go b/vtab.go index 6057b354..c1f9c9c2 100644 --- a/vtab.go +++ b/vtab.go @@ -1,68 +1,90 @@ package sqlite3 +// https://sqlite.org/vtab.html#xconnect type Module interface { - Connect(db *Conn, arg ...string) (Vtab, error) + Connect(db *Conn, arg ...string) (VTab, error) } +// https://sqlite.org/vtab.html#xcreate type ModuleCreator interface { Module - Create(db *Conn, arg ...string) (Vtab, error) + Create(db *Conn, arg ...string) (VTabDestroyer, error) } -type ModuleShadowNamer interface { - Module - ShadowName(suffix string) bool -} - -type Vtab interface { +type VTab interface { + // https://sqlite.org/vtab.html#xbestindex BestIndex(*IndexInfo) error + // https://sqlite.org/vtab.html#xdisconnect Disconnect() error + // https://sqlite.org/vtab.html#xopen + Open() (VTabCursor, error) +} + +// https://sqlite.org/vtab.html#sqlite3_module.xDestroy +type VTabDestroyer interface { + VTab Destroy() error - Open() (VtabCursor, error) } -type VtabUpdater interface { - Vtab +// https://sqlite.org/vtab.html#xupdate +type VTabUpdater interface { + VTab Update(arg ...Value) (rowid int64, err error) } -type VtabRenamer interface { - Vtab +// https://sqlite.org/vtab.html#xrename +type VTabRenamer interface { + VTab Rename(new string) error } -type VtabOverloader interface { - Vtab - FindFunction(arg int, name string) (func(ctx Context, arg ...Value), error) +// https://sqlite.org/vtab.html#xfindfunction +type VTabOverloader interface { + VTab + FindFunction(arg int, name string) (func(ctx Context, arg ...Value), IndexConstraint) } -type VtabChecker interface { - Vtab +// https://sqlite.org/vtab.html#xintegrity +type VTabChecker interface { + VTab Integrity(schema, table string, flags int) error } -type VtabTx interface { - Vtab +type VTabTx interface { + VTab + // https://sqlite.org/vtab.html#xBegin Begin() error + // https://sqlite.org/vtab.html#xsync Sync() error + // https://sqlite.org/vtab.html#xcommit Commit() error + // https://sqlite.org/vtab.html#xrollback Rollback() error } -type VtabSavepointer interface { - VtabTx - Savepoint(n int) error - Release(n int) error - RollbackTo(n int) error +// https://sqlite.org/vtab.html#xsavepoint +type VTabSavepointer interface { + VTabTx + Savepoint(id int) error + Release(id int) error + RollbackTo(id int) error } -type VtabCursor interface { +type VTabCursor interface { + // https://sqlite.org/vtab.html#xclose Close() error - Filter(idxNum int, idxStr string, arg ...Value) + // https://sqlite.org/vtab.html#xfilter + Filter(idxNum int, idxStr string, arg ...Value) error + // https://sqlite.org/vtab.html#xnext Next() error - Eof() bool + // https://sqlite.org/vtab.html#xeof + EOF() bool + // https://sqlite.org/vtab.html#xcolumn Column(ctx *Context, n int) error - Rowid() (int64, error) + // https://sqlite.org/vtab.html#xrowid + RowID() (int64, error) } type IndexInfo struct{} + +type IndexConstraint uint8