diff --git a/cadence/server/db_redis.go b/cadence/server/db_redis.go index cd377854..7676991d 100644 --- a/cadence/server/db_redis.go +++ b/cadence/server/db_redis.go @@ -5,6 +5,7 @@ package main import ( "context" + "net" "net/http" "time" @@ -28,8 +29,12 @@ func redisInit() { func rateLimit(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - ip := r.RemoteAddr - _, err := dbr.RateLimit.Get(ctx, ip).Result() + ip, err := checkIP(r) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) // 500 Internal Server Error + return + } + _, err = dbr.RateLimit.Get(ctx, ip).Result() if err != nil { if err == redis.Nil { // redis.Nil means the IP is not in the database. @@ -47,3 +52,16 @@ func rateLimit(next http.Handler) http.Handler { } }) } + +func checkIP(r *http.Request) (ip string, err error) { + // We look at the remote address and check the IP. + // If for some reason no remote IP is there, we error to reject. + if r.RemoteAddr != "" { + ip, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil || ip == "" { + return "", err + } + return ip, nil + } + return "", err +}