Skip to content

Commit

Permalink
cleanup write path channels
Browse files Browse the repository at this point in the history
  • Loading branch information
sukunrt committed Feb 20, 2024
1 parent 7e005b5 commit 0094df6
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 61 deletions.
67 changes: 24 additions & 43 deletions p2p/transport/webrtc/stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,10 @@ type stream struct {
nextMessage *pb.Message
receiveState receiveState

writer pbio.Writer // concurrent writes prevented by mx
sendStateChanged chan struct{}
sendState sendState
writeDeadline time.Time
writeDeadlineUpdated chan struct{}
writeAvailable chan struct{}
writer pbio.Writer // concurrent writes prevented by mx
writeStateChanged chan struct{}
sendState sendState
writeDeadline time.Time

controlMessageReaderOnce sync.Once
// controlMessageReaderEndTime is the end time for reading FIN_ACK from the control
Expand All @@ -113,38 +111,30 @@ func newStream(
onDone func(),
) *stream {
s := &stream{
reader: pbio.NewDelimitedReader(rwc, maxMessageSize),
writer: pbio.NewDelimitedWriter(rwc),

sendStateChanged: make(chan struct{}, 1),
writeDeadlineUpdated: make(chan struct{}, 1),
writeAvailable: make(chan struct{}, 1),

controlMessageReaderDone: sync.WaitGroup{},

id: *channel.ID(),
dataChannel: rwc.(*datachannel.DataChannel),
onDone: onDone,
reader: pbio.NewDelimitedReader(rwc, maxMessageSize),
writer: pbio.NewDelimitedWriter(rwc),
writeStateChanged: make(chan struct{}, 1),
id: *channel.ID(),
dataChannel: rwc.(*datachannel.DataChannel),
onDone: onDone,
}
// released when the controlMessageReader goroutine exits
s.controlMessageReaderDone.Add(1)
s.dataChannel.SetBufferedAmountLowThreshold(bufferedAmountLowThreshold)
s.dataChannel.OnBufferedAmountLow(func() {
select {
case s.writeAvailable <- struct{}{}:
default:
}
s.notifyWriteStateChanged()

})
return s
}

func (s *stream) Close() error {
s.mx.Lock()
if s.closeForShutdownErr != nil {
s.mx.Unlock()
isClosed := s.closeForShutdownErr != nil
s.mx.Unlock()
if isClosed {
return nil
}
s.mx.Unlock()

closeWriteErr := s.CloseWrite()
closeReadErr := s.CloseRead()
Expand All @@ -166,11 +156,11 @@ func (s *stream) Close() error {

func (s *stream) Reset() error {
s.mx.Lock()
if s.closeForShutdownErr != nil {
s.mx.Unlock()
isClosed := s.closeForShutdownErr != nil
s.mx.Unlock()
if isClosed {
return nil
}
s.mx.Unlock()

defer s.cleanup()
cancelWriteErr := s.cancelWrite()
Expand All @@ -189,10 +179,7 @@ func (s *stream) closeForShutdown(closeErr error) {

s.closeForShutdownErr = closeErr
s.SetReadDeadline(time.Now().Add(-1 * time.Hour))
select {
case s.sendStateChanged <- struct{}{}:
default:
}
s.notifyWriteStateChanged()
}

func (s *stream) SetDeadline(t time.Time) error {
Expand All @@ -214,16 +201,10 @@ func (s *stream) processIncomingFlag(flag *pb.Message_Flag) {
if s.sendState == sendStateSending || s.sendState == sendStateDataSent {
s.sendState = sendStateReset
}
select {
case s.sendStateChanged <- struct{}{}:
default:
}
s.notifyWriteStateChanged()
case pb.Message_FIN_ACK:
s.sendState = sendStateDataReceived
select {
case s.sendStateChanged <- struct{}{}:
default:
}
s.notifyWriteStateChanged()
case pb.Message_FIN:
if s.receiveState == receiveStateReceiving {
s.receiveState = receiveStateDataRead
Expand Down Expand Up @@ -285,8 +266,9 @@ func (s *stream) spawnControlMessageReader() {
return
}
s.mx.Unlock()
if err := s.reader.ReadMsg(&msg); err != nil {
s.mx.Lock()
err := s.reader.ReadMsg(&msg)
s.mx.Lock()
if err != nil {
// We have to manually manage deadline exceeded errors since pion/sctp can
// return deadline exceeded error for cancelled deadlines
// see: https://github.com/pion/sctp/pull/290/files
Expand All @@ -295,7 +277,6 @@ func (s *stream) spawnControlMessageReader() {
}
return
}
s.mx.Lock()
s.processIncomingFlag(msg.Flag)
}
}()
Expand Down
6 changes: 3 additions & 3 deletions p2p/transport/webrtc/stream_read.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,9 @@ func (s *stream) Read(b []byte) (int, error) {
// load the next message
s.mx.Unlock()
var msg pb.Message
if err := s.reader.ReadMsg(&msg); err != nil {
s.mx.Lock()
err := s.reader.ReadMsg(&msg)
s.mx.Lock()
if err != nil {
// connection was closed
if s.closeForShutdownErr != nil {
return 0, s.closeForShutdownErr
Expand All @@ -61,7 +62,6 @@ func (s *stream) Read(b []byte) (int, error) {
}
return 0, err
}
s.mx.Lock()
s.nextMessage = &msg
}

Expand Down
26 changes: 11 additions & 15 deletions p2p/transport/webrtc/stream_write.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,10 @@ func (s *stream) Write(b []byte) (int, error) {
if availableSpace < minMessageSize {
s.mx.Unlock()
select {
case <-s.writeAvailable:
case <-writeDeadlineChan:
s.mx.Lock()
return n, os.ErrDeadlineExceeded
case <-s.sendStateChanged:
case <-s.writeDeadlineUpdated:
case <-s.writeStateChanged:
}
s.mx.Lock()
continue
Expand All @@ -108,10 +106,7 @@ func (s *stream) SetWriteDeadline(t time.Time) error {
s.mx.Lock()
defer s.mx.Unlock()
s.writeDeadline = t
select {
case s.writeDeadlineUpdated <- struct{}{}:
default:
}
s.notifyWriteStateChanged()
return nil
}

Expand All @@ -134,10 +129,7 @@ func (s *stream) cancelWrite() error {
return nil
}
s.sendState = sendStateReset
select {
case s.sendStateChanged <- struct{}{}:
default:
}
s.notifyWriteStateChanged()
if err := s.writer.WriteMsg(&pb.Message{Flag: pb.Message_RESET.Enum()}); err != nil {
return err
}
Expand All @@ -152,12 +144,16 @@ func (s *stream) CloseWrite() error {
return nil
}
s.sendState = sendStateDataSent
select {
case s.sendStateChanged <- struct{}{}:
default:
}
s.notifyWriteStateChanged()
if err := s.writer.WriteMsg(&pb.Message{Flag: pb.Message_FIN.Enum()}); err != nil {
return err
}
return nil
}

func (s *stream) notifyWriteStateChanged() {
select {
case s.writeStateChanged <- struct{}{}:
default:
}
}

0 comments on commit 0094df6

Please sign in to comment.