Skip to content

Commit

Permalink
webrtc: cleanup FIN_ACK procedure
Browse files Browse the repository at this point in the history
  • Loading branch information
sukunrt committed Oct 23, 2023
1 parent 416c934 commit ca1cf7d
Show file tree
Hide file tree
Showing 6 changed files with 86 additions and 65 deletions.
1 change: 0 additions & 1 deletion p2p/transport/webrtc/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,6 @@ func SetupDataChannelQueue(pc *webrtc.PeerConnection, queueLen int) chan Detache
rwc.Close()
}
})

})
return queue
}
76 changes: 33 additions & 43 deletions p2p/transport/webrtc/stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package libp2pwebrtc

import (
"errors"
"io"
"os"
"sync"
"time"
Expand Down Expand Up @@ -46,7 +47,7 @@ type receiveState uint8
const (
receiveStateReceiving receiveState = iota
receiveStateDataRead // received and read the FIN
receiveStateReset // either by calling CloseRead locally, or by receiving
receiveStateReset // by calling CloseRead locally or receiving RESET
)

type sendState uint8
Expand Down Expand Up @@ -129,7 +130,6 @@ func newStream(
return
}
}
s.maybeDeclareStreamDone()
select {
case s.writeAvailable <- struct{}{}:
default:
Expand All @@ -148,10 +148,10 @@ func (s *stream) Close() error {
s.Reset()
return closeWriteErr
}
s.waitForFINACK()
s.processPendingMessages()
s.mx.Lock()
defer s.mx.Unlock()
s.maybeDeclareStreamDone()
s.declareStreamDone()
return errors.Join(closeWriteErr, closeReadErr)
}

Expand All @@ -163,15 +163,20 @@ func (s *stream) AsyncClose(onDone func()) error {
if closeWriteErr != nil {
// writing FIN failed, reset the stream
s.Reset()
onDone()
s.declareStreamDone()
if onDone != nil {
onDone()
}
return closeWriteErr
}
go func() {
s.waitForFINACK()
s.processPendingMessages()
s.mx.Lock()
defer s.mx.Unlock()
s.maybeDeclareStreamDone()
onDone()
s.declareStreamDone()
if onDone != nil {
onDone()
}
}()
return errors.Join(closeWriteErr, closeReadErr)
}
Expand All @@ -182,7 +187,7 @@ func (s *stream) Reset() error {
dcCloseErr := s.dataChannel.Close()
s.mx.Lock()
defer s.mx.Unlock()
s.maybeDeclareStreamDone()
s.declareStreamDone()
return errors.Join(cancelWriteErr, closeReadErr, dcCloseErr)
}

Expand All @@ -191,37 +196,30 @@ func (s *stream) SetDeadline(t time.Time) error {
return s.SetWriteDeadline(t)
}

