Skip to content

Commit

Permalink
Merge pull request #11 from trimble-oss/cidr_check_not_grants
Browse files Browse the repository at this point in the history
Only check cidr in query auth scenario.
  • Loading branch information
joel-rieke authored Mar 7, 2024
2 parents 6bc5051 + 820714b commit f08eaec
Show file tree
Hide file tree
Showing 6 changed files with 29 additions and 26 deletions.
2 changes: 1 addition & 1 deletion sql/analyzer/privileges.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func validatePrivileges(ctx *sql.Context, a *Analyzer, n sql.Node, scope *Scope,
}

client := ctx.Session.Client()
user := mysqlDb.GetUser(client.User, client.Address, false)
user := mysqlDb.GetUser(client.User, client.Address, false, false)
if user == nil {
return nil, transform.SameTree, mysql.NewSQLError(mysql.ERAccessDeniedError, mysql.SSAccessDeniedError, "Access denied for user '%v'", ctx.Session.Client().User)
}
Expand Down
21 changes: 12 additions & 9 deletions sql/mysql_db/mysql_db.go
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ func (db *MySQLDb) LockUser(readUserEntry *User) {

// GetUser returns a user matching the given user and host if it exists. Due to the slight difference between users and
// roles, roleSearch changes whether the search matches against user or role rules.
func (db *MySQLDb) GetUser(user string, host string, roleSearch bool) *User {
func (db *MySQLDb) GetUser(user string, host string, roleSearch bool, skipCidrChecks bool) *User {
//TODO: Determine what the localhost is on the machine, then handle the conversion between IP and localhost.
// For now, this just treats localhost and 127.0.0.1 as the same.
//TODO: Determine how to match anonymous roles (roles with an empty user string), which differs from users
Expand All @@ -303,6 +303,9 @@ func (db *MySQLDb) GetUser(user string, host string, roleSearch bool) *User {

if len(userEntries) == 1 {
readUserEntry := userEntries[0].(*User)
if skipCidrChecks {
return readUserEntry
}

if lockTime, isLocked := lockUserMap.GetUser(readUserEntry); isLocked {
if time.Since(lockTime) > time.Hour {
Expand All @@ -321,9 +324,6 @@ func (db *MySQLDb) GetUser(user string, host string, roleSearch bool) *User {
if hostIp != nil && network.Contains(hostIp) {
return readUserEntry
} else {
if readUserEntry.IsSuperUser {
return readUserEntry
}
return nil
}
} else {
Expand All @@ -340,6 +340,9 @@ func (db *MySQLDb) GetUser(user string, host string, roleSearch bool) *User {
})
for _, readUserEntry := range userEntries {
readUserEntry := readUserEntry.(*User)
if skipCidrChecks {
return readUserEntry
}

if lockTime, isLocked := lockUserMap.GetUser(readUserEntry); isLocked {
if time.Since(lockTime) > time.Hour {
Expand Down Expand Up @@ -381,7 +384,7 @@ func (db *MySQLDb) UserActivePrivilegeSet(ctx *sql.Context) PrivilegeSet {
}

client := ctx.Session.Client()
user := db.GetUser(client.User, client.Address, false)
user := db.GetUser(client.User, client.Address, false, false)
if user == nil {
return NewPrivilegeSet()
}
Expand All @@ -395,7 +398,7 @@ func (db *MySQLDb) UserActivePrivilegeSet(ctx *sql.Context) PrivilegeSet {
//TODO: System variable "activate_all_roles_on_login", if set, will set all roles as active upon logging in
for _, roleEdgeEntry := range roleEdgeEntries {
roleEdge := roleEdgeEntry.(*RoleEdge)
role := db.GetUser(roleEdge.FromUser, roleEdge.FromHost, true)
role := db.GetUser(roleEdge.FromUser, roleEdge.FromHost, true, false)
if role != nil {
privSet.UnionWith(role.PrivilegeSet)
}
Expand Down Expand Up @@ -489,7 +492,7 @@ func (db *MySQLDb) AuthMethod(user, addr string) (string, error) {
host = splitHost
}

u := db.GetUser(user, host, false)
u := db.GetUser(user, host, false, false)
if u == nil {
return "", mysql.NewSQLError(mysql.ERAccessDeniedError, mysql.SSAccessDeniedError, "User not found '%v'", user)
}
Expand Down Expand Up @@ -521,7 +524,7 @@ func (db *MySQLDb) ValidateHash(salt []byte, user string, authResponse []byte, a
return MysqlConnectionUser{User: user, Host: host}, nil
}

userEntry := db.GetUser(user, host, false)
userEntry := db.GetUser(user, host, false, false)
if userEntry == nil || userEntry.Locked {
return nil, mysql.NewSQLError(mysql.ERAccessDeniedError, mysql.SSAccessDeniedError, "Access denied for user '%v'", user)
}
Expand Down Expand Up @@ -555,7 +558,7 @@ func (db *MySQLDb) Negotiate(c *mysql.Conn, user string, addr net.Addr) (mysql.G
if !db.Enabled {
return connUser, nil
}
userEntry := db.GetUser(user, host, false)
userEntry := db.GetUser(user, host, false, false)

if userEntry.Plugin != "" {
authplugin, ok := db.plugins[userEntry.Plugin]
Expand Down
2 changes: 1 addition & 1 deletion sql/plan/drop_user.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ func (n *DropUser) RowIter(ctx *sql.Context, row sql.Row) (sql.RowIter, error) {
userTableData := mysqlDb.UserTable().Data()
roleEdgesData := mysqlDb.RoleEdgesTable().Data()
for _, user := range n.Users {
existingUser := mysqlDb.GetUser(user.Name, user.Host, false)
existingUser := mysqlDb.GetUser(user.Name, user.Host, false, true)
if existingUser == nil {
if n.IfExists {
continue
Expand Down
14 changes: 7 additions & 7 deletions sql/plan/grant.go
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ func (n *Grant) RowIter(ctx *sql.Context, row sql.Row) (sql.RowIter, error) {
return nil, fmt.Errorf("GRANT has not yet implemented user assumption")
}
for _, grantUser := range n.Users {
user := mysqlDb.GetUser(grantUser.Name, grantUser.Host, false)
user := mysqlDb.GetUser(grantUser.Name, grantUser.Host, false, true)
if user == nil {
return nil, sql.ErrGrantUserDoesNotExist.New()
}
Expand All @@ -235,7 +235,7 @@ func (n *Grant) RowIter(ctx *sql.Context, row sql.Row) (sql.RowIter, error) {
return nil, fmt.Errorf("GRANT has not yet implemented user assumption")
}
for _, grantUser := range n.Users {
user := mysqlDb.GetUser(grantUser.Name, grantUser.Host, false)
user := mysqlDb.GetUser(grantUser.Name, grantUser.Host, false, true)
if user == nil {
return nil, sql.ErrGrantUserDoesNotExist.New()
}
Expand All @@ -262,7 +262,7 @@ func (n *Grant) RowIter(ctx *sql.Context, row sql.Row) (sql.RowIter, error) {
return nil, fmt.Errorf("GRANT has not yet implemented user assumption")
}
for _, grantUser := range n.Users {
user := mysqlDb.GetUser(grantUser.Name, grantUser.Host, false)
user := mysqlDb.GetUser(grantUser.Name, grantUser.Host, false, true)
if user == nil {
return nil, sql.ErrGrantUserDoesNotExist.New()
}
Expand Down Expand Up @@ -638,7 +638,7 @@ func (n *GrantRole) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOp
//TODO: only active roles may be assigned if the SUPER privilege is not held
mysqlDb := n.MySQLDb.(*mysql_db.MySQLDb)
client := ctx.Session.Client()
user := mysqlDb.GetUser(client.User, client.Address, false)
user := mysqlDb.GetUser(client.User, client.Address, false, true)
if user == nil {
return false
}
Expand All @@ -647,7 +647,7 @@ func (n *GrantRole) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOp
ToUser: user.User,
})
for _, roleName := range n.Roles {
role := mysqlDb.GetUser(roleName.Name, roleName.Host, true)
role := mysqlDb.GetUser(roleName.Name, roleName.Host, true, true)
if role == nil {
return false
}
Expand Down Expand Up @@ -677,12 +677,12 @@ func (n *GrantRole) RowIter(ctx *sql.Context, row sql.Row) (sql.RowIter, error)
}
roleEdgesData := mysqlDb.RoleEdgesTable().Data()
for _, targetUser := range n.TargetUsers {
user := mysqlDb.GetUser(targetUser.Name, targetUser.Host, false)
user := mysqlDb.GetUser(targetUser.Name, targetUser.Host, false, true)
if user == nil {
return nil, sql.ErrGrantRevokeRoleDoesNotExist.New(targetUser.String("`"))
}
for _, targetRole := range n.Roles {
role := mysqlDb.GetUser(targetRole.Name, targetRole.Host, true)
role := mysqlDb.GetUser(targetRole.Name, targetRole.Host, true, true)
if role == nil {
return nil, sql.ErrGrantRevokeRoleDoesNotExist.New(targetRole.String("`"))
}
Expand Down
14 changes: 7 additions & 7 deletions sql/plan/revoke.go
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ func (n *Revoke) RowIter(ctx *sql.Context, row sql.Row) (sql.RowIter, error) {
return nil, sql.ErrGrantRevokeIllegalPrivilege.New()
}
for _, revokeUser := range n.Users {
user := mysqlDb.GetUser(revokeUser.Name, revokeUser.Host, false)
user := mysqlDb.GetUser(revokeUser.Name, revokeUser.Host, false, true)
if user == nil {
return nil, sql.ErrGrantUserDoesNotExist.New()
}
Expand All @@ -222,7 +222,7 @@ func (n *Revoke) RowIter(ctx *sql.Context, row sql.Row) (sql.RowIter, error) {
return nil, sql.ErrGrantRevokeIllegalPrivilege.New()
}
for _, revokeUser := range n.Users {
user := mysqlDb.GetUser(revokeUser.Name, revokeUser.Host, false)
user := mysqlDb.GetUser(revokeUser.Name, revokeUser.Host, false, true)
if user == nil {
return nil, sql.ErrGrantUserDoesNotExist.New()
}
Expand All @@ -243,7 +243,7 @@ func (n *Revoke) RowIter(ctx *sql.Context, row sql.Row) (sql.RowIter, error) {
return nil, fmt.Errorf("GRANT has not yet implemented object types")
}
for _, grantUser := range n.Users {
user := mysqlDb.GetUser(grantUser.Name, grantUser.Host, false)
user := mysqlDb.GetUser(grantUser.Name, grantUser.Host, false, true)
if user == nil {
return nil, sql.ErrGrantUserDoesNotExist.New()
}
Expand Down Expand Up @@ -590,7 +590,7 @@ func (n *RevokeRole) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedO
//TODO: only active roles may be revoked if the SUPER privilege is not held
mysqlDb := n.MySQLDb.(*mysql_db.MySQLDb)
client := ctx.Session.Client()
user := mysqlDb.GetUser(client.User, client.Address, false)
user := mysqlDb.GetUser(client.User, client.Address, false, true)
if user == nil {
return false
}
Expand All @@ -599,7 +599,7 @@ func (n *RevokeRole) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedO
ToUser: user.User,
})
for _, roleName := range n.Roles {
role := mysqlDb.GetUser(roleName.Name, roleName.Host, true)
role := mysqlDb.GetUser(roleName.Name, roleName.Host, true, true)
if role == nil {
return false
}
Expand Down Expand Up @@ -629,12 +629,12 @@ func (n *RevokeRole) RowIter(ctx *sql.Context, row sql.Row) (sql.RowIter, error)
}
roleEdgesData := mysqlDb.RoleEdgesTable().Data()
for _, targetUser := range n.TargetUsers {
user := mysqlDb.GetUser(targetUser.Name, targetUser.Host, false)
user := mysqlDb.GetUser(targetUser.Name, targetUser.Host, false, true)
if user == nil {
return nil, sql.ErrGrantRevokeRoleDoesNotExist.New(targetUser.String("`"))
}
for _, targetRole := range n.Roles {
role := mysqlDb.GetUser(targetRole.Name, targetRole.Host, true)
role := mysqlDb.GetUser(targetRole.Name, targetRole.Host, true, true)
if role == nil {
return nil, sql.ErrGrantRevokeRoleDoesNotExist.New(targetRole.String("`"))
}
Expand Down
2 changes: 1 addition & 1 deletion sql/plan/show_grants.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ func (n *ShowGrants) RowIter(ctx *sql.Context, row sql.Row) (sql.RowIter, error)
Host: client.Address,
}
}
user := mysqlDb.GetUser(n.For.Name, n.For.Host, false)
user := mysqlDb.GetUser(n.For.Name, n.For.Host, false, true)
if user == nil {
return nil, sql.ErrShowGrantsUserDoesNotExist.New(n.For.Name, n.For.Host)
}
Expand Down

0 comments on commit f08eaec

Please sign in to comment.