diff --git a/conn.go b/conn.go index d82228f3..be417554 100644 --- a/conn.go +++ b/conn.go @@ -54,6 +54,11 @@ type addrPkt struct { data []byte } +type recvHandshakeState struct { + done chan struct{} + isRetransmit bool +} + // Conn represents a DTLS connection type Conn struct { lock sync.RWMutex // Internal lock (must not be public) @@ -82,7 +87,7 @@ type Conn struct { log logging.LeveledLogger reading chan struct{} - handshakeRecv chan chan struct{} + handshakeRecv chan recvHandshakeState cancelHandshaker func() cancelHandshakeReader func() @@ -137,7 +142,7 @@ func createConn(nextConn net.PacketConn, rAddr net.Addr, config *Config, isClien writeDeadline: deadline.New(), reading: make(chan struct{}, 1), - handshakeRecv: make(chan chan struct{}), + handshakeRecv: make(chan recvHandshakeState), closed: closer.NewCloser(), cancelHandshaker: func() {}, @@ -704,9 +709,9 @@ func (c *Conn) readAndBuffer(ctx context.Context) error { return err } - var hasHandshake bool + var hasHandshake, isRetransmit bool for _, p := range pkts { - hs, alert, err := c.handleIncomingPacket(ctx, p, rAddr, true) + hs, rtx, alert, err := c.handleIncomingPacket(ctx, p, rAddr, true) if alert != nil { if alertErr := c.notify(ctx, alert.Level, alert.Description); alertErr != nil { if err == nil { @@ -725,14 +730,20 @@ func (c *Conn) readAndBuffer(ctx context.Context) error { if hs { hasHandshake = true } + if rtx { + isRetransmit = true + } } if hasHandshake { - done := make(chan struct{}) + s := recvHandshakeState{ + done: make(chan struct{}), + isRetransmit: isRetransmit, + } select { - case c.handshakeRecv <- done: + case c.handshakeRecv <- s: // If the other party may retransmit the flight, // we should respond even if it not a new message. - <-done + <-s.done case <-c.fsm.Done(): } } @@ -744,7 +755,7 @@ func (c *Conn) handleQueuedPackets(ctx context.Context) error { c.encryptedPackets = nil for _, p := range pkts { - _, alert, err := c.handleIncomingPacket(ctx, p.data, p.rAddr, false) // don't re-enqueue + _, _, alert, err := c.handleIncomingPacket(ctx, p.data, p.rAddr, false) // don't re-enqueue if alert != nil { if alertErr := c.notify(ctx, alert.Level, alert.Description); alertErr != nil { if err == nil { @@ -771,7 +782,7 @@ func (c *Conn) enqueueEncryptedPackets(packet addrPkt) bool { return false } -func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, rAddr net.Addr, enqueue bool) (bool, *alert.Alert, error) { //nolint:gocognit +func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, rAddr net.Addr, enqueue bool) (bool, bool, *alert.Alert, error) { //nolint:gocognit h := &recordlayer.Header{} // Set connection ID size so that records of content type tls12_cid will // be parsed correctly. @@ -782,7 +793,7 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, rAddr net.A // Decode error must be silently discarded // [RFC6347 Section-4.1.2.7] c.log.Debugf("discarded broken packet: %v", err) - return false, nil, nil + return false, false, nil, nil } // Validate epoch remoteEpoch := c.state.getRemoteEpoch() @@ -791,14 +802,14 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, rAddr net.A c.log.Debugf("discarded future packet (epoch: %d, seq: %d)", h.Epoch, h.SequenceNumber, ) - return false, nil, nil + return false, false, nil, nil } if enqueue { if ok := c.enqueueEncryptedPackets(addrPkt{rAddr, buf}); ok { c.log.Debug("received packet of next epoch, queuing packet") } } - return false, nil, nil + return false, false, nil, nil } // Anti-replay protection @@ -812,7 +823,7 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, rAddr net.A c.log.Debugf("discarded duplicated packet (epoch: %d, seq: %d)", h.Epoch, h.SequenceNumber, ) - return false, nil, nil + return false, false, nil, nil } // originalCID indicates whether the original record had content type @@ -827,14 +838,14 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, rAddr net.A c.log.Debug("handshake not finished, queuing packet") } } - return false, nil, nil + return false, false, nil, nil } // If a connection identifier had been negotiated and encryption is // enabled, the connection identifier MUST be sent. if len(c.state.localConnectionID) > 0 && h.ContentType != protocol.ContentTypeConnectionID { c.log.Debug("discarded packet missing connection ID after value negotiated") - return false, nil, nil + return false, false, nil, nil } var err error @@ -845,7 +856,7 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, rAddr net.A buf, err = c.state.cipherSuite.Decrypt(hdr, buf) if err != nil { c.log.Debugf("%s: decrypt failed: %s", srvCliStr(c.state.isClient), err) - return false, nil, nil + return false, false, nil, nil } // If this is a connection ID record, make it look like a normal record for // further processing. @@ -854,7 +865,7 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, rAddr net.A ip := &recordlayer.InnerPlaintext{} if err := ip.Unmarshal(buf[h.Size():]); err != nil { //nolint:govet c.log.Debugf("unpacking inner plaintext failed: %s", err) - return false, nil, nil + return false, false, nil, nil } unpacked := &recordlayer.Header{ ContentType: ip.RealType, @@ -866,7 +877,7 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, rAddr net.A buf, err = unpacked.Marshal() if err != nil { c.log.Debugf("converting CID record to inner plaintext failed: %s", err) - return false, nil, nil + return false, false, nil, nil } buf = append(buf, ip.Content...) } @@ -874,18 +885,19 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, rAddr net.A // If connection ID does not match discard the packet. if !bytes.Equal(c.state.localConnectionID, h.ConnectionID) { c.log.Debug("unexpected connection ID") - return false, nil, nil + return false, false, nil, nil } } - isHandshake, err := c.fragmentBuffer.push(append([]byte{}, buf...)) + isHandshake, isRetransmit, err := c.fragmentBuffer.push(append([]byte{}, buf...)) if err != nil { // Decode error must be silently discarded // [RFC6347 Section-4.1.2.7] c.log.Debugf("defragment failed: %s", err) - return false, nil, nil + return false, false, nil, nil } else if isHandshake { markPacketAsValid() + for out, epoch := c.fragmentBuffer.pop(); out != nil; out, epoch = c.fragmentBuffer.pop() { header := &handshake.Header{} if err := header.Unmarshal(out); err != nil { @@ -895,12 +907,12 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, rAddr net.A c.handshakeCache.push(out, epoch, header.MessageSequence, header.Type, !c.state.isClient) } - return true, nil, nil + return true, isRetransmit, nil, nil } r := &recordlayer.RecordLayer{} if err := r.Unmarshal(buf); err != nil { - return false, &alert.Alert{Level: alert.Fatal, Description: alert.DecodeError}, err + return false, false, &alert.Alert{Level: alert.Fatal, Description: alert.DecodeError}, err } isLatestSeqNum := false @@ -913,7 +925,7 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, rAddr net.A a = &alert.Alert{Level: alert.Warning, Description: alert.CloseNotify} } _ = markPacketAsValid() - return false, a, &alertError{content} + return false, false, a, &alertError{content} case *protocol.ChangeCipherSpec: if c.state.cipherSuite == nil || !c.state.cipherSuite.IsInitialized() { if enqueue { @@ -921,7 +933,7 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, rAddr net.A c.log.Debugf("CipherSuite not initialized, queuing packet") } } - return false, nil, nil + return false, false, nil, nil } newRemoteEpoch := h.Epoch + 1 @@ -933,7 +945,7 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, rAddr net.A } case *protocol.ApplicationData: if h.Epoch == 0 { - return false, &alert.Alert{Level: alert.Fatal, Description: alert.UnexpectedMessage}, errApplicationDataEpochZero + return false, false, &alert.Alert{Level: alert.Fatal, Description: alert.UnexpectedMessage}, errApplicationDataEpochZero } isLatestSeqNum = markPacketAsValid() @@ -945,7 +957,7 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, rAddr net.A } default: - return false, &alert.Alert{Level: alert.Fatal, Description: alert.UnexpectedMessage}, fmt.Errorf("%w: %d", errUnhandledContextType, content.ContentType()) + return false, false, &alert.Alert{Level: alert.Fatal, Description: alert.UnexpectedMessage}, fmt.Errorf("%w: %d", errUnhandledContextType, content.ContentType()) } // Any valid connection ID record is a candidate for updating the remote @@ -959,10 +971,10 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, rAddr net.A } } - return false, nil, nil + return false, false, nil, nil } -func (c *Conn) recvHandshake() <-chan chan struct{} { +func (c *Conn) recvHandshake() <-chan recvHandshakeState { return c.handshakeRecv } diff --git a/e2e/e2e_lossy_test.go b/e2e/e2e_lossy_test.go index 3e03037b..a306584b 100644 --- a/e2e/e2e_lossy_test.go +++ b/e2e/e2e_lossy_test.go @@ -45,10 +45,11 @@ func TestPionE2ELossy(t *testing.T) { } for _, test := range []struct { - LossChanceRange int - DoClientAuth bool - CipherSuites []dtls.CipherSuiteID - MTU int + LossChanceRange int + DoClientAuth bool + CipherSuites []dtls.CipherSuiteID + MTU int + DisableServerFlightInterval bool }{ { LossChanceRange: 0, @@ -109,6 +110,20 @@ func TestPionE2ELossy(t *testing.T) { MTU: 100, DoClientAuth: true, }, + // Incoming retransmitted handshakes should cause us to retransmit. Disabling the FlightInterval on one side + // means that a incoming re-transmissions causes the retransmission to be fired + { + LossChanceRange: 10, + DisableServerFlightInterval: true, + }, + { + LossChanceRange: 20, + DisableServerFlightInterval: true, + }, + { + LossChanceRange: 50, + DisableServerFlightInterval: true, + }, } { name := fmt.Sprintf("Loss%d_MTU%d", test.LossChanceRange, test.MTU) if test.DoClientAuth { @@ -117,6 +132,10 @@ func TestPionE2ELossy(t *testing.T) { for _, ciph := range test.CipherSuites { name += "_With" + ciph.String() } + if test.DisableServerFlightInterval { + name += "_WithNoServerFlightInterval" + } + test := test t.Run(name, func(t *testing.T) { // Limit runtime in case of deadlocks @@ -162,6 +181,10 @@ func TestPionE2ELossy(t *testing.T) { cfg.ClientAuth = dtls.RequireAnyClientCert } + if test.DisableServerFlightInterval { + cfg.FlightInterval = time.Hour + } + server, startupErr := dtls.Server(dtlsnet.PacketConnFromConn(br.GetConn1()), br.GetConn1().RemoteAddr(), cfg) serverDone <- runResult{server, startupErr} }() diff --git a/flight1handler_test.go b/flight1handler_test.go index 2d590dcc..461ef7cf 100644 --- a/flight1handler_test.go +++ b/flight1handler_test.go @@ -21,7 +21,7 @@ func (f *flight1TestMockFlightConn) notify(context.Context, alert.Level, alert.D return nil } func (f *flight1TestMockFlightConn) writePackets(context.Context, []*packet) error { return nil } -func (f *flight1TestMockFlightConn) recvHandshake() <-chan chan struct{} { return nil } +func (f *flight1TestMockFlightConn) recvHandshake() <-chan recvHandshakeState { return nil } func (f *flight1TestMockFlightConn) setLocalEpoch(uint16) {} func (f *flight1TestMockFlightConn) handleQueuedPackets(context.Context) error { return nil } func (f *flight1TestMockFlightConn) sessionKey() []byte { return nil } diff --git a/flight4handler_test.go b/flight4handler_test.go index f4446c40..bf99a160 100644 --- a/flight4handler_test.go +++ b/flight4handler_test.go @@ -27,7 +27,7 @@ func (f *flight4TestMockFlightConn) notify(context.Context, alert.Level, alert.D return nil } func (f *flight4TestMockFlightConn) writePackets(context.Context, []*packet) error { return nil } -func (f *flight4TestMockFlightConn) recvHandshake() <-chan chan struct{} { return nil } +func (f *flight4TestMockFlightConn) recvHandshake() <-chan recvHandshakeState { return nil } func (f *flight4TestMockFlightConn) setLocalEpoch(uint16) {} func (f *flight4TestMockFlightConn) handleQueuedPackets(context.Context) error { return nil } func (f *flight4TestMockFlightConn) sessionKey() []byte { return nil } diff --git a/fragment_buffer.go b/fragment_buffer.go index fb5af6c3..f7785ec3 100644 --- a/fragment_buffer.go +++ b/fragment_buffer.go @@ -43,26 +43,29 @@ func (f *fragmentBuffer) size() int { // Attempts to push a DTLS packet to the fragmentBuffer // when it returns true it means the fragmentBuffer has inserted and the buffer shouldn't be handled // when an error returns it is fatal, and the DTLS connection should be stopped -func (f *fragmentBuffer) push(buf []byte) (bool, error) { +func (f *fragmentBuffer) push(buf []byte) (isHandshake, isRetransmit bool, err error) { if f.size()+len(buf) >= fragmentBufferMaxSize { - return false, errFragmentBufferOverflow + return false, false, errFragmentBufferOverflow } frag := new(fragment) if err := frag.recordLayerHeader.Unmarshal(buf); err != nil { - return false, err + return false, false, err } // fragment isn't a handshake, we don't need to handle it if frag.recordLayerHeader.ContentType != protocol.ContentTypeHandshake { - return false, nil + return false, false, nil } for buf = buf[recordlayer.FixedHeaderSize:]; len(buf) != 0; frag = new(fragment) { if err := frag.handshakeHeader.Unmarshal(buf); err != nil { - return false, err + return false, false, err } + // Fragment is a retransmission. We have already assembled it before successfully + isRetransmit = frag.handshakeHeader.FragmentOffset == 0 && frag.handshakeHeader.MessageSequence < f.currentMessageSequenceNumber + if _, ok := f.cache[frag.handshakeHeader.MessageSequence]; !ok { f.cache[frag.handshakeHeader.MessageSequence] = []*fragment{} } @@ -80,7 +83,7 @@ func (f *fragmentBuffer) push(buf []byte) (bool, error) { buf = buf[end:] } - return true, nil + return true, isRetransmit, nil } func (f *fragmentBuffer) pop() (content []byte, epoch uint16) { diff --git a/fragment_buffer_test.go b/fragment_buffer_test.go index ad8834e7..2b2f62c7 100644 --- a/fragment_buffer_test.go +++ b/fragment_buffer_test.go @@ -94,7 +94,7 @@ func TestFragmentBuffer(t *testing.T) { } { fragmentBuffer := newFragmentBuffer() for _, frag := range test.In { - status, err := fragmentBuffer.push(frag) + status, _, err := fragmentBuffer.push(frag) if err != nil { t.Error(err) } else if !status { @@ -122,13 +122,13 @@ func TestFragmentBuffer_Overflow(t *testing.T) { fragmentBuffer := newFragmentBuffer() // Push a buffer that doesn't exceed size limits - if _, err := fragmentBuffer.push([]byte{0x16, 0xfe, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0F, 0x03, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03, 0xfe, 0xff, 0x00}); err != nil { + if _, _, err := fragmentBuffer.push([]byte{0x16, 0xfe, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0F, 0x03, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03, 0xfe, 0xff, 0x00}); err != nil { t.Fatal(err) } // Allocate a buffer that exceeds cache size largeBuffer := make([]byte, fragmentBufferMaxSize) - if _, err := fragmentBuffer.push(largeBuffer); !errors.Is(err, errFragmentBufferOverflow) { + if _, _, err := fragmentBuffer.push(largeBuffer); !errors.Is(err, errFragmentBufferOverflow) { t.Fatalf("Pushing a large buffer returned (%s) expected(%s)", err, errFragmentBufferOverflow) } } diff --git a/handshaker.go b/handshaker.go index a7125f8d..a585c3db 100644 --- a/handshaker.go +++ b/handshaker.go @@ -137,7 +137,7 @@ type handshakeConfig struct { type flightConn interface { notify(ctx context.Context, level alert.Level, desc alert.Description) error writePackets(context.Context, []*packet) error - recvHandshake() <-chan chan struct{} + recvHandshake() <-chan recvHandshakeState setLocalEpoch(epoch uint16) handleQueuedPackets(context.Context) error sessionKey() []byte @@ -280,10 +280,15 @@ func (s *handshakeFSM) wait(ctx context.Context, c flightConn) (handshakeState, retransmitTimer := time.NewTimer(s.retransmitInterval) for { select { - case done := <-c.recvHandshake(): + case state := <-c.recvHandshake(): + if state.isRetransmit { + close(state.done) + return handshakeSending, nil + } + nextFlight, alert, err := parse(ctx, c, s.state, s.cache, s.cfg) s.retransmitInterval = s.cfg.initialRetransmitInterval - close(done) + close(state.done) if alert != nil { if alertErr := c.notify(ctx, alert.Level, alert.Description); alertErr != nil { if err != nil { @@ -328,8 +333,8 @@ func (s *handshakeFSM) wait(ctx context.Context, c flightConn) (handshakeState, func (s *handshakeFSM) finish(ctx context.Context, c flightConn) (handshakeState, error) { select { - case done := <-c.recvHandshake(): - close(done) + case state := <-c.recvHandshake(): + close(state.done) return handshakeSending, nil case <-ctx.Done(): return handshakeErrored, ctx.Err() diff --git a/handshaker_test.go b/handshaker_test.go index b7d4f1bf..b814fc90 100644 --- a/handshaker_test.go +++ b/handshaker_test.go @@ -349,8 +349,8 @@ type TestEndpoint struct { func flightTestPipe(ctx context.Context, clientEndpoint TestEndpoint, serverEndpoint TestEndpoint) (*flightTestConn, *flightTestConn) { ca := newHandshakeCache() cb := newHandshakeCache() - chA := make(chan chan struct{}) - chB := make(chan chan struct{}) + chA := make(chan recvHandshakeState) + chB := make(chan recvHandshakeState) return &flightTestConn{ handshakeCache: ca, otherEndCache: cb, @@ -373,7 +373,7 @@ func flightTestPipe(ctx context.Context, clientEndpoint TestEndpoint, serverEndp type flightTestConn struct { state State handshakeCache *handshakeCache - recv chan chan struct{} + recv chan recvHandshakeState done <-chan struct{} epoch uint16 @@ -382,10 +382,10 @@ type flightTestConn struct { delay time.Duration otherEndCache *handshakeCache - otherEndRecv chan chan struct{} + otherEndRecv chan recvHandshakeState } -func (c *flightTestConn) recvHandshake() <-chan chan struct{} { +func (c *flightTestConn) recvHandshake() <-chan recvHandshakeState { return c.recv } @@ -427,7 +427,7 @@ func (c *flightTestConn) writePackets(_ context.Context, pkts []*packet) error { } go func() { select { - case c.otherEndRecv <- make(chan struct{}): + case c.otherEndRecv <- recvHandshakeState{done: make(chan struct{})}: case <-c.done: } }()