From 8154f42bdeb82d7d0b4cc798a9bb9b29310e17e6 Mon Sep 17 00:00:00 2001 From: Yuxuan 'fishy' Wang Date: Wed, 17 Apr 2024 09:24:05 -0700 Subject: [PATCH] httpbp: Fix retry request body When we set GetBody in http.Request, it's expected that Body is also set, add special handling in Retries to make sure we also set Body when retrying when GetBody is also set. --- httpbp/client_middlewares.go | 12 ++++++++++-- httpbp/client_middlewares_test.go | 27 +++++++++++++++++++++++++-- 2 files changed, 35 insertions(+), 4 deletions(-) diff --git a/httpbp/client_middlewares.go b/httpbp/client_middlewares.go index f621a826c..33a3407f9 100644 --- a/httpbp/client_middlewares.go +++ b/httpbp/client_middlewares.go @@ -2,6 +2,7 @@ package httpbp import ( "errors" + "fmt" "io" "net/http" "strconv" @@ -200,16 +201,23 @@ func CircuitBreaker(config breakerbp.Config) ClientMiddleware { // Retries provides a retry middleware by ensuring certain HTTP responses are // wrapped in errors. Retries wraps the ClientErrorWrapper middleware, e.g. if // you are using Retries there is no need to also use ClientErrorWrapper. -func Retries(limit int, retryOptions ...retry.Option) ClientMiddleware { +func Retries(maxErrorReadAhead int, retryOptions ...retry.Option) ClientMiddleware { if len(retryOptions) == 0 { retryOptions = []retry.Option{retry.Attempts(1)} } return func(next http.RoundTripper) http.RoundTripper { return roundTripperFunc(func(req *http.Request) (resp *http.Response, err error) { + if req.GetBody != nil { + body, err := req.GetBody() + if err != nil { + return nil, fmt.Errorf("httpbp.Retries: GetBody returned error: %w", err) + } + req.Body = body + } err = retrybp.Do(req.Context(), func() error { // include ClientErrorWrapper to ensure retry is applied for // some HTTP 5xx responses - resp, err = ClientErrorWrapper(limit)(next).RoundTrip(req) + resp, err = ClientErrorWrapper(maxErrorReadAhead)(next).RoundTrip(req) if err != nil { return err } diff --git a/httpbp/client_middlewares_test.go b/httpbp/client_middlewares_test.go index 426e74e35..5755c8af4 100644 --- a/httpbp/client_middlewares_test.go +++ b/httpbp/client_middlewares_test.go @@ -8,6 +8,7 @@ import ( "io" "net/http" "net/http/httptest" + "strings" "sync" "sync/atomic" "testing" @@ -259,7 +260,20 @@ func TestClientErrorWrapper(t *testing.T) { func TestRetry(t *testing.T) { t.Run("retry for timeout", func(t *testing.T) { const timeout = time.Millisecond * 10 + const body = "body" server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Body == nil { + t.Error("Request body is empty") + } else { + read, err := io.ReadAll(r.Body) + if err != nil { + t.Errorf("Failed to read body: %v", err) + } else { + if got := string(read); got != body { + t.Errorf("Request body got %q want %q", got, body) + } + } + } time.Sleep(timeout * 10) })) defer server.Close() @@ -276,8 +290,17 @@ func TestRetry(t *testing.T) { )(http.DefaultTransport), Timeout: timeout, } - _, err := client.Get(server.URL) - if err == nil { + getBody := func() io.Reader { + return strings.NewReader(body) + } + req, err := http.NewRequest(http.MethodGet, server.URL, getBody()) + if err != nil { + t.Fatal(err) + } + req.GetBody = func() (io.ReadCloser, error) { + return io.NopCloser(getBody()), nil + } + if _, err := client.Do(req); err == nil { t.Fatalf("expected error to be non-nil") } expected := uint(1)