From 8992b8c91e46e2698e1f72d234284375ddbad1dd Mon Sep 17 00:00:00 2001 From: sukun Date: Tue, 20 Feb 2024 18:08:56 +0530 Subject: [PATCH] disallow SetReadDeadline after read half is closed --- p2p/transport/webrtc/stream.go | 29 ++++++++++++++--------------- p2p/transport/webrtc/stream_read.go | 13 ++++++++++++- p2p/transport/webrtc/stream_test.go | 2 +- 3 files changed, 27 insertions(+), 17 deletions(-) diff --git a/p2p/transport/webrtc/stream.go b/p2p/transport/webrtc/stream.go index a78604b99e..45641321dc 100644 --- a/p2p/transport/webrtc/stream.go +++ b/p2p/transport/webrtc/stream.go @@ -142,15 +142,17 @@ func (s *stream) Close() error { s.Reset() return errors.Join(closeWriteErr, closeReadErr) } + s.mx.Lock() - s.controlMessageReaderEndTime = time.Now().Add(maxFINACKWait) + if s.controlMessageReaderEndTime.IsZero() { + s.controlMessageReaderEndTime = time.Now().Add(maxFINACKWait) + s.setDataChannelReadDeadline(time.Now().Add(-1 * time.Hour)) + go func() { + s.controlMessageReaderDone.Wait() + s.cleanup() + }() + } s.mx.Unlock() - s.SetReadDeadline(time.Now().Add(-1 * time.Hour)) - - go func() { - s.controlMessageReaderDone.Wait() - s.cleanup() - }() return nil } @@ -165,10 +167,8 @@ func (s *stream) Reset() error { defer s.cleanup() cancelWriteErr := s.cancelWrite() closeReadErr := s.CloseRead() - if cancelWriteErr != nil { - return cancelWriteErr - } - return closeReadErr + s.setDataChannelReadDeadline(time.Now().Add(-1 * time.Hour)) + return errors.Join(closeReadErr, cancelWriteErr) } func (s *stream) closeForShutdown(closeErr error) { @@ -178,7 +178,6 @@ func (s *stream) closeForShutdown(closeErr error) { defer s.mx.Unlock() s.closeForShutdownErr = closeErr - s.SetReadDeadline(time.Now().Add(-1 * time.Hour)) s.notifyWriteStateChanged() } @@ -230,18 +229,18 @@ func (s *stream) spawnControlMessageReader() { go func() { defer s.controlMessageReaderDone.Done() // cleanup the sctp deadline timer goroutine - defer s.SetReadDeadline(time.Time{}) + defer s.setDataChannelReadDeadline(time.Time{}) setDeadline := func() bool { if s.controlMessageReaderEndTime.IsZero() || time.Now().Before(s.controlMessageReaderEndTime) { - s.SetReadDeadline(s.controlMessageReaderEndTime) + s.setDataChannelReadDeadline(s.controlMessageReaderEndTime) return true } return false } // Unblock any Read call waiting on reader.ReadMsg - s.SetReadDeadline(time.Now().Add(-1 * time.Hour)) + s.setDataChannelReadDeadline(time.Now().Add(-1 * time.Hour)) s.readerMx.Lock() // We have the lock any readers blocked on reader.ReadMsg have exited. diff --git a/p2p/transport/webrtc/stream_read.go b/p2p/transport/webrtc/stream_read.go index f9df06eca2..80d99ea91c 100644 --- a/p2p/transport/webrtc/stream_read.go +++ b/p2p/transport/webrtc/stream_read.go @@ -87,7 +87,18 @@ func (s *stream) Read(b []byte) (int, error) { } } -func (s *stream) SetReadDeadline(t time.Time) error { return s.dataChannel.SetReadDeadline(t) } +func (s *stream) SetReadDeadline(t time.Time) error { + s.mx.Lock() + defer s.mx.Unlock() + if s.receiveState == receiveStateReceiving { + s.setDataChannelReadDeadline(t) + } + return nil +} + +func (s *stream) setDataChannelReadDeadline(t time.Time) error { + return s.dataChannel.SetReadDeadline(t) +} func (s *stream) CloseRead() error { s.mx.Lock() diff --git a/p2p/transport/webrtc/stream_test.go b/p2p/transport/webrtc/stream_test.go index 884e688bad..8f1ec165cf 100644 --- a/p2p/transport/webrtc/stream_test.go +++ b/p2p/transport/webrtc/stream_test.go @@ -451,7 +451,7 @@ func TestStreamConcurrentClose(t *testing.T) { func TestStreamResetAfterClose(t *testing.T) { client, _ := getDetachedDataChannels(t) - done := make(chan bool, 1) + done := make(chan bool, 2) clientStr := newStream(client.dc, client.rwc, func() { done <- true }) clientStr.Close()