Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for Thrift and HTTP client-side fault injection. #666

Open
wants to merge 33 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
767a3b6
Rough draft of idea.
HiramSilvey Nov 14, 2024
89dee2f
Refactored into common library.
HiramSilvey Nov 15, 2024
0aff005
Return universal Thrift errors.
HiramSilvey Nov 18, 2024
fb9035e
Revert to only support Transport errors.
HiramSilvey Nov 18, 2024
c4c7c1c
Initial POC test. Needs refactoring still.
HiramSilvey Nov 19, 2024
ed571bf
Refactored and moved heavy testing logic into common test.
HiramSilvey Nov 22, 2024
73cee20
WIP HTTP tests.
HiramSilvey Nov 25, 2024
f18a4c3
Fixed HTTP tests.
HiramSilvey Dec 6, 2024
43a5c57
Add logging and refine tests.
HiramSilvey Dec 6, 2024
ac4e91a
Use official rand library and update logging.
HiramSilvey Dec 6, 2024
b2d1003
go mod tidy
HiramSilvey Dec 6, 2024
99dc665
Fixed shortened address to include port trimming.
HiramSilvey Dec 6, 2024
3f63e05
Remove unused function in test code.
HiramSilvey Dec 6, 2024
4444f33
Fix lint errors.
HiramSilvey Dec 6, 2024
1c77548
Update based on PR review comments.
HiramSilvey Dec 10, 2024
71114b0
Simplified random int generation.
HiramSilvey Dec 10, 2024
4fd7018
Move faults into internal directory.
HiramSilvey Dec 10, 2024
291b7bc
Fix style.
HiramSilvey Dec 10, 2024
b6d720f
Remove unused type struct and fix imports.
HiramSilvey Dec 10, 2024
a892c01
Expand testing and fix PR review nits.
HiramSilvey Dec 11, 2024
3bb8fd0
Fix imports formatting.
HiramSilvey Dec 11, 2024
66851ea
Fix typo.
HiramSilvey Dec 11, 2024
e2908ff
Strip port and anything after it for non-cluster-local addresses.
HiramSilvey Dec 11, 2024
2e2882b
Fix tests and update HTTP address to strip port automatically.
HiramSilvey Dec 11, 2024
6f2024e
Add edge case test.
HiramSilvey Dec 11, 2024
d467f04
Update to use a different random integer per feature.
HiramSilvey Dec 11, 2024
d602696
Fix typo.
HiramSilvey Dec 12, 2024
788b3e4
Abort delay if request context cancelled.
HiramSilvey Dec 12, 2024
01cf3f0
Update test to be more clear what the effect is.
HiramSilvey Dec 12, 2024
e75e5ea
Fix printf formatters.
HiramSilvey Dec 13, 2024
d03c802
Fix oversight where skipping the delay would skip abort as well.
HiramSilvey Dec 13, 2024
065e336
Update abort section to mirror delay section more closely in style.
HiramSilvey Dec 13, 2024
c2d6cd2
Move sleep fn up for better readability.
HiramSilvey Dec 13, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions httpbp/client_middlewares.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"log/slog"
"net/http"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
Expand All @@ -15,6 +16,7 @@ import (
"github.com/prometheus/client_golang/prometheus"

"github.com/reddit/baseplate.go/breakerbp"
"github.com/reddit/baseplate.go/internal/faults"
//lint:ignore SA1019 This library is internal only, not actually deprecated
"github.com/reddit/baseplate.go/internalv2compat"
"github.com/reddit/baseplate.go/retrybp"
Expand Down Expand Up @@ -43,6 +45,8 @@ func (f roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
// plus any additional client middleware passed into this function. Default
// middlewares are:
//
// * FaultInjection
//
// * MonitorClient with transport.WithRetrySlugSuffix
//
// * PrometheusClientMetrics with transport.WithRetrySlugSuffix
Expand Down Expand Up @@ -76,6 +80,7 @@ func NewClient(config ClientConfig, middleware ...ClientMiddleware) (*http.Clien
}

defaults := []ClientMiddleware{
FaultInjection(),
MonitorClient(config.Slug + transport.WithRetrySlugSuffix),
PrometheusClientMetrics(config.Slug + transport.WithRetrySlugSuffix),
Retries(config.MaxErrorReadAhead, config.RetryOptions...),
Expand Down Expand Up @@ -349,3 +354,44 @@ func PrometheusClientMetrics(serverSlug string) ClientMiddleware {
})
}
}

