Skip to content

Commit

Permalink
Stream: Add "SetReadDeadline"
Browse files Browse the repository at this point in the history
Add SetReadDeadline method to the Stream object to allow read timeout
to be implemented
  • Loading branch information
Sam Lancia committed Oct 16, 2022
1 parent d13d723 commit 396e5d5
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 0 deletions.
5 changes: 5 additions & 0 deletions association.go
Original file line number Diff line number Diff line change
Expand Up @@ -605,6 +605,11 @@ func (a *Association) unregisterStream(s *Stream, err error) {
delete(a.streams, s.streamIdentifier)
s.readErr = err
s.readNotifier.Broadcast()

if s.readTimeoutCancel != nil {
close(s.readTimeoutCancel)
s.readTimeoutCancel = nil
}
}

// handleInbound parses incoming raw packets
Expand Down
52 changes: 52 additions & 0 deletions association_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"math"
"math/rand"
"net"
"os"
"runtime"
"strings"
"sync"
Expand Down Expand Up @@ -466,6 +467,57 @@ func TestAssocReliable(t *testing.T) {
closeAssociationPair(br, a0, a1)
})

t.Run("ReadDeadline", func(t *testing.T) {
lim := test.TimeOut(time.Second * 10)
defer lim.Stop()

const si uint16 = 1
const msg = "ABC"
br := test.NewBridge()

a0, a1, err := createNewAssociationPair(br, ackModeNoDelay, 0)
if !assert.Nil(t, err, "failed to create associations") {
assert.FailNow(t, "failed due to earlier error")
}

s0, s1, err := establishSessionPair(br, a0, a1, si)
assert.Nil(t, err, "failed to establish session pair")

assert.Equal(t, 0, a0.bufferedAmount(), "incorrect bufferedAmount")

assert.NoError(t, s1.SetReadDeadline(time.Now().Add(time.Millisecond)), "failed to set read deadline")
buf := make([]byte, 32)
// First fails
n, ppi, err := s1.ReadSCTP(buf)
assert.Equal(t, 0, n)
assert.Equal(t, PayloadProtocolIdentifier(0), ppi)
assert.True(t, errors.Is(err, os.ErrDeadlineExceeded))
// Second too
n, ppi, err = s1.ReadSCTP(buf)
assert.Equal(t, 0, n)
assert.Equal(t, PayloadProtocolIdentifier(0), ppi)
assert.True(t, errors.Is(err, os.ErrDeadlineExceeded))
assert.NoError(t, s1.SetReadDeadline(time.Time{}), "failed to disable read deadline")

n, err = s0.WriteSCTP([]byte(msg), PayloadTypeWebRTCBinary)
if err != nil {
assert.FailNow(t, "failed due to earlier error")
}
assert.Equal(t, len(msg), n, "unexpected length of received data")
assert.Equal(t, len(msg), a0.bufferedAmount(), "incorrect bufferedAmount")

flushBuffers(br, a0, a1)

n, ppi, err = s1.ReadSCTP(buf)
if !assert.Nil(t, err, "ReadSCTP failed") {
assert.FailNow(t, "failed due to earlier error")
}
assert.Equal(t, n, len(msg), "unexpected length of received data")
assert.Equal(t, ppi, PayloadTypeWebRTCBinary, "unexpected ppi")

closeAssociationPair(br, a0, a1)
})

t.Run("ordered reordered", func(t *testing.T) {
lim := test.TimeOut(time.Second * 10)
defer lim.Stop()
Expand Down
46 changes: 46 additions & 0 deletions stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@ import (
"fmt"
"io"
"math"
"os"
"sync"
"sync/atomic"
"time"

"github.com/pion/logging"
)
Expand Down Expand Up @@ -46,6 +48,7 @@ func (ss StreamState) String() string {
var (
errOutboundPacketTooLarge = errors.New("outbound packet larger than maximum message size")
errStreamClosed = errors.New("stream closed")
errReadDeadlineExceeded = fmt.Errorf("read deadline exceeded: %w", os.ErrDeadlineExceeded)
)

// Stream represents an SCTP stream
Expand All @@ -58,6 +61,7 @@ type Stream struct {
sequenceNumber uint16
readNotifier *sync.Cond
readErr error
readTimeoutCancel chan struct{}
unordered bool
reliabilityType byte
reliabilityValue uint32
Expand Down Expand Up @@ -132,6 +136,43 @@ func (s *Stream) ReadSCTP(p []byte) (int, PayloadProtocolIdentifier, error) {
}
}

// SetReadDeadline sets the read deadline in an identical way to net.Conn
func (s *Stream) SetReadDeadline(deadline time.Time) error {
s.lock.Lock()
defer s.lock.Unlock()

if s.readTimeoutCancel != nil {
close(s.readTimeoutCancel)
s.readTimeoutCancel = nil
}

if deadline.IsZero() {
if errors.Is(s.readErr, errReadDeadlineExceeded) {
s.readErr = nil
}
} else {
s.readTimeoutCancel = make(chan struct{})

go func(readTimeoutCancel chan struct{}) {
t := time.NewTimer(time.Until(deadline))
select {
case <-readTimeoutCancel:
t.Stop()
return
case <-t.C:
s.lock.Lock()
if s.readErr == nil {
s.readErr = errReadDeadlineExceeded
}
s.lock.Unlock()

s.readNotifier.Signal()
}
}(s.readTimeoutCancel)
}
return nil
}

func (s *Stream) handleData(pd *chunkPayloadData) {
s.lock.Lock()
defer s.lock.Unlock()
Expand Down Expand Up @@ -396,6 +437,11 @@ func (s *Stream) onInboundStreamReset() {
s.readErr = io.EOF
s.readNotifier.Broadcast()

if s.readTimeoutCancel != nil {
close(s.readTimeoutCancel)
s.readTimeoutCancel = nil
}

if s.state == StreamStateClosing {
s.log.Debugf("[%s] state change: closing => closed", s.name)
s.state = StreamStateClosed
Expand Down

0 comments on commit 396e5d5

Please sign in to comment.