Skip to content

Commit

Permalink
TLS: Add support for CA certificate in Kafka (#625)
Browse files Browse the repository at this point in the history
This commit adds support for setting the TLS CA in the Kafka TLS configuration.

It supports hot reloading the certificate if it's updated.

---------

Signed-off-by: Marc Lopez Rubio <[email protected]>
  • Loading branch information
marclop authored Feb 13, 2025
1 parent 59896b8 commit f1f3b66
Show file tree
Hide file tree
Showing 2 changed files with 235 additions and 1 deletion.
95 changes: 94 additions & 1 deletion kafka/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,14 @@ package kafka
import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"net"
"os"
"strings"
"sync"
"sync/atomic"
"time"

awsconfig "github.com/aws/aws-sdk-go-v2/config"
Expand Down Expand Up @@ -212,9 +215,28 @@ func (cfg *CommonConfig) finalize() error {
case cfg.TLS == nil && cfg.Dialer == nil && os.Getenv("KAFKA_PLAINTEXT") != "true":
// Auto-configure TLS from environment variables.
cfg.TLS = &tls.Config{}
if os.Getenv("KAFKA_TLS_INSECURE") == "true" {
tlsInsecure := os.Getenv("KAFKA_TLS_INSECURE") == "true"
caCertPath := os.Getenv("KAFKA_TLS_CA_CERT_PATH")
if tlsInsecure && caCertPath != "" {
errs = append(errs, errors.New(
"kafka: cannot set both KAFKA_TLS_INSECURE and KAFKA_TLS_CA_CERT_PATH",
))
break
}
if tlsInsecure {
cfg.TLS.InsecureSkipVerify = true
}
if caCertPath != "" {
// Auto-configure a dialer that reloads the CA cert when the file
// changes.
dialFn, err := newCertReloadingDialer(caCertPath, cfg.TLS)
if err != nil {
errs = append(errs, fmt.Errorf("kafka: error creating dialer with CA cert: %w", err))
break
}
cfg.Dialer = dialFn
cfg.TLS = nil
}
}
if cfg.SASL == nil {
saslConfig := saslConfigProperties{
Expand Down Expand Up @@ -355,3 +377,74 @@ func topicFieldFunc(f TopicLogFieldFunc) TopicLogFieldFunc {
return zap.Skip()
}
}

// newCertReloadingDialer returns a dialer that reloads the CA cert when the
// file mod time changes.
func newCertReloadingDialer(caPath string, tlsCfg *tls.Config) (func(ctx context.Context, network, address string) (net.Conn, error), error) {
p, err := os.Stat(caPath)
if err != nil {
return nil, err
}
dialer := &net.Dialer{Timeout: 10 * time.Second} // default dialer timeout in kgo.
cfg := tlsCfg.Clone()
caCert, err := os.ReadFile(caPath)
if err != nil {
return nil, fmt.Errorf("kafka: failed to read CA cert: %w", err)
}
cfg.RootCAs = x509.NewCertPool()
if !cfg.RootCAs.AppendCertsFromPEM(caCert) {
return nil, errors.New("kafka: failed to append CA cert")
}
var certModTS atomic.Int64
certModTS.Store(p.ModTime().UnixNano())
var mu sync.RWMutex // guards cfg.RootCAs and certModTS
return func(ctx context.Context, network, host string) (net.Conn, error) {
if p, err := os.Stat(caPath); err == nil {
if modTS := certModTS.Load(); p.ModTime().UnixNano() != modTS {
if err := func() error { // anonymous function to defer unlock.
mu.Lock()
defer mu.Unlock()
currentModTS := p.ModTime().UnixNano()
if modTS := certModTS.Load(); currentModTS != modTS {
caCert, err := os.ReadFile(caPath)
if err != nil {
return fmt.Errorf(
"failed to read CA cert on reload: %w", err,
)
}
if len(caCert) == 0 {
// Nothing is written to the file yet, it may be in
// the process of being written. Return early, since
// we cannot reload the cert yet.
return nil
}
cfg.RootCAs = x509.NewCertPool()
if !cfg.RootCAs.AppendCertsFromPEM(caCert) {
return errors.New("failed to append CA cert on reload")
}
certModTS.Store(currentModTS)
}
return nil
}(); err != nil {
return nil, fmt.Errorf("kafka: %w", err)
}
}
}
mu.RLock()
c := cfg.Clone()
mu.RUnlock()
// Copied this pattern from franz-go client.go.
// https://github.com/twmb/franz-go/blob/f30c518d6b727b9169a90b8c10e2127301822a3a/pkg/kgo/client.go#L440-L453
if c.ServerName == "" {
server, _, err := net.SplitHostPort(host)
if err != nil {
return nil, fmt.Errorf("dialer: unable to split host:port for dialing: %w", err)
}
c.ServerName = server
}
return (&tls.Dialer{
NetDialer: dialer,
Config: c,
}).DialContext(ctx, network, host)
}, nil
}
141 changes: 141 additions & 0 deletions kafka/common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,24 @@ package kafka

import (
"context"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"errors"
"fmt"
"io"
"math/big"
"net"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"runtime"
"strings"
"sync"
"testing"
"time"

Expand Down Expand Up @@ -64,6 +76,17 @@ func TestCommonConfig(t *testing.T) {
"kafka: at least one broker must be set",
)
})
t.Run("invalid KAFKA_TLS_INSECURE and KAFKA_TLS_CA_CERT_PATH", func(t *testing.T) {
t.Setenv("KAFKA_TLS_INSECURE", "true")
t.Setenv("KAFKA_TLS_CA_CERT_PATH", "ca_cert.pem")
t.Setenv("KAFKA_PLAINTEXT", "")
assertErrors(t, CommonConfig{
Brokers: []string{"broker"},
Logger: zap.NewNop(),
},
"kafka: cannot set both KAFKA_TLS_INSECURE and KAFKA_TLS_CA_CERT_PATH",
)
})

t.Run("tls_or_dialer", func(t *testing.T) {
assertErrors(t, CommonConfig{
Expand Down Expand Up @@ -318,3 +341,121 @@ func TestTopicFieldFunc(t *testing.T) {
assert.Equal(t, zap.String("topic", "c"), topic)
})
}

// generateValidCACert creates a valid self-signed CA certificate in PEM format.
func generateValidCACert(t testing.TB) []byte {
key, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)

template := x509.Certificate{
SerialNumber: big.NewInt(int64(time.Now().Year())),
Subject: pkix.Name{CommonName: "Test CA"},
NotBefore: time.Now().Add(-time.Hour),
NotAfter: time.Now().Add(time.Hour),
IsCA: true,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
BasicConstraintsValid: true,
}
derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &key.PublicKey, key)
require.NoError(t, err)

return pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
}

func TestTLSCACertPath(t *testing.T) {
t.Run("valid cert", func(t *testing.T) {
t.Setenv("KAFKA_PLAINTEXT", "") // clear plaintext mode

tempFile := filepath.Join(t.TempDir(), "ca_cert.pem")
validCert := generateValidCACert(t)
err := os.WriteFile(tempFile, validCert, 0644)
require.NoError(t, err)

t.Setenv("KAFKA_TLS_CA_CERT_PATH", tempFile)
cfg := CommonConfig{Brokers: []string{"broker"}, Logger: zap.NewNop()}
require.NoError(t, cfg.finalize())
require.NotNil(t, cfg.Dialer)
require.Nil(t, cfg.TLS)
})
t.Run("missing file", func(t *testing.T) {
t.Setenv("KAFKA_PLAINTEXT", "")
tempFile := filepath.Join(t.TempDir(), "nonexistent_cert.pem")
t.Setenv("KAFKA_TLS_CA_CERT_PATH", tempFile)
cfg := CommonConfig{Brokers: []string{"broker"}, Logger: zap.NewNop()}
err := cfg.finalize()
require.Error(t, err)
require.Contains(t, err.Error(), "kafka: error creating dialer with CA cert")
require.Contains(t, err.Error(), "no such file or directory")
})
t.Run("invalid cert", func(t *testing.T) {
t.Setenv("KAFKA_PLAINTEXT", "")
tempFile := filepath.Join(t.TempDir(), "invalid_cert.pem")
err := os.WriteFile(tempFile, []byte("invalid pem data"), 0644)
require.NoError(t, err)

t.Setenv("KAFKA_TLS_CA_CERT_PATH", tempFile)
cfg := CommonConfig{Brokers: []string{"broker"}, Logger: zap.NewNop()}
err = cfg.finalize()
require.Error(t, err)
require.Contains(t, err.Error(), "kafka: error creating dialer with CA cert")
})
}

func TestTLSHotReload(t *testing.T) {
srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
}))
t.Cleanup(srv.Close)

