From 5e9404717e1c962ab729cde150ef5b40ea31a6e8 Mon Sep 17 00:00:00 2001 From: Chris Stockton <180184+cstockton@users.noreply.github.com> Date: Mon, 14 Oct 2024 02:14:17 -0700 Subject: [PATCH] feat: configurable email and sms rate limiting (#1800) Adds two new configuration values for rate limiting the sending of emails and sms messages: - GOTRUE_RATE_LIMIT_EMAIL_SENT - GOTRUE_RATE_LIMIT_SMS_SENT It is implemented with a simple rate limiter that resets a counter at a regular interval. The first intervals start time is set when the counter is initialized. It will be reset when the server is restarted, but preserved when the config is reloaded. Syntax examples: ``` 1.5 # Allow 1.5 events over 1 hour (legacy format) 100 # Allow 100 events over 1 hour (1h is default) 100/1h # Allow 100 events over 1 hour (explicit duration) 100/24h # Allow 100 events over 24 hours 100/72h # Allow 100 events over 72 hours (use hours for days) 10/30m # Allow 10 events over 30 minutes 3/10s # Allow 3 events over 10 seconds ``` Syntax in ABNF to express the format as value: ``` value = count / rate count = 1*DIGIT ["." 1*DIGIT] rate = 1*DIGIT "/" ival ival = ival-sec / ival-min / ival-hr ival-sec = 1*DIGIT "s" ival-min = 1*DIGIT "s" ival-hr = 1*DIGIT "h" ``` This change was a continuation of https://github.com/supabase/auth/pull/1746 adapted to support the recent preservation of rate limiters across server reloads. --------- Co-authored-by: Chris Stockton Co-authored-by: Stojan Dimitrovski --- internal/api/api.go | 20 ++-- internal/api/context.go | 19 ---- internal/api/mail.go | 17 ++-- internal/api/middleware.go | 21 ---- internal/api/middleware_test.go | 174 -------------------------------- internal/api/options.go | 17 +--- internal/api/phone.go | 9 +- internal/api/ratelimits.go | 49 +++++++++ internal/api/ratelimits_test.go | 99 ++++++++++++++++++ internal/conf/configuration.go | 23 +++-- internal/conf/rate.go | 58 +++++++++++ internal/conf/rate_test.go | 53 ++++++++++ 12 files changed, 292 insertions(+), 267 deletions(-) create mode 100644 internal/api/ratelimits.go create mode 100644 internal/api/ratelimits_test.go create mode 100644 internal/conf/rate.go create mode 100644 internal/conf/rate_test.go diff --git a/internal/api/api.go b/internal/api/api.go index dba93ae150..580280aa52 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -140,9 +140,8 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne r.Get("/authorize", api.ExternalProviderRedirect) - sharedLimiter := api.limitEmailOrPhoneSentHandler(api.limiterOpts) - r.With(sharedLimiter).With(api.requireAdminCredentials).Post("/invite", api.Invite) - r.With(sharedLimiter).With(api.verifyCaptcha).Route("/signup", func(r *router) { + r.With(api.requireAdminCredentials).Post("/invite", api.Invite) + r.With(api.verifyCaptcha).Route("/signup", func(r *router) { // rate limit per hour limitAnonymousSignIns := api.limiterOpts.AnonymousSignIns limitSignups := api.limiterOpts.Signups @@ -165,24 +164,20 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne if _, err := api.limitHandler(limitSignups)(w, r); err != nil { return err } - // apply shared rate limiting on email / phone - if _, err := sharedLimiter(w, r); err != nil { - return err - } return api.Signup(w, r) }) }) r.With(api.limitHandler(api.limiterOpts.Recover)). - With(sharedLimiter).With(api.verifyCaptcha).With(api.requireEmailProvider).Post("/recover", api.Recover) + With(api.verifyCaptcha).With(api.requireEmailProvider).Post("/recover", api.Recover) r.With(api.limitHandler(api.limiterOpts.Resend)). - With(sharedLimiter).With(api.verifyCaptcha).Post("/resend", api.Resend) + With(api.verifyCaptcha).Post("/resend", api.Resend) r.With(api.limitHandler(api.limiterOpts.MagicLink)). - With(sharedLimiter).With(api.verifyCaptcha).Post("/magiclink", api.MagicLink) + With(api.verifyCaptcha).Post("/magiclink", api.MagicLink) r.With(api.limitHandler(api.limiterOpts.Otp)). - With(sharedLimiter).With(api.verifyCaptcha).Post("/otp", api.Otp) + With(api.verifyCaptcha).Post("/otp", api.Otp) r.With(api.limitHandler(api.limiterOpts.Token)). With(api.verifyCaptcha).Post("/token", api.Token) @@ -200,8 +195,7 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne r.With(api.requireAuthentication).Route("/user", func(r *router) { r.Get("/", api.UserGet) - r.With(api.limitHandler(api.limiterOpts.User)). - With(sharedLimiter).Put("/", api.UserUpdate) + r.With(api.limitHandler(api.limiterOpts.User)).Put("/", api.UserUpdate) r.Route("/identities", func(r *router) { r.Use(api.requireManualLinkingEnabled) diff --git a/internal/api/context.go b/internal/api/context.go index ff01e71204..3047f3dd6a 100644 --- a/internal/api/context.go +++ b/internal/api/context.go @@ -4,7 +4,6 @@ import ( "context" "net/url" - "github.com/didip/tollbooth/v5/limiter" jwt "github.com/golang-jwt/jwt/v5" "github.com/supabase/auth/internal/models" ) @@ -32,7 +31,6 @@ const ( ssoProviderKey = contextKey("sso_provider") externalHostKey = contextKey("external_host") flowStateKey = contextKey("flow_state_id") - sharedLimiterKey = contextKey("shared_limiter") ) // withToken adds the JWT token to the context. @@ -243,20 +241,3 @@ func getExternalHost(ctx context.Context) *url.URL { } return obj.(*url.URL) } - -type SharedLimiter struct { - EmailLimiter *limiter.Limiter - PhoneLimiter *limiter.Limiter -} - -func withLimiter(ctx context.Context, limiter *SharedLimiter) context.Context { - return context.WithValue(ctx, sharedLimiterKey, limiter) -} - -func getLimiter(ctx context.Context) *SharedLimiter { - obj := ctx.Value(sharedLimiterKey) - if obj == nil { - return nil - } - return obj.(*SharedLimiter) -} diff --git a/internal/api/mail.go b/internal/api/mail.go index 696510862f..440e04c90c 100644 --- a/internal/api/mail.go +++ b/internal/api/mail.go @@ -5,7 +5,6 @@ import ( "strings" "time" - "github.com/didip/tollbooth/v5" "github.com/supabase/auth/internal/hooks" mail "github.com/supabase/auth/internal/mailer" "go.opentelemetry.io/otel/attribute" @@ -578,15 +577,13 @@ func (a *API) sendEmail(r *http.Request, tx *storage.Connection, u *models.User, externalURL := getExternalHost(ctx) // apply rate limiting before the email is sent out - if limiter := getLimiter(ctx); limiter != nil { - if err := tollbooth.LimitByKeys(limiter.EmailLimiter, []string{"email_functions"}); err != nil { - emailRateLimitCounter.Add( - ctx, - 1, - metric.WithAttributeSet(attribute.NewSet(attribute.String("path", r.URL.Path))), - ) - return EmailRateLimitExceeded - } + if ok := a.limiterOpts.Email.Allow(); !ok { + emailRateLimitCounter.Add( + ctx, + 1, + metric.WithAttributeSet(attribute.NewSet(attribute.String("path", r.URL.Path))), + ) + return EmailRateLimitExceeded } if config.Hook.SendEmail.Enabled { diff --git a/internal/api/middleware.go b/internal/api/middleware.go index 88d95e20c5..1ad8e8687d 100644 --- a/internal/api/middleware.go +++ b/internal/api/middleware.go @@ -77,27 +77,6 @@ func (a *API) limitHandler(lmt *limiter.Limiter) middlewareHandler { } } -func (a *API) limitEmailOrPhoneSentHandler(limiterOptions *LimiterOptions) middlewareHandler { - return func(w http.ResponseWriter, req *http.Request) (context.Context, error) { - c := req.Context() - config := a.config - shouldRateLimitEmail := config.External.Email.Enabled && !config.Mailer.Autoconfirm - shouldRateLimitPhone := config.External.Phone.Enabled && !config.Sms.Autoconfirm - - if shouldRateLimitEmail || shouldRateLimitPhone { - if req.Method == "PUT" || req.Method == "POST" { - // store rate limiter in request context - c = withLimiter(c, &SharedLimiter{ - EmailLimiter: limiterOptions.Email, - PhoneLimiter: limiterOptions.Phone, - }) - } - } - - return c, nil - } -} - func (a *API) requireAdminCredentials(w http.ResponseWriter, req *http.Request) (context.Context, error) { t, err := a.extractBearerToken(req) if err != nil || t == "" { diff --git a/internal/api/middleware_test.go b/internal/api/middleware_test.go index 33b62a79db..4d529068b1 100644 --- a/internal/api/middleware_test.go +++ b/internal/api/middleware_test.go @@ -185,52 +185,6 @@ func (ts *MiddlewareTestSuite) TestVerifyCaptchaInvalid() { } } -func (ts *MiddlewareTestSuite) TestLimitEmailOrPhoneSentHandler() { - // Set up rate limit config for this test - ts.Config.RateLimitEmailSent = 5 - ts.Config.RateLimitSmsSent = 5 - ts.Config.External.Phone.Enabled = true - - cases := []struct { - desc string - expectedErrorMsg string - requestBody map[string]interface{} - }{ - { - desc: "Email rate limit exceeded", - expectedErrorMsg: "429: Email rate limit exceeded", - requestBody: map[string]interface{}{ - "email": "test@example.com", - }, - }, - { - desc: "SMS rate limit exceeded", - expectedErrorMsg: "429: SMS rate limit exceeded", - requestBody: map[string]interface{}{ - "phone": "+1233456789", - }, - }, - } - - limiter := ts.API.limitEmailOrPhoneSentHandler(NewLimiterOptions(ts.Config)) - for _, c := range cases { - ts.Run(c.desc, func() { - var buffer bytes.Buffer - require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.requestBody)) - req := httptest.NewRequest(http.MethodPost, "http://localhost", &buffer) - req.Header.Set("Content-Type", "application/json") - w := httptest.NewRecorder() - - ctx, err := limiter(w, req) - require.NoError(ts.T(), err) - - // check that shared limiter is set in the request context - sharedLimiter := getLimiter(ctx) - require.NotNil(ts.T(), sharedLimiter) - }) - } -} - func (ts *MiddlewareTestSuite) TestIsValidExternalHost() { cases := []struct { desc string @@ -388,134 +342,6 @@ func (ts *MiddlewareTestSuite) TestLimitHandler() { require.Equal(ts.T(), http.StatusTooManyRequests, w.Code) } -func (ts *MiddlewareTestSuite) TestLimitHandlerWithSharedLimiter() { - // setup config for shared limiter and ip-based limiter to work - ts.Config.RateLimitHeader = "X-Rate-Limit" - ts.Config.External.Email.Enabled = true - ts.Config.External.Phone.Enabled = true - ts.Config.Mailer.Autoconfirm = false - ts.Config.Sms.Autoconfirm = false - - ipBasedLimiter := func(max float64) *limiter.Limiter { - return tollbooth.NewLimiter(max, &limiter.ExpirableOptions{ - DefaultExpirationTTL: time.Hour, - }) - } - - okHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - limiter := getLimiter(r.Context()) - if limiter != nil { - var requestBody struct { - Email string `json:"email"` - Phone string `json:"phone"` - } - err := retrieveRequestParams(r, &requestBody) - require.NoError(ts.T(), err) - - if requestBody.Email != "" { - if err := tollbooth.LimitByKeys(limiter.EmailLimiter, []string{"email_functions"}); err != nil { - sendJSON(w, http.StatusTooManyRequests, HTTPError{ - HTTPStatus: http.StatusTooManyRequests, - ErrorCode: ErrorCodeOverEmailSendRateLimit, - Message: "Email rate limit exceeded", - }) - } - } - if requestBody.Phone != "" { - if err := tollbooth.LimitByKeys(limiter.EmailLimiter, []string{"phone_functions"}); err != nil { - sendJSON(w, http.StatusTooManyRequests, HTTPError{ - HTTPStatus: http.StatusTooManyRequests, - ErrorCode: ErrorCodeOverSMSSendRateLimit, - Message: "SMS rate limit exceeded", - }) - } - } - } - w.WriteHeader(http.StatusOK) - }) - - cases := []struct { - desc string - sharedLimiterConfig *conf.GlobalConfiguration - ipBasedLimiterConfig float64 - body map[string]interface{} - expectedErrorCode string - }{ - { - desc: "Exceed ip-based rate limit before shared limiter", - sharedLimiterConfig: &conf.GlobalConfiguration{ - RateLimitEmailSent: 10, - RateLimitSmsSent: 10, - }, - ipBasedLimiterConfig: 1, - body: map[string]interface{}{ - "email": "foo@example.com", - }, - expectedErrorCode: ErrorCodeOverRequestRateLimit, - }, - { - desc: "Exceed email shared limiter", - sharedLimiterConfig: &conf.GlobalConfiguration{ - RateLimitEmailSent: 1, - RateLimitSmsSent: 1, - }, - ipBasedLimiterConfig: 10, - body: map[string]interface{}{ - "email": "foo@example.com", - }, - expectedErrorCode: ErrorCodeOverEmailSendRateLimit, - }, - { - desc: "Exceed sms shared limiter", - sharedLimiterConfig: &conf.GlobalConfiguration{ - RateLimitEmailSent: 1, - RateLimitSmsSent: 1, - }, - ipBasedLimiterConfig: 10, - body: map[string]interface{}{ - "phone": "123456789", - }, - expectedErrorCode: ErrorCodeOverSMSSendRateLimit, - }, - } - - for _, c := range cases { - ts.Run(c.desc, func() { - ts.Config.RateLimitEmailSent = c.sharedLimiterConfig.RateLimitEmailSent - ts.Config.RateLimitSmsSent = c.sharedLimiterConfig.RateLimitSmsSent - lmt := ts.API.limitHandler(ipBasedLimiter(c.ipBasedLimiterConfig)) - sharedLimiter := ts.API.limitEmailOrPhoneSentHandler(NewLimiterOptions(ts.Config)) - - // get the minimum amount to reach the threshold just before the rate limit is exceeded - threshold := min(c.sharedLimiterConfig.RateLimitEmailSent, c.sharedLimiterConfig.RateLimitSmsSent, c.ipBasedLimiterConfig) - for i := 0; i < int(threshold); i++ { - var buffer bytes.Buffer - require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.body)) - req := httptest.NewRequest(http.MethodPost, "http://localhost", &buffer) - req.Header.Add(ts.Config.RateLimitHeader, "0.0.0.0") - - w := httptest.NewRecorder() - lmt.handler(sharedLimiter.handler(okHandler)).ServeHTTP(w, req) - require.Equal(ts.T(), http.StatusOK, w.Code) - } - - var buffer bytes.Buffer - require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.body)) - req := httptest.NewRequest(http.MethodPost, "http://localhost", &buffer) - req.Header.Add(ts.Config.RateLimitHeader, "0.0.0.0") - - // check if the rate limit is exceeded with the expected error code - w := httptest.NewRecorder() - lmt.handler(sharedLimiter.handler(okHandler)).ServeHTTP(w, req) - require.Equal(ts.T(), http.StatusTooManyRequests, w.Code) - - var data map[string]interface{} - require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data)) - require.Equal(ts.T(), c.expectedErrorCode, data["error_code"]) - }) - } -} - func (ts *MiddlewareTestSuite) TestIsValidAuthorizedEmail() { ts.API.config.External.Email.AuthorizedAddresses = []string{"valid@example.com"} diff --git a/internal/api/options.go b/internal/api/options.go index 345e99d811..7c755b78ec 100644 --- a/internal/api/options.go +++ b/internal/api/options.go @@ -13,8 +13,9 @@ type Option interface { } type LimiterOptions struct { - Email *limiter.Limiter - Phone *limiter.Limiter + Email *RateLimiter + Phone *RateLimiter + Signups *limiter.Limiter AnonymousSignIns *limiter.Limiter Recover *limiter.Limiter @@ -35,16 +36,8 @@ func (lo *LimiterOptions) apply(a *API) { a.limiterOpts = lo } func NewLimiterOptions(gc *conf.GlobalConfiguration) *LimiterOptions { o := &LimiterOptions{} - o.Email = tollbooth.NewLimiter(gc.RateLimitEmailSent/(60*60), - &limiter.ExpirableOptions{ - DefaultExpirationTTL: time.Hour, - }).SetBurst(int(gc.RateLimitEmailSent)).SetMethods([]string{"PUT", "POST"}) - - o.Phone = tollbooth.NewLimiter(gc.RateLimitSmsSent/(60*60), - &limiter.ExpirableOptions{ - DefaultExpirationTTL: time.Hour, - }).SetBurst(int(gc.RateLimitSmsSent)).SetMethods([]string{"PUT", "POST"}) - + o.Email = newRateLimiter(gc.RateLimitEmailSent) + o.Phone = newRateLimiter(gc.RateLimitSmsSent) o.AnonymousSignIns = tollbooth.NewLimiter(gc.RateLimitAnonymousUsers/(60*60), &limiter.ExpirableOptions{ DefaultExpirationTTL: time.Hour, diff --git a/internal/api/phone.go b/internal/api/phone.go index ce11c5a3f6..7886210f6e 100644 --- a/internal/api/phone.go +++ b/internal/api/phone.go @@ -8,7 +8,6 @@ import ( "text/template" "time" - "github.com/didip/tollbooth/v5" "github.com/supabase/auth/internal/hooks" "github.com/pkg/errors" @@ -45,7 +44,6 @@ func formatPhoneNumber(phone string) string { // sendPhoneConfirmation sends an otp to the user's phone number func (a *API) sendPhoneConfirmation(r *http.Request, tx *storage.Connection, user *models.User, phone, otpType string, channel string) (string, error) { - ctx := r.Context() config := a.config var token *string @@ -89,11 +87,8 @@ func (a *API) sendPhoneConfirmation(r *http.Request, tx *storage.Connection, use // not using test OTPs if otp == "" { // apply rate limiting before the sms is sent out - limiter := getLimiter(ctx) - if limiter != nil { - if err := tollbooth.LimitByKeys(limiter.PhoneLimiter, []string{"phone_functions"}); err != nil { - return "", tooManyRequestsError(ErrorCodeOverSMSSendRateLimit, "SMS rate limit exceeded") - } + if ok := a.limiterOpts.Phone.Allow(); !ok { + return "", tooManyRequestsError(ErrorCodeOverSMSSendRateLimit, "SMS rate limit exceeded") } otp, err = crypto.GenerateOtp(config.Sms.OtpLength) if err != nil { diff --git a/internal/api/ratelimits.go b/internal/api/ratelimits.go new file mode 100644 index 0000000000..349c6a01b7 --- /dev/null +++ b/internal/api/ratelimits.go @@ -0,0 +1,49 @@ +package api + +import ( + "sync" + "time" + + "github.com/supabase/auth/internal/conf" +) + +// RateLimiter will limit the number of calls to Allow per interval. +type RateLimiter struct { + mu sync.Mutex + ival time.Duration // Count is reset and time updated every ival. + limit int // Limit calls to Allow() per ival. + + // Guarded by mu. + last time.Time // When the limiter was last reset. + count int // Total calls to Allow() since time. +} + +// newRateLimiter returns a rate limiter configured using the given conf.Rate. +func newRateLimiter(r conf.Rate) *RateLimiter { + return &RateLimiter{ + ival: r.OverTime, + limit: int(r.Events), + last: time.Now(), + } +} + +func (rl *RateLimiter) Allow() bool { + rl.mu.Lock() + defer rl.mu.Unlock() + + now := time.Now() + return rl.allowAt(now) +} + +func (rl *RateLimiter) allowAt(at time.Time) bool { + since := at.Sub(rl.last) + if ivals := int64(since / rl.ival); ivals > 0 { + rl.last = rl.last.Add(time.Duration(ivals) * rl.ival) + rl.count = 0 + } + if rl.count < rl.limit { + rl.count++ + return true + } + return false +} diff --git a/internal/api/ratelimits_test.go b/internal/api/ratelimits_test.go new file mode 100644 index 0000000000..a2cf0ec081 --- /dev/null +++ b/internal/api/ratelimits_test.go @@ -0,0 +1,99 @@ +package api + +import ( + "fmt" + "testing" + "time" + + "github.com/supabase/auth/internal/conf" +) + +func Example_newRateLimiter() { + now, _ := time.Parse(time.RFC3339, "2024-09-24T10:00:00.00Z") + cfg := conf.Rate{Events: 100, OverTime: time.Hour * 24} + rl := newRateLimiter(cfg) + rl.last = now + + cur := now + allowed := 0 + + for days := 0; days < 2; days++ { + // First 100 events succeed. + for i := 0; i < 100; i++ { + allow := rl.allowAt(cur) + cur = cur.Add(time.Second) + + if !allow { + fmt.Printf("false @ %v after %v events... [FAILED]\n", cur, allowed) + return + } + allowed++ + } + fmt.Printf("true @ %v for last %v events...\n", cur, allowed) + + // We try hourly until it allows us to make requests again. + denied := 0 + for i := 0; i < 23; i++ { + cur = cur.Add(time.Hour) + allow := rl.allowAt(cur) + if allow { + fmt.Printf("true @ %v before quota reset... [FAILED]\n", cur) + return + } + denied++ + } + fmt.Printf("false @ %v for last %v events...\n", cur, denied) + + cur = cur.Add(time.Hour) + } + + // Output: + // true @ 2024-09-24 10:01:40 +0000 UTC for last 100 events... + // false @ 2024-09-25 09:01:40 +0000 UTC for last 23 events... + // true @ 2024-09-25 10:03:20 +0000 UTC for last 200 events... + // false @ 2024-09-26 09:03:20 +0000 UTC for last 23 events... +} + +func TestNewRateLimiter(t *testing.T) { + now, _ := time.Parse(time.RFC3339, "2024-09-24T10:00:00.00Z") + + type event struct { + ok bool + at time.Time + r int + } + cases := []struct { + cfg conf.Rate + now time.Time + evts []event + }{ + { + cfg: conf.Rate{Events: 100, OverTime: time.Hour * 24}, + now: now, + evts: []event{ + {true, now, 0}, + {true, now.Add(time.Minute), 98}, + {false, now.Add(time.Minute), 0}, + {false, now.Add(time.Minute * 14), 0}, + {false, now.Add(time.Minute * 15), 0}, + {false, now.Add(time.Minute * 16), 0}, + {false, now.Add(time.Minute * 17), 0}, + {false, now.Add(time.Minute * 17), 0}, + {true, now.Add(time.Hour * 24), 0}, + {true, now.Add(time.Hour * 25), 0}, + }, + }, + } + for _, tc := range cases { + rl := newRateLimiter(tc.cfg) + rl.last = tc.now + + for _, evt := range tc.evts { + for i := 0; i <= evt.r; i++ { + if exp, got := evt.ok, rl.allowAt(evt.at); exp != got { + t.Fatalf("exp AllowN(%v, 1) to be %v; got %v", evt.at, exp, got) + } + } + } + } +} diff --git a/internal/conf/configuration.go b/internal/conf/configuration.go index f96f274132..4db5e2ddce 100644 --- a/internal/conf/configuration.go +++ b/internal/conf/configuration.go @@ -237,18 +237,19 @@ type PasswordConfiguration struct { // GlobalConfiguration holds all the configuration that applies to all instances. type GlobalConfiguration struct { - API APIConfiguration - DB DBConfiguration - External ProviderConfiguration - Logging LoggingConfig `envconfig:"LOG"` - Profiler ProfilerConfig `envconfig:"PROFILER"` - OperatorToken string `split_words:"true" required:"false"` - Tracing TracingConfig - Metrics MetricsConfig - SMTP SMTPConfiguration + API APIConfiguration + DB DBConfiguration + External ProviderConfiguration + Logging LoggingConfig `envconfig:"LOG"` + Profiler ProfilerConfig `envconfig:"PROFILER"` + OperatorToken string `split_words:"true" required:"false"` + Tracing TracingConfig + Metrics MetricsConfig + SMTP SMTPConfiguration + RateLimitHeader string `split_words:"true"` - RateLimitEmailSent float64 `split_words:"true" default:"30"` - RateLimitSmsSent float64 `split_words:"true" default:"30"` + RateLimitEmailSent Rate `split_words:"true" default:"30"` + RateLimitSmsSent Rate `split_words:"true" default:"30"` RateLimitVerify float64 `split_words:"true" default:"30"` RateLimitTokenRefresh float64 `split_words:"true" default:"150"` RateLimitSso float64 `split_words:"true" default:"30"` diff --git a/internal/conf/rate.go b/internal/conf/rate.go new file mode 100644 index 0000000000..ebe7ba475b --- /dev/null +++ b/internal/conf/rate.go @@ -0,0 +1,58 @@ +package conf + +import ( + "fmt" + "strconv" + "strings" + "time" +) + +const defaultOverTime = time.Hour + +type Rate struct { + Events float64 `json:"events,omitempty"` + OverTime time.Duration `json:"over_time,omitempty"` +} + +func (r *Rate) EventsPerSecond() float64 { + d := r.OverTime + if d == 0 { + d = defaultOverTime + } + return r.Events / d.Seconds() +} + +// Decode is used by envconfig to parse the env-config string to a Rate value. +func (r *Rate) Decode(value string) error { + if f, err := strconv.ParseFloat(value, 64); err == nil { + r.Events = f + r.OverTime = defaultOverTime + return nil + } + parts := strings.Split(value, "/") + if len(parts) != 2 { + return fmt.Errorf("rate: value does not match rate syntax %q", value) + } + + // 52 because the uint needs to fit in a float64 + e, err := strconv.ParseUint(parts[0], 10, 52) + if err != nil { + return fmt.Errorf("rate: events part of rate value %q failed to parse as uint64: %w", value, err) + } + + d, err := time.ParseDuration(parts[1]) + if err != nil { + return fmt.Errorf("rate: over-time part of rate value %q failed to parse as duration: %w", value, err) + } + + r.Events = float64(e) + r.OverTime = d + return nil +} + +func (r *Rate) String() string { + if r.OverTime == 0 { + return fmt.Sprintf("%f", r.Events) + } + return fmt.Sprintf("%d/%s", uint64(r.Events), r.OverTime.String()) +} diff --git a/internal/conf/rate_test.go b/internal/conf/rate_test.go new file mode 100644 index 0000000000..264965607d --- /dev/null +++ b/internal/conf/rate_test.go @@ -0,0 +1,53 @@ +package conf + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestRateDecode(t *testing.T) { + cases := []struct { + str string + eps float64 + exp Rate + err string + }{ + {str: "1800", eps: 0.5, exp: Rate{Events: 1800, OverTime: time.Hour}}, + {str: "1800.0", eps: 0.5, exp: Rate{Events: 1800, OverTime: time.Hour}}, + {str: "3600/1h", eps: 1, exp: Rate{Events: 3600, OverTime: time.Hour}}, + {str: "100/24h", + eps: 0.0011574074074074073, + exp: Rate{Events: 100, OverTime: time.Hour * 24}}, + + {str: "", eps: 1, exp: Rate{}, + err: `rate: value does not match`}, + {str: "1h", eps: 1, exp: Rate{}, + err: `rate: value does not match`}, + {str: "/", eps: 1, exp: Rate{}, + err: `rate: events part of rate value`}, + {str: "/1h", eps: 1, exp: Rate{}, + err: `rate: events part of rate value`}, + {str: "3600.0/1h", eps: 1, exp: Rate{}, + err: `rate: events part of rate value "3600.0/1h" failed to parse`}, + {str: "100/", eps: 1, exp: Rate{}, + err: `rate: over-time part of rate value`}, + {str: "100/1", eps: 1, exp: Rate{}, + err: `rate: over-time part of rate value`}, + } + for idx, tc := range cases { + var r Rate + err := r.Decode(tc.str) + require.Equal(t, tc.exp, r) // verify don't mutate r on errr + t.Logf("tc #%v - duration str %v", idx, tc.str) + if tc.err != "" { + require.Error(t, err) + require.Contains(t, err.Error(), tc.err) + continue + } + require.NoError(t, err) + require.Equal(t, tc.exp, r) + require.Equal(t, tc.eps, r.EventsPerSecond()) + } +}