From 23f7b422e4b8fb862cfc8d067217d8d60a4829ac Mon Sep 17 00:00:00 2001 From: Jerko Steiner Date: Sat, 20 Feb 2021 10:41:43 +0100 Subject: [PATCH] Add sctp.Association.Abort(reason string) method According to RFC 4960. Also see my comment at #176. Closes #182. I don't think we currently do any tag verification for any packets, but we can implement that later. --- association.go | 74 +++++++++++++++++++++++++++-- association_test.go | 57 ++++++++++++++++++++-- error_cause.go | 3 ++ error_cause_user_initiated_abort.go | 46 ++++++++++++++++++ 4 files changed, 171 insertions(+), 9 deletions(-) create mode 100644 error_cause_user_initiated_abort.go diff --git a/association.go b/association.go index e66634f6..71e2415d 100644 --- a/association.go +++ b/association.go @@ -152,6 +152,9 @@ type Association struct { willSendShutdownAck bool willSendShutdownComplete bool + willSendAbort bool + willSendAbortCause errorCause + // Reconfig myNextRSN uint32 reconfigs map[uint32]*chunkReconfig @@ -469,6 +472,26 @@ func (a *Association) close() error { return err } +// Abort sends the abort packet with user initiated abort and immediately +// closes the connection. +func (a *Association) Abort(reason string) { + a.log.Debugf("[%s] aborting association: %s", a.name, reason) + + a.lock.Lock() + + a.willSendAbort = true + a.willSendAbortCause = &errorCauseUserInitiatedAbort{ + upperLayerAbortReason: []byte(reason), + } + + a.lock.Unlock() + + a.awakeWriteLoop() + + // Wait for readLoop to end + <-a.readLoopCloseCh +} + func (a *Association) closeAllTimers() { // Close all retransmission & ack timers a.t1Init.close() @@ -829,12 +852,39 @@ func (a *Association) gatherOutboundShutdownPackets(rawPackets [][]byte) ([][]by return rawPackets, ok } +func (a *Association) gatherAbortPacket() ([]byte, error) { + cause := a.willSendAbortCause + + a.willSendAbort = false + a.willSendAbortCause = nil + + abort := &chunkAbort{} + + if cause != nil { + abort.errorCauses = []errorCause{cause} + } + + raw, err := a.createPacket([]chunk{abort}).marshal() + + return raw, err +} + // gatherOutbound gathers outgoing packets. The returned bool value set to // false means the association should be closed down after the final send. func (a *Association) gatherOutbound() ([][]byte, bool) { a.lock.Lock() defer a.lock.Unlock() + if a.willSendAbort { + pkt, err := a.gatherAbortPacket() + if err != nil { + a.log.Warnf("[%s] failed to serialize an abort packet", a.name) + return nil, false + } + + return [][]byte{pkt}, false + } + rawPackets := [][]byte{} if a.controlQueue.size() > 0 { @@ -1747,6 +1797,17 @@ func (a *Association) handleShutdownComplete(_ *chunkShutdownComplete) error { return nil } +func (a *Association) handleAbort(c *chunkAbort) error { + var errStr string + for _, e := range c.errorCauses { + errStr += fmt.Sprintf("(%s)", e) + } + + _ = a.close() + + return fmt.Errorf("[%s] %w: %s", a.name, errChunk, errStr) +} + // createForwardTSN generates ForwardTSN chunk. // This method will be be called if useForwardTSN is set to false. // The caller should hold the lock. @@ -2251,6 +2312,8 @@ func (a *Association) handleChunk(p *packet, c chunk) error { return nil } + isAbort := false + switch c := c.(type) { case *chunkInit: packets, err = a.handleInit(p, c) @@ -2259,11 +2322,8 @@ func (a *Association) handleChunk(p *packet, c chunk) error { err = a.handleInitAck(p, c) case *chunkAbort: - var errStr string - for _, e := range c.errorCauses { - errStr += fmt.Sprintf("(%s)", e) - } - return fmt.Errorf("[%s] %w: %s", a.name, errChunk, errStr) + isAbort = true + err = a.handleAbort(c) case *chunkError: var errStr string @@ -2306,6 +2366,10 @@ func (a *Association) handleChunk(p *packet, c chunk) error { // Log and return, the only condition that is fatal is a ABORT chunk if err != nil { + if isAbort { + return err + } + a.log.Errorf("Failed to handle chunk: %v", err) return nil } diff --git a/association_test.go b/association_test.go index f118b428..6dd13351 100644 --- a/association_test.go +++ b/association_test.go @@ -2652,11 +2652,12 @@ func TestAssociation_ShutdownDuringWrite(t *testing.T) { } } -func TestAssociation_HandlePacketBeforeInit(t *testing.T) { +func TestAssociation_HandlePacketInCookieWaitState(t *testing.T) { loggerFactory := logging.NewDefaultLoggerFactory() testCases := map[string]struct { inputPacket *packet + skipClose bool }{ "InitAck": { inputPacket: &packet{ @@ -2680,6 +2681,8 @@ func TestAssociation_HandlePacketBeforeInit(t *testing.T) { destinationPort: 1, chunks: []chunk{&chunkAbort{}}, }, + // Prevent "use of close network connection" error on close. + skipClose: true, }, "CoockeEcho": { inputPacket: &packet{ @@ -2774,9 +2777,12 @@ func TestAssociation_HandlePacketBeforeInit(t *testing.T) { LoggerFactory: loggerFactory, }) a.init(true) - defer func() { - assert.NoError(t, a.close()) - }() + + if !testCase.skipClose { + defer func() { + assert.NoError(t, a.close()) + }() + } packet, err := testCase.inputPacket.marshal() assert.NoError(t, err) @@ -2788,3 +2794,46 @@ func TestAssociation_HandlePacketBeforeInit(t *testing.T) { }) } } + +func TestAssociation_Abort(t *testing.T) { + runtime.GC() + n0 := runtime.NumGoroutine() + + defer func() { + runtime.GC() + assert.Equal(t, n0, runtime.NumGoroutine(), "goroutine is leaked") + }() + + a1, a2 := createAssocs(t) + + s11, err := a1.OpenStream(1, PayloadTypeWebRTCString) + require.NoError(t, err) + + s21, err := a2.OpenStream(1, PayloadTypeWebRTCString) + require.NoError(t, err) + + testData := []byte("test") + + i, err := s11.Write(testData) + assert.Equal(t, len(testData), i) + assert.NoError(t, err) + + buf := make([]byte, len(testData)) + i, err = s21.Read(buf) + assert.Equal(t, len(testData), i) + assert.NoError(t, err) + assert.Equal(t, testData, buf) + + a1.Abort("1234") + + // Wait for close read loop channels to prevent flaky tests. + select { + case <-a2.readLoopCloseCh: + case <-time.After(1 * time.Second): + assert.Fail(t, "timed out waiting for a2 read loop to close") + } + + i, err = s21.Read(buf) + assert.Equal(t, i, 0, "expected no data read") + assert.Error(t, err, "User Initiated Abort: 1234", "expected abort reason") +} diff --git a/error_cause.go b/error_cause.go index db9a56f0..a511c133 100644 --- a/error_cause.go +++ b/error_cause.go @@ -32,6 +32,8 @@ func buildErrorCause(raw []byte) (errorCause, error) { e = &errorCauseUnrecognizedChunkType{} case protocolViolation: e = &errorCauseProtocolViolation{} + case userInitiatedAbort: + e = &errorCauseUserInitiatedAbort{} default: return nil, fmt.Errorf("%w: %s", errBuildErrorCaseHandle, c.String()) } @@ -39,6 +41,7 @@ func buildErrorCause(raw []byte) (errorCause, error) { if err := e.unmarshal(raw); err != nil { return nil, err } + return e, nil } diff --git a/error_cause_user_initiated_abort.go b/error_cause_user_initiated_abort.go new file mode 100644 index 00000000..5cb4125a --- /dev/null +++ b/error_cause_user_initiated_abort.go @@ -0,0 +1,46 @@ +package sctp + +import ( + "fmt" +) + +/* + This error cause MAY be included in ABORT chunks that are sent + because of an upper-layer request. The upper layer can specify an + Upper Layer Abort Reason that is transported by SCTP transparently + and MAY be delivered to the upper-layer protocol at the peer. + + 0 1 2 3 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | Cause Code=12 | Cause Length=Variable | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + / Upper Layer Abort Reason / + \ \ + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +*/ +type errorCauseUserInitiatedAbort struct { + errorCauseHeader + upperLayerAbortReason []byte +} + +func (e *errorCauseUserInitiatedAbort) marshal() ([]byte, error) { + e.code = userInitiatedAbort + e.errorCauseHeader.raw = e.upperLayerAbortReason + return e.errorCauseHeader.marshal() +} + +func (e *errorCauseUserInitiatedAbort) unmarshal(raw []byte) error { + err := e.errorCauseHeader.unmarshal(raw) + if err != nil { + return err + } + + e.upperLayerAbortReason = e.errorCauseHeader.raw + return nil +} + +// String makes errorCauseUserInitiatedAbort printable +func (e *errorCauseUserInitiatedAbort) String() string { + return fmt.Sprintf("%s: %s", e.errorCauseHeader.String(), e.upperLayerAbortReason) +}