From 071849e2eec7ff86d647a287dbf4dc62bb6129ea Mon Sep 17 00:00:00 2001 From: sukun Date: Mon, 23 Oct 2023 17:40:29 +0530 Subject: [PATCH] webrtc: wait for fin_ack for closing datachannel --- core/network/mux.go | 7 +++ p2p/net/swarm/swarm_stream.go | 5 ++ p2p/transport/webrtc/pb/message.pb.go | 29 +++++---- p2p/transport/webrtc/pb/message.proto | 4 ++ p2p/transport/webrtc/stream.go | 88 +++++++++++++++++++++++---- p2p/transport/webrtc/stream_read.go | 3 + p2p/transport/webrtc/stream_test.go | 33 +++++++++- 7 files changed, 146 insertions(+), 23 deletions(-) diff --git a/core/network/mux.go b/core/network/mux.go index d12e2ea34b..7f96125207 100644 --- a/core/network/mux.go +++ b/core/network/mux.go @@ -61,6 +61,13 @@ type MuxedStream interface { SetWriteDeadline(time.Time) error } +// AsyncCloser is implemented by streams that need to do expensive operations on close. Closing the stream async avoids +// blocking the calling goroutine. +type AsyncCloser interface { + // AsyncClose closes the stream and executes onDone when stream is closed + AsyncClose(onDone func()) error +} + // MuxedConn represents a connection to a remote peer that has been // extended to support stream multiplexing. // diff --git a/p2p/net/swarm/swarm_stream.go b/p2p/net/swarm/swarm_stream.go index b7846adec2..1f63844544 100644 --- a/p2p/net/swarm/swarm_stream.go +++ b/p2p/net/swarm/swarm_stream.go @@ -78,6 +78,11 @@ func (s *Stream) Write(p []byte) (int, error) { // Close closes the stream, closing both ends and freeing all associated // resources. func (s *Stream) Close() error { + if as, ok := s.stream.(network.AsyncCloser); ok { + return as.AsyncClose(func() { + s.closeOnce.Do(s.remove) + }) + } err := s.stream.Close() s.closeAndRemoveStream() return err diff --git a/p2p/transport/webrtc/pb/message.pb.go b/p2p/transport/webrtc/pb/message.pb.go index fffc025f7f..384bddd289 100644 --- a/p2p/transport/webrtc/pb/message.pb.go +++ b/p2p/transport/webrtc/pb/message.pb.go @@ -31,6 +31,10 @@ const ( // The sender abruptly terminates the sending part of the stream. The // receiver can discard any data that it already received on that stream. Message_RESET Message_Flag = 2 + // Sending the FIN_ACK flag acknowledges the previous receipt of a message + // with the FIN flag set. Receiving a FIN_ACK flag gives the recipient + // confidence that the remote has received all sent messages. + Message_FIN_ACK Message_Flag = 3 ) // Enum value maps for Message_Flag. @@ -39,11 +43,13 @@ var ( 0: "FIN", 1: "STOP_SENDING", 2: "RESET", + 3: "FIN_ACK", } Message_Flag_value = map[string]int32{ "FIN": 0, "STOP_SENDING": 1, "RESET": 2, + "FIN_ACK": 3, } ) @@ -143,17 +149,18 @@ var File_message_proto protoreflect.FileDescriptor var file_message_proto_rawDesc = []byte{ 0x0a, 0x0d, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, - 0x74, 0x0a, 0x07, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x12, 0x21, 0x0a, 0x04, 0x66, 0x6c, - 0x61, 0x67, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x0d, 0x2e, 0x4d, 0x65, 0x73, 0x73, 0x61, - 0x67, 0x65, 0x2e, 0x46, 0x6c, 0x61, 0x67, 0x52, 0x04, 0x66, 0x6c, 0x61, 0x67, 0x12, 0x18, 0x0a, - 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x07, - 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x2c, 0x0a, 0x04, 0x46, 0x6c, 0x61, 0x67, 0x12, - 0x07, 0x0a, 0x03, 0x46, 0x49, 0x4e, 0x10, 0x00, 0x12, 0x10, 0x0a, 0x0c, 0x53, 0x54, 0x4f, 0x50, - 0x5f, 0x53, 0x45, 0x4e, 0x44, 0x49, 0x4e, 0x47, 0x10, 0x01, 0x12, 0x09, 0x0a, 0x05, 0x52, 0x45, - 0x53, 0x45, 0x54, 0x10, 0x02, 0x42, 0x35, 0x5a, 0x33, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, - 0x63, 0x6f, 0x6d, 0x2f, 0x6c, 0x69, 0x62, 0x70, 0x32, 0x70, 0x2f, 0x67, 0x6f, 0x2d, 0x6c, 0x69, - 0x62, 0x70, 0x32, 0x70, 0x2f, 0x70, 0x32, 0x70, 0x2f, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, - 0x72, 0x74, 0x2f, 0x77, 0x65, 0x62, 0x72, 0x74, 0x63, 0x2f, 0x70, 0x62, + 0x81, 0x01, 0x0a, 0x07, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x12, 0x21, 0x0a, 0x04, 0x66, + 0x6c, 0x61, 0x67, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x0d, 0x2e, 0x4d, 0x65, 0x73, 0x73, + 0x61, 0x67, 0x65, 0x2e, 0x46, 0x6c, 0x61, 0x67, 0x52, 0x04, 0x66, 0x6c, 0x61, 0x67, 0x12, 0x18, + 0x0a, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, + 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x39, 0x0a, 0x04, 0x46, 0x6c, 0x61, 0x67, + 0x12, 0x07, 0x0a, 0x03, 0x46, 0x49, 0x4e, 0x10, 0x00, 0x12, 0x10, 0x0a, 0x0c, 0x53, 0x54, 0x4f, + 0x50, 0x5f, 0x53, 0x45, 0x4e, 0x44, 0x49, 0x4e, 0x47, 0x10, 0x01, 0x12, 0x09, 0x0a, 0x05, 0x52, + 0x45, 0x53, 0x45, 0x54, 0x10, 0x02, 0x12, 0x0b, 0x0a, 0x07, 0x46, 0x49, 0x4e, 0x5f, 0x41, 0x43, + 0x4b, 0x10, 0x03, 0x42, 0x35, 0x5a, 0x33, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, + 0x6d, 0x2f, 0x6c, 0x69, 0x62, 0x70, 0x32, 0x70, 0x2f, 0x67, 0x6f, 0x2d, 0x6c, 0x69, 0x62, 0x70, + 0x32, 0x70, 0x2f, 0x70, 0x32, 0x70, 0x2f, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, 0x72, 0x74, + 0x2f, 0x77, 0x65, 0x62, 0x72, 0x74, 0x63, 0x2f, 0x70, 0x62, } var ( diff --git a/p2p/transport/webrtc/pb/message.proto b/p2p/transport/webrtc/pb/message.proto index d6b1957beb..aab885b0da 100644 --- a/p2p/transport/webrtc/pb/message.proto +++ b/p2p/transport/webrtc/pb/message.proto @@ -12,6 +12,10 @@ message Message { // The sender abruptly terminates the sending part of the stream. The // receiver can discard any data that it already received on that stream. RESET = 2; + // Sending the FIN_ACK flag acknowledges the previous receipt of a message + // with the FIN flag set. Receiving a FIN_ACK flag gives the recipient + // confidence that the remote has received all sent messages. + FIN_ACK = 3; } optional Flag flag=1; diff --git a/p2p/transport/webrtc/stream.go b/p2p/transport/webrtc/stream.go index 5d1368b1ab..3b30257972 100644 --- a/p2p/transport/webrtc/stream.go +++ b/p2p/transport/webrtc/stream.go @@ -1,6 +1,7 @@ package libp2pwebrtc import ( + "errors" "sync" "time" @@ -53,12 +54,16 @@ const ( sendStateSending sendState = iota sendStateDataSent sendStateReset + sendStateDataReceived ) // Package pion detached data channel into a net.Conn // and then a network.MuxedStream type stream struct { mx sync.Mutex + + // readMx ensures there's only a single goroutine reading from reader + readMx sync.Mutex // pbio.Reader is not thread safe, // and while our Read is not promised to be thread safe, // we ourselves internally read from multiple routines... @@ -132,21 +137,50 @@ func newStream( } func (s *stream) Close() error { + // Close read before write to ensure that the STOP_SENDING message is delivered before + // we close the data channel + closeReadErr := s.CloseRead() closeWriteErr := s.CloseWrite() + if closeWriteErr != nil { + // writing FIN failed, reset the stream + s.Reset() + return closeWriteErr + } + s.waitForFINACK() + s.mx.Lock() + defer s.mx.Unlock() + s.maybeDeclareStreamDone() + return errors.Join(closeWriteErr, closeReadErr) +} + +func (s *stream) AsyncClose(onDone func()) error { + // Close read before write to ensure that the STOP_SENDING message is delivered before + // we close the data channel closeReadErr := s.CloseRead() + closeWriteErr := s.CloseWrite() if closeWriteErr != nil { + // writing FIN failed, reset the stream + s.Reset() return closeWriteErr } - return closeReadErr + go func() { + s.waitForFINACK() + s.mx.Lock() + defer s.mx.Unlock() + s.maybeDeclareStreamDone() + onDone() + }() + return errors.Join(closeWriteErr, closeReadErr) } func (s *stream) Reset() error { cancelWriteErr := s.cancelWrite() closeReadErr := s.CloseRead() - if cancelWriteErr != nil { - return cancelWriteErr - } - return closeReadErr + dcCloseErr := s.dataChannel.Close() + s.mx.Lock() + defer s.mx.Unlock() + s.maybeDeclareStreamDone() + return errors.Join(cancelWriteErr, closeReadErr, dcCloseErr) } func (s *stream) SetDeadline(t time.Time) error { @@ -154,6 +188,35 @@ func (s *stream) SetDeadline(t time.Time) error { return s.SetWriteDeadline(t) } +func (s *stream) waitForFINACK() { + s.mx.Lock() + defer s.mx.Unlock() + // We can only wait for FIN_ACK if we've stopped reading from the stream + if s.sendState != sendStateDataSent || s.receiveState == receiveStateReceiving { + return + } + // First wait for any existing readers to exit + s.readMx.Lock() + s.SetReadDeadline(time.Now().Add(5 * time.Second)) + var msg pb.Message + for { + s.mx.Unlock() + if err := s.reader.ReadMsg(&msg); err != nil { + s.readMx.Unlock() + s.mx.Lock() + s.sendState = sendStateDataReceived + break + } + s.readMx.Unlock() + s.mx.Lock() + s.processIncomingFlag(msg.Flag) + if s.sendState != sendStateDataSent { + break + } + s.readMx.Lock() + } +} + // processIncomingFlag process the flag on an incoming message // It needs to be called with msg.Flag, not msg.GetFlag(), // otherwise we'd misinterpret the default value. @@ -168,6 +231,9 @@ func (s *stream) processIncomingFlag(flag *pb.Message_Flag) { if s.receiveState == receiveStateReceiving { s.receiveState = receiveStateDataRead } + if err := s.sendControlMessage(&pb.Message{Flag: pb.Message_FIN_ACK.Enum()}); err != nil { + log.Debugf("failed to send FIN_ACK:", err) + } case pb.Message_STOP_SENDING: if s.sendState == sendStateSending { s.sendState = sendStateReset @@ -180,24 +246,24 @@ func (s *stream) processIncomingFlag(flag *pb.Message_Flag) { if s.receiveState == receiveStateReceiving { s.receiveState = receiveStateReset } + case pb.Message_FIN_ACK: + if s.sendState == sendStateDataSent { + s.sendState = sendStateDataReceived + } } s.maybeDeclareStreamDone() } // maybeDeclareStreamDone is used to force reset a stream. It must be called with mx acquired func (s *stream) maybeDeclareStreamDone() { - if (s.sendState == sendStateReset || s.sendState == sendStateDataSent) && + if (s.sendState == sendStateReset || s.sendState == sendStateDataReceived) && (s.receiveState == receiveStateReset || s.receiveState == receiveStateDataRead) && len(s.controlMsgQueue) == 0 { s.mx.Unlock() defer s.mx.Lock() _ = s.SetReadDeadline(time.Now().Add(-1 * time.Hour)) // pion ignores zero times - // TODO: we should be closing the underlying datachannel, but this resets the stream - // See https://github.com/libp2p/specs/issues/575 for details. - // _ = s.dataChannel.Close() - // TODO: write for the spawned reader to return - + s.dataChannel.Close() s.onDone() } } diff --git a/p2p/transport/webrtc/stream_read.go b/p2p/transport/webrtc/stream_read.go index 8a6300c0a9..a93c6d6e5c 100644 --- a/p2p/transport/webrtc/stream_read.go +++ b/p2p/transport/webrtc/stream_read.go @@ -32,7 +32,9 @@ func (s *stream) Read(b []byte) (int, error) { // load the next message s.mx.Unlock() var msg pb.Message + s.readMx.Lock() if err := s.reader.ReadMsg(&msg); err != nil { + s.readMx.Unlock() s.mx.Lock() if err == io.EOF { // if the channel was properly closed, return EOF @@ -48,6 +50,7 @@ func (s *stream) Read(b []byte) (int, error) { } return 0, err } + s.readMx.Unlock() s.mx.Lock() s.nextMessage = &msg } diff --git a/p2p/transport/webrtc/stream_test.go b/p2p/transport/webrtc/stream_test.go index 6aecea80ea..6cdfaa7cd3 100644 --- a/p2p/transport/webrtc/stream_test.go +++ b/p2p/transport/webrtc/stream_test.go @@ -126,13 +126,15 @@ func TestStreamSimpleReadWriteClose(t *testing.T) { _, err = serverStr.Write([]byte("lorem ipsum")) require.NoError(t, err) require.NoError(t, serverStr.CloseWrite()) - require.True(t, serverDone) // and read it at the client require.False(t, clientDone) b, err = io.ReadAll(clientStr) require.NoError(t, err) require.Equal(t, []byte("lorem ipsum"), b) require.True(t, clientDone) + // Need to call Close for cleanup. Otherwise the FIN_ACK is never read + require.NoError(t, serverStr.Close()) + require.True(t, serverDone) } func TestStreamPartialReads(t *testing.T) { @@ -332,3 +334,32 @@ func TestStreamReadAfterClose(t *testing.T) { _, err = clientStr.Read(nil) require.ErrorIs(t, err, network.ErrReset) } + +func TestStreamCloseAfterFINACK(t *testing.T) { + client, server := getDetachedDataChannels(t) + + done := make(chan bool, 1) + clientStr := newStream(client.dc, client.rwc, func() { done <- true }) + serverStr := newStream(server.dc, server.rwc, func() {}) + + go func() { + done <- true + clientStr.Close() + }() + <-done + + select { + case <-done: + t.Fatalf("Close should not have completed without processing FIN_ACK") + case <-time.After(3 * time.Second): + } + + b := make([]byte, 1) + _, err := serverStr.Read(b) + require.Error(t, err) + select { + case <-done: + case <-time.After(3 * time.Second): + t.Fatalf("Close should have completed") + } +}