Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

httpbp: Fix Retries middleware #647

Merged
merged 1 commit into from
Apr 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 26 additions & 4 deletions httpbp/client_middlewares.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ package httpbp

import (
"errors"
"fmt"
"io"
"log/slog"
"net/http"
"strconv"
"sync"
Expand Down Expand Up @@ -200,16 +202,36 @@ func CircuitBreaker(config breakerbp.Config) ClientMiddleware {
// Retries provides a retry middleware by ensuring certain HTTP responses are
// wrapped in errors. Retries wraps the ClientErrorWrapper middleware, e.g. if
// you are using Retries there is no need to also use ClientErrorWrapper.
func Retries(limit int, retryOptions ...retry.Option) ClientMiddleware {
func Retries(maxErrorReadAhead int, retryOptions ...retry.Option) ClientMiddleware {
if len(retryOptions) == 0 {
retryOptions = []retry.Option{retry.Attempts(1)}
}
return func(next http.RoundTripper) http.RoundTripper {
// include ClientErrorWrapper to ensure retry is applied for some HTTP 5xx
// responses
next = ClientErrorWrapper(maxErrorReadAhead)(next)

return roundTripperFunc(func(req *http.Request) (resp *http.Response, err error) {
if req.Body != nil && req.Body != http.NoBody && req.GetBody == nil {
slog.WarnContext(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🔕 do you want to this as a once or do it on a rate limit?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this will only log once per request (it's not inside the retry loop), so if it's logged multiple times that means users have multiple requests constructed incorrectly, which I think it's warranted to log multiple times (for each incorrect request)

and we are using slog, so if users really want to rate-limit it/suppress it they can do so in their slog handler :)

wdyt?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🤷

req.Context(),
"Request comes with a Body but nil GetBody cannot be retried. httpbp.Retries middleware skipped.",
"req", req,
)
return next.RoundTrip(req)
}

err = retrybp.Do(req.Context(), func() error {
// include ClientErrorWrapper to ensure retry is applied for
// some HTTP 5xx responses
resp, err = ClientErrorWrapper(limit)(next).RoundTrip(req)
req = req.Clone(req.Context())
if req.GetBody != nil {
body, err := req.GetBody()
if err != nil {
return fmt.Errorf("httpbp.Retries: GetBody returned error: %w", err)
}
req.Body = body
}

resp, err = next.RoundTrip(req)
if err != nil {
return err
}
Expand Down
127 changes: 89 additions & 38 deletions httpbp/client_middlewares_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"io"
"net/http"
"net/http/httptest"
"net/url"
"sync"
"sync/atomic"
"testing"
Expand Down Expand Up @@ -256,11 +257,22 @@ func TestClientErrorWrapper(t *testing.T) {
})
}

func unwrapRetryErrors(err error) []error {
var errs interface {
error

Unwrap() []error
}
if errors.As(err, &errs) {
return errs.Unwrap()
}
return []error{err}
}

func TestRetry(t *testing.T) {
t.Run("retry for timeout", func(t *testing.T) {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this subtest is removed because we don't really retry for timeouts (there's no timeout budget left for any retries), the expected attempt in this subtest is also 1.

const timeout = time.Millisecond * 10
t.Run("retry for HTTP 500", func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(timeout * 10)
w.WriteHeader(http.StatusInternalServerError)
}))
defer server.Close()

Expand All @@ -274,78 +286,117 @@ func TestRetry(t *testing.T) {
attempts = n + 1
}),
)(http.DefaultTransport),
Timeout: timeout,
}
_, err := client.Get(server.URL)
u, err := url.Parse(server.URL)
if err != nil {
t.Fatalf("Failed to parse url %q: %v", server.URL, err)
}
req := &http.Request{
Method: http.MethodPost,
URL: u,

// Explicitly set Body to http.NoBody and GetBody to nil,
// This request should not cause Retries middleware to be skipped.
Body: http.NoBody,
GetBody: nil,
}
_, err = client.Do(req)
if err == nil {
t.Fatalf("expected error to be non-nil")
}
expected := uint(1)
expected := uint(2)
if attempts != expected {
t.Errorf("expected %d, actual: %d", expected, attempts)
}
errs := unwrapRetryErrors(err)
if len(errs) != int(expected) {
t.Errorf("Expected %d retry erros, got %+v", expected, errs)
}
for i, err := range errs {
var ce *ClientError
if errors.As(err, &ce) {
if got, want := ce.StatusCode, http.StatusInternalServerError; got != want {
t.Errorf("#%d: status got %d want %d", i, got, want)
}
} else {
t.Errorf("#%d: %#v is not of type *httpbp.ClientError", i, err)
}
}
})

t.Run("retry for HTTP 500", func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Run("retry POST+HTTPS request", func(t *testing.T) {
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
b, err := io.ReadAll(r.Body)
if err != nil {
t.Fatal(err)
}
expected := "{}"
got := string(b)
if got != expected {
t.Errorf("expected %q, got: %q", expected, got)
}
t.Logf("Full body: %q", got)
w.WriteHeader(http.StatusInternalServerError)
}))
defer server.Close()

var attempts uint
client := &http.Client{
Transport: Retries(
DefaultMaxErrorReadAhead,
retry.Attempts(2),
retry.OnRetry(func(n uint, err error) {
// set number of attempts to check if retries were attempted
attempts = n + 1
}),
)(http.DefaultTransport),
}
_, err := client.Get(server.URL)
t.Log(server.URL)
client := server.Client()
client.Transport = Retries(
DefaultMaxErrorReadAhead,
retry.Attempts(2),
retry.OnRetry(func(n uint, err error) {
// set number of attempts to check if retries were attempted
attempts = n + 1
}),
)(client.Transport)
_, err := client.Post(server.URL, "application/json", bytes.NewBufferString("{}"))
if err == nil {
t.Fatalf("expected error to be non-nil")
}
expected := uint(2)
if attempts != expected {
t.Errorf("expected %d, actual: %d", expected, attempts)
}
errs := unwrapRetryErrors(err)
if len(errs) != int(expected) {
t.Errorf("Expected %d retry erros, got %+v", expected, errs)
}
for i, err := range errs {
var ce *ClientError
if errors.As(err, &ce) {
if got, want := ce.StatusCode, http.StatusInternalServerError; got != want {
t.Errorf("#%d: status got %d want %d", i, got, want)
}
} else {
t.Errorf("#%d: %#v is not of type *httpbp.ClientError", i, err)
}
}
})

