Skip to content

Commit

Permalink
feat(stmt):added PreparedStmt support in Tx (#8)
Browse files Browse the repository at this point in the history
  • Loading branch information
cnlangzi authored Feb 13, 2024
1 parent e7e3244 commit d385ed3
Show file tree
Hide file tree
Showing 7 changed files with 374 additions and 45 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- added custom `Binder` support on `Bind` (#4)
- added `Select` and `Delete` on `SQLBuilder` (#6)
- added `PreparedStmt` support on `Query` and `Exec` (#7)
- added `PreparedStmt` support on `Tx` (#8)
### Fixed
- fixed `sql.Scanner` support on `Bind` (#2)

Expand Down
51 changes: 32 additions & 19 deletions db.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,29 @@ package sqle
import (
"context"
"database/sql"
"sync"

"github.com/rs/zerolog/log"
)

type DB struct {
noCopy
*sql.DB
noCopy //nolint

sync.Mutex
stmts map[string]*cachedStmt
stmtsMutex sync.RWMutex
}

func Open(db *sql.DB) *DB {
return &DB{
DB: db,
d := &DB{
DB: db,
stmts: make(map[string]*cachedStmt),
}

go d.closeIdleStmt()

return d
}

func (db *DB) Query(query string, args ...any) (*Rows, error) {
Expand All @@ -36,7 +46,7 @@ func (db *DB) QueryContext(ctx context.Context, query string, args ...any) (*Row
var stmt *sql.Stmt
var err error
if len(args) > 0 {
stmt, err = prepareStmt(ctx, db.DB, query)
stmt, err = db.prepareStmt(ctx, query)
if err == nil {
rows, err = stmt.QueryContext(ctx, args...)
if err != nil {
Expand Down Expand Up @@ -76,15 +86,22 @@ func (db *DB) QueryRowContext(ctx context.Context, query string, args ...any) *R
var err error

if len(args) > 0 {
stmt, err = prepareStmt(ctx, db.DB, query)
if err == nil {
rows, err = stmt.QueryContext(ctx, args...)
stmt, err = db.prepareStmt(ctx, query)
if err != nil {
return &Row{
err: err,
query: query,
}
}
rows, err = stmt.QueryContext(ctx, args...)
return &Row{
rows: rows,
err: err,
query: query,
}

} else {
rows, err = db.DB.QueryContext(ctx, query, args...)
}

rows, err = db.DB.QueryContext(ctx, query, args...)
return &Row{
rows: rows,
err: err,
Expand All @@ -107,7 +124,7 @@ func (db *DB) ExecBuilder(ctx context.Context, b *Builder) (sql.Result, error) {

func (db *DB) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) {
if len(args) > 0 {
stmt, err := prepareStmt(ctx, db.DB, query)
stmt, err := db.prepareStmt(ctx, query)
if err != nil {
return nil, err
}
Expand All @@ -118,12 +135,8 @@ func (db *DB) ExecContext(ctx context.Context, query string, args ...any) (sql.R
}

func (db *DB) Begin(opts *sql.TxOptions) (*Tx, error) {
tx, err := db.DB.BeginTx(context.Background(), opts)
if err != nil {
return nil, err
}
return db.BeginTx(context.TODO(), opts)

return &Tx{Tx: tx}, nil
}

func (db *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) {
Expand All @@ -132,16 +145,16 @@ func (db *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) {
return nil, err
}

return &Tx{Tx: tx}, nil
return &Tx{Tx: tx, cachedStmts: make(map[string]*sql.Stmt)}, nil
}

func (db *DB) Transaction(ctx context.Context, opts *sql.TxOptions, fn func(tx *Tx) error) error {
func (db *DB) Transaction(ctx context.Context, opts *sql.TxOptions, fn func(ctx context.Context, tx *Tx) error) error {
tx, err := db.BeginTx(ctx, opts)
if err != nil {
return err
}

err = fn(tx)
err = fn(ctx, tx)
if err != nil {
if e := tx.Rollback(); e != nil {
log.Error().Str("pkg", "sqle").Str("tag", "tx").Err(e)
Expand Down
32 changes: 32 additions & 0 deletions db_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package sqle

import (
"database/sql"
"os"

_ "github.com/mattn/go-sqlite3"
)

func createSQLite3() (*sql.DB, func(), error) {
f, err := os.CreateTemp(".", "*.db")
f.Close()

clean := func() {
os.Remove(f.Name()) //nolint
}

if err != nil {
return nil, clean, err
}

//db, err := sql.Open("sqlite3", "file:"+f.Name()+"?cache=shared")

db, err := sql.Open("sqlite3", "file::memory:")

if err != nil {
return nil, clean, err
}
//https://github.com/mattn/go-sqlite3/issues/209
// db.SetMaxOpenConns(1)
return db, clean, nil
}
2 changes: 1 addition & 1 deletion migrate/migrator.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ func (m *Migrator) Init(ctx context.Context) error {
func (m *Migrator) Migrate(ctx context.Context) error {
var err error
for _, v := range m.Versions {
err = m.db.Transaction(ctx, nil, func(tx *sqle.Tx) error {
err = m.db.Transaction(ctx, nil, func(ctx context.Context, tx *sqle.Tx) error {

var checksum string

Expand Down
35 changes: 13 additions & 22 deletions stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,59 +7,50 @@ import (
"time"
)

var (
stmts = make(map[string]*cachedStmt)
stmtsMutex sync.RWMutex
)

type cachedStmt struct {
sync.Mutex
stmt *sql.Stmt
lastUsed time.Time
}

func prepareStmt(ctx context.Context, db *sql.DB, query string) (*sql.Stmt, error) {
stmtsMutex.RLock()
s, ok := stmts[query]
stmtsMutex.RUnlock()
func (db *DB) prepareStmt(ctx context.Context, query string) (*sql.Stmt, error) {
db.stmtsMutex.RLock()
s, ok := db.stmts[query]
db.stmtsMutex.RUnlock()
if ok {
s.Lock()
s.lastUsed = time.Now()
s.Unlock()
return s.stmt, nil
}

stmt, err := db.PrepareContext(ctx, query)
stmt, err := db.DB.PrepareContext(ctx, query)
if err != nil {
return nil, err
}

stmtsMutex.Lock()
stmts[query] = &cachedStmt{
db.stmtsMutex.Lock()
db.stmts[query] = &cachedStmt{
stmt: stmt,
lastUsed: time.Now(),
}
stmtsMutex.Unlock()
db.stmtsMutex.Unlock()

return stmt, nil
}

func init() {
go releaseCachedStmt()
}

func releaseCachedStmt() {
func (db *DB) closeIdleStmt() {
for {
<-time.After(1 * time.Minute)

stmtsMutex.Lock()
db.stmtsMutex.Lock()
lastActive := time.Now().Add(-1 * time.Minute)
for k, v := range stmts {
for k, v := range db.stmts {
if v.lastUsed.Before(lastActive) {
delete(stmts, k)
delete(db.stmts, k)
go v.stmt.Close() //nolint: errcheck
}
}
stmtsMutex.Unlock()
db.stmtsMutex.Unlock()
}
}
89 changes: 86 additions & 3 deletions tx.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,33 @@ import (

type Tx struct {
*sql.Tx
noCopy
noCopy //nolint
cachedStmts map[string]*sql.Stmt
}

func (tx *Tx) prepareStmt(ctx context.Context, query string) (*sql.Stmt, error) {
if tx.cachedStmts == nil {
tx.cachedStmts = make(map[string]*sql.Stmt)
}
s, ok := tx.cachedStmts[query]
if ok {
return s, nil
}

s, err := tx.Tx.PrepareContext(ctx, query)
if err != nil {
return nil, err
}

tx.cachedStmts[query] = s

return s, nil
}

func (tx *Tx) closeStmt() {
for _, stmt := range tx.cachedStmts {
stmt.Close()
}
}

func (tx *Tx) Query(query string, args ...any) (*Rows, error) {
Expand All @@ -24,6 +50,20 @@ func (tx *Tx) QueryBuilder(b *Builder) (*Rows, error) {
}

func (tx *Tx) QueryContext(ctx context.Context, query string, args ...any) (*Rows, error) {

if len(args) > 0 {
stmt, err := tx.prepareStmt(ctx, query)
if err != nil {
return nil, err
}

rows, err := stmt.QueryContext(ctx, args...)
if err != nil {
return nil, err
}
return &Rows{Rows: rows, query: query}, nil
}

rows, err := tx.Tx.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
Expand All @@ -48,6 +88,29 @@ func (tx *Tx) QueryRowBuilder(b *Builder) *Row {
}

func (tx *Tx) QueryRowContext(ctx context.Context, query string, args ...any) *Row {

if len(args) > 0 {
stmt, err := tx.prepareStmt(ctx, query)
if err != nil {
return &Row{
err: err,
query: query,
}
}

rows, err := stmt.QueryContext(ctx, args...)
if err != nil {
return &Row{
err: err,
query: query,
}
}
return &Row{
rows: rows,
query: query,
}
}

rows, err := tx.Tx.QueryContext(ctx, query, args...)

return &Row{
Expand All @@ -58,17 +121,37 @@ func (tx *Tx) QueryRowContext(ctx context.Context, query string, args ...any) *R
}

func (tx *Tx) Exec(query string, args ...any) (sql.Result, error) {
return tx.Tx.ExecContext(context.Background(), query, args...)
return tx.ExecContext(context.Background(), query, args...)
}

func (tx *Tx) ExecBuilder(ctx context.Context, b *Builder) (sql.Result, error) {
query, args, err := b.Build()
if err != nil {
return nil, err
}
return tx.Tx.ExecContext(ctx, query, args...)
return tx.ExecContext(ctx, query, args...)
}

func (tx *Tx) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) {

if len(args) > 0 {
stmt, err := tx.prepareStmt(ctx, query)
if err != nil {
return nil, err
}

return stmt.ExecContext(ctx, args...)
}

return tx.Tx.ExecContext(context.Background(), query, args...)
}

func (tx *Tx) Rollback() error {
defer tx.closeStmt()
return tx.Tx.Rollback()
}

func (tx *Tx) Commit() error {
defer tx.closeStmt()
return tx.Tx.Commit()
}
Loading

0 comments on commit d385ed3

Please sign in to comment.