Skip to content

Commit

Permalink
Merge pull request #14 from trimble-oss/rolling_locks
Browse files Browse the repository at this point in the history
Adding thresh-hold rolling locking.
  • Loading branch information
joel-rieke authored Mar 13, 2024
2 parents c3744ab + 5a3d366 commit 663421f
Showing 1 changed file with 27 additions and 15 deletions.
42 changes: 27 additions & 15 deletions sql/mysql_db/mysql_db.go
Original file line number Diff line number Diff line change
Expand Up @@ -243,32 +243,41 @@ type LockUserMap struct {
sync.Map
}

func (m *LockUserMap) lockUser(user string, host string, value time.Time) {
type LockEntry struct {
LockTime time.Time
LockCount int
}

func (m *LockUserMap) lockUser(user string, host string, value *LockEntry) {
m.Set(fmt.Sprintf("%s-%s", user, host), value)
}

func (m *LockUserMap) Set(key string, value time.Time) {
func (m *LockUserMap) Set(key string, value *LockEntry) {
m.Store(key, value)
}

func (m *LockUserMap) GetUser(readUserEntry *User, host string) (time.Time, bool) {
func (m *LockUserMap) GetUser(readUserEntry *User, host string) (*LockEntry, bool) {
return m.Get(fmt.Sprintf("%s-%s", readUserEntry.User, readUserEntry.Host))
}

func (db *MySQLDb) AddUser(readUserEntry *User, host string) {
if readUserEntry == nil {
return
} else {
lockUserMap.lockUser(readUserEntry.User, host, time.Now())
if lockEntry, isLocked := lockUserMap.GetUser(readUserEntry, host); isLocked {
lockEntry.LockCount++
} else {
lockUserMap.lockUser(readUserEntry.User, host, &LockEntry{LockTime: time.Now(), LockCount: 0})
}
}
}

func (m *LockUserMap) Get(key string) (time.Time, bool) {
func (m *LockUserMap) Get(key string) (*LockEntry, bool) {
val, ok := m.Load(key)
if !ok {
return time.Time{}, false
return nil, false
}
return val.(time.Time), true
return val.(*LockEntry), true
}

func (m *LockUserMap) RemoveUser(readUserEntry *User, host string) {
Expand Down Expand Up @@ -303,13 +312,14 @@ func (db *MySQLDb) GetUser(user string, host string, roleSearch bool, skipCidrCh
return readUserEntry
}

if lockTime, isLocked := lockUserMap.GetUser(readUserEntry, host); isLocked {
if time.Since(lockTime) > time.Hour {
if lockedUser, isLocking := lockUserMap.GetUser(readUserEntry, host); isLocking {
if time.Since(lockedUser.LockTime) > time.Hour {
readUserEntry.Locked = false
lockUserMap.RemoveUser(readUserEntry, host)
} else {
readUserEntry.Locked = true
return readUserEntry
if lockedUser.LockCount > 250 {
readUserEntry.Locked = true
}
}
}

Expand Down Expand Up @@ -340,15 +350,17 @@ func (db *MySQLDb) GetUser(user string, host string, roleSearch bool, skipCidrCh
return readUserEntry
}

if lockTime, isLocked := lockUserMap.GetUser(readUserEntry, host); isLocked {
if time.Since(lockTime) > time.Hour {
if lockedUser, isLocking := lockUserMap.GetUser(readUserEntry, host); isLocking {
if time.Since(lockedUser.LockTime) > time.Hour {
readUserEntry.Locked = false
lockUserMap.RemoveUser(readUserEntry, host)
} else {
readUserEntry.Locked = true
return readUserEntry
if lockedUser.LockCount > 250 {
readUserEntry.Locked = true
}
}
}

if strings.Contains(readUserEntry.Host, "/") {
_, network, cidrParseErr := net.ParseCIDR(readUserEntry.Host)
if cidrParseErr == nil {
Expand Down

0 comments on commit 663421f

Please sign in to comment.