From f1f3b66f2ea86dc1356198b7cca4d22d89c332cc Mon Sep 17 00:00:00 2001 From: Marc Lopez Rubio Date: Thu, 13 Feb 2025 16:19:16 +0800 Subject: [PATCH] TLS: Add support for CA certificate in Kafka (#625) This commit adds support for setting the TLS CA in the Kafka TLS configuration. It supports hot reloading the certificate if it's updated. --------- Signed-off-by: Marc Lopez Rubio --- kafka/common.go | 95 ++++++++++++++++++++++++++++- kafka/common_test.go | 141 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 235 insertions(+), 1 deletion(-) diff --git a/kafka/common.go b/kafka/common.go index ae2ff716..95f31072 100644 --- a/kafka/common.go +++ b/kafka/common.go @@ -20,11 +20,14 @@ package kafka import ( "context" "crypto/tls" + "crypto/x509" "errors" "fmt" "net" "os" "strings" + "sync" + "sync/atomic" "time" awsconfig "github.com/aws/aws-sdk-go-v2/config" @@ -212,9 +215,28 @@ func (cfg *CommonConfig) finalize() error { case cfg.TLS == nil && cfg.Dialer == nil && os.Getenv("KAFKA_PLAINTEXT") != "true": // Auto-configure TLS from environment variables. cfg.TLS = &tls.Config{} - if os.Getenv("KAFKA_TLS_INSECURE") == "true" { + tlsInsecure := os.Getenv("KAFKA_TLS_INSECURE") == "true" + caCertPath := os.Getenv("KAFKA_TLS_CA_CERT_PATH") + if tlsInsecure && caCertPath != "" { + errs = append(errs, errors.New( + "kafka: cannot set both KAFKA_TLS_INSECURE and KAFKA_TLS_CA_CERT_PATH", + )) + break + } + if tlsInsecure { cfg.TLS.InsecureSkipVerify = true } + if caCertPath != "" { + // Auto-configure a dialer that reloads the CA cert when the file + // changes. + dialFn, err := newCertReloadingDialer(caCertPath, cfg.TLS) + if err != nil { + errs = append(errs, fmt.Errorf("kafka: error creating dialer with CA cert: %w", err)) + break + } + cfg.Dialer = dialFn + cfg.TLS = nil + } } if cfg.SASL == nil { saslConfig := saslConfigProperties{ @@ -355,3 +377,74 @@ func topicFieldFunc(f TopicLogFieldFunc) TopicLogFieldFunc { return zap.Skip() } } + +// newCertReloadingDialer returns a dialer that reloads the CA cert when the +// file mod time changes. +func newCertReloadingDialer(caPath string, tlsCfg *tls.Config) (func(ctx context.Context, network, address string) (net.Conn, error), error) { + p, err := os.Stat(caPath) + if err != nil { + return nil, err + } + dialer := &net.Dialer{Timeout: 10 * time.Second} // default dialer timeout in kgo. + cfg := tlsCfg.Clone() + caCert, err := os.ReadFile(caPath) + if err != nil { + return nil, fmt.Errorf("kafka: failed to read CA cert: %w", err) + } + cfg.RootCAs = x509.NewCertPool() + if !cfg.RootCAs.AppendCertsFromPEM(caCert) { + return nil, errors.New("kafka: failed to append CA cert") + } + var certModTS atomic.Int64 + certModTS.Store(p.ModTime().UnixNano()) + var mu sync.RWMutex // guards cfg.RootCAs and certModTS + return func(ctx context.Context, network, host string) (net.Conn, error) { + if p, err := os.Stat(caPath); err == nil { + if modTS := certModTS.Load(); p.ModTime().UnixNano() != modTS { + if err := func() error { // anonymous function to defer unlock. + mu.Lock() + defer mu.Unlock() + currentModTS := p.ModTime().UnixNano() + if modTS := certModTS.Load(); currentModTS != modTS { + caCert, err := os.ReadFile(caPath) + if err != nil { + return fmt.Errorf( + "failed to read CA cert on reload: %w", err, + ) + } + if len(caCert) == 0 { + // Nothing is written to the file yet, it may be in + // the process of being written. Return early, since + // we cannot reload the cert yet. + return nil + } + cfg.RootCAs = x509.NewCertPool() + if !cfg.RootCAs.AppendCertsFromPEM(caCert) { + return errors.New("failed to append CA cert on reload") + } + certModTS.Store(currentModTS) + } + return nil + }(); err != nil { + return nil, fmt.Errorf("kafka: %w", err) + } + } + } + mu.RLock() + c := cfg.Clone() + mu.RUnlock() + // Copied this pattern from franz-go client.go. + // https://github.com/twmb/franz-go/blob/f30c518d6b727b9169a90b8c10e2127301822a3a/pkg/kgo/client.go#L440-L453 + if c.ServerName == "" { + server, _, err := net.SplitHostPort(host) + if err != nil { + return nil, fmt.Errorf("dialer: unable to split host:port for dialing: %w", err) + } + c.ServerName = server + } + return (&tls.Dialer{ + NetDialer: dialer, + Config: c, + }).DialContext(ctx, network, host) + }, nil +} diff --git a/kafka/common_test.go b/kafka/common_test.go index 7cddc7f4..8c6ab6f0 100644 --- a/kafka/common_test.go +++ b/kafka/common_test.go @@ -19,12 +19,24 @@ package kafka import ( "context" + "crypto/rand" + "crypto/rsa" "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "errors" "fmt" + "io" + "math/big" "net" + "net/http" + "net/http/httptest" "os" "path/filepath" + "runtime" "strings" + "sync" "testing" "time" @@ -64,6 +76,17 @@ func TestCommonConfig(t *testing.T) { "kafka: at least one broker must be set", ) }) + t.Run("invalid KAFKA_TLS_INSECURE and KAFKA_TLS_CA_CERT_PATH", func(t *testing.T) { + t.Setenv("KAFKA_TLS_INSECURE", "true") + t.Setenv("KAFKA_TLS_CA_CERT_PATH", "ca_cert.pem") + t.Setenv("KAFKA_PLAINTEXT", "") + assertErrors(t, CommonConfig{ + Brokers: []string{"broker"}, + Logger: zap.NewNop(), + }, + "kafka: cannot set both KAFKA_TLS_INSECURE and KAFKA_TLS_CA_CERT_PATH", + ) + }) t.Run("tls_or_dialer", func(t *testing.T) { assertErrors(t, CommonConfig{ @@ -318,3 +341,121 @@ func TestTopicFieldFunc(t *testing.T) { assert.Equal(t, zap.String("topic", "c"), topic) }) } + +// generateValidCACert creates a valid self-signed CA certificate in PEM format. +func generateValidCACert(t testing.TB) []byte { + key, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + + template := x509.Certificate{ + SerialNumber: big.NewInt(int64(time.Now().Year())), + Subject: pkix.Name{CommonName: "Test CA"}, + NotBefore: time.Now().Add(-time.Hour), + NotAfter: time.Now().Add(time.Hour), + IsCA: true, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, + BasicConstraintsValid: true, + } + derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &key.PublicKey, key) + require.NoError(t, err) + + return pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) +} + +func TestTLSCACertPath(t *testing.T) { + t.Run("valid cert", func(t *testing.T) { + t.Setenv("KAFKA_PLAINTEXT", "") // clear plaintext mode + + tempFile := filepath.Join(t.TempDir(), "ca_cert.pem") + validCert := generateValidCACert(t) + err := os.WriteFile(tempFile, validCert, 0644) + require.NoError(t, err) + + t.Setenv("KAFKA_TLS_CA_CERT_PATH", tempFile) + cfg := CommonConfig{Brokers: []string{"broker"}, Logger: zap.NewNop()} + require.NoError(t, cfg.finalize()) + require.NotNil(t, cfg.Dialer) + require.Nil(t, cfg.TLS) + }) + t.Run("missing file", func(t *testing.T) { + t.Setenv("KAFKA_PLAINTEXT", "") + tempFile := filepath.Join(t.TempDir(), "nonexistent_cert.pem") + t.Setenv("KAFKA_TLS_CA_CERT_PATH", tempFile) + cfg := CommonConfig{Brokers: []string{"broker"}, Logger: zap.NewNop()} + err := cfg.finalize() + require.Error(t, err) + require.Contains(t, err.Error(), "kafka: error creating dialer with CA cert") + require.Contains(t, err.Error(), "no such file or directory") + }) + t.Run("invalid cert", func(t *testing.T) { + t.Setenv("KAFKA_PLAINTEXT", "") + tempFile := filepath.Join(t.TempDir(), "invalid_cert.pem") + err := os.WriteFile(tempFile, []byte("invalid pem data"), 0644) + require.NoError(t, err) + + t.Setenv("KAFKA_TLS_CA_CERT_PATH", tempFile) + cfg := CommonConfig{Brokers: []string{"broker"}, Logger: zap.NewNop()} + err = cfg.finalize() + require.Error(t, err) + require.Contains(t, err.Error(), "kafka: error creating dialer with CA cert") + }) +} + +func TestTLSHotReload(t *testing.T) { + srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + t.Cleanup(srv.Close) + + // Get the certificate from the test server and encode it in PEM format. + certPEM := pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", Bytes: srv.TLS.Certificates[0].Certificate[0], + }) + tempFile := filepath.Join(t.TempDir(), "cert.pem") + require.NoError(t, os.WriteFile(tempFile, certPEM, 0644)) + + dialFunc, err := newCertReloadingDialer(tempFile, &tls.Config{}) + require.NoError(t, err) + + var wg sync.WaitGroup + addr := srv.Listener.Addr() + ctx, cancel := context.WithCancel(context.Background()) + for i := 0; i < runtime.GOMAXPROCS(0)*4; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for { + select { + case <-ctx.Done(): + return + case <-time.After(time.Millisecond): + conn, err := dialFunc(ctx, addr.Network(), addr.String()) + if err != nil && !errors.Is(err, io.EOF) { + select { + case <-ctx.Done(): + return + default: + } + // Ensure no TLS errors occur. + require.NoError(t, err) + } + if conn != nil { + conn.Close() + } + } + } + }() + } + + <-time.After(200 * time.Millisecond) + + for i := 0; i < runtime.GOMAXPROCS(0); i++ { + // Update the file, so that the CA cert is reloaded when dialer is called again. + require.NoError(t, os.WriteFile(tempFile, certPEM, 0644)) + <-time.After(50 * time.Millisecond) + } + + cancel() // allow go routines to exit + wg.Wait() // wait for all go routines to finish +}