t.Run("retry POST request", func(t *testing.T) {
t.Run("skip retry for wrongly constructed request", func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
b, err := io.ReadAll(r.Body)
if err != nil {
t.Fatal(err)
}
expected := "{}"
got := string(b)
if got != expected {
t.Errorf("expected %q, got: %q", expected, got)
}
w.WriteHeader(http.StatusInternalServerError)
}))
defer server.Close()

var attempts uint
client := &http.Client{
Transport: Retries(
DefaultMaxErrorReadAhead,
retry.Attempts(2),
retry.OnRetry(func(n uint, err error) {
// set number of attempts to check if retries were attempted
attempts = n + 1
t.Errorf("Retry not skipped. OnRetry called with (%d, %v)", n, err)
}),
)(http.DefaultTransport),
}
_, err := client.Post(server.URL, "application/json", bytes.NewBufferString("{}"))
if err == nil {
t.Fatalf("expected error to be non-nil")
req, err := http.NewRequest(http.MethodGet, server.URL, bytes.NewBufferString("{}"))
if err != nil {
t.Fatalf("Failed to create http request: %v", err)
}
expected := uint(2)
if attempts != expected {
t.Errorf("expected %d, actual: %d", expected, attempts)
req.GetBody = nil
if _, err := client.Do(req); err == nil {
t.Fatalf("expected error to be non-nil")
}
})
}
Expand Down
Loading