diff --git a/thriftbp/client_pool.go b/thriftbp/client_pool.go index 1f307d921..27a66efb9 100644 --- a/thriftbp/client_pool.go +++ b/thriftbp/client_pool.go @@ -211,6 +211,11 @@ type ClientPoolConfig struct { // // Optional. If this is empty, no "User-Agent" header will be sent. ClientName string `yaml:"clientName"` + + // The hostname to add as a "thrift-hostname" header. + // + // Optional. If empty, no "thrift-hostname" header will be sent. + ThriftHostnameHeader string `yaml:"thriftHostnameHeader"` } // Validate checks ClientPoolConfig for any missing or erroneous values. @@ -429,6 +434,24 @@ func NewCustomClientPoolWithContext( return newClientPool(ctx, cfg, genAddr, protoFactory, middlewares...) } +const ThriftHostnameHeader = "thrift-hostname" + +// thriftHostnameHeaderMiddleware adds a `thrift-hostname` header if one was +// specified in the configuration. +// This middleware is always added but will only add the header is necessary. +func thriftHostnameHeaderMiddleware(hostname string) thrift.ClientMiddleware { + return func(next thrift.TClient) thrift.TClient { + return thrift.WrappedTClient{ + Wrapped: func(ctx context.Context, method string, args, result thrift.TStruct) (thrift.ResponseMeta, error) { + if hostname != "" { + ctx = AddClientHeader(ctx, ThriftHostnameHeader, hostname) + } + return next.Call(ctx, method, args, result) + }, + } + } +} + func newClientPool( ctx context.Context, cfg ClientPoolConfig, @@ -505,6 +528,8 @@ func newClientPool( slug: cfg.ServiceSlug, } + middlewares = append(middlewares, thriftHostnameHeaderMiddleware(cfg.ThriftHostnameHeader)) + // finish setting up the clientPool by wrapping the inner "Call" with the // given middleware. // diff --git a/thriftbp/client_pool_test.go b/thriftbp/client_pool_test.go index 71e0de803..f31b09e6e 100644 --- a/thriftbp/client_pool_test.go +++ b/thriftbp/client_pool_test.go @@ -11,8 +11,11 @@ import ( "github.com/apache/thrift/lib/go/thrift" + "github.com/reddit/baseplate.go" "github.com/reddit/baseplate.go/ecinterface" + baseplatethrift "github.com/reddit/baseplate.go/internal/gen-go/reddit/baseplate" "github.com/reddit/baseplate.go/thriftbp" + "github.com/reddit/baseplate.go/thriftbp/thrifttest" ) const ( @@ -190,6 +193,49 @@ func TestBehaviorWithNetworkIssues(t *testing.T) { } } +type thriftHostnameHandler struct { + server baseplate.Server +} + +func (thriftHostnameHandler) IsHealthy(ctx context.Context, _ *baseplatethrift.IsHealthyRequest) (r bool, err error) { + value, ok := thrift.GetHeader(ctx, thriftbp.ThriftHostnameHeader) + if !ok { + return false, errors.New("did not find the thrift header") + } + if value != "my-thrift-header" { + return false, errors.New("unexpected value for the thrift header") + } + return true, nil +} + +func TestThriftHostnameHeader(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + store := newSecretsStore(t) + defer store.Close() + + handler := thriftHostnameHandler{} + server, err := thrifttest.NewBaseplateServer(thrifttest.ServerConfig{ + Processor: baseplatethrift.NewBaseplateServiceV2Processor(&handler), + SecretStore: store, + ClientConfig: thriftbp.ClientPoolConfig{ + ThriftHostnameHeader: "my-thrift-header", + }, + }) + if err != nil { + t.Fatal(err) + } + handler.server = server + server.Start(ctx) + + client := baseplatethrift.NewBaseplateServiceV2Client(server.ClientPool.TClient()) + _, err = client.IsHealthy(ctx, &baseplatethrift.IsHealthyRequest{}) + if err != nil { + t.Fatal(err) + } +} + func TestInitialConnectionsFallback(t *testing.T) { ln, err := net.Listen("tcp", addr) if err != nil {