Skip to content

Commit

Permalink
webrtc: return error on 0 len writes after close
Browse files Browse the repository at this point in the history
  • Loading branch information
sukunrt committed Oct 23, 2023
1 parent fd471fa commit b7c04f8
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 4 deletions.
7 changes: 3 additions & 4 deletions p2p/transport/webrtc/stream_read.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,6 @@ import (
)

func (s *stream) Read(b []byte) (int, error) {
if len(b) == 0 {
return 0, nil
}

s.mx.Lock()
defer s.mx.Unlock()

Expand All @@ -27,6 +23,9 @@ func (s *stream) Read(b []byte) (int, error) {
return 0, network.ErrReset
}

if len(b) == 0 {
return 0, nil
}
var read int
for {
if s.nextMessage == nil {
Expand Down
27 changes: 27 additions & 0 deletions p2p/transport/webrtc/stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ type detachedChan struct {
func getDetachedDataChannels(t *testing.T) (detachedChan, detachedChan) {
s := webrtc.SettingEngine{}
s.DetachDataChannels()
s.SetIncludeLoopbackCandidate(true)
api := webrtc.NewAPI(webrtc.WithSettingEngine(s))

offerPC, err := api.NewPeerConnection(webrtc.Configuration{})
Expand Down Expand Up @@ -305,3 +306,29 @@ func TestStreamWriteDeadlineAsync(t *testing.T) {
require.GreaterOrEqual(t, took, timeout)
require.LessOrEqual(t, took, timeout*3/2)
}

func TestStreamReadAfterClose(t *testing.T) {
client, server := getDetachedDataChannels(t)

clientStr := newStream(client.dc, client.rwc, func() {})
serverStr := newStream(server.dc, server.rwc, func() {})

serverStr.Close()
b := make([]byte, 1)
_, err := clientStr.Read(b)
require.Equal(t, io.EOF, err)
_, err = clientStr.Read(nil)
require.Equal(t, io.EOF, err)

client, server = getDetachedDataChannels(t)

clientStr = newStream(client.dc, client.rwc, func() {})
serverStr = newStream(server.dc, server.rwc, func() {})

serverStr.Reset()
b = make([]byte, 1)
_, err = clientStr.Read(b)
require.ErrorIs(t, err, network.ErrReset)
_, err = clientStr.Read(nil)
require.ErrorIs(t, err, network.ErrReset)
}

0 comments on commit b7c04f8

Please sign in to comment.