Skip to content

Commit

Permalink
Update stmtexec.go
Browse files Browse the repository at this point in the history
  • Loading branch information
donnie4w committed Sep 26, 2024
1 parent e0bca24 commit 65c004f
Showing 1 changed file with 18 additions and 11 deletions.
29 changes: 18 additions & 11 deletions stmtexec.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,17 @@ import (
. "github.com/donnie4w/gdao/base"
"github.com/donnie4w/gdao/util"
"github.com/donnie4w/gofer/hashmap"
goutil "github.com/donnie4w/gofer/util"
"sync"
"sync/atomic"
)

var sqlWare = hashmap.NewLimitHashMap[string, *int64](1 << 19)
var stmtExec = &stmtexec{stmtMap: hashmap.NewMap[*sql.DB, *hashmap.MapL[string, *sql.Stmt]](), mux: &sync.Mutex{}}
var sqlWare = hashmap.NewLimitHashMap[uint64, *int64](1 << 19)
var stmtExec = &stmtexec{stmtMap: hashmap.NewMap[*sql.DB, *hashmap.MapL[uint64, *sql.Stmt]](), mux: &sync.Mutex{}}
var errorStmt = errors.New("")

type stmtexec struct {
stmtMap *hashmap.Map[*sql.DB, *hashmap.MapL[string, *sql.Stmt]]
stmtMap *hashmap.Map[*sql.DB, *hashmap.MapL[uint64, *sql.Stmt]]
mux *sync.Mutex
lock int64
}
Expand All @@ -48,7 +49,7 @@ func (se *stmtexec) clear(db *sql.DB) {
defer atomic.StoreInt64(&se.lock, 0)
if sm, _ := se.stmtMap.Get(db); sm != nil {
if sm.Len() >= stmtLimit {
sm.Range(func(k string, v *sql.Stmt) bool {
sm.Range(func(k uint64, v *sql.Stmt) bool {
sm.Del(k)
v.Close()
return true
Expand All @@ -58,11 +59,11 @@ func (se *stmtexec) clear(db *sql.DB) {
}
}

func (se *stmtexec) newmap(db *sql.DB) (r *hashmap.MapL[string, *sql.Stmt]) {
func (se *stmtexec) newmap(db *sql.DB) (r *hashmap.MapL[uint64, *sql.Stmt]) {
se.mux.Lock()
defer se.mux.Unlock()
if !se.stmtMap.Has(db) {
r = hashmap.NewMapL[string, *sql.Stmt]()
r = hashmap.NewMapL[uint64, *sql.Stmt]()
se.stmtMap.Put(db, r)
}
return r
Expand All @@ -73,16 +74,21 @@ func (se *stmtexec) Prepare(sqlStr string, db *sql.DB) (stmt *sql.Stmt, err erro
se.clear(db)
return stmt, errorStmt
}
var hm *hashmap.MapL[string, *sql.Stmt]
var hm *hashmap.MapL[uint64, *sql.Stmt]
var sqlhs uint64
if hm, _ = se.stmtMap.Get(db); hm != nil {
if a, b := hm.Get(sqlStr); b {
sqlhs := goutil.Hash64([]byte(sqlStr))
if a, b := hm.Get(sqlhs); b {
return a, nil
}
} else {
hm = se.newmap(db)
}
if stmt, err = db.Prepare(sqlStr); err == nil {
if p, ok := hm.Put(sqlStr, stmt); ok && p != nil {
if sqlhs == 0 {
sqlhs = goutil.Hash64([]byte(sqlStr))
}
if p, ok := hm.Put(sqlhs, stmt); ok && p != nil {
p.Close()
}
}
Expand All @@ -103,10 +109,11 @@ func (se *stmtexec) nostmt(tx *sql.Tx, db *sql.DB, sqlstr string) (b bool) {
if stmtLimit == 0 || tx != nil || se.len(db) >= stmtLimit {
return true
}
if v, ok := sqlWare.Get(sqlstr); ok {
sqlhs := goutil.Hash64([]byte(sqlstr))
if v, ok := sqlWare.Get(sqlhs); ok {
b = atomic.AddInt64(v, 1) < 16
} else {
sqlWare.Put(sqlstr, new(int64))
sqlWare.Put(sqlhs, new(int64))
b = true
}
return
Expand Down

0 comments on commit 65c004f

Please sign in to comment.