From 65c004f4b15c18fb6da4ae6d37b0e2a4d170c934 Mon Sep 17 00:00:00 2001 From: donnie4w Date: Thu, 26 Sep 2024 12:20:10 +0800 Subject: [PATCH] Update stmtexec.go --- stmtexec.go | 29 ++++++++++++++++++----------- 1 file changed, 18 insertions(+), 11 deletions(-) diff --git a/stmtexec.go b/stmtexec.go index 03408f3..cfffbf4 100644 --- a/stmtexec.go +++ b/stmtexec.go @@ -13,16 +13,17 @@ import ( . "github.com/donnie4w/gdao/base" "github.com/donnie4w/gdao/util" "github.com/donnie4w/gofer/hashmap" + goutil "github.com/donnie4w/gofer/util" "sync" "sync/atomic" ) -var sqlWare = hashmap.NewLimitHashMap[string, *int64](1 << 19) -var stmtExec = &stmtexec{stmtMap: hashmap.NewMap[*sql.DB, *hashmap.MapL[string, *sql.Stmt]](), mux: &sync.Mutex{}} +var sqlWare = hashmap.NewLimitHashMap[uint64, *int64](1 << 19) +var stmtExec = &stmtexec{stmtMap: hashmap.NewMap[*sql.DB, *hashmap.MapL[uint64, *sql.Stmt]](), mux: &sync.Mutex{}} var errorStmt = errors.New("") type stmtexec struct { - stmtMap *hashmap.Map[*sql.DB, *hashmap.MapL[string, *sql.Stmt]] + stmtMap *hashmap.Map[*sql.DB, *hashmap.MapL[uint64, *sql.Stmt]] mux *sync.Mutex lock int64 } @@ -48,7 +49,7 @@ func (se *stmtexec) clear(db *sql.DB) { defer atomic.StoreInt64(&se.lock, 0) if sm, _ := se.stmtMap.Get(db); sm != nil { if sm.Len() >= stmtLimit { - sm.Range(func(k string, v *sql.Stmt) bool { + sm.Range(func(k uint64, v *sql.Stmt) bool { sm.Del(k) v.Close() return true @@ -58,11 +59,11 @@ func (se *stmtexec) clear(db *sql.DB) { } } -func (se *stmtexec) newmap(db *sql.DB) (r *hashmap.MapL[string, *sql.Stmt]) { +func (se *stmtexec) newmap(db *sql.DB) (r *hashmap.MapL[uint64, *sql.Stmt]) { se.mux.Lock() defer se.mux.Unlock() if !se.stmtMap.Has(db) { - r = hashmap.NewMapL[string, *sql.Stmt]() + r = hashmap.NewMapL[uint64, *sql.Stmt]() se.stmtMap.Put(db, r) } return r @@ -73,16 +74,21 @@ func (se *stmtexec) Prepare(sqlStr string, db *sql.DB) (stmt *sql.Stmt, err erro se.clear(db) return stmt, errorStmt } - var hm *hashmap.MapL[string, *sql.Stmt] + var hm *hashmap.MapL[uint64, *sql.Stmt] + var sqlhs uint64 if hm, _ = se.stmtMap.Get(db); hm != nil { - if a, b := hm.Get(sqlStr); b { + sqlhs := goutil.Hash64([]byte(sqlStr)) + if a, b := hm.Get(sqlhs); b { return a, nil } } else { hm = se.newmap(db) } if stmt, err = db.Prepare(sqlStr); err == nil { - if p, ok := hm.Put(sqlStr, stmt); ok && p != nil { + if sqlhs == 0 { + sqlhs = goutil.Hash64([]byte(sqlStr)) + } + if p, ok := hm.Put(sqlhs, stmt); ok && p != nil { p.Close() } } @@ -103,10 +109,11 @@ func (se *stmtexec) nostmt(tx *sql.Tx, db *sql.DB, sqlstr string) (b bool) { if stmtLimit == 0 || tx != nil || se.len(db) >= stmtLimit { return true } - if v, ok := sqlWare.Get(sqlstr); ok { + sqlhs := goutil.Hash64([]byte(sqlstr)) + if v, ok := sqlWare.Get(sqlhs); ok { b = atomic.AddInt64(v, 1) < 16 } else { - sqlWare.Put(sqlstr, new(int64)) + sqlWare.Put(sqlhs, new(int64)) b = true } return