Skip to content

Commit

Permalink
webrtc: wait for fin_ack for closing datachannel
Browse files Browse the repository at this point in the history
  • Loading branch information
sukunrt committed Oct 23, 2023
1 parent 851d1be commit 071849e
Show file tree
Hide file tree
Showing 7 changed files with 146 additions and 23 deletions.
7 changes: 7 additions & 0 deletions core/network/mux.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,13 @@ type MuxedStream interface {
SetWriteDeadline(time.Time) error
}

// AsyncCloser is implemented by streams that need to do expensive operations on close. Closing the stream async avoids
// blocking the calling goroutine.
type AsyncCloser interface {
// AsyncClose closes the stream and executes onDone when stream is closed
AsyncClose(onDone func()) error
}

// MuxedConn represents a connection to a remote peer that has been
// extended to support stream multiplexing.
//
Expand Down
5 changes: 5 additions & 0 deletions p2p/net/swarm/swarm_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,11 @@ func (s *Stream) Write(p []byte) (int, error) {
// Close closes the stream, closing both ends and freeing all associated
// resources.
func (s *Stream) Close() error {
if as, ok := s.stream.(network.AsyncCloser); ok {
return as.AsyncClose(func() {
s.closeOnce.Do(s.remove)

Check failure on line 83 in p2p/net/swarm/swarm_stream.go

View workflow job for this annotation

GitHub Actions / go-test / ubuntu (go 1.20.x)

s.closeOnce undefined (type *Stream has no field or method closeOnce)

Check failure on line 83 in p2p/net/swarm/swarm_stream.go

View workflow job for this annotation

GitHub Actions / go-test / ubuntu (go 1.20.x)

s.remove undefined (type *Stream has no field or method remove)

Check failure on line 83 in p2p/net/swarm/swarm_stream.go

View workflow job for this annotation

GitHub Actions / go-test / ubuntu (go 1.21.x)

s.closeOnce undefined (type *Stream has no field or method closeOnce)

Check failure on line 83 in p2p/net/swarm/swarm_stream.go

View workflow job for this annotation

GitHub Actions / go-test / ubuntu (go 1.21.x)

s.remove undefined (type *Stream has no field or method remove)

Check failure on line 83 in p2p/net/swarm/swarm_stream.go

View workflow job for this annotation

GitHub Actions / go-test / windows (go 1.20.x)

s.closeOnce undefined (type *Stream has no field or method closeOnce)

Check failure on line 83 in p2p/net/swarm/swarm_stream.go

View workflow job for this annotation

GitHub Actions / go-test / windows (go 1.20.x)

s.remove undefined (type *Stream has no field or method remove)

Check failure on line 83 in p2p/net/swarm/swarm_stream.go

View workflow job for this annotation

GitHub Actions / go-test / windows (go 1.21.x)

s.closeOnce undefined (type *Stream has no field or method closeOnce)

Check failure on line 83 in p2p/net/swarm/swarm_stream.go

View workflow job for this annotation

GitHub Actions / go-test / windows (go 1.21.x)

s.remove undefined (type *Stream has no field or method remove)

Check failure on line 83 in p2p/net/swarm/swarm_stream.go

View workflow job for this annotation

GitHub Actions / go-test / macos (go 1.20.x)

s.closeOnce undefined (type *Stream has no field or method closeOnce)

Check failure on line 83 in p2p/net/swarm/swarm_stream.go

View workflow job for this annotation

GitHub Actions / go-test / macos (go 1.20.x)

s.remove undefined (type *Stream has no field or method remove)

Check failure on line 83 in p2p/net/swarm/swarm_stream.go

View workflow job for this annotation

GitHub Actions / go-test / macos (go 1.21.x)

s.closeOnce undefined (type *Stream has no field or method closeOnce)

Check failure on line 83 in p2p/net/swarm/swarm_stream.go

View workflow job for this annotation

GitHub Actions / go-test / macos (go 1.21.x)

s.remove undefined (type *Stream has no field or method remove)
})
}
err := s.stream.Close()
s.closeAndRemoveStream()
return err
Expand Down
29 changes: 18 additions & 11 deletions p2p/transport/webrtc/pb/message.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions p2p/transport/webrtc/pb/message.proto
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ message Message {
// The sender abruptly terminates the sending part of the stream. The
// receiver can discard any data that it already received on that stream.
RESET = 2;
// Sending the FIN_ACK flag acknowledges the previous receipt of a message
// with the FIN flag set. Receiving a FIN_ACK flag gives the recipient
// confidence that the remote has received all sent messages.
FIN_ACK = 3;
}

optional Flag flag=1;
Expand Down
88 changes: 77 additions & 11 deletions p2p/transport/webrtc/stream.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package libp2pwebrtc

import (
"errors"
"sync"
"time"

Expand Down Expand Up @@ -53,12 +54,16 @@ const (
sendStateSending sendState = iota
sendStateDataSent
sendStateReset
sendStateDataReceived
)

// Package pion detached data channel into a net.Conn
// and then a network.MuxedStream
type stream struct {
mx sync.Mutex

// readMx ensures there's only a single goroutine reading from reader
readMx sync.Mutex
// pbio.Reader is not thread safe,
// and while our Read is not promised to be thread safe,
// we ourselves internally read from multiple routines...
Expand Down Expand Up @@ -132,28 +137,86 @@ func newStream(
}

func (s *stream) Close() error {
// Close read before write to ensure that the STOP_SENDING message is delivered before
// we close the data channel
closeReadErr := s.CloseRead()
closeWriteErr := s.CloseWrite()
if closeWriteErr != nil {
// writing FIN failed, reset the stream
s.Reset()
return closeWriteErr
}
s.waitForFINACK()
s.mx.Lock()
defer s.mx.Unlock()
s.maybeDeclareStreamDone()
return errors.Join(closeWriteErr, closeReadErr)
}

func (s *stream) AsyncClose(onDone func()) error {
// Close read before write to ensure that the STOP_SENDING message is delivered before
// we close the data channel
closeReadErr := s.CloseRead()
closeWriteErr := s.CloseWrite()
if closeWriteErr != nil {
// writing FIN failed, reset the stream
s.Reset()
return closeWriteErr
}
return closeReadErr
go func() {
s.waitForFINACK()
s.mx.Lock()
defer s.mx.Unlock()
s.maybeDeclareStreamDone()
onDone()
}()
return errors.Join(closeWriteErr, closeReadErr)
}

func (s *stream) Reset() error {
cancelWriteErr := s.cancelWrite()
closeReadErr := s.CloseRead()
if cancelWriteErr != nil {
return cancelWriteErr
}
return closeReadErr
dcCloseErr := s.dataChannel.Close()
s.mx.Lock()
defer s.mx.Unlock()
s.maybeDeclareStreamDone()
return errors.Join(cancelWriteErr, closeReadErr, dcCloseErr)
}

func (s *stream) SetDeadline(t time.Time) error {
_ = s.SetReadDeadline(t)
return s.SetWriteDeadline(t)
}

func (s *stream) waitForFINACK() {
s.mx.Lock()
defer s.mx.Unlock()
// We can only wait for FIN_ACK if we've stopped reading from the stream
if s.sendState != sendStateDataSent || s.receiveState == receiveStateReceiving {
return
}
// First wait for any existing readers to exit
s.readMx.Lock()
s.SetReadDeadline(time.Now().Add(5 * time.Second))
var msg pb.Message
for {
s.mx.Unlock()
if err := s.reader.ReadMsg(&msg); err != nil {
s.readMx.Unlock()
s.mx.Lock()
s.sendState = sendStateDataReceived
break
}
s.readMx.Unlock()
s.mx.Lock()
s.processIncomingFlag(msg.Flag)
if s.sendState != sendStateDataSent {
break
}
s.readMx.Lock()
}
}

// processIncomingFlag process the flag on an incoming message
// It needs to be called with msg.Flag, not msg.GetFlag(),
// otherwise we'd misinterpret the default value.
Expand All @@ -168,6 +231,9 @@ func (s *stream) processIncomingFlag(flag *pb.Message_Flag) {
if s.receiveState == receiveStateReceiving {
s.receiveState = receiveStateDataRead
}
if err := s.sendControlMessage(&pb.Message{Flag: pb.Message_FIN_ACK.Enum()}); err != nil {
log.Debugf("failed to send FIN_ACK:", err)
}
case pb.Message_STOP_SENDING:
if s.sendState == sendStateSending {
s.sendState = sendStateReset
Expand All @@ -180,24 +246,24 @@ func (s *stream) processIncomingFlag(flag *pb.Message_Flag) {
if s.receiveState == receiveStateReceiving {
s.receiveState = receiveStateReset
}
case pb.Message_FIN_ACK:
if s.sendState == sendStateDataSent {
s.sendState = sendStateDataReceived
}
}
s.maybeDeclareStreamDone()
}

// maybeDeclareStreamDone is used to force reset a stream. It must be called with mx acquired
func (s *stream) maybeDeclareStreamDone() {
if (s.sendState == sendStateReset || s.sendState == sendStateDataSent) &&
if (s.sendState == sendStateReset || s.sendState == sendStateDataReceived) &&
(s.receiveState == receiveStateReset || s.receiveState == receiveStateDataRead) &&
len(s.controlMsgQueue) == 0 {

s.mx.Unlock()
defer s.mx.Lock()
_ = s.SetReadDeadline(time.Now().Add(-1 * time.Hour)) // pion ignores zero times
// TODO: we should be closing the underlying datachannel, but this resets the stream
// See https://github.com/libp2p/specs/issues/575 for details.
// _ = s.dataChannel.Close()
// TODO: write for the spawned reader to return

s.dataChannel.Close()
s.onDone()
}
}
Expand Down
3 changes: 3 additions & 0 deletions p2p/transport/webrtc/stream_read.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@ func (s *stream) Read(b []byte) (int, error) {
// load the next message
s.mx.Unlock()
var msg pb.Message
s.readMx.Lock()
if err := s.reader.ReadMsg(&msg); err != nil {
s.readMx.Unlock()
s.mx.Lock()
if err == io.EOF {
// if the channel was properly closed, return EOF
Expand All @@ -48,6 +50,7 @@ func (s *stream) Read(b []byte) (int, error) {
}
return 0, err
}
s.readMx.Unlock()
s.mx.Lock()
s.nextMessage = &msg
}
Expand Down
33 changes: 32 additions & 1 deletion p2p/transport/webrtc/stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,13 +126,15 @@ func TestStreamSimpleReadWriteClose(t *testing.T) {
_, err = serverStr.Write([]byte("lorem ipsum"))
require.NoError(t, err)
require.NoError(t, serverStr.CloseWrite())
require.True(t, serverDone)
// and read it at the client
require.False(t, clientDone)
b, err = io.ReadAll(clientStr)
require.NoError(t, err)
require.Equal(t, []byte("lorem ipsum"), b)
require.True(t, clientDone)
// Need to call Close for cleanup. Otherwise the FIN_ACK is never read
require.NoError(t, serverStr.Close())
require.True(t, serverDone)
}

func TestStreamPartialReads(t *testing.T) {
Expand Down Expand Up @@ -332,3 +334,32 @@ func TestStreamReadAfterClose(t *testing.T) {
_, err = clientStr.Read(nil)
require.ErrorIs(t, err, network.ErrReset)
}

func TestStreamCloseAfterFINACK(t *testing.T) {
client, server := getDetachedDataChannels(t)

done := make(chan bool, 1)
clientStr := newStream(client.dc, client.rwc, func() { done <- true })
serverStr := newStream(server.dc, server.rwc, func() {})

go func() {
done <- true
clientStr.Close()
}()
<-done

select {
case <-done:
t.Fatalf("Close should not have completed without processing FIN_ACK")
case <-time.After(3 * time.Second):
}

b := make([]byte, 1)
_, err := serverStr.Read(b)
require.Error(t, err)
select {
case <-done:
case <-time.After(3 * time.Second):
t.Fatalf("Close should have completed")
}
}

0 comments on commit 071849e

Please sign in to comment.