func FaultInjection() ClientMiddleware {
return func(next http.RoundTripper) http.RoundTripper {
return roundTripperFunc(func(req *http.Request) (*http.Response, error) {
resumeFn := func() (*http.Response, error) {
return next.RoundTrip(req)
}
responseFn := func(code int, message string) (*http.Response, error) {
return &http.Response{
HiramSilvey marked this conversation as resolved.
Show resolved Hide resolved
Status: http.StatusText(code),
StatusCode: code,
Proto: req.Proto,
ProtoMajor: req.ProtoMajor,
ProtoMinor: req.ProtoMinor,
Header: map[string][]string{
// Copied from the standard http.Error() function.
"Content-Type": {"text/plain; charset=utf-8"},
"X-Content-Type-Options": {"nosniff"},
},
ContentLength: 0,
TransferEncoding: req.TransferEncoding,
Request: req,
TLS: req.TLS,
}, nil
}

resp, err := faults.InjectFault(faults.InjectFaultParams[*http.Response]{
Context: req.Context(),
CallerName: "httpbp.FaultInjection",
Address: req.URL.Hostname(),
Method: strings.TrimPrefix(req.URL.Path, "/"),
AbortCodeMin: 400,
AbortCodeMax: 599,
GetHeaderFn: faults.GetHeaderFn(req.Header.Get),
ResumeFn: resumeFn,
ResponseFn: responseFn,
})
return resp, err
})
}
}
138 changes: 138 additions & 0 deletions httpbp/client_middlewares_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"github.com/sony/gobreaker"

"github.com/reddit/baseplate.go/breakerbp"
"github.com/reddit/baseplate.go/internal/faults"
)

func TestNewClient(t *testing.T) {
Expand Down Expand Up @@ -395,3 +396,140 @@ func TestCircuitBreaker(t *testing.T) {
t.Errorf("Expected the third request to return %v, got %v", gobreaker.ErrOpenState, err)
}
}

func TestFaultInjection(t *testing.T) {
testCases := []struct {
name string
faultServerAddrMatch bool
faultServerMethodHeader string
faultDelayMsHeader string
faultDelayPercentageHeader string
faultAbortCodeHeader string
faultAbortMessageHeader string
faultAbortPercentageHeader string

wantResp *http.Response
}{
{
name: "no fault specified",
wantResp: &http.Response{
StatusCode: http.StatusOK,
},
},
{
name: "abort",

faultServerAddrMatch: true,
faultServerMethodHeader: "testMethod",
faultAbortCodeHeader: "500",

wantResp: &http.Response{
StatusCode: http.StatusInternalServerError,
},
},
{
name: "service does not match",

faultServerAddrMatch: false,
faultServerMethodHeader: "testMethod",
faultAbortCodeHeader: "500",

wantResp: &http.Response{
StatusCode: http.StatusOK,
},
},
{
name: "method does not match",

faultServerAddrMatch: true,
faultServerMethodHeader: "fooMethod",
faultAbortCodeHeader: "500",

wantResp: &http.Response{
StatusCode: http.StatusOK,
},
},
{
name: "less than min abort code",

faultServerAddrMatch: true,
faultServerMethodHeader: "testMethod",
faultAbortCodeHeader: "99",

wantResp: &http.Response{
StatusCode: http.StatusOK,
},
},
{
name: "greater than max abort code",

faultServerAddrMatch: true,
faultServerMethodHeader: "testMethod",
faultAbortCodeHeader: "600",

wantResp: &http.Response{
StatusCode: http.StatusOK,
},
},
}

for _, tt := range testCases {
t.Run(tt.name, func(t *testing.T) {

server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
io.WriteString(w, "Success!")
}))
defer server.Close()

client, err := NewClient(ClientConfig{
Slug: "test",
})
if err != nil {
t.Fatalf("NewClient returned error: %v", err)
}

req, err := http.NewRequest("GET", server.URL+"/testMethod", nil)
if err != nil {
t.Fatalf("unexpected error when creating request: %v", err)
}

if tt.faultServerAddrMatch {
// We can't set a specific address here because the middleware
// relies on the DNS address, which is not customizable when making
// real requests to a local HTTP test server.
parsed, err := url.Parse(server.URL)
if err != nil {
t.Fatalf("unexpected error when parsing httptest server URL: %v", err)
}
req.Header.Set(faults.FaultServerAddressHeader, parsed.Hostname())
}
if tt.faultServerMethodHeader != "" {
req.Header.Set(faults.FaultServerMethodHeader, tt.faultServerMethodHeader)
}
if tt.faultDelayMsHeader != "" {
req.Header.Set(faults.FaultDelayMsHeader, tt.faultDelayMsHeader)
}
if tt.faultDelayPercentageHeader != "" {
req.Header.Set(faults.FaultDelayPercentageHeader, tt.faultDelayPercentageHeader)
}
if tt.faultAbortCodeHeader != "" {
req.Header.Set(faults.FaultAbortCodeHeader, tt.faultAbortCodeHeader)
}
if tt.faultAbortMessageHeader != "" {
req.Header.Set(faults.FaultAbortMessageHeader, tt.faultAbortMessageHeader)
}
if tt.faultAbortPercentageHeader != "" {
req.Header.Set(faults.FaultAbortPercentageHeader, tt.faultAbortPercentageHeader)
}

resp, err := client.Do(req)

if err != nil {
t.Fatalf("expected no error, got %v", err)
}
if tt.wantResp.StatusCode != resp.StatusCode {
t.Fatalf("expected response code %v, got %v", tt.wantResp.StatusCode, resp.StatusCode)
}
})
}
}
162 changes: 162 additions & 0 deletions internal/faults/common.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
// Package faults provides common headers and client-side fault injection
// functionality.
package faults

