diff --git a/server/handler.go b/server/handler.go index 9a2bf66b1c..3f5d47ac66 100644 --- a/server/handler.go +++ b/server/handler.go @@ -15,6 +15,7 @@ package server import ( + goerrors "errors" "io" "net" "regexp" @@ -23,8 +24,6 @@ import ( "sync" "time" - goerrors "errors" - "github.com/dolthub/vitess/go/mysql" "github.com/dolthub/vitess/go/netutil" "github.com/dolthub/vitess/go/sqltypes" diff --git a/sql/mysql_db/mysql_db.go b/sql/mysql_db/mysql_db.go index db0b7877e7..53aa1f4b54 100644 --- a/sql/mysql_db/mysql_db.go +++ b/sql/mysql_db/mysql_db.go @@ -23,6 +23,8 @@ import ( "net" "sort" "strings" + "sync" + "time" "github.com/dolthub/vitess/go/mysql" flatbuffers "github.com/google/flatbuffers/go" @@ -237,6 +239,48 @@ func (db *MySQLDb) AddSuperUser(username string, host string, password string) { db.updateCounter++ } +type LockUserMap struct { + sync.Map +} + +func (m *LockUserMap) SetUser(readUserEntry *User, value time.Time) { + m.Set(fmt.Sprintf("%s-%s", readUserEntry.User, readUserEntry.Host), value) +} + +func (m *LockUserMap) Set(key string, value time.Time) { + m.Store(key, value) +} + +func (m *LockUserMap) GetUser(readUserEntry *User) (time.Time, bool) { + return m.Get(fmt.Sprintf("%s-%s", readUserEntry.User, readUserEntry.Host)) +} + +func (m *LockUserMap) Get(key string) (time.Time, bool) { + val, ok := m.Load(key) + if !ok { + return time.Time{}, false + } + return val.(time.Time), true +} + +func (m *LockUserMap) RemoveUser(readUserEntry *User) { + m.Delete(fmt.Sprintf("%s-%s", readUserEntry.User, readUserEntry.Host)) +} + +func (m *LockUserMap) Remove(key string) { + m.Delete(key) +} + +var lockUserMap = &LockUserMap{} + +func (db *MySQLDb) LockUser(readUserEntry *User) { + if readUserEntry == nil { + return + } else { + lockUserMap.SetUser(readUserEntry, time.Now()) + } +} + // GetUser returns a user matching the given user and host if it exists. Due to the slight difference between users and // roles, roleSearch changes whether the search matches against user or role rules. func (db *MySQLDb) GetUser(user string, host string, roleSearch bool) *User { @@ -258,7 +302,31 @@ func (db *MySQLDb) GetUser(user string, host string, roleSearch bool) *User { }) if len(userEntries) == 1 { - return userEntries[0].(*User) + readUserEntry := userEntries[0].(*User) + + if lockTime, isLocked := lockUserMap.GetUser(readUserEntry); isLocked { + if time.Since(lockTime) > time.Hour { + readUserEntry.Locked = false + lockUserMap.RemoveUser(readUserEntry) + } else { + readUserEntry.Locked = true + return readUserEntry + } + } + if strings.Contains(readUserEntry.Host, "/") { + _, network, cidrParseErr := net.ParseCIDR(readUserEntry.Host) + if cidrParseErr == nil { + hostIp := net.ParseIP(host) + if hostIp != nil && network.Contains(hostIp) { + return readUserEntry + } else { + return nil + } + } else { + return nil + } + } + return readUserEntry } // First we check for matches on the same user, then we try the anonymous user @@ -268,6 +336,16 @@ func (db *MySQLDb) GetUser(user string, host string, roleSearch bool) *User { }) for _, readUserEntry := range userEntries { readUserEntry := readUserEntry.(*User) + + if lockTime, isLocked := lockUserMap.GetUser(readUserEntry); isLocked { + if time.Since(lockTime) > time.Hour { + readUserEntry.Locked = false + lockUserMap.RemoveUser(readUserEntry) + } else { + readUserEntry.Locked = true + return readUserEntry + } + } if strings.Contains(readUserEntry.Host, "/") { _, network, cidrParseErr := net.ParseCIDR(readUserEntry.Host) if cidrParseErr == nil { @@ -445,6 +523,7 @@ func (db *MySQLDb) ValidateHash(salt []byte, user string, authResponse []byte, a } if len(userEntry.Password) > 0 { if !validateMysqlNativePassword(authResponse, salt, userEntry.Password) { + db.LockUser(userEntry) return nil, mysql.NewSQLError(mysql.ERAccessDeniedError, mysql.SSAccessDeniedError, "Access denied for user '%v'", user) } } else if len(authResponse) > 0 { // password is nil or empty, therefore no password is set