Skip to content

Commit

Permalink
Add sctp.Association.Abort(reason string) method
Browse files Browse the repository at this point in the history
According to RFC 4960. Also see my comment at #176.

Closes #182.

I don't think we currently do any tag verification for any packets, but
we can implement that later.
  • Loading branch information
jeremija committed Feb 20, 2021
1 parent 04897bc commit cd4051f
Show file tree
Hide file tree
Showing 4 changed files with 144 additions and 5 deletions.
71 changes: 66 additions & 5 deletions association.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,9 @@ type Association struct {
willSendShutdownAck bool
willSendShutdownComplete bool

willSendAbort bool
willSendAbortCause errorCause

// Reconfig
myNextRSN uint32
reconfigs map[uint32]*chunkReconfig
Expand Down Expand Up @@ -470,6 +473,22 @@ func (a *Association) close() error {
return err
}

// Abort sends the abort packet with user initiated abort and immediately
// closes the connection.
func (a *Association) Abort(reason string) {
a.log.Debugf("[%s] aborting association: %s", a.name, reason)

a.willSendAbort = true
a.willSendAbortCause = &errorCauseUserInitiatedAbort{
upperLayerAbortReason: []byte(reason),
}

a.awakeWriteLoop()

// Wait for readLoop to end
<-a.readLoopCloseCh
}

func (a *Association) closeAllTimers() {
// Close all retransmission & ack timers
a.t1Init.close()
Expand Down Expand Up @@ -587,6 +606,7 @@ func (a *Association) unregisterStream(s *Stream, err error) {

// handleInbound parses incoming raw packets
func (a *Association) handleInbound(raw []byte) error {

p := &packet{}
if err := p.unmarshal(raw); err != nil {
a.log.Warnf("[%s] unable to parse SCTP packet %s", a.name, err)
Expand Down Expand Up @@ -830,12 +850,39 @@ func (a *Association) gatherOutboundShutdownPackets(rawPackets [][]byte) ([][]by
return rawPackets, ok
}

func (a *Association) gatherAbortPacket() ([]byte, error) {
cause := a.willSendAbortCause

a.willSendAbort = false
a.willSendAbortCause = nil

abort := &chunkAbort{}

if cause != nil {
abort.errorCauses = []errorCause{cause}
}

raw, err := a.createPacket([]chunk{abort}).marshal()

return raw, err
}

// gatherOutbound gathers outgoing packets. The returned bool value set to
// false means the association should be closed down after the final send.
func (a *Association) gatherOutbound() ([][]byte, bool) {
a.lock.Lock()
defer a.lock.Unlock()

if a.willSendAbort {
pkt, err := a.gatherAbortPacket()
if err != nil {
a.log.Warnf("[%s] failed to serialize an abort packet", a.name)
return nil, false
}

return [][]byte{pkt}, false
}

rawPackets := [][]byte{}

if a.controlQueue.size() > 0 {
Expand Down Expand Up @@ -1745,6 +1792,17 @@ func (a *Association) handleShutdownComplete(_ *chunkShutdownComplete) error {
return nil
}

func (a *Association) handleAbort(c *chunkAbort) error {
var errStr string
for _, e := range c.errorCauses {
errStr += fmt.Sprintf("(%s)", e)
}

_ = a.close()

return fmt.Errorf("[%s] %w: %s", a.name, errChunk, errStr)
}

// createForwardTSN generates ForwardTSN chunk.
// This method will be be called if useForwardTSN is set to false.
// The caller should hold the lock.
Expand Down Expand Up @@ -2249,6 +2307,8 @@ func (a *Association) handleChunk(p *packet, c chunk) error {
return nil
}

isAbort := false

switch c := c.(type) {
case *chunkInit:
packets, err = a.handleInit(p, c)
Expand All @@ -2257,11 +2317,8 @@ func (a *Association) handleChunk(p *packet, c chunk) error {
err = a.handleInitAck(p, c)

case *chunkAbort:
var errStr string
for _, e := range c.errorCauses {
errStr += fmt.Sprintf("(%s)", e)
}
return fmt.Errorf("[%s] %w: %s", a.name, errChunk, errStr)
isAbort = true
err = a.handleAbort(c)

case *chunkError:
var errStr string
Expand Down Expand Up @@ -2304,6 +2361,10 @@ func (a *Association) handleChunk(p *packet, c chunk) error {

// Log and return, the only condition that is fatal is a ABORT chunk
if err != nil {
if isAbort {
return err
}

a.log.Errorf("Failed to handle chunk: %v", err)
return nil
}
Expand Down
43 changes: 43 additions & 0 deletions association_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2650,3 +2650,46 @@ func TestAssociation_ShutdownDuringWrite(t *testing.T) {
assert.Fail(t, "timed out waiting for a2 read loop to close")
}
}

func TestAssociation_Abort(t *testing.T) {
runtime.GC()
n0 := runtime.NumGoroutine()

defer func() {
runtime.GC()
assert.Equal(t, n0, runtime.NumGoroutine(), "goroutine is leaked")
}()

a1, a2 := createAssocs(t)

s11, err := a1.OpenStream(1, PayloadTypeWebRTCString)
require.NoError(t, err)

s21, err := a2.OpenStream(1, PayloadTypeWebRTCString)
require.NoError(t, err)

testData := []byte("test")

i, err := s11.Write(testData)
assert.Equal(t, len(testData), i)
assert.NoError(t, err)

buf := make([]byte, len(testData))
i, err = s21.Read(buf)
assert.Equal(t, len(testData), i)
assert.NoError(t, err)
assert.Equal(t, testData, buf)

a1.Abort("1234")

// Wait for close read loop channels to prevent flaky tests.
select {
case <-a2.readLoopCloseCh:
case <-time.After(1 * time.Second):
assert.Fail(t, "timed out waiting for a2 read loop to close")
}

i, err = s21.Read(buf)
assert.Equal(t, i, 0, "expected no data read")
assert.Error(t, err, "User Initiated Abort: 1234", "expected abort reason")
}
3 changes: 3 additions & 0 deletions error_cause.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,16 @@ func buildErrorCause(raw []byte) (errorCause, error) {
e = &errorCauseUnrecognizedChunkType{}
case protocolViolation:
e = &errorCauseProtocolViolation{}
case userInitiatedAbort:
e = &errorCauseUserInitiatedAbort{}
default:
return nil, fmt.Errorf("%w: %s", errBuildErrorCaseHandle, c.String())
}

if err := e.unmarshal(raw); err != nil {
return nil, err
}

return e, nil
}

Expand Down
32 changes: 32 additions & 0 deletions error_cause_user_initiated_abort.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package sctp

import (
"fmt"
)

// errorCauseUserInitiatedAbort represents an SCTP error cause
type errorCauseUserInitiatedAbort struct {
errorCauseHeader
upperLayerAbortReason []byte
}

func (e *errorCauseUserInitiatedAbort) marshal() ([]byte, error) {
e.code = userInitiatedAbort
e.errorCauseHeader.raw = e.upperLayerAbortReason
return e.errorCauseHeader.marshal()
}

func (e *errorCauseUserInitiatedAbort) unmarshal(raw []byte) error {
err := e.errorCauseHeader.unmarshal(raw)
if err != nil {
return err
}

e.upperLayerAbortReason = e.errorCauseHeader.raw
return nil
}

// String makes errorCauseUserInitiatedAbort printable
func (e *errorCauseUserInitiatedAbort) String() string {
return fmt.Sprintf("%s: %s", e.errorCauseHeader.String(), e.upperLayerAbortReason)
}

0 comments on commit cd4051f

Please sign in to comment.