diff --git a/association_test.go b/association_test.go index 6dd13351..2079d22a 100644 --- a/association_test.go +++ b/association_test.go @@ -12,6 +12,7 @@ import ( "math" "math/rand" "net" + "os" "runtime" "strings" "sync" @@ -466,6 +467,57 @@ func TestAssocReliable(t *testing.T) { closeAssociationPair(br, a0, a1) }) + t.Run("ReadDeadline", func(t *testing.T) { + lim := test.TimeOut(time.Second * 10) + defer lim.Stop() + + const si uint16 = 1 + const msg = "ABC" + br := test.NewBridge() + + a0, a1, err := createNewAssociationPair(br, ackModeNoDelay, 0) + if !assert.Nil(t, err, "failed to create associations") { + assert.FailNow(t, "failed due to earlier error") + } + + s0, s1, err := establishSessionPair(br, a0, a1, si) + assert.Nil(t, err, "failed to establish session pair") + + assert.Equal(t, 0, a0.bufferedAmount(), "incorrect bufferedAmount") + + assert.NoError(t, s1.SetReadDeadline(time.Now().Add(time.Millisecond)), "failed to set read deadline") + buf := make([]byte, 32) + // First fails + n, ppi, err := s1.ReadSCTP(buf) + assert.Equal(t, 0, n) + assert.Equal(t, PayloadProtocolIdentifier(0), ppi) + assert.True(t, errors.Is(err, os.ErrDeadlineExceeded)) + // Second too + n, ppi, err = s1.ReadSCTP(buf) + assert.Equal(t, 0, n) + assert.Equal(t, PayloadProtocolIdentifier(0), ppi) + assert.True(t, errors.Is(err, os.ErrDeadlineExceeded)) + assert.NoError(t, s1.SetReadDeadline(time.Time{}), "failed to disable read deadline") + + n, err = s0.WriteSCTP([]byte(msg), PayloadTypeWebRTCBinary) + if err != nil { + assert.FailNow(t, "failed due to earlier error") + } + assert.Equal(t, len(msg), n, "unexpected length of received data") + assert.Equal(t, len(msg), a0.bufferedAmount(), "incorrect bufferedAmount") + + flushBuffers(br, a0, a1) + + n, ppi, err = s1.ReadSCTP(buf) + if !assert.Nil(t, err, "ReadSCTP failed") { + assert.FailNow(t, "failed due to earlier error") + } + assert.Equal(t, n, len(msg), "unexpected length of received data") + assert.Equal(t, ppi, PayloadTypeWebRTCBinary, "unexpected ppi") + + closeAssociationPair(br, a0, a1) + }) + t.Run("ordered reordered", func(t *testing.T) { lim := test.TimeOut(time.Second * 10) defer lim.Stop() diff --git a/stream.go b/stream.go index f9304e17..fd2c5fc7 100644 --- a/stream.go +++ b/stream.go @@ -5,8 +5,10 @@ import ( "fmt" "io" "math" + "os" "sync" "sync/atomic" + "time" "github.com/pion/logging" ) @@ -46,6 +48,7 @@ func (ss StreamState) String() string { var ( errOutboundPacketTooLarge = errors.New("outbound packet larger than maximum message size") errStreamClosed = errors.New("stream closed") + errReadDeadlineExceeded = fmt.Errorf("read deadline exceeded: %w", os.ErrDeadlineExceeded) ) // Stream represents an SCTP stream @@ -58,6 +61,7 @@ type Stream struct { sequenceNumber uint16 readNotifier *sync.Cond readErr error + readTimeoutCancel chan struct{} unordered bool reliabilityType byte reliabilityValue uint32 @@ -115,6 +119,14 @@ func (s *Stream) ReadSCTP(p []byte) (int, PayloadProtocolIdentifier, error) { s.lock.Lock() defer s.lock.Unlock() + defer func() { + // close readTimeoutCancel if the current read timeout routine is no longer effective + if s.readTimeoutCancel != nil && s.readErr != nil { + close(s.readTimeoutCancel) + s.readTimeoutCancel = nil + } + }() + for { n, ppi, err := s.reassemblyQueue.read(p) if err == nil { @@ -132,6 +144,47 @@ func (s *Stream) ReadSCTP(p []byte) (int, PayloadProtocolIdentifier, error) { } } +// SetReadDeadline sets the read deadline in an identical way to net.Conn +func (s *Stream) SetReadDeadline(deadline time.Time) error { + s.lock.Lock() + defer s.lock.Unlock() + + if s.readTimeoutCancel != nil { + close(s.readTimeoutCancel) + s.readTimeoutCancel = nil + } + + if s.readErr != nil { + if !errors.Is(s.readErr, errReadDeadlineExceeded) { + return nil + } + s.readErr = nil + } + + if !deadline.IsZero() { + s.readTimeoutCancel = make(chan struct{}) + + go func(readTimeoutCancel chan struct{}) { + t := time.NewTimer(time.Until(deadline)) + select { + case <-readTimeoutCancel: + t.Stop() + return + case <-t.C: + s.lock.Lock() + if s.readErr == nil { + s.readErr = errReadDeadlineExceeded + } + s.readTimeoutCancel = nil + s.lock.Unlock() + + s.readNotifier.Signal() + } + }(s.readTimeoutCancel) + } + return nil +} + func (s *Stream) handleData(pd *chunkPayloadData) { s.lock.Lock() defer s.lock.Unlock()