From 9be049c8cd13b8aad85d5df6214cef208d767bdf Mon Sep 17 00:00:00 2001 From: Sukun Date: Mon, 4 Mar 2024 19:49:52 +0530 Subject: [PATCH] Limit size of encrypted packet queue Before we would queue unbounded encrypted data until the handshake finished/timed out. This sets a limit of 100 packets until the handshake completes. --- conn.go | 94 ++++++++++++++++++++++++++++++++++------------------ conn_test.go | 84 ++++++++++++++++++++++++++++++++++++++++++++++ resume.go | 6 +++- 3 files changed, 150 insertions(+), 34 deletions(-) diff --git a/conn.go b/conn.go index 332d9a211..e65163cf7 100644 --- a/conn.go +++ b/conn.go @@ -35,6 +35,9 @@ const ( inboundBufferSize = 8192 // Default replay protection window is specified by RFC 6347 Section 4.1.2.6 defaultReplayProtectionWindow = 64 + // maxAppDataPacketQueueSize is the maximum number of app data packets we will + // enqueue before the handshake is completed + maxAppDataPacketQueueSize = 100 ) func invalidKeyingLabels() map[string]bool { @@ -88,7 +91,7 @@ type Conn struct { replayProtectionWindow uint } -func createConn(ctx context.Context, nextConn net.PacketConn, rAddr net.Addr, config *Config, isClient bool, initialState *State) (*Conn, error) { +func createConn(nextConn net.PacketConn, rAddr net.Addr, config *Config, isClient bool) (*Conn, error) { if err := validateConfig(config); err != nil { return nil, err } @@ -97,21 +100,6 @@ func createConn(ctx context.Context, nextConn net.PacketConn, rAddr net.Addr, co return nil, errNilNextConn } - cipherSuites, err := parseCipherSuites(config.CipherSuites, config.CustomCipherSuites, config.includeCertificateSuites(), config.PSK != nil) - if err != nil { - return nil, err - } - - signatureSchemes, err := signaturehash.ParseSignatureSchemes(config.SignatureSchemes, config.InsecureHashes) - if err != nil { - return nil, err - } - - workerInterval := initialTickerInterval - if config.FlightInterval != 0 { - workerInterval = config.FlightInterval - } - loggerFactory := config.LoggerFactory if loggerFactory == nil { loggerFactory = logging.NewDefaultLoggerFactory() @@ -162,6 +150,28 @@ func createConn(ctx context.Context, nextConn net.PacketConn, rAddr net.Addr, co c.setRemoteEpoch(0) c.setLocalEpoch(0) + return c, nil +} + +func handshakeConn(ctx context.Context, conn *Conn, config *Config, isClient bool, initialState *State) (*Conn, error) { + if conn == nil { + return nil, errNilNextConn + } + + cipherSuites, err := parseCipherSuites(config.CipherSuites, config.CustomCipherSuites, config.includeCertificateSuites(), config.PSK != nil) + if err != nil { + return nil, err + } + + signatureSchemes, err := signaturehash.ParseSignatureSchemes(config.SignatureSchemes, config.InsecureHashes) + if err != nil { + return nil, err + } + + workerInterval := initialTickerInterval + if config.FlightInterval != 0 { + workerInterval = config.FlightInterval + } serverName := config.ServerName // Do not allow the use of an IP address literal as an SNI value. @@ -193,7 +203,7 @@ func createConn(ctx context.Context, nextConn net.PacketConn, rAddr net.Addr, co clientCAs: config.ClientCAs, customCipherSuites: config.CustomCipherSuites, retransmitInterval: workerInterval, - log: logger, + log: conn.log, initialEpoch: 0, keyLogWriter: config.KeyLogWriter, sessionStore: config.SessionStore, @@ -222,16 +232,16 @@ func createConn(ctx context.Context, nextConn net.PacketConn, rAddr net.Addr, co var initialFSMState handshakeState if initialState != nil { - if c.state.isClient { + if conn.state.isClient { initialFlight = flight5 } else { initialFlight = flight6 } initialFSMState = handshakeFinished - c.state = *initialState + conn.state = *initialState } else { - if c.state.isClient { + if conn.state.isClient { initialFlight = flight1 } else { initialFlight = flight0 @@ -239,13 +249,13 @@ func createConn(ctx context.Context, nextConn net.PacketConn, rAddr net.Addr, co initialFSMState = handshakePreparing } // Do handshake - if err := c.handshake(ctx, hsCfg, initialFlight, initialFSMState); err != nil { + if err := conn.handshake(ctx, hsCfg, initialFlight, initialFSMState); err != nil { return nil, err } - c.log.Trace("Handshake Completed") + conn.log.Trace("Handshake Completed") - return c, nil + return conn, nil } // Dial connects to the given network address and establishes a DTLS connection on top. @@ -301,7 +311,12 @@ func ClientWithContext(ctx context.Context, conn net.PacketConn, rAddr net.Addr, return nil, errPSKAndIdentityMustBeSetForClient } - return createConn(ctx, conn, rAddr, config, true, nil) + dconn, err := createConn(conn, rAddr, config, true) + if err != nil { + return nil, err + } + + return handshakeConn(ctx, dconn, config, true, nil) } // ServerWithContext listens for incoming DTLS connections. @@ -309,8 +324,11 @@ func ServerWithContext(ctx context.Context, conn net.PacketConn, rAddr net.Addr, if config == nil { return nil, errNoConfigProvided } - - return createConn(ctx, conn, rAddr, config, false, nil) + dconn, err := createConn(conn, rAddr, config, false) + if err != nil { + return nil, err + } + return handshakeConn(ctx, dconn, config, false, nil) } // Read reads data from the connection. @@ -738,6 +756,14 @@ func (c *Conn) handleQueuedPackets(ctx context.Context) error { return nil } +func (c *Conn) enqueueEncryptedPackets(packet addrPkt) bool { + if len(c.encryptedPackets) < maxAppDataPacketQueueSize { + c.encryptedPackets = append(c.encryptedPackets, packet) + return true + } + return false +} + func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, rAddr net.Addr, enqueue bool) (bool, *alert.Alert, error) { //nolint:gocognit h := &recordlayer.Header{} // Set connection ID size so that records of content type tls12_cid will @@ -751,7 +777,6 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, rAddr net.A c.log.Debugf("discarded broken packet: %v", err) return false, nil, nil } - // Validate epoch remoteEpoch := c.state.getRemoteEpoch() if h.Epoch > remoteEpoch { @@ -762,8 +787,9 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, rAddr net.A return false, nil, nil } if enqueue { - c.log.Debug("received packet of next epoch, queuing packet") - c.encryptedPackets = append(c.encryptedPackets, addrPkt{rAddr, buf}) + if ok := c.enqueueEncryptedPackets(addrPkt{rAddr, buf}); ok { + c.log.Debug("received packet of next epoch, queuing packet") + } } return false, nil, nil } @@ -790,8 +816,9 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, rAddr net.A if h.Epoch != 0 { if c.state.cipherSuite == nil || !c.state.cipherSuite.IsInitialized() { if enqueue { - c.encryptedPackets = append(c.encryptedPackets, addrPkt{rAddr, buf}) - c.log.Debug("handshake not finished, queuing packet") + if ok := c.enqueueEncryptedPackets(addrPkt{rAddr, buf}); ok { + c.log.Debug("handshake not finished, queuing packet") + } } return false, nil, nil } @@ -883,8 +910,9 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, rAddr net.A case *protocol.ChangeCipherSpec: if c.state.cipherSuite == nil || !c.state.cipherSuite.IsInitialized() { if enqueue { - c.encryptedPackets = append(c.encryptedPackets, addrPkt{rAddr, buf}) - c.log.Debugf("CipherSuite not initialized, queuing packet") + if ok := c.enqueueEncryptedPackets(addrPkt{rAddr, buf}); ok { + c.log.Debugf("CipherSuite not initialized, queuing packet") + } } return false, nil, nil } diff --git a/conn_test.go b/conn_test.go index 90d1b004e..17eb932f0 100644 --- a/conn_test.go +++ b/conn_test.go @@ -3279,3 +3279,87 @@ func (c *connWithCallback) Write(b []byte) (int, error) { } return c.Conn.Write(b) } + +func TestApplicationDataQueueLimited(t *testing.T) { + // Limit runtime in case of deadlocks + lim := test.TimeOut(time.Second * 20) + defer lim.Stop() + + // Check for leaking routines + report := test.CheckRoutines(t) + defer report() + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + ca, cb := dpipe.Pipe() + defer ca.Close() //nolint:errcheck + defer cb.Close() //nolint:errcheck + + done := make(chan struct{}) + go func() { + serverCert, err := selfsign.GenerateSelfSigned() + if err != nil { + t.Error(err) + return + } + cfg := &Config{} + cfg.Certificates = []tls.Certificate{serverCert} + + dconn, err := createConn(dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), cfg, false) + if err != nil { + t.Error(err) + return + } + go func() { + for i := 0; i < 5; i++ { + dconn.lock.RLock() + qlen := len(dconn.encryptedPackets) + dconn.lock.RUnlock() + if qlen > maxAppDataPacketQueueSize { + t.Error("too many encrypted packets enqueued", len(dconn.encryptedPackets)) + } + t.Log(qlen) + time.Sleep(1 * time.Second) + } + }() + if _, err := handshakeConn(ctx, dconn, cfg, false, nil); err == nil { + t.Error("expected handshake to fail") + } + close(done) + }() + extensions := []extension.Extension{} + + time.Sleep(50 * time.Millisecond) + + err := sendClientHello([]byte{}, ca, 0, extensions) + if err != nil { + t.Fatal(err) + } + + time.Sleep(50 * time.Millisecond) + + for i := 0; i < 1000; i++ { + // Send an application data packet + packet, err := (&recordlayer.RecordLayer{ + Header: recordlayer.Header{ + Version: protocol.Version1_2, + SequenceNumber: uint64(3), + Epoch: 1, // use an epoch greater than 0 + }, + Content: &protocol.ApplicationData{ + Data: []byte{1, 2, 3, 4}, + }, + }).Marshal() + if err != nil { + t.Fatal(err) + } + ca.Write(packet) // nolint + if i%100 == 0 { + time.Sleep(10 * time.Millisecond) + } + } + time.Sleep(1 * time.Second) + ca.Close() // nolint + <-done +} diff --git a/resume.go b/resume.go index 9e8a2ae42..6cd1c5a69 100644 --- a/resume.go +++ b/resume.go @@ -13,7 +13,11 @@ func Resume(state *State, conn net.PacketConn, rAddr net.Addr, config *Config) ( if err := state.initCipherSuite(); err != nil { return nil, err } - c, err := createConn(context.Background(), conn, rAddr, config, state.isClient, state) + dconn, err := createConn(conn, rAddr, config, state.isClient) + if err != nil { + return nil, err + } + c, err := handshakeConn(context.Background(), dconn, config, state.isClient, state) if err != nil { return nil, err }