diff --git a/p2p/transport/webrtc/connection.go b/p2p/transport/webrtc/connection.go index 5a0a8fc5f3..bcbab352e3 100644 --- a/p2p/transport/webrtc/connection.go +++ b/p2p/transport/webrtc/connection.go @@ -337,7 +337,6 @@ func SetupDataChannelQueue(pc *webrtc.PeerConnection, queueLen int) chan Detache rwc.Close() } }) - }) return queue } diff --git a/p2p/transport/webrtc/stream.go b/p2p/transport/webrtc/stream.go index d31e9645e0..35aa328e74 100644 --- a/p2p/transport/webrtc/stream.go +++ b/p2p/transport/webrtc/stream.go @@ -2,6 +2,7 @@ package libp2pwebrtc import ( "errors" + "io" "os" "sync" "time" @@ -46,7 +47,7 @@ type receiveState uint8 const ( receiveStateReceiving receiveState = iota receiveStateDataRead // received and read the FIN - receiveStateReset // either by calling CloseRead locally, or by receiving + receiveStateReset // by calling CloseRead locally or receiving RESET ) type sendState uint8 @@ -129,7 +130,6 @@ func newStream( return } } - s.maybeDeclareStreamDone() select { case s.writeAvailable <- struct{}{}: default: @@ -148,10 +148,10 @@ func (s *stream) Close() error { s.Reset() return closeWriteErr } - s.waitForFINACK() + s.processPendingMessages() s.mx.Lock() defer s.mx.Unlock() - s.maybeDeclareStreamDone() + s.declareStreamDone() return errors.Join(closeWriteErr, closeReadErr) } @@ -163,15 +163,20 @@ func (s *stream) AsyncClose(onDone func()) error { if closeWriteErr != nil { // writing FIN failed, reset the stream s.Reset() - onDone() + s.declareStreamDone() + if onDone != nil { + onDone() + } return closeWriteErr } go func() { - s.waitForFINACK() + s.processPendingMessages() s.mx.Lock() defer s.mx.Unlock() - s.maybeDeclareStreamDone() - onDone() + s.declareStreamDone() + if onDone != nil { + onDone() + } }() return errors.Join(closeWriteErr, closeReadErr) } @@ -182,7 +187,7 @@ func (s *stream) Reset() error { dcCloseErr := s.dataChannel.Close() s.mx.Lock() defer s.mx.Unlock() - s.maybeDeclareStreamDone() + s.declareStreamDone() return errors.Join(cancelWriteErr, closeReadErr, dcCloseErr) } @@ -191,37 +196,30 @@ func (s *stream) SetDeadline(t time.Time) error { return s.SetWriteDeadline(t) } -func (s *stream) waitForFINACK() { - s.mx.Lock() - defer s.mx.Unlock() - // Only wait for FIN_ACK if we are waiting for FIN_ACK and we have stopped reading from the stream - if s.sendState != sendStateDataSent || s.receiveState == receiveStateReceiving { - return - } +func (s *stream) processPendingMessages() { // First wait for any existing readers to exit - s.SetReadDeadline(time.Now().Add(-1 * time.Minute)) + s.SetReadDeadline(time.Now().Add(-1 * time.Hour)) + s.readerMx.Lock() + defer s.readerMx.Unlock() + + s.mx.Lock() + sendState := s.sendState + s.mx.Unlock() s.SetReadDeadline(time.Now().Add(10 * time.Second)) var msg pb.Message - for { - s.mx.Unlock() + for sendState == sendStateDataSent { if err := s.reader.ReadMsg(&msg); err != nil { - s.readerMx.Unlock() - s.mx.Lock() // 10 seconds is enough time for the message to be delivered. The peer just hasn't responded // with FIN_ACK - if errors.Is(err, os.ErrDeadlineExceeded) { - s.sendState = sendStateDataReceived + if errors.Is(err, os.ErrDeadlineExceeded) || errors.Is(err, io.EOF) { + break } - break } - s.readerMx.Unlock() s.mx.Lock() s.processIncomingFlag(msg.Flag) - if s.sendState != sendStateDataSent { - break - } - s.readerMx.Lock() + sendState = s.sendState + s.mx.Unlock() } } @@ -236,7 +234,7 @@ func (s *stream) processIncomingFlag(flag *pb.Message_Flag) { switch *flag { case pb.Message_FIN: - if s.receiveState == receiveStateReceiving { + if s.receiveState == receiveStateReceiving || s.receiveState == receiveStateReset { s.receiveState = receiveStateDataRead } if err := s.sendControlMessage(&pb.Message{Flag: pb.Message_FIN_ACK.Enum()}); err != nil { @@ -259,21 +257,13 @@ func (s *stream) processIncomingFlag(flag *pb.Message_Flag) { s.sendState = sendStateDataReceived } } - s.maybeDeclareStreamDone() } -// maybeDeclareStreamDone is used to force reset a stream. It must be called with mx acquired -func (s *stream) maybeDeclareStreamDone() { - if (s.sendState == sendStateReset || s.sendState == sendStateDataReceived) && - (s.receiveState == receiveStateReset || s.receiveState == receiveStateDataRead) && - len(s.controlMsgQueue) == 0 { - - s.mx.Unlock() - defer s.mx.Lock() - _ = s.SetReadDeadline(time.Now().Add(-1 * time.Hour)) // pion ignores zero times - s.dataChannel.Close() - s.onDone() - } +// declareStreamDone cleansup the stream +func (s *stream) declareStreamDone() { + _ = s.SetReadDeadline(time.Now().Add(-1 * time.Hour)) // pion ignores zero times + s.dataChannel.Close() + s.onDone() } func (s *stream) setCloseError(e error) { diff --git a/p2p/transport/webrtc/stream_read.go b/p2p/transport/webrtc/stream_read.go index 9d2f9a3c13..17895126bd 100644 --- a/p2p/transport/webrtc/stream_read.go +++ b/p2p/transport/webrtc/stream_read.go @@ -1,7 +1,6 @@ package libp2pwebrtc import ( - "errors" "io" "time" @@ -43,7 +42,7 @@ func (s *stream) Read(b []byte) (int, error) { } // This case occurs when the remote node closes the stream without writing a FIN message // There's little we can do here - return 0, errors.New("didn't receive final state for stream") + s.receiveState = receiveStateReset } if s.receiveState == receiveStateReset { return 0, network.ErrReset @@ -61,7 +60,6 @@ func (s *stream) Read(b []byte) (int, error) { s.nextMessage.Message = s.nextMessage.Message[n:] return read, nil } - // process flags on the message after reading all the data s.processIncomingFlag(s.nextMessage.Flag) s.nextMessage = nil @@ -91,12 +89,10 @@ func (s *stream) CloseRead() error { var err error if s.receiveState == receiveStateReceiving && s.closeErr == nil { err = s.sendControlMessage(&pb.Message{Flag: pb.Message_STOP_SENDING.Enum()}) + s.receiveState = receiveStateReset } - s.receiveState = receiveStateReset - s.maybeDeclareStreamDone() - // make any calls to Read blocking on ReadMsg return immediately - s.dataChannel.SetReadDeadline(time.Now()) + s.dataChannel.SetReadDeadline(time.Now().Add(-1 * time.Hour)) return err } diff --git a/p2p/transport/webrtc/stream_test.go b/p2p/transport/webrtc/stream_test.go index 6cdfaa7cd3..3f4a98b679 100644 --- a/p2p/transport/webrtc/stream_test.go +++ b/p2p/transport/webrtc/stream_test.go @@ -131,10 +131,15 @@ func TestStreamSimpleReadWriteClose(t *testing.T) { b, err = io.ReadAll(clientStr) require.NoError(t, err) require.Equal(t, []byte("lorem ipsum"), b) - require.True(t, clientDone) + require.False(t, clientDone) + + // stream is only cleaned up on calling Close or AsyncClose or Reset + clientStr.AsyncClose(nil) + serverStr.AsyncClose(nil) + require.Eventually(t, func() bool { return clientDone }, 10*time.Second, 100*time.Millisecond) // Need to call Close for cleanup. Otherwise the FIN_ACK is never read require.NoError(t, serverStr.Close()) - require.True(t, serverDone) + require.Eventually(t, func() bool { return serverDone }, 10*time.Second, 100*time.Millisecond) } func TestStreamPartialReads(t *testing.T) { @@ -205,7 +210,7 @@ func TestStreamReadReturnsOnClose(t *testing.T) { errChan <- err }() time.Sleep(50 * time.Millisecond) // give the Read call some time to hit the loop - require.NoError(t, clientStr.Close()) + require.NoError(t, clientStr.AsyncClose(nil)) select { case err := <-errChan: require.ErrorIs(t, err, network.ErrReset) @@ -245,7 +250,8 @@ func TestStreamResets(t *testing.T) { _, err := serverStr.Write([]byte("foobar")) return errors.Is(err, network.ErrReset) }, time.Second, 50*time.Millisecond) - require.True(t, serverDone) + serverStr.AsyncClose(nil) + require.Eventually(t, func() bool { return serverDone }, 5*time.Second, 100*time.Millisecond) } func TestStreamReadDeadlineAsync(t *testing.T) { @@ -315,7 +321,7 @@ func TestStreamReadAfterClose(t *testing.T) { clientStr := newStream(client.dc, client.rwc, func() {}) serverStr := newStream(server.dc, server.rwc, func() {}) - serverStr.Close() + serverStr.AsyncClose(nil) b := make([]byte, 1) _, err := clientStr.Read(b) require.Equal(t, io.EOF, err) @@ -363,3 +369,35 @@ func TestStreamCloseAfterFINACK(t *testing.T) { t.Fatalf("Close should have completed") } } + +func TestStreamFinAckAfterStopSending(t *testing.T) { + client, server := getDetachedDataChannels(t) + + done := make(chan bool, 1) + clientStr := newStream(client.dc, client.rwc, func() { done <- true }) + serverStr := newStream(server.dc, server.rwc, func() {}) + + go func() { + done <- true + clientStr.CloseRead() + clientStr.Write([]byte("hello world")) + clientStr.Close() + }() + <-done + + select { + case <-done: + t.Fatalf("Close should not have completed without processing FIN_ACK") + case <-time.After(1 * time.Second): + } + + b := make([]byte, 24) + _, err := serverStr.Read(b) + require.NoError(t, err) + serverStr.Close() // Sends stop_sending, fin + select { + case <-done: + case <-time.After(5 * time.Second): + t.Fatalf("Close should have completed") + } +} diff --git a/p2p/transport/webrtc/stream_write.go b/p2p/transport/webrtc/stream_write.go index 698af9c4d6..59d8a8d5a4 100644 --- a/p2p/transport/webrtc/stream_write.go +++ b/p2p/transport/webrtc/stream_write.go @@ -181,7 +181,6 @@ func (s *stream) cancelWrite() error { if err := s.sendControlMessage(&pb.Message{Flag: pb.Message_RESET.Enum()}); err != nil { return err } - s.maybeDeclareStreamDone() return nil } @@ -196,6 +195,5 @@ func (s *stream) CloseWrite() error { if err := s.sendControlMessage(&pb.Message{Flag: pb.Message_FIN.Enum()}); err != nil { return err } - s.maybeDeclareStreamDone() return nil } diff --git a/p2p/transport/webrtc/transport_test.go b/p2p/transport/webrtc/transport_test.go index 7f4df94fc1..9e9ae0039a 100644 --- a/p2p/transport/webrtc/transport_test.go +++ b/p2p/transport/webrtc/transport_test.go @@ -470,17 +470,17 @@ func TestTransportWebRTC_RemoteReadsAfterClose(t *testing.T) { done <- err return } - stream, err := lconn.AcceptStream() + s, err := lconn.AcceptStream() if err != nil { done <- err return } - _, err = stream.Write([]byte{1, 2, 3, 4}) + _, err = s.Write([]byte{1, 2, 3, 4}) if err != nil { done <- err return } - err = stream.Close() + err = s.(*stream).AsyncClose(nil) if err != nil { done <- err return @@ -541,12 +541,12 @@ func TestTransportWebRTC_RemoteReadsAfterClose2(t *testing.T) { conn, err := tr1.Dial(context.Background(), listener.Multiaddr(), listeningPeer) require.NoError(t, err) - // create a stream - stream, err := conn.OpenStream(context.Background()) + // create a s + s, err := conn.OpenStream(context.Background()) require.NoError(t, err) - _, err = stream.Write([]byte{1, 2, 3, 4}) + _, err = s.Write([]byte{1, 2, 3, 4}) require.NoError(t, err) - err = stream.Close() + err = s.(*stream).AsyncClose(nil) require.NoError(t, err) // signal stream closure close(awaitStreamClosure)