Skip to content

Commit

Permalink
Implement shutdown sequence
Browse files Browse the repository at this point in the history
Closes #152.
  • Loading branch information
jeremija committed Dec 20, 2020
1 parent 8e5c733 commit 7d59e0d
Show file tree
Hide file tree
Showing 10 changed files with 744 additions and 27 deletions.
265 changes: 238 additions & 27 deletions association.go

Large diffs are not rendered by default.

182 changes: 182 additions & 0 deletions association_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,14 @@ import (
"runtime"
"strings"
"sync"
"sync/atomic"
"testing"
"time"

"github.com/pion/logging"
"github.com/pion/transport/test"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

var (
Expand Down Expand Up @@ -2471,3 +2473,183 @@ func TestAssocMaxMessageSize(t *testing.T) {
assert.Equal(t, uint32(20000), a.MaxMessageSize(), "should match")
})
}

func createAssocs(t *testing.T) (a1, a2 *Association) {
addr1 := &net.UDPAddr{
IP: net.IP{127, 0, 0, 1},
Port: 1234,
}

addr2 := &net.UDPAddr{
IP: net.IP{127, 0, 0, 1},
Port: 5678,
}

udp1, err := net.DialUDP("udp", addr1, addr2)
if err != nil {
panic(err)
}

udp2, err := net.DialUDP("udp", addr2, addr1)
if err != nil {
panic(err)
}

loggerFactory := logging.NewDefaultLoggerFactory()

a1Chan := make(chan *Association)
a2Chan := make(chan *Association)

go func() {
a1, err := Client(Config{
NetConn: udp1,
LoggerFactory: loggerFactory,
})
require.NoError(t, err)

a1Chan <- a1
}()

go func() {
a2, err := Client(Config{
NetConn: udp2,
LoggerFactory: loggerFactory,
})
require.NoError(t, err)

a2Chan <- a2
}()

select {
case a1 = <-a1Chan:
case <-time.After(time.Second):
assert.Fail(t, "timed out waiting for a1")
}

select {
case a2 = <-a2Chan:
case <-time.After(time.Second):
assert.Fail(t, "timed out waiting for a2")
}

return a1, a2
}

func TestAssociation_Shutdown(t *testing.T) {
runtime.GC()
n0 := runtime.NumGoroutine()

defer func() {
runtime.GC()
assert.Equal(t, n0, runtime.NumGoroutine(), "goroutine is leaked")
}()

a1, a2 := createAssocs(t)

s11, err := a1.OpenStream(1, PayloadTypeWebRTCString)
require.NoError(t, err)

s21, err := a2.OpenStream(1, PayloadTypeWebRTCString)
require.NoError(t, err)

testData := []byte("test")

i, err := s11.Write(testData)
assert.Equal(t, len(testData), i)
assert.NoError(t, err)

buf := make([]byte, len(testData))
i, err = s21.Read(buf)
assert.Equal(t, len(testData), i)
assert.NoError(t, err)
assert.Equal(t, testData, buf)

done, err := a1.Shutdown()
require.NoError(t, err)

select {
case <-done:
case <-time.After(1 * time.Second):
assert.Fail(t, "timed out waiting for a1 shutdown to complete")
}

// Wait for close read loop channels to prevent flaky tests.
select {
case <-a2.readLoopCloseCh:
case <-time.After(1 * time.Second):
assert.Fail(t, "timed out waiting for a2 read loop to close")
}
}

func TestAssociation_ShutdownDuringWrite(t *testing.T) {
runtime.GC()
n0 := runtime.NumGoroutine()

defer func() {
runtime.GC()
assert.Equal(t, n0, runtime.NumGoroutine(), "goroutine is leaked")
}()

a1, a2 := createAssocs(t)

s11, err := a1.OpenStream(1, PayloadTypeWebRTCString)
require.NoError(t, err)

s21, err := a2.OpenStream(1, PayloadTypeWebRTCString)
require.NoError(t, err)

var stopWriting int32
writingDone := make(chan struct{})

go func() {
defer close(writingDone)
var i byte

for {
i++

_, err := s21.Write([]byte{i})
if err != nil {
return
}
}
}()

testData := []byte("test")

i, err := s11.Write(testData)
assert.Equal(t, len(testData), i)
assert.NoError(t, err)

buf := make([]byte, len(testData))
i, err = s21.Read(buf)
assert.Equal(t, len(testData), i)
assert.NoError(t, err)
assert.Equal(t, testData, buf)

done, err := a1.Shutdown()
require.NoError(t, err)

// running this test with -race flag is very slow so timeout needs to be high.
timeout := 5 * time.Minute

select {
case <-done:
atomic.AddInt32(&stopWriting, 1)
case <-time.After(timeout):
assert.Fail(t, "timed out waiting for a1 shutdown to complete")
}

select {
case <-writingDone:
case <-time.After(timeout):
assert.Fail(t, "timed out waiting writing goroutine to exit")
}

// Wait for close read loop channels to prevent flaky tests.
select {
case <-a2.readLoopCloseCh:
case <-time.After(timeout):
assert.Fail(t, "timed out waiting for a2 read loop to close")
}
}
68 changes: 68 additions & 0 deletions chunk_shutdown.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package sctp

