Skip to content

Commit

Permalink
httpbp: Fix retry request body
Browse files Browse the repository at this point in the history
When we set GetBody in http.Request, it's expected that Body is also
set, add special handling in Retries to make sure we also set Body when
retrying when GetBody is also set.
  • Loading branch information
fishy committed Apr 17, 2024
1 parent 5eb5e90 commit 8154f42
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 4 deletions.
12 changes: 10 additions & 2 deletions httpbp/client_middlewares.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package httpbp

import (
"errors"
"fmt"
"io"
"net/http"
"strconv"
Expand Down Expand Up @@ -200,16 +201,23 @@ func CircuitBreaker(config breakerbp.Config) ClientMiddleware {
// Retries provides a retry middleware by ensuring certain HTTP responses are
// wrapped in errors. Retries wraps the ClientErrorWrapper middleware, e.g. if
// you are using Retries there is no need to also use ClientErrorWrapper.
func Retries(limit int, retryOptions ...retry.Option) ClientMiddleware {
func Retries(maxErrorReadAhead int, retryOptions ...retry.Option) ClientMiddleware {
if len(retryOptions) == 0 {
retryOptions = []retry.Option{retry.Attempts(1)}
}
return func(next http.RoundTripper) http.RoundTripper {
return roundTripperFunc(func(req *http.Request) (resp *http.Response, err error) {
if req.GetBody != nil {
body, err := req.GetBody()
if err != nil {
return nil, fmt.Errorf("httpbp.Retries: GetBody returned error: %w", err)
}
req.Body = body
}
err = retrybp.Do(req.Context(), func() error {
// include ClientErrorWrapper to ensure retry is applied for
// some HTTP 5xx responses
resp, err = ClientErrorWrapper(limit)(next).RoundTrip(req)
resp, err = ClientErrorWrapper(maxErrorReadAhead)(next).RoundTrip(req)
if err != nil {
return err
}
Expand Down
27 changes: 25 additions & 2 deletions httpbp/client_middlewares_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"io"
"net/http"
"net/http/httptest"
"strings"
"sync"
"sync/atomic"
"testing"
Expand Down Expand Up @@ -259,7 +260,20 @@ func TestClientErrorWrapper(t *testing.T) {
func TestRetry(t *testing.T) {
t.Run("retry for timeout", func(t *testing.T) {
const timeout = time.Millisecond * 10
const body = "body"
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Body == nil {
t.Error("Request body is empty")
} else {
read, err := io.ReadAll(r.Body)
if err != nil {
t.Errorf("Failed to read body: %v", err)
} else {
if got := string(read); got != body {
t.Errorf("Request body got %q want %q", got, body)
}
}
}
time.Sleep(timeout * 10)
}))
defer server.Close()
Expand All @@ -276,8 +290,17 @@ func TestRetry(t *testing.T) {
)(http.DefaultTransport),
Timeout: timeout,
}
_, err := client.Get(server.URL)
if err == nil {
getBody := func() io.Reader {
return strings.NewReader(body)
}
req, err := http.NewRequest(http.MethodGet, server.URL, getBody())
if err != nil {
t.Fatal(err)
}
req.GetBody = func() (io.ReadCloser, error) {
return io.NopCloser(getBody()), nil
}
if _, err := client.Do(req); err == nil {
t.Fatalf("expected error to be non-nil")
}
expected := uint(1)
Expand Down

0 comments on commit 8154f42

Please sign in to comment.