Skip to content

Commit

Permalink
Merge pull request #195 from wubin1989/main
Browse files Browse the repository at this point in the history
...
  • Loading branch information
wubin1989 authored Apr 20, 2024
2 parents aed011e + d4a213e commit e456925
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 40 deletions.
12 changes: 7 additions & 5 deletions framework/database/database.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
package database

import (
gocache "github.com/eko/gocache/lib/v4/cache"
"log"
"os"
"strings"
"time"

gocache "github.com/eko/gocache/lib/v4/cache"
"gorm.io/driver/clickhouse"
"gorm.io/driver/mysql"
"gorm.io/driver/postgres"
Expand Down Expand Up @@ -283,9 +283,12 @@ func NewDb(conf config.Config) (db *gorm.DB) {
case DriverPostgres:
collectors = append(collectors, &prometheus.Postgres{})
}
ConfigureMetrics(Db, conf.Db.Prometheus.DBName, uint32(conf.Db.Prometheus.RefreshInterval),
ConfigureMetrics(db, conf.Db.Prometheus.DBName, uint32(conf.Db.Prometheus.RefreshInterval),
nil, collectors...)
}
if conf.Db.Cache.Enable && stringutils.IsNotEmpty(conf.Cache.Stores) {
ConfigureDBCache(db, cache.NewCacheManager(conf))
}
return
}

Expand All @@ -299,9 +302,8 @@ func ConfigureMetrics(db *gorm.DB, dbName string, refreshInterval uint32, labels
}

func ConfigureDBCache(db *gorm.DB, cacheManager gocache.CacheInterface[any]) {
cachesPlugin := &caches.Caches{Conf: &caches.Config{
db.Use(&caches.Caches{Conf: &caches.Config{
Easer: true,
Cacher: NewCacherAdapter(cacheManager),
}}
db.Use(cachesPlugin)
}})
}
88 changes: 54 additions & 34 deletions toolkit/caches/caches.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,13 @@ import (
"gorm.io/driver/postgres"
"gorm.io/gorm"
"gorm.io/gorm/callbacks"
"strings"
"sync"
)

type Caches struct {
Conf *Config

queue *sync.Map
queryCb func(*gorm.DB)
Conf *Config
queue *sync.Map
}

type Config struct {
Expand All @@ -43,9 +42,9 @@ func (c *Caches) Initialize(db *gorm.DB) error {
c.queue = &sync.Map{}
}

c.queryCb = db.Callback().Query().Get("gorm:query")
callback := db.Callback().Query().Get("gorm:query")

err := db.Callback().Query().Replace("gorm:query", c.Query)
err := db.Callback().Query().Replace("gorm:query", c.Query(callback))
if err != nil {
return err
}
Expand All @@ -65,33 +64,41 @@ func (c *Caches) Initialize(db *gorm.DB) error {
return err
}

err = db.Callback().Raw().After("gorm:raw").Register("cache:after_raw", c.AfterWrite)
if err != nil {
return err
}

return nil
}

func (c *Caches) Query(db *gorm.DB) {
if c.Conf.Easer == false && c.Conf.Cacher == nil {
c.queryCb(db)
return
}
func (c *Caches) Query(callback func(*gorm.DB)) func(*gorm.DB) {
return func(db *gorm.DB) {
if c.Conf.Easer == false && c.Conf.Cacher == nil {
callback(db)
return
}

identifier := buildIdentifier(db)
identifier := buildIdentifier(db)

if db.DryRun {
return
}
if db.DryRun {
return
}

if c.checkCache(db, identifier) {
return
}
if res, ok := c.checkCache(identifier); ok {
res.replaceOn(db)
return
}

c.ease(db, identifier)
if db.Error != nil {
return
}
c.ease(db, identifier, callback)
if db.Error != nil {
return
}

c.storeInCache(db, identifier)
if db.Error != nil {
return
c.storeInCache(db, identifier)
if db.Error != nil {
return
}
}
}

Expand All @@ -114,16 +121,16 @@ func (c *Caches) AfterWrite(db *gorm.DB) {
}
}

func (c *Caches) ease(db *gorm.DB, identifier string) {
func (c *Caches) ease(db *gorm.DB, identifier string, callback func(*gorm.DB)) {
if c.Conf.Easer == false {
c.queryCb(db)
callback(db)
return
}

res := ease(&queryTask{
id: identifier,
db: db,
queryCb: c.queryCb,
queryCb: callback,
}, c.queue).(*queryTask)

if db.Error != nil {
Expand All @@ -141,17 +148,17 @@ func (c *Caches) ease(db *gorm.DB, identifier string) {
q.replaceOn(res.db)
}

func (c *Caches) checkCache(db *gorm.DB, identifier string) bool {
func (c *Caches) checkCache(identifier string) (res *Query, ok bool) {
if c.Conf.Cacher != nil {
if res := c.Conf.Cacher.Get(identifier); res != nil {
res.replaceOn(db)
return true
if res = c.Conf.Cacher.Get(identifier); res != nil {
return res, true
}
}
return false
return nil, false
}

func getTables(db *gorm.DB) []string {
callbacks.BuildQuerySQL(db)
switch db.Dialector.(type) {
case *mysql.Dialector:
return getTablesMysql(db)
Expand Down Expand Up @@ -189,7 +196,20 @@ func getTablesPostgres(db *gorm.DB) []string {
//log.Printf("%T", node)
switch expr := node.(type) {
case *tree.TableName:
tableNames = append(tableNames, expr.Table())
var sb strings.Builder
fmtCtx := tree.NewFmtCtx(tree.FmtSimple)
expr.TableNamePrefix.Format(fmtCtx)
sb.WriteString(fmtCtx.String())

if sb.Len() > 0 {
sb.WriteString(".")
}

fmtCtx = tree.NewFmtCtx(tree.FmtSimple)
expr.TableName.Format(fmtCtx)
sb.WriteString(fmtCtx.String())

tableNames = append(tableNames, sb.String())
case *tree.Insert:
fmtCtx := tree.NewFmtCtx(tree.FmtSimple)
expr.Table.Format(fmtCtx)
Expand Down
2 changes: 1 addition & 1 deletion toolkit/gormgen/do.go
Original file line number Diff line number Diff line change
Expand Up @@ -812,7 +812,7 @@ func (d *DO) Rows() (*sql.Rows, error) {

// Scan ...
func (d *DO) Scan(dest interface{}) error {
return d.db.Model(d.newResultPointer()).Scan(dest).Error
return d.db.Model(d.newResultPointer()).Find(dest).Error
}

// Pluck ...
Expand Down

0 comments on commit e456925

Please sign in to comment.