diff --git a/httpbp/client_middlewares.go b/httpbp/client_middlewares.go index f621a826c..3944a1fab 100644 --- a/httpbp/client_middlewares.go +++ b/httpbp/client_middlewares.go @@ -2,7 +2,9 @@ package httpbp import ( "errors" + "fmt" "io" + "log/slog" "net/http" "strconv" "sync" @@ -200,16 +202,36 @@ func CircuitBreaker(config breakerbp.Config) ClientMiddleware { // Retries provides a retry middleware by ensuring certain HTTP responses are // wrapped in errors. Retries wraps the ClientErrorWrapper middleware, e.g. if // you are using Retries there is no need to also use ClientErrorWrapper. -func Retries(limit int, retryOptions ...retry.Option) ClientMiddleware { +func Retries(maxErrorReadAhead int, retryOptions ...retry.Option) ClientMiddleware { if len(retryOptions) == 0 { retryOptions = []retry.Option{retry.Attempts(1)} } return func(next http.RoundTripper) http.RoundTripper { + // include ClientErrorWrapper to ensure retry is applied for some HTTP 5xx + // responses + next = ClientErrorWrapper(maxErrorReadAhead)(next) + return roundTripperFunc(func(req *http.Request) (resp *http.Response, err error) { + if req.Body != nil && req.Body != http.NoBody && req.GetBody == nil { + slog.WarnContext( + req.Context(), + "Request comes with a Body but nil GetBody cannot be retried. httpbp.Retries middleware skipped.", + "req", req, + ) + return next.RoundTrip(req) + } + err = retrybp.Do(req.Context(), func() error { - // include ClientErrorWrapper to ensure retry is applied for - // some HTTP 5xx responses - resp, err = ClientErrorWrapper(limit)(next).RoundTrip(req) + req = req.Clone(req.Context()) + if req.GetBody != nil { + body, err := req.GetBody() + if err != nil { + return fmt.Errorf("httpbp.Retries: GetBody returned error: %w", err) + } + req.Body = body + } + + resp, err = next.RoundTrip(req) if err != nil { return err } diff --git a/httpbp/client_middlewares_test.go b/httpbp/client_middlewares_test.go index 426e74e35..bfd9efc59 100644 --- a/httpbp/client_middlewares_test.go +++ b/httpbp/client_middlewares_test.go @@ -8,6 +8,7 @@ import ( "io" "net/http" "net/http/httptest" + "net/url" "sync" "sync/atomic" "testing" @@ -256,11 +257,22 @@ func TestClientErrorWrapper(t *testing.T) { }) } +func unwrapRetryErrors(err error) []error { + var errs interface { + error + + Unwrap() []error + } + if errors.As(err, &errs) { + return errs.Unwrap() + } + return []error{err} +} + func TestRetry(t *testing.T) { - t.Run("retry for timeout", func(t *testing.T) { - const timeout = time.Millisecond * 10 + t.Run("retry for HTTP 500", func(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - time.Sleep(timeout * 10) + w.WriteHeader(http.StatusInternalServerError) })) defer server.Close() @@ -274,36 +286,72 @@ func TestRetry(t *testing.T) { attempts = n + 1 }), )(http.DefaultTransport), - Timeout: timeout, } - _, err := client.Get(server.URL) + u, err := url.Parse(server.URL) + if err != nil { + t.Fatalf("Failed to parse url %q: %v", server.URL, err) + } + req := &http.Request{ + Method: http.MethodPost, + URL: u, + + // Explicitly set Body to http.NoBody and GetBody to nil, + // This request should not cause Retries middleware to be skipped. + Body: http.NoBody, + GetBody: nil, + } + _, err = client.Do(req) if err == nil { t.Fatalf("expected error to be non-nil") } - expected := uint(1) + expected := uint(2) if attempts != expected { t.Errorf("expected %d, actual: %d", expected, attempts) } + errs := unwrapRetryErrors(err) + if len(errs) != int(expected) { + t.Errorf("Expected %d retry erros, got %+v", expected, errs) + } + for i, err := range errs { + var ce *ClientError + if errors.As(err, &ce) { + if got, want := ce.StatusCode, http.StatusInternalServerError; got != want { + t.Errorf("#%d: status got %d want %d", i, got, want) + } + } else { + t.Errorf("#%d: %#v is not of type *httpbp.ClientError", i, err) + } + } }) - t.Run("retry for HTTP 500", func(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Run("retry POST+HTTPS request", func(t *testing.T) { + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + b, err := io.ReadAll(r.Body) + if err != nil { + t.Fatal(err) + } + expected := "{}" + got := string(b) + if got != expected { + t.Errorf("expected %q, got: %q", expected, got) + } + t.Logf("Full body: %q", got) w.WriteHeader(http.StatusInternalServerError) })) defer server.Close() var attempts uint - client := &http.Client{ - Transport: Retries( - DefaultMaxErrorReadAhead, - retry.Attempts(2), - retry.OnRetry(func(n uint, err error) { - // set number of attempts to check if retries were attempted - attempts = n + 1 - }), - )(http.DefaultTransport), - } - _, err := client.Get(server.URL) + t.Log(server.URL) + client := server.Client() + client.Transport = Retries( + DefaultMaxErrorReadAhead, + retry.Attempts(2), + retry.OnRetry(func(n uint, err error) { + // set number of attempts to check if retries were attempted + attempts = n + 1 + }), + )(client.Transport) + _, err := client.Post(server.URL, "application/json", bytes.NewBufferString("{}")) if err == nil { t.Fatalf("expected error to be non-nil") } @@ -311,41 +359,44 @@ func TestRetry(t *testing.T) { if attempts != expected { t.Errorf("expected %d, actual: %d", expected, attempts) } + errs := unwrapRetryErrors(err) + if len(errs) != int(expected) { + t.Errorf("Expected %d retry erros, got %+v", expected, errs) + } + for i, err := range errs { + var ce *ClientError + if errors.As(err, &ce) { + if got, want := ce.StatusCode, http.StatusInternalServerError; got != want { + t.Errorf("#%d: status got %d want %d", i, got, want) + } + } else { + t.Errorf("#%d: %#v is not of type *httpbp.ClientError", i, err) + } + } }) - t.Run("retry POST request", func(t *testing.T) { + t.Run("skip retry for wrongly constructed request", func(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - b, err := io.ReadAll(r.Body) - if err != nil { - t.Fatal(err) - } - expected := "{}" - got := string(b) - if got != expected { - t.Errorf("expected %q, got: %q", expected, got) - } w.WriteHeader(http.StatusInternalServerError) })) defer server.Close() - var attempts uint client := &http.Client{ Transport: Retries( DefaultMaxErrorReadAhead, retry.Attempts(2), retry.OnRetry(func(n uint, err error) { - // set number of attempts to check if retries were attempted - attempts = n + 1 + t.Errorf("Retry not skipped. OnRetry called with (%d, %v)", n, err) }), )(http.DefaultTransport), } - _, err := client.Post(server.URL, "application/json", bytes.NewBufferString("{}")) - if err == nil { - t.Fatalf("expected error to be non-nil") + req, err := http.NewRequest(http.MethodGet, server.URL, bytes.NewBufferString("{}")) + if err != nil { + t.Fatalf("Failed to create http request: %v", err) } - expected := uint(2) - if attempts != expected { - t.Errorf("expected %d, actual: %d", expected, attempts) + req.GetBody = nil + if _, err := client.Do(req); err == nil { + t.Fatalf("expected error to be non-nil") } }) }