import (
"encoding/binary"
"errors"
"fmt"
)

/*
chunkShutdown represents an SCTP Chunk of type chunkShutdown
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Type = 7 | Chunk Flags | Length = 8 |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Cumulative TSN Ack |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
*/
type chunkShutdown struct {
chunkHeader
cumulativeTSNAck uint32
}

const (
cumulativeTSNAckLength = 4
)

var (
errInvalidChunkSize = errors.New("invalid chunk size")
errChunkTypeNotShutdown = errors.New("ChunkType is not of type SHUTDOWN")
)

func (c *chunkShutdown) unmarshal(raw []byte) error {
if err := c.chunkHeader.unmarshal(raw); err != nil {
return err
}

if c.typ != ctShutdown {
return fmt.Errorf("%w: actually is %s", errChunkTypeNotShutdown, c.typ.String())
}

if len(c.raw) != cumulativeTSNAckLength {
return errInvalidChunkSize
}

c.cumulativeTSNAck = binary.BigEndian.Uint32(c.raw[0:])

return nil
}

func (c *chunkShutdown) marshal() ([]byte, error) {
out := make([]byte, cumulativeTSNAckLength)
binary.BigEndian.PutUint32(out[0:], c.cumulativeTSNAck)

c.typ = ctShutdown
c.raw = out
return c.chunkHeader.marshal()
}

func (c *chunkShutdown) check() (abort bool, err error) {
return false, nil
}

// String makes chunkShutdown printable
func (c *chunkShutdown) String() string {
return c.chunkHeader.String()
}
47 changes: 47 additions & 0 deletions chunk_shutdown_ack.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package sctp

import (
"errors"
"fmt"
)

/*
chunkShutdownAck represents an SCTP Chunk of type chunkShutdownAck
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Type = 8 | Chunk Flags | Length = 4 |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
*/
type chunkShutdownAck struct {
chunkHeader
}

var errChunkTypeNotShutdownAck = errors.New("ChunkType is not of type SHUTDOWN-ACK")

func (c *chunkShutdownAck) unmarshal(raw []byte) error {
if err := c.chunkHeader.unmarshal(raw); err != nil {
return err
}

if c.typ != ctShutdownAck {
return fmt.Errorf("%w: actually is %s", errChunkTypeNotShutdownAck, c.typ.String())
}

return nil
}

func (c *chunkShutdownAck) marshal() ([]byte, error) {
c.typ = ctShutdownAck
return c.chunkHeader.marshal()
}

func (c *chunkShutdownAck) check() (abort bool, err error) {
return false, nil
}

// String makes chunkShutdownAck printable
func (c *chunkShutdownAck) String() string {
return c.chunkHeader.String()
}
48 changes: 48 additions & 0 deletions chunk_shutdown_ack_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package sctp

import (
"testing"

"github.com/stretchr/testify/assert"
)

func TestChunkShutdownAck_Success(t *testing.T) {
tt := []struct {
binary []byte
}{
{[]byte{0x08, 0x00, 0x00, 0x04}},
}

for i, tc := range tt {
actual := &chunkShutdownAck{}
err := actual.unmarshal(tc.binary)
if err != nil {
t.Fatalf("failed to unmarshal #%d: %v", i, err)
}

b, err := actual.marshal()
if err != nil {
t.Fatalf("failed to marshal: %v", err)
}
assert.Equal(t, tc.binary, b, "test %d not equal", i)
}
}

func TestChunkShutdownAck_Failure(t *testing.T) {
tt := []struct {
name string
binary []byte
}{
{"length too short", []byte{0x08, 0x00, 0x00}},
{"length too long", []byte{0x08, 0x00, 0x00, 0x04, 0x12}},
{"invalid type", []byte{0x0f, 0x00, 0x00, 0x04}},
}

for i, tc := range tt {
actual := &chunkShutdownAck{}
err := actual.unmarshal(tc.binary)
if err == nil {
t.Errorf("expected unmarshal #%d: '%s' to fail.", i, tc.name)
}
}
}
Loading

0 comments on commit 7d59e0d

Please sign in to comment.