diff --git a/p2p/transport/webrtc/stream.go b/p2p/transport/webrtc/stream.go index 23fef07bbc..a78604b99e 100644 --- a/p2p/transport/webrtc/stream.go +++ b/p2p/transport/webrtc/stream.go @@ -83,12 +83,10 @@ type stream struct { nextMessage *pb.Message receiveState receiveState - writer pbio.Writer // concurrent writes prevented by mx - sendStateChanged chan struct{} - sendState sendState - writeDeadline time.Time - writeDeadlineUpdated chan struct{} - writeAvailable chan struct{} + writer pbio.Writer // concurrent writes prevented by mx + writeStateChanged chan struct{} + sendState sendState + writeDeadline time.Time controlMessageReaderOnce sync.Once // controlMessageReaderEndTime is the end time for reading FIN_ACK from the control @@ -113,38 +111,30 @@ func newStream( onDone func(), ) *stream { s := &stream{ - reader: pbio.NewDelimitedReader(rwc, maxMessageSize), - writer: pbio.NewDelimitedWriter(rwc), - - sendStateChanged: make(chan struct{}, 1), - writeDeadlineUpdated: make(chan struct{}, 1), - writeAvailable: make(chan struct{}, 1), - - controlMessageReaderDone: sync.WaitGroup{}, - - id: *channel.ID(), - dataChannel: rwc.(*datachannel.DataChannel), - onDone: onDone, + reader: pbio.NewDelimitedReader(rwc, maxMessageSize), + writer: pbio.NewDelimitedWriter(rwc), + writeStateChanged: make(chan struct{}, 1), + id: *channel.ID(), + dataChannel: rwc.(*datachannel.DataChannel), + onDone: onDone, } // released when the controlMessageReader goroutine exits s.controlMessageReaderDone.Add(1) s.dataChannel.SetBufferedAmountLowThreshold(bufferedAmountLowThreshold) s.dataChannel.OnBufferedAmountLow(func() { - select { - case s.writeAvailable <- struct{}{}: - default: - } + s.notifyWriteStateChanged() + }) return s } func (s *stream) Close() error { s.mx.Lock() - if s.closeForShutdownErr != nil { - s.mx.Unlock() + isClosed := s.closeForShutdownErr != nil + s.mx.Unlock() + if isClosed { return nil } - s.mx.Unlock() closeWriteErr := s.CloseWrite() closeReadErr := s.CloseRead() @@ -166,11 +156,11 @@ func (s *stream) Close() error { func (s *stream) Reset() error { s.mx.Lock() - if s.closeForShutdownErr != nil { - s.mx.Unlock() + isClosed := s.closeForShutdownErr != nil + s.mx.Unlock() + if isClosed { return nil } - s.mx.Unlock() defer s.cleanup() cancelWriteErr := s.cancelWrite() @@ -189,10 +179,7 @@ func (s *stream) closeForShutdown(closeErr error) { s.closeForShutdownErr = closeErr s.SetReadDeadline(time.Now().Add(-1 * time.Hour)) - select { - case s.sendStateChanged <- struct{}{}: - default: - } + s.notifyWriteStateChanged() } func (s *stream) SetDeadline(t time.Time) error { @@ -214,16 +201,10 @@ func (s *stream) processIncomingFlag(flag *pb.Message_Flag) { if s.sendState == sendStateSending || s.sendState == sendStateDataSent { s.sendState = sendStateReset } - select { - case s.sendStateChanged <- struct{}{}: - default: - } + s.notifyWriteStateChanged() case pb.Message_FIN_ACK: s.sendState = sendStateDataReceived - select { - case s.sendStateChanged <- struct{}{}: - default: - } + s.notifyWriteStateChanged() case pb.Message_FIN: if s.receiveState == receiveStateReceiving { s.receiveState = receiveStateDataRead @@ -285,8 +266,9 @@ func (s *stream) spawnControlMessageReader() { return } s.mx.Unlock() - if err := s.reader.ReadMsg(&msg); err != nil { - s.mx.Lock() + err := s.reader.ReadMsg(&msg) + s.mx.Lock() + if err != nil { // We have to manually manage deadline exceeded errors since pion/sctp can // return deadline exceeded error for cancelled deadlines // see: https://github.com/pion/sctp/pull/290/files @@ -295,7 +277,6 @@ func (s *stream) spawnControlMessageReader() { } return } - s.mx.Lock() s.processIncomingFlag(msg.Flag) } }() diff --git a/p2p/transport/webrtc/stream_read.go b/p2p/transport/webrtc/stream_read.go index 002ebac0ec..f9df06eca2 100644 --- a/p2p/transport/webrtc/stream_read.go +++ b/p2p/transport/webrtc/stream_read.go @@ -35,8 +35,9 @@ func (s *stream) Read(b []byte) (int, error) { // load the next message s.mx.Unlock() var msg pb.Message - if err := s.reader.ReadMsg(&msg); err != nil { - s.mx.Lock() + err := s.reader.ReadMsg(&msg) + s.mx.Lock() + if err != nil { // connection was closed if s.closeForShutdownErr != nil { return 0, s.closeForShutdownErr @@ -61,7 +62,6 @@ func (s *stream) Read(b []byte) (int, error) { } return 0, err } - s.mx.Lock() s.nextMessage = &msg } diff --git a/p2p/transport/webrtc/stream_write.go b/p2p/transport/webrtc/stream_write.go index b0b2837d91..82d4ac287d 100644 --- a/p2p/transport/webrtc/stream_write.go +++ b/p2p/transport/webrtc/stream_write.go @@ -76,12 +76,10 @@ func (s *stream) Write(b []byte) (int, error) { if availableSpace < minMessageSize { s.mx.Unlock() select { - case <-s.writeAvailable: case <-writeDeadlineChan: s.mx.Lock() return n, os.ErrDeadlineExceeded - case <-s.sendStateChanged: - case <-s.writeDeadlineUpdated: + case <-s.writeStateChanged: } s.mx.Lock() continue @@ -108,10 +106,7 @@ func (s *stream) SetWriteDeadline(t time.Time) error { s.mx.Lock() defer s.mx.Unlock() s.writeDeadline = t - select { - case s.writeDeadlineUpdated <- struct{}{}: - default: - } + s.notifyWriteStateChanged() return nil } @@ -134,10 +129,7 @@ func (s *stream) cancelWrite() error { return nil } s.sendState = sendStateReset - select { - case s.sendStateChanged <- struct{}{}: - default: - } + s.notifyWriteStateChanged() if err := s.writer.WriteMsg(&pb.Message{Flag: pb.Message_RESET.Enum()}); err != nil { return err } @@ -152,12 +144,16 @@ func (s *stream) CloseWrite() error { return nil } s.sendState = sendStateDataSent - select { - case s.sendStateChanged <- struct{}{}: - default: - } + s.notifyWriteStateChanged() if err := s.writer.WriteMsg(&pb.Message{Flag: pb.Message_FIN.Enum()}); err != nil { return err } return nil } + +func (s *stream) notifyWriteStateChanged() { + select { + case s.writeStateChanged <- struct{}{}: + default: + } +}