From 8a0f0104b551dd7a89d1eee356cb2031cc2bcd63 Mon Sep 17 00:00:00 2001 From: bwplotka Date: Thu, 16 Jan 2025 09:24:05 +0000 Subject: [PATCH] http_config: Allow customizing TLS config and settings. Signed-off-by: bwplotka --- config/http_config.go | 53 ++++++++++++++++++++++++++++++------------- 1 file changed, 37 insertions(+), 16 deletions(-) diff --git a/config/http_config.go b/config/http_config.go index 57ec252a..b1422890 100644 --- a/config/http_config.go +++ b/config/http_config.go @@ -453,13 +453,14 @@ func (a *BasicAuth) UnmarshalYAML(unmarshal func(interface{}) error) error { type DialContextFunc func(context.Context, string, string) (net.Conn, error) type httpClientOptions struct { - dialContextFunc DialContextFunc - keepAlivesEnabled bool - http2Enabled bool - idleConnTimeout time.Duration - userAgent string - host string - secretManager SecretManager + dialContextFunc DialContextFunc + keepAlivesEnabled bool + http2Enabled bool + idleConnTimeout time.Duration + userAgent string + host string + secretManager SecretManager + extendTLSConfigFunc TLSConfigExtension } // HTTPClientOption defines an option that can be applied to the HTTP client. @@ -515,6 +516,17 @@ func WithHost(host string) HTTPClientOption { }) } +// TLSConfigExtension modifies the given tls config and settings. +type TLSConfigExtension func(*tls.Config, TLSRoundTripperSettings) (*tls.Config, TLSRoundTripperSettings, error) + +// WithTLSConfigExtension allows to insert extension function that can freely modify +// TLSConfig and TLSRoundTripperSettings used for the round tripper creation. +func WithTLSConfigExtension(extendTLSConfigFunc TLSConfigExtension) HTTPClientOption { + return httpClientOptionFunc(func(opts *httpClientOptions) { + opts.extendTLSConfigFunc = extendTLSConfigFunc + }) +} + type secretManagerOption struct { secretManager SecretManager } @@ -679,6 +691,15 @@ func NewRoundTripperFromConfigWithContext(ctx context.Context, cfg HTTPClientCon if err != nil { return nil, err } + + // Allow customizing the TLS config and settings, if specified in opts. + if opts.extendTLSConfigFunc != nil { + tlsConfig, tlsSettings, err = opts.extendTLSConfigFunc(tlsConfig, tlsSettings) + if err != nil { + return nil, err + } + } + if tlsSettings.immutable() { // No need for a RoundTripper that reloads the files automatically. return newRT(tlsConfig) @@ -1264,18 +1285,18 @@ func (t *TLSRoundTripperSettings) immutable() bool { } func NewTLSRoundTripper( - cfg *tls.Config, - settings TLSRoundTripperSettings, - newRT func(*tls.Config) (http.RoundTripper, error), + cfg *tls.Config, + settings TLSRoundTripperSettings, + newRT func(*tls.Config) (http.RoundTripper, error), ) (http.RoundTripper, error) { return NewTLSRoundTripperWithContext(context.Background(), cfg, settings, newRT) } func NewTLSRoundTripperWithContext( - ctx context.Context, - cfg *tls.Config, - settings TLSRoundTripperSettings, - newRT func(*tls.Config) (http.RoundTripper, error), + ctx context.Context, + cfg *tls.Config, + settings TLSRoundTripperSettings, + newRT func(*tls.Config) (http.RoundTripper, error), ) (http.RoundTripper, error) { t := &tlsRoundTripper{ settings: settings, @@ -1347,8 +1368,8 @@ func (t *tlsRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { t.mtx.RLock() equal := bytes.Equal(caHash[:], t.hashCAData) && - bytes.Equal(certHash[:], t.hashCertData) && - bytes.Equal(keyHash[:], t.hashKeyData) + bytes.Equal(certHash[:], t.hashCertData) && + bytes.Equal(keyHash[:], t.hashKeyData) rt := t.rt t.mtx.RUnlock() if equal {