import (
"context"
"fmt"
"log/slog"
"math/rand/v2"
"strconv"
"strings"
"time"
)

// GetHeaderFn is the function type to return the value of a protocol-specific
// header with the given key.
type GetHeaderFn func(key string) string

// ResumeFn is the function type to continue processing the protocol-specific
// request without injecting a fault.
type ResumeFn[T any] func() (T, error)

// ResponseFn is the function type to inject a protocol-specific fault with the
// given code and message.
type ResponseFn[T any] func(code int, message string) (T, error)

// sleepFn is the function type to sleep for the given duration. Only used in
// tests.
type sleepFn func(ctx context.Context, d time.Duration) error

// The canonical address for a cluster-local address is <service>.<namespace>,
// without the local cluster suffix or port. The canonical address for a
// non-cluster-local address is the full original address without the port.
func getCanonicalAddress(serverAddress string) string {
// Cluster-local address.
if i := strings.Index(serverAddress, ".svc.cluster.local"); i != -1 {
return serverAddress[:i]
}
// External host:port address.
if i := strings.LastIndex(serverAddress, ":"); i != -1 {
port := serverAddress[i+1:]
// Verify this is actually a port number.
if port != "" && port[0] >= '0' && port[0] <= '9' {
return serverAddress[:i]
}
}
// Other address, i.e. unix domain socket.
return serverAddress
}

func parsePercentage(percentage string) (int, error) {
if percentage == "" {
return 100, nil
}
intPercentage, err := strconv.Atoi(percentage)
if err != nil {
return 0, fmt.Errorf("provided percentage %q is not a valid integer: %w", percentage, err)
}
if intPercentage < 0 || intPercentage > 100 {
return 0, fmt.Errorf("provided percentage \"%d\" is outside the valid range of [0-100]", intPercentage)
}
return intPercentage, nil
}

func selected(randInt *int, percentage int) bool {
if randInt != nil {
return *randInt < percentage
}
// Use a different random integer per feature as per
// https://github.com/grpc/proposal/blob/master/A33-Fault-Injection.md#evaluate-possibility-fraction.
return rand.IntN(100) < percentage
}

func sleep(ctx context.Context, d time.Duration) error {
t := time.NewTimer(d)
select {
case <-t.C:
case <-ctx.Done():
t.Stop()
return ctx.Err()
}
return nil
}

type InjectFaultParams[T any] struct {
Context context.Context
CallerName string

Address, Method string
AbortCodeMin, AbortCodeMax int

GetHeaderFn GetHeaderFn
ResumeFn ResumeFn[T]
ResponseFn ResponseFn[T]

randInt *int
sleepFn *sleepFn
}

func InjectFault[T any](params InjectFaultParams[T]) (T, error) {
faultHeaderAddress := params.GetHeaderFn(FaultServerAddressHeader)
requestAddress := getCanonicalAddress(params.Address)
if faultHeaderAddress == "" || faultHeaderAddress != requestAddress {
return params.ResumeFn()
}

serverMethod := params.GetHeaderFn(FaultServerMethodHeader)
if serverMethod != "" && serverMethod != params.Method {
return params.ResumeFn()
}

delayMs := params.GetHeaderFn(FaultDelayMsHeader)
if delayMs != "" {
percentage, err := parsePercentage(params.GetHeaderFn(FaultDelayPercentageHeader))
if err != nil {
slog.Warn(fmt.Sprintf("%s: %v", params.CallerName, err))
return params.ResumeFn()
}

if selected(params.randInt, percentage) {
delay, err := strconv.Atoi(delayMs)
if err != nil {
slog.Warn(fmt.Sprintf("%s: provided delay %q is not a valid integer", params.CallerName, delayMs))
return params.ResumeFn()
}

sleepFn := sleep
if params.sleepFn != nil {
sleepFn = *params.sleepFn
}
if err := sleepFn(params.Context, time.Duration(delay)*time.Millisecond); err != nil {
slog.Warn(fmt.Sprintf("%s: error when delaying request: %v", params.CallerName, err))
return params.ResumeFn()
}
}
}

abortCode := params.GetHeaderFn(FaultAbortCodeHeader)
if abortCode != "" {
percentage, err := parsePercentage(params.GetHeaderFn(FaultAbortPercentageHeader))
if err != nil {
slog.Warn(fmt.Sprintf("%s: %v", params.CallerName, err))
return params.ResumeFn()
}

if selected(params.randInt, percentage) {
code, err := strconv.Atoi(abortCode)
if err != nil {
slog.Warn(fmt.Sprintf("%s: provided abort code %q is not a valid integer", params.CallerName, abortCode))
return params.ResumeFn()
}
if code < params.AbortCodeMin || code > params.AbortCodeMax {
slog.Warn(fmt.Sprintf("%s: provided abort code \"%d\" is outside of the valid range", params.CallerName, code))
return params.ResumeFn()
}
abortMessage := params.GetHeaderFn(FaultAbortMessageHeader)
return params.ResponseFn(code, abortMessage)
}
}

return params.ResumeFn()
}
Loading
Loading