// Get the certificate from the test server and encode it in PEM format.
certPEM := pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE", Bytes: srv.TLS.Certificates[0].Certificate[0],
})
tempFile := filepath.Join(t.TempDir(), "cert.pem")
require.NoError(t, os.WriteFile(tempFile, certPEM, 0644))

dialFunc, err := newCertReloadingDialer(tempFile, &tls.Config{})
require.NoError(t, err)

var wg sync.WaitGroup
addr := srv.Listener.Addr()
ctx, cancel := context.WithCancel(context.Background())
for i := 0; i < runtime.GOMAXPROCS(0)*4; i++ {
wg.Add(1)
go func() {
defer wg.Done()
for {
select {
case <-ctx.Done():
return
case <-time.After(time.Millisecond):
conn, err := dialFunc(ctx, addr.Network(), addr.String())
if err != nil && !errors.Is(err, io.EOF) {
select {
case <-ctx.Done():
return
default:
}
// Ensure no TLS errors occur.
require.NoError(t, err)
}
if conn != nil {
conn.Close()
}
}
}
}()
}

<-time.After(200 * time.Millisecond)

for i := 0; i < runtime.GOMAXPROCS(0); i++ {
// Update the file, so that the CA cert is reloaded when dialer is called again.
require.NoError(t, os.WriteFile(tempFile, certPEM, 0644))
<-time.After(50 * time.Millisecond)
}

cancel() // allow go routines to exit
wg.Wait() // wait for all go routines to finish
}

0 comments on commit f1f3b66

Please sign in to comment.