func (s *stream) waitForFINACK() {
s.mx.Lock()
defer s.mx.Unlock()
// Only wait for FIN_ACK if we are waiting for FIN_ACK and we have stopped reading from the stream
if s.sendState != sendStateDataSent || s.receiveState == receiveStateReceiving {
return
}
func (s *stream) processPendingMessages() {
// First wait for any existing readers to exit
s.SetReadDeadline(time.Now().Add(-1 * time.Minute))
s.SetReadDeadline(time.Now().Add(-1 * time.Hour))

s.readerMx.Lock()
defer s.readerMx.Unlock()

s.mx.Lock()
sendState := s.sendState
s.mx.Unlock()
s.SetReadDeadline(time.Now().Add(10 * time.Second))
var msg pb.Message
for {
s.mx.Unlock()
for sendState == sendStateDataSent {
if err := s.reader.ReadMsg(&msg); err != nil {
s.readerMx.Unlock()
s.mx.Lock()
// 10 seconds is enough time for the message to be delivered. The peer just hasn't responded
// with FIN_ACK
if errors.Is(err, os.ErrDeadlineExceeded) {
s.sendState = sendStateDataReceived
if errors.Is(err, os.ErrDeadlineExceeded) || errors.Is(err, io.EOF) {
break
}
break
}
s.readerMx.Unlock()
s.mx.Lock()
s.processIncomingFlag(msg.Flag)
if s.sendState != sendStateDataSent {
break
}
s.readerMx.Lock()
sendState = s.sendState
s.mx.Unlock()
}
}

Expand All @@ -236,7 +234,7 @@ func (s *stream) processIncomingFlag(flag *pb.Message_Flag) {

switch *flag {
case pb.Message_FIN:
if s.receiveState == receiveStateReceiving {
if s.receiveState == receiveStateReceiving || s.receiveState == receiveStateReset {
s.receiveState = receiveStateDataRead
}
if err := s.sendControlMessage(&pb.Message{Flag: pb.Message_FIN_ACK.Enum()}); err != nil {
Expand All @@ -259,21 +257,13 @@ func (s *stream) processIncomingFlag(flag *pb.Message_Flag) {
s.sendState = sendStateDataReceived
}
}
s.maybeDeclareStreamDone()
}

// maybeDeclareStreamDone is used to force reset a stream. It must be called with mx acquired
func (s *stream) maybeDeclareStreamDone() {
if (s.sendState == sendStateReset || s.sendState == sendStateDataReceived) &&
(s.receiveState == receiveStateReset || s.receiveState == receiveStateDataRead) &&
len(s.controlMsgQueue) == 0 {

s.mx.Unlock()
defer s.mx.Lock()
_ = s.SetReadDeadline(time.Now().Add(-1 * time.Hour)) // pion ignores zero times
s.dataChannel.Close()
s.onDone()
}
// declareStreamDone cleansup the stream
func (s *stream) declareStreamDone() {
_ = s.SetReadDeadline(time.Now().Add(-1 * time.Hour)) // pion ignores zero times
s.dataChannel.Close()
s.onDone()
}

func (s *stream) setCloseError(e error) {
Expand Down
10 changes: 3 additions & 7 deletions p2p/transport/webrtc/stream_read.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package libp2pwebrtc

import (
"errors"
"io"
"time"

Expand Down Expand Up @@ -43,7 +42,7 @@ func (s *stream) Read(b []byte) (int, error) {
}
// This case occurs when the remote node closes the stream without writing a FIN message
// There's little we can do here
return 0, errors.New("didn't receive final state for stream")
s.receiveState = receiveStateReset
}
if s.receiveState == receiveStateReset {
return 0, network.ErrReset
Expand All @@ -61,7 +60,6 @@ func (s *stream) Read(b []byte) (int, error) {
s.nextMessage.Message = s.nextMessage.Message[n:]
return read, nil
}

// process flags on the message after reading all the data
s.processIncomingFlag(s.nextMessage.Flag)
s.nextMessage = nil
Expand Down Expand Up @@ -91,12 +89,10 @@ func (s *stream) CloseRead() error {
var err error
if s.receiveState == receiveStateReceiving && s.closeErr == nil {
err = s.sendControlMessage(&pb.Message{Flag: pb.Message_STOP_SENDING.Enum()})
s.receiveState = receiveStateReset
}
s.receiveState = receiveStateReset
s.maybeDeclareStreamDone()

// make any calls to Read blocking on ReadMsg return immediately
s.dataChannel.SetReadDeadline(time.Now())
s.dataChannel.SetReadDeadline(time.Now().Add(-1 * time.Hour))

return err
}
48 changes: 43 additions & 5 deletions p2p/transport/webrtc/stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,10 +131,15 @@ func TestStreamSimpleReadWriteClose(t *testing.T) {
b, err = io.ReadAll(clientStr)
require.NoError(t, err)
require.Equal(t, []byte("lorem ipsum"), b)
require.True(t, clientDone)
require.False(t, clientDone)

// stream is only cleaned up on calling Close or AsyncClose or Reset
clientStr.AsyncClose(nil)
serverStr.AsyncClose(nil)
require.Eventually(t, func() bool { return clientDone }, 10*time.Second, 100*time.Millisecond)
// Need to call Close for cleanup. Otherwise the FIN_ACK is never read
require.NoError(t, serverStr.Close())
require.True(t, serverDone)
require.Eventually(t, func() bool { return serverDone }, 10*time.Second, 100*time.Millisecond)
}

func TestStreamPartialReads(t *testing.T) {
Expand Down Expand Up @@ -205,7 +210,7 @@ func TestStreamReadReturnsOnClose(t *testing.T) {
errChan <- err
}()
time.Sleep(50 * time.Millisecond) // give the Read call some time to hit the loop
require.NoError(t, clientStr.Close())
require.NoError(t, clientStr.AsyncClose(nil))
select {
case err := <-errChan:
require.ErrorIs(t, err, network.ErrReset)
Expand Down Expand Up @@ -245,7 +250,8 @@ func TestStreamResets(t *testing.T) {
_, err := serverStr.Write([]byte("foobar"))
return errors.Is(err, network.ErrReset)
}, time.Second, 50*time.Millisecond)
require.True(t, serverDone)
serverStr.AsyncClose(nil)
require.Eventually(t, func() bool { return serverDone }, 5*time.Second, 100*time.Millisecond)
}

func TestStreamReadDeadlineAsync(t *testing.T) {
Expand Down Expand Up @@ -315,7 +321,7 @@ func TestStreamReadAfterClose(t *testing.T) {
clientStr := newStream(client.dc, client.rwc, func() {})
serverStr := newStream(server.dc, server.rwc, func() {})

serverStr.Close()
serverStr.AsyncClose(nil)
b := make([]byte, 1)
_, err := clientStr.Read(b)
require.Equal(t, io.EOF, err)
Expand Down Expand Up @@ -363,3 +369,35 @@ func TestStreamCloseAfterFINACK(t *testing.T) {
t.Fatalf("Close should have completed")
}
}

func TestStreamFinAckAfterStopSending(t *testing.T) {
client, server := getDetachedDataChannels(t)

done := make(chan bool, 1)
clientStr := newStream(client.dc, client.rwc, func() { done <- true })
serverStr := newStream(server.dc, server.rwc, func() {})

go func() {
done <- true
clientStr.CloseRead()
clientStr.Write([]byte("hello world"))
clientStr.Close()
}()
<-done

select {
case <-done:
t.Fatalf("Close should not have completed without processing FIN_ACK")
case <-time.After(1 * time.Second):
}

b := make([]byte, 24)
_, err := serverStr.Read(b)
require.NoError(t, err)
serverStr.Close() // Sends stop_sending, fin
select {
case <-done:
case <-time.After(5 * time.Second):
t.Fatalf("Close should have completed")
}
}
2 changes: 0 additions & 2 deletions p2p/transport/webrtc/stream_write.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,6 @@ func (s *stream) cancelWrite() error {
if err := s.sendControlMessage(&pb.Message{Flag: pb.Message_RESET.Enum()}); err != nil {
return err
}
s.maybeDeclareStreamDone()
return nil
}

Expand All @@ -196,6 +195,5 @@ func (s *stream) CloseWrite() error {
if err := s.sendControlMessage(&pb.Message{Flag: pb.Message_FIN.Enum()}); err != nil {
return err
}
s.maybeDeclareStreamDone()
return nil
}
14 changes: 7 additions & 7 deletions p2p/transport/webrtc/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -470,17 +470,17 @@ func TestTransportWebRTC_RemoteReadsAfterClose(t *testing.T) {
done <- err
return
}
stream, err := lconn.AcceptStream()
s, err := lconn.AcceptStream()
if err != nil {
done <- err
return
}
_, err = stream.Write([]byte{1, 2, 3, 4})
_, err = s.Write([]byte{1, 2, 3, 4})
if err != nil {
done <- err
return
}
err = stream.Close()
err = s.(*stream).AsyncClose(nil)
if err != nil {
done <- err
return
Expand Down Expand Up @@ -541,12 +541,12 @@ func TestTransportWebRTC_RemoteReadsAfterClose2(t *testing.T) {

conn, err := tr1.Dial(context.Background(), listener.Multiaddr(), listeningPeer)
require.NoError(t, err)
// create a stream
stream, err := conn.OpenStream(context.Background())
// create a s
s, err := conn.OpenStream(context.Background())
require.NoError(t, err)
_, err = stream.Write([]byte{1, 2, 3, 4})
_, err = s.Write([]byte{1, 2, 3, 4})
require.NoError(t, err)
err = stream.Close()
err = s.(*stream).AsyncClose(nil)
require.NoError(t, err)
// signal stream closure
close(awaitStreamClosure)
Expand Down

0 comments on commit ca1cf7d

Please sign in to comment.