webrtc: wait for fin_ack for closing datachannel
sukunrt committed Oct 23, 2023
1 parent b7c04f8 commit 6e4d8c2
Showing 9 changed files with 192 additions and 37 deletions.
7 changes: 7 additions & 0 deletions core/network/mux.go
@@ -61,6 +61,13 @@ type MuxedStream interface {
SetWriteDeadline(time.Time) error
}

// AsyncCloser is implemented by streams that need to do expensive operations on close. Closing the stream
// asynchronously avoids blocking the calling goroutine.
type AsyncCloser interface {
// AsyncClose closes the stream and executes onDone after the stream is closed
AsyncClose(onDone func()) error
}

// MuxedConn represents a connection to a remote peer that has been
// extended to support stream multiplexing.
//
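
For orientation, a minimal sketch of an implementer, assuming a hypothetical stream whose close has to wait on a network round trip. slowStream and ackReceived are illustrative names (the core/network import is assumed); this is not part of the commit:

type slowStream struct {
	network.MuxedStream
	ackReceived chan struct{} // closed when the peer acknowledges our FIN
}

func (s *slowStream) AsyncClose(onDone func()) error {
	// Send our FIN synchronously; only the wait for the acknowledgment is deferred.
	if err := s.CloseWrite(); err != nil {
		return err
	}
	go func() {
		<-s.ackReceived       // block off the calling goroutine
		s.MuxedStream.Close() // release the underlying stream
		onDone()
	}()
	return nil
}
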
6 changes: 6 additions & 0 deletions p2p/net/swarm/swarm_stream.go
@@ -78,6 +78,12 @@ func (s *Stream) Write(p []byte) (int, error) {
// Close closes the stream, closing both ends and freeing all associated
// resources.
func (s *Stream) Close() error {
if as, ok := s.stream.(network.AsyncCloser); ok {
err := as.AsyncClose(func() {
s.closeAndRemoveStream()
})
return err
}
err := s.stream.Close()
s.closeAndRemoveStream()
return err
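
The type assertion above is the usual Go optional-interface upgrade: probe for the richer interface, fall back to the baseline. A standalone illustration of the pattern with hypothetical names (Closer, FlushCloser, shutdown):

type Closer interface{ Close() error }

type FlushCloser interface {
	Closer
	FlushClose(onDone func()) error
}

func shutdown(c Closer) error {
	if fc, ok := c.(FlushCloser); ok {
		// upgraded path: deferred cleanup runs when onDone fires
		return fc.FlushClose(func() {})
	}
	// fallback path for implementations without the optional interface
	return c.Close()
}
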
45 changes: 45 additions & 0 deletions p2p/net/swarm/swarm_stream_test.go
@@ -0,0 +1,45 @@
package swarm

import (
"context"
"sync/atomic"
"testing"

"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peerstore"
"github.com/stretchr/testify/require"
)

type asyncStreamWrapper struct {
network.MuxedStream
before func()
}

func (s *asyncStreamWrapper) AsyncClose(onDone func()) error {
s.before()
err := s.Close()
onDone()
return err
}

func TestStreamAsyncCloser(t *testing.T) {
s1 := makeSwarm(t)
s2 := makeSwarm(t)

s1.Peerstore().AddAddrs(s2.LocalPeer(), s2.ListenAddresses(), peerstore.TempAddrTTL)
s, err := s1.NewStream(context.Background(), s2.LocalPeer())
require.NoError(t, err)
ss, ok := s.(*Stream)
require.True(t, ok)

var called atomic.Bool
as := &asyncStreamWrapper{
MuxedStream: ss.stream,
before: func() {
called.Store(true)
},
}
ss.stream = as
ss.Close()
require.True(t, called.Load())
}
14 changes: 0 additions & 14 deletions p2p/transport/webrtc/connection.go
@@ -28,8 +28,6 @@ var _ tpt.CapableConn = &connection{}
// maxAcceptQueueLen is the number of waiting streams.
const maxAcceptQueueLen = 256

const maxDataChannelID = 1 << 10

type errConnectionTimeout struct{}

var _ net.Error = &errConnectionTimeout{}
@@ -195,12 +193,6 @@ func (c *connection) OpenStream(ctx context.Context) (network.MuxedStream, error
if id > math.MaxUint16 {
return nil, errors.New("exhausted stream ID space")
}
// Limit the number of streams, since we're not able to actually properly close them.
// See https://github.com/libp2p/specs/issues/575 for details.
if id > maxDataChannelID {
c.Close()
return c.OpenStream(ctx)
}

streamID := uint16(id)
dc, err := c.pc.CreateDataChannel("", &webrtc.DataChannelInit{ID: &streamID})
@@ -329,12 +321,6 @@ func (c *connection) setRemotePublicKey(key ic.PubKey) {
func SetupDataChannelQueue(pc *webrtc.PeerConnection, queueLen int) chan DetachedDataChannel {
queue := make(chan DetachedDataChannel, queueLen)
pc.OnDataChannel(func(dc *webrtc.DataChannel) {
// Limit the number of streams, since we're not able to actually properly close them.
// See https://github.com/libp2p/specs/issues/575 for details.
if *dc.ID() > maxDataChannelID {
dc.Close()
return
}
dc.OnOpen(func() {
rwc, err := dc.Detach()
if err != nil {
29 changes: 18 additions & 11 deletions p2p/transport/webrtc/pb/message.pb.go

Some generated files are not rendered by default.

4 changes: 4 additions & 0 deletions p2p/transport/webrtc/pb/message.proto
@@ -12,6 +12,10 @@ message Message {
// The sender abruptly terminates the sending part of the stream. The
// receiver can discard any data that it already received on that stream.
RESET = 2;
// Sending the FIN_ACK flag acknowledges the previous receipt of a message
// with the FIN flag set. Receiving a FIN_ACK flag gives the recipient
// confidence that the remote has received all sent messages.
FIN_ACK = 3;
}

optional Flag flag=1;
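
A rough sketch of the sender side of the handshake this flag enables, assuming the pbio reader/writer the transport already uses over the detached data channel. closeWithAck is an illustrative helper, not part of this commit; the real stream interleaves this with its normal read path:

func closeWithAck(w pbio.Writer, r pbio.Reader) error {
	// Signal that we will send no more data.
	if err := w.WriteMsg(&pb.Message{Flag: pb.Message_FIN.Enum()}); err != nil {
		return err
	}
	var msg pb.Message
	for {
		if err := r.ReadMsg(&msg); err != nil {
			return err // deadline expired or channel closed before the ack arrived
		}
		if msg.GetFlag() == pb.Message_FIN_ACK {
			return nil // the peer has received everything we sent
		}
		// Other flags (FIN, STOP_SENDING, ...) are handled by the stream's
		// normal read path in the real implementation.
	}
}
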
88 changes: 77 additions & 11 deletions p2p/transport/webrtc/stream.go
@@ -1,6 +1,7 @@
package libp2pwebrtc

import (
"errors"
"sync"
"time"

@@ -53,12 +54,16 @@
sendStateSending sendState = iota
sendStateDataSent
sendStateReset
sendStateDataReceived
)

// stream wraps a detached data channel from package pion into a net.Conn
// and then into a network.MuxedStream
type stream struct {
mx sync.Mutex

// readMx ensures there's only a single goroutine reading from reader
readMx sync.Mutex
// pbio.Reader is not thread safe,
// and while our Read is not promised to be thread safe,
// we ourselves internally read from multiple routines...
@@ -132,28 +137,86 @@ func newStream(
}

func (s *stream) Close() error {
// Close read before write to ensure that the STOP_SENDING message is delivered before
// we close the data channel
closeReadErr := s.CloseRead()
closeWriteErr := s.CloseWrite()
if closeWriteErr != nil {
// writing FIN failed, reset the stream
s.Reset()
return closeWriteErr
}
s.waitForFINACK()
s.mx.Lock()
defer s.mx.Unlock()
s.maybeDeclareStreamDone()
return errors.Join(closeWriteErr, closeReadErr)
}

func (s *stream) AsyncClose(onDone func()) error {
// Close read before write to ensure that the STOP_SENDING message is delivered before
// we close the data channel
closeReadErr := s.CloseRead()
closeWriteErr := s.CloseWrite()
if closeWriteErr != nil {
// writing FIN failed, reset the stream
s.Reset()
return closeWriteErr
}
go func() {
s.waitForFINACK()
s.mx.Lock()
defer s.mx.Unlock()
s.maybeDeclareStreamDone()
onDone()
}()
return errors.Join(closeWriteErr, closeReadErr)
}

func (s *stream) Reset() error {
cancelWriteErr := s.cancelWrite()
closeReadErr := s.CloseRead()
dcCloseErr := s.dataChannel.Close()
s.mx.Lock()
defer s.mx.Unlock()
s.maybeDeclareStreamDone()
return errors.Join(cancelWriteErr, closeReadErr, dcCloseErr)
}

func (s *stream) SetDeadline(t time.Time) error {
_ = s.SetReadDeadline(t)
return s.SetWriteDeadline(t)
}

func (s *stream) waitForFINACK() {
s.mx.Lock()
defer s.mx.Unlock()
// We can only wait for the FIN_ACK if we've sent a FIN and the application has stopped reading from the stream
if s.sendState != sendStateDataSent || s.receiveState == receiveStateReceiving {
return
}
// First wait for any existing readers to exit
s.readMx.Lock()
// Bound the wait so a peer that never acknowledges cannot block Close forever
s.SetReadDeadline(time.Now().Add(5 * time.Second))
var msg pb.Message
for {
// Release mx while blocked in ReadMsg so other operations on the stream can make progress
s.mx.Unlock()
if err := s.reader.ReadMsg(&msg); err != nil {
// Read failed or the deadline expired; give up and proceed as if the FIN were acknowledged
s.readMx.Unlock()
s.mx.Lock()
s.sendState = sendStateDataReceived
break
}
s.readMx.Unlock()
s.mx.Lock()
s.processIncomingFlag(msg.Flag)
if s.sendState != sendStateDataSent {
// FIN_ACK received (or the stream was reset); done waiting
break
}
s.readMx.Lock()
}
}

// processIncomingFlag processes the flag on an incoming message
// It needs to be called with msg.Flag, not msg.GetFlag(),
// otherwise we'd misinterpret the default value.
@@ -168,6 +231,9 @@ func (s *stream) processIncomingFlag(flag *pb.Message_Flag) {
if s.receiveState == receiveStateReceiving {
s.receiveState = receiveStateDataRead
}
if err := s.sendControlMessage(&pb.Message{Flag: pb.Message_FIN_ACK.Enum()}); err != nil {
log.Debugf("failed to send FIN_ACK: %s", err)
}
case pb.Message_STOP_SENDING:
if s.sendState == sendStateSending {
s.sendState = sendStateReset
@@ -180,24 +246,24 @@ func (s *stream) processIncomingFlag(flag *pb.Message_Flag) {
if s.receiveState == receiveStateReceiving {
s.receiveState = receiveStateReset
}
case pb.Message_FIN_ACK:
if s.sendState == sendStateDataSent {
s.sendState = sendStateDataReceived
}
}
s.maybeDeclareStreamDone()
}

// maybeDeclareStreamDone closes the data channel and marks the stream done once both directions have finished. It must be called with mx acquired
func (s *stream) maybeDeclareStreamDone() {
if (s.sendState == sendStateReset || s.sendState == sendStateDataSent) &&
if (s.sendState == sendStateReset || s.sendState == sendStateDataReceived) &&
(s.receiveState == receiveStateReset || s.receiveState == receiveStateDataRead) &&
len(s.controlMsgQueue) == 0 {

s.mx.Unlock()
defer s.mx.Lock()
_ = s.SetReadDeadline(time.Now().Add(-1 * time.Hour)) // pion ignores zero times
// TODO: wait for the spawned reader to return

s.dataChannel.Close()
s.onDone()
}
}
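
To summarize the caller-visible difference between the two paths, a sketch written as if in-package. demoClose and demoAsyncClose are illustrative, assume the package's log helper, and are not part of this commit; timings depend on the peer:

func demoClose(s *stream) {
	start := time.Now()
	_ = s.Close() // blocks until the FIN_ACK arrives or waitForFINACK's 5s read deadline expires
	log.Debugf("synchronous close took %s", time.Since(start))
}

func demoAsyncClose(s *stream) {
	done := make(chan struct{})
	_ = s.AsyncClose(func() { close(done) }) // returns as soon as the FIN is written
	<-done // the FIN_ACK wait and cleanup happened on a background goroutine
}
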
3 changes: 3 additions & 0 deletions p2p/transport/webrtc/stream_read.go
@@ -32,7 +32,9 @@ func (s *stream) Read(b []byte) (int, error) {
// load the next message
s.mx.Unlock()
var msg pb.Message
s.readMx.Lock()
if err := s.reader.ReadMsg(&msg); err != nil {
s.readMx.Unlock()
s.mx.Lock()
if err == io.EOF {
// if the channel was properly closed, return EOF
@@ -48,6 +50,7 @@
}
return 0, err
}
s.readMx.Unlock()
s.mx.Lock()
s.nextMessage = &msg
}
33 changes: 32 additions & 1 deletion p2p/transport/webrtc/stream_test.go
@@ -126,13 +126,15 @@ func TestStreamSimpleReadWriteClose(t *testing.T) {
_, err = serverStr.Write([]byte("lorem ipsum"))
require.NoError(t, err)
require.NoError(t, serverStr.CloseWrite())
// and read it at the client
require.False(t, clientDone)
b, err = io.ReadAll(clientStr)
require.NoError(t, err)
require.Equal(t, []byte("lorem ipsum"), b)
require.True(t, clientDone)
// Need to call Close for cleanup. Otherwise the FIN_ACK is never read
require.NoError(t, serverStr.Close())
require.True(t, serverDone)
}

func TestStreamPartialReads(t *testing.T) {
@@ -332,3 +334,32 @@ func TestStreamReadAfterClose(t *testing.T) {
_, err = clientStr.Read(nil)
require.ErrorIs(t, err, network.ErrReset)
}

func TestStreamCloseAfterFINACK(t *testing.T) {
client, server := getDetachedDataChannels(t)

done := make(chan bool, 1)
clientStr := newStream(client.dc, client.rwc, func() { done <- true })
serverStr := newStream(server.dc, server.rwc, func() {})

go func() {
done <- true
clientStr.Close()
}()
<-done

select {
case <-done:
t.Fatalf("Close should not have completed without processing FIN_ACK")
case <-time.After(3 * time.Second):
}

b := make([]byte, 1)
_, err := serverStr.Read(b)
require.Error(t, err)
select {
case <-done:
case <-time.After(3 * time.Second):
t.Fatalf("Close should have completed")
}
}
