diff --git a/core/network/mux.go b/core/network/mux.go index d12e2ea34b..fdda55365a 100644 --- a/core/network/mux.go +++ b/core/network/mux.go @@ -61,6 +61,13 @@ type MuxedStream interface { SetWriteDeadline(time.Time) error } +// AsyncCloser is implemented by streams that need to do expensive operations on close before +// releasing the resources. Closing the stream async avoids blocking the calling goroutine. +type AsyncCloser interface { + // AsyncClose closes the stream and executes onDone after the stream is closed + AsyncClose(onDone func()) error +} + // MuxedConn represents a connection to a remote peer that has been // extended to support stream multiplexing. // diff --git a/p2p/net/swarm/swarm_stream.go b/p2p/net/swarm/swarm_stream.go index b7846adec2..1339709db2 100644 --- a/p2p/net/swarm/swarm_stream.go +++ b/p2p/net/swarm/swarm_stream.go @@ -78,6 +78,12 @@ func (s *Stream) Write(p []byte) (int, error) { // Close closes the stream, closing both ends and freeing all associated // resources. func (s *Stream) Close() error { + if as, ok := s.stream.(network.AsyncCloser); ok { + err := as.AsyncClose(func() { + s.closeAndRemoveStream() + }) + return err + } err := s.stream.Close() s.closeAndRemoveStream() return err diff --git a/p2p/net/swarm/swarm_stream_test.go b/p2p/net/swarm/swarm_stream_test.go new file mode 100644 index 0000000000..653489fe8f --- /dev/null +++ b/p2p/net/swarm/swarm_stream_test.go @@ -0,0 +1,45 @@ +package swarm + +import ( + "context" + "sync/atomic" + "testing" + + "github.com/libp2p/go-libp2p/core/network" + "github.com/libp2p/go-libp2p/core/peerstore" + "github.com/stretchr/testify/require" +) + +type asyncStreamWrapper struct { + network.MuxedStream + beforeClose func() +} + +func (s *asyncStreamWrapper) AsyncClose(onDone func()) error { + s.beforeClose() + err := s.Close() + onDone() + return err +} + +func TestStreamAsyncCloser(t *testing.T) { + s1 := makeSwarm(t) + s2 := makeSwarm(t) + + s1.Peerstore().AddAddrs(s2.LocalPeer(), s2.ListenAddresses(), peerstore.TempAddrTTL) + s, err := s1.NewStream(context.Background(), s2.LocalPeer()) + require.NoError(t, err) + ss, ok := s.(*Stream) + require.True(t, ok) + + var called atomic.Bool + as := &asyncStreamWrapper{ + MuxedStream: ss.stream, + beforeClose: func() { + called.Store(true) + }, + } + ss.stream = as + ss.Close() + require.True(t, called.Load()) +} diff --git a/p2p/transport/webrtc/connection.go b/p2p/transport/webrtc/connection.go index fd31f8351a..776ddd5de5 100644 --- a/p2p/transport/webrtc/connection.go +++ b/p2p/transport/webrtc/connection.go @@ -27,8 +27,6 @@ var _ tpt.CapableConn = &connection{} const maxAcceptQueueLen = 10 -const maxDataChannelID = 1 << 10 - type errConnectionTimeout struct{} var _ net.Error = &errConnectionTimeout{} @@ -47,7 +45,8 @@ type connection struct { transport *WebRTCTransport scope network.ConnManagementScope - closeErr error + closeOnce sync.Once + closeErr error localPeer peer.ID localMultiaddr ma.Multiaddr @@ -110,12 +109,6 @@ func newConnection( if c.IsClosed() { return } - // Limit the number of streams, since we're not able to actually properly close them. - // See https://github.com/libp2p/specs/issues/575 for details. - if *dc.ID() > maxDataChannelID { - c.Close() - return - } dc.OnOpen(func() { rwc, err := dc.Detach() if err != nil { @@ -133,7 +126,6 @@ func newConnection( } }) }) - return c, nil } @@ -144,16 +136,41 @@ func (c *connection) ConnState() network.ConnectionState { // Close closes the underlying peerconnection. 
func (c *connection) Close() error { - if c.IsClosed() { - return nil - } + c.closeOnce.Do(func() { + c.closeErr = errors.New("connection closed") + // cancel must be called after closeErr is set. This ensures interested goroutines waiting on + // ctx.Done can read closeErr without holding the conn lock. + c.cancel() + c.m.Lock() + streams := c.streams + c.streams = nil + c.m.Unlock() + for _, str := range streams { + str.Reset() + } + c.pc.Close() + c.scope.Done() + }) + return nil +} - c.m.Lock() - defer c.m.Unlock() - c.scope.Done() - c.closeErr = errors.New("connection closed") - c.cancel() - return c.pc.Close() +func (c *connection) closeTimedOut() error { + c.closeOnce.Do(func() { + c.closeErr = errConnectionTimeout{} + // cancel must be called after closeErr is set. This ensures interested goroutines waiting on + // ctx.Done can read closeErr without holding the conn lock. + c.cancel() + c.m.Lock() + streams := c.streams + c.streams = nil + c.m.Unlock() + for _, str := range streams { + str.closeWithError(errConnectionTimeout{}) + } + c.pc.Close() + c.scope.Done() + }) + return nil } func (c *connection) IsClosed() bool { @@ -174,12 +191,6 @@ func (c *connection) OpenStream(ctx context.Context) (network.MuxedStream, error if id > math.MaxUint16 { return nil, errors.New("exhausted stream ID space") } - // Limit the number of streams, since we're not able to actually properly close them. - // See https://github.com/libp2p/specs/issues/575 for details. - if id > maxDataChannelID { - c.Close() - return c.OpenStream(ctx) - } streamID := uint16(id) dc, err := c.pc.CreateDataChannel("", &webrtc.DataChannelInit{ID: &streamID}) @@ -238,20 +249,7 @@ func (c *connection) removeStream(id uint16) { func (c *connection) onConnectionStateChange(state webrtc.PeerConnectionState) { if state == webrtc.PeerConnectionStateFailed || state == webrtc.PeerConnectionStateClosed { - // reset any streams - if c.IsClosed() { - return - } - c.m.Lock() - defer c.m.Unlock() - c.closeErr = errConnectionTimeout{} - for k, str := range c.streams { - str.setCloseError(c.closeErr) - delete(c.streams, k) - } - c.cancel() - c.scope.Done() - c.pc.Close() + c.closeTimedOut() } } diff --git a/p2p/transport/webrtc/pb/message.pb.go b/p2p/transport/webrtc/pb/message.pb.go index fffc025f7f..384bddd289 100644 --- a/p2p/transport/webrtc/pb/message.pb.go +++ b/p2p/transport/webrtc/pb/message.pb.go @@ -31,6 +31,10 @@ const ( // The sender abruptly terminates the sending part of the stream. The // receiver can discard any data that it already received on that stream. Message_RESET Message_Flag = 2 + // Sending the FIN_ACK flag acknowledges the previous receipt of a message + // with the FIN flag set. Receiving a FIN_ACK flag gives the recipient + // confidence that the remote has received all sent messages. + Message_FIN_ACK Message_Flag = 3 ) // Enum value maps for Message_Flag. 
@@ -39,11 +43,13 @@ var ( 0: "FIN", 1: "STOP_SENDING", 2: "RESET", + 3: "FIN_ACK", } Message_Flag_value = map[string]int32{ "FIN": 0, "STOP_SENDING": 1, "RESET": 2, + "FIN_ACK": 3, } ) @@ -143,17 +149,18 @@ var File_message_proto protoreflect.FileDescriptor var file_message_proto_rawDesc = []byte{ 0x0a, 0x0d, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, - 0x74, 0x0a, 0x07, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x12, 0x21, 0x0a, 0x04, 0x66, 0x6c, - 0x61, 0x67, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x0d, 0x2e, 0x4d, 0x65, 0x73, 0x73, 0x61, - 0x67, 0x65, 0x2e, 0x46, 0x6c, 0x61, 0x67, 0x52, 0x04, 0x66, 0x6c, 0x61, 0x67, 0x12, 0x18, 0x0a, - 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x07, - 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x2c, 0x0a, 0x04, 0x46, 0x6c, 0x61, 0x67, 0x12, - 0x07, 0x0a, 0x03, 0x46, 0x49, 0x4e, 0x10, 0x00, 0x12, 0x10, 0x0a, 0x0c, 0x53, 0x54, 0x4f, 0x50, - 0x5f, 0x53, 0x45, 0x4e, 0x44, 0x49, 0x4e, 0x47, 0x10, 0x01, 0x12, 0x09, 0x0a, 0x05, 0x52, 0x45, - 0x53, 0x45, 0x54, 0x10, 0x02, 0x42, 0x35, 0x5a, 0x33, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, - 0x63, 0x6f, 0x6d, 0x2f, 0x6c, 0x69, 0x62, 0x70, 0x32, 0x70, 0x2f, 0x67, 0x6f, 0x2d, 0x6c, 0x69, - 0x62, 0x70, 0x32, 0x70, 0x2f, 0x70, 0x32, 0x70, 0x2f, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, - 0x72, 0x74, 0x2f, 0x77, 0x65, 0x62, 0x72, 0x74, 0x63, 0x2f, 0x70, 0x62, + 0x81, 0x01, 0x0a, 0x07, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x12, 0x21, 0x0a, 0x04, 0x66, + 0x6c, 0x61, 0x67, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x0d, 0x2e, 0x4d, 0x65, 0x73, 0x73, + 0x61, 0x67, 0x65, 0x2e, 0x46, 0x6c, 0x61, 0x67, 0x52, 0x04, 0x66, 0x6c, 0x61, 0x67, 0x12, 0x18, + 0x0a, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, + 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x39, 0x0a, 0x04, 0x46, 0x6c, 0x61, 0x67, + 0x12, 0x07, 0x0a, 0x03, 0x46, 0x49, 0x4e, 0x10, 0x00, 0x12, 0x10, 0x0a, 0x0c, 0x53, 0x54, 0x4f, + 0x50, 0x5f, 0x53, 0x45, 0x4e, 0x44, 0x49, 0x4e, 0x47, 0x10, 0x01, 0x12, 0x09, 0x0a, 0x05, 0x52, + 0x45, 0x53, 0x45, 0x54, 0x10, 0x02, 0x12, 0x0b, 0x0a, 0x07, 0x46, 0x49, 0x4e, 0x5f, 0x41, 0x43, + 0x4b, 0x10, 0x03, 0x42, 0x35, 0x5a, 0x33, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, + 0x6d, 0x2f, 0x6c, 0x69, 0x62, 0x70, 0x32, 0x70, 0x2f, 0x67, 0x6f, 0x2d, 0x6c, 0x69, 0x62, 0x70, + 0x32, 0x70, 0x2f, 0x70, 0x32, 0x70, 0x2f, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, 0x72, 0x74, + 0x2f, 0x77, 0x65, 0x62, 0x72, 0x74, 0x63, 0x2f, 0x70, 0x62, } var ( diff --git a/p2p/transport/webrtc/pb/message.proto b/p2p/transport/webrtc/pb/message.proto index d6b1957beb..aab885b0da 100644 --- a/p2p/transport/webrtc/pb/message.proto +++ b/p2p/transport/webrtc/pb/message.proto @@ -12,6 +12,10 @@ message Message { // The sender abruptly terminates the sending part of the stream. The // receiver can discard any data that it already received on that stream. RESET = 2; + // Sending the FIN_ACK flag acknowledges the previous receipt of a message + // with the FIN flag set. Receiving a FIN_ACK flag gives the recipient + // confidence that the remote has received all sent messages. 
+ FIN_ACK = 3; } optional Flag flag=1; diff --git a/p2p/transport/webrtc/stream.go b/p2p/transport/webrtc/stream.go index 0358dce56c..945712b826 100644 --- a/p2p/transport/webrtc/stream.go +++ b/p2p/transport/webrtc/stream.go @@ -1,6 +1,7 @@ package libp2pwebrtc import ( + "errors" "sync" "time" @@ -52,6 +53,7 @@ type sendState uint8 const ( sendStateSending sendState = iota sendStateDataSent + sendStateDataReceived sendStateReset ) @@ -59,27 +61,32 @@ const ( // and then a network.MuxedStream type stream struct { mx sync.Mutex - // pbio.Reader is not thread safe, - // and while our Read is not promised to be thread safe, - // we ourselves internally read from multiple routines... - reader pbio.Reader + + // readerMx ensures that only a single goroutine is blocked on reader.ReadMsg. We need + // concurrent reads from reader as control messages are multiplexed on the same stream with + // data messages. + readerMx sync.Mutex + reader pbio.Reader + // this buffer is limited up to a single message. Reason we need it // is because a reader might read a message midway, and so we need a // wait to buffer that for as long as the remaining part is not (yet) read nextMessage *pb.Message receiveState receiveState - // The public Write API is not promised to be thread safe, - // but we need to be able to write control messages. + // writerMx ensures that only a single goroutine is calling WriteMsg on writer. writer is a + // pbio.uvarintWriter which is not thread safe. The public Write API is not promised to be + // thread safe, but we need to be able to write control messages concurrently + writerMx sync.Mutex writer pbio.Writer sendStateChanged chan struct{} sendState sendState - controlMsgQueue []*pb.Message writeDeadline time.Time writeDeadlineUpdated chan struct{} writeAvailable chan struct{} - readLoopOnce sync.Once + controlMessageReaderOnce sync.Once + controlMessageReaderDone chan struct{} onDone func() id uint16 // for logging purposes @@ -102,6 +109,8 @@ func newStream( writeDeadlineUpdated: make(chan struct{}, 1), writeAvailable: make(chan struct{}, 1), + controlMessageReaderDone: make(chan struct{}), + id: *channel.ID(), dataChannel: rwc.(*datachannel.DataChannel), onDone: onDone, @@ -111,35 +120,6 @@ func newStream( channel.OnBufferedAmountLow(func() { s.mx.Lock() defer s.mx.Unlock() - // first send out queued control messages - for len(s.controlMsgQueue) > 0 { - msg := s.controlMsgQueue[0] - available := s.availableSendSpace() - if controlMsgSize < available { - s.writer.WriteMsg(msg) // TODO: handle error - s.controlMsgQueue = s.controlMsgQueue[1:] - } else { - return - } - } - - if s.isDone() { - // onDone removes the stream from the connection and requires the connection lock. - // This callback(onBufferedAmountLow) is executing in the sctp readLoop goroutine. - // If Connection.Close is called concurrently, the closing goroutine will acquire - // the connection lock and wait for sctp readLoop to exit, the sctp readLoop will - // wait for the connection lock before exiting, causing a deadlock. - // Run this in a different goroutine to avoid the deadlock. - go func() { - s.mx.Lock() - defer s.mx.Unlock() - // TODO: we should be closing the underlying datachannel, but this resets the stream - // See https://github.com/libp2p/specs/issues/575 for details. 
- // _ = s.dataChannel.Close() - // TODO: write for the spawned reader to return - s.onDone() - }() - } select { case s.writeAvailable <- struct{}{}: @@ -150,15 +130,48 @@ func newStream( } func (s *stream) Close() error { + defer s.cleanup() + closeWriteErr := s.CloseWrite() closeReadErr := s.CloseRead() - if closeWriteErr != nil { - return closeWriteErr + if closeWriteErr != nil || closeReadErr != nil { + s.Reset() + return errors.Join(closeWriteErr, closeReadErr) } - return closeReadErr + s.controlMessageReaderOnce.Do(s.spawnControlMessageReader) + // wait only 10 seconds for FIN_ACK + s.SetReadDeadline(time.Now().Add(10 * time.Second)) + <-s.controlMessageReaderDone + // cleanup the deadline timer goroutine in sctp + s.SetReadDeadline(time.Time{}) + return nil +} + +func (s *stream) AsyncClose(onDone func()) error { + closeWriteErr := s.CloseWrite() + closeReadErr := s.CloseRead() + if closeWriteErr != nil || closeReadErr != nil { + s.Reset() + return errors.Join(closeWriteErr, closeReadErr) + } + s.controlMessageReaderOnce.Do(s.spawnControlMessageReader) + // wait only 10 seconds for FIN_ACK + s.SetReadDeadline(time.Now().Add(10 * time.Second)) + go func() { + <-s.controlMessageReaderDone + s.cleanup() + if onDone != nil { + onDone() + } + // cleanup the deadline timer goroutine in sctp + s.SetReadDeadline(time.Time{}) + }() + return nil } func (s *stream) Reset() error { + defer s.cleanup() + cancelWriteErr := s.cancelWrite() closeReadErr := s.CloseRead() if cancelWriteErr != nil { @@ -167,14 +180,20 @@ func (s *stream) Reset() error { return closeReadErr } +func (s *stream) closeWithError(e error) { + defer s.cleanup() + + s.mx.Lock() + defer s.mx.Unlock() + s.closeErr = e +} + func (s *stream) SetDeadline(t time.Time) error { _ = s.SetReadDeadline(t) return s.SetWriteDeadline(t) } // processIncomingFlag process the flag on an incoming message -// It needs to be called with msg.Flag, not msg.GetFlag(), -// otherwise we'd misinterpret the default value. // It needs to be called while the mutex is locked. func (s *stream) processIncomingFlag(flag *pb.Message_Flag) { if flag == nil { @@ -182,50 +201,77 @@ func (s *stream) processIncomingFlag(flag *pb.Message_Flag) { } switch *flag { - case pb.Message_FIN: - if s.receiveState == receiveStateReceiving { - s.receiveState = receiveStateDataRead - } case pb.Message_STOP_SENDING: - if s.sendState == sendStateSending { + // We must process STOP_SENDING after sending a FIN(sendStateDataSent). 
Remote peer + // may not send a FIN_ACK once it has sent a STOP_SENDING. + if s.sendState == sendStateSending || s.sendState == sendStateDataSent { s.sendState = sendStateReset } select { case s.sendStateChanged <- struct{}{}: default: } + case pb.Message_FIN_ACK: + s.sendState = sendStateDataReceived + select { + case s.sendStateChanged <- struct{}{}: + default: + } + case pb.Message_FIN: + if s.receiveState == receiveStateReceiving { + s.receiveState = receiveStateDataRead + } + if err := s.writeMsgOnWriter(&pb.Message{Flag: pb.Message_FIN_ACK.Enum()}); err != nil { + log.Debugf("failed to send FIN_ACK: %s", err) + // Remote has finished writing all the data. It'll stop waiting for the + // FIN_ACK eventually or will be notified when we close the datachannel. + } + s.SetReadDeadline(time.Now().Add(-1 * time.Hour)) case pb.Message_RESET: if s.receiveState == receiveStateReceiving { s.receiveState = receiveStateReset } + s.SetReadDeadline(time.Now().Add(-1 * time.Hour)) } - s.maybeDeclareStreamDone() } -// maybeDeclareStreamDone is used to force reset a stream. It should be called with -// the stream lock acquired. It calls stream.onDone which requires the connection lock. -func (s *stream) maybeDeclareStreamDone() { - if s.isDone() { - _ = s.SetReadDeadline(time.Now().Add(-1 * time.Hour)) // pion ignores zero times - // TODO: we should be closing the underlying datachannel, but this resets the stream - // See https://github.com/libp2p/specs/issues/575 for details. - // _ = s.dataChannel.Close() - // TODO: write for the spawned reader to return - s.onDone() - } -} +// spawnControlMessageReader is used for processing control messages after the reader is closed. +func (s *stream) spawnControlMessageReader() { + // Unblock waiting readers + s.SetReadDeadline(time.Now().Add(-1 * time.Hour)) + s.readerMx.Lock() + // no deadline needed, Read will return once there's a new message, or an error occurs + s.dataChannel.SetReadDeadline(time.Time{}) + s.readerMx.Unlock() -// isDone indicates whether the stream is completed and all the control messages have also been -// flushed. It must be called with the stream lock acquired. 
-func (s *stream) isDone() bool { - return (s.sendState == sendStateReset || s.sendState == sendStateDataSent) && - (s.receiveState == receiveStateReset || s.receiveState == receiveStateDataRead) && - len(s.controlMsgQueue) == 0 -} + go func() { + defer close(s.controlMessageReaderDone) + s.mx.Lock() + if s.nextMessage != nil { + s.processIncomingFlag(s.nextMessage.Flag) + s.nextMessage = nil + } + s.mx.Unlock() -func (s *stream) setCloseError(e error) { - s.mx.Lock() - defer s.mx.Unlock() + isSendCompleted := func() bool { + s.mx.Lock() + defer s.mx.Unlock() + return s.sendState == sendStateDataReceived || s.sendState == sendStateReset + } - s.closeErr = e + for !isSendCompleted() { + var msg pb.Message + if err := s.reader.ReadMsg(&msg); err != nil { + return + } + s.mx.Lock() + s.processIncomingFlag(msg.Flag) + s.mx.Unlock() + } + }() +} + +func (s *stream) cleanup() { + s.dataChannel.Close() + s.onDone() } diff --git a/p2p/transport/webrtc/stream_read.go b/p2p/transport/webrtc/stream_read.go index e064c8558b..3e5dff3ee2 100644 --- a/p2p/transport/webrtc/stream_read.go +++ b/p2p/transport/webrtc/stream_read.go @@ -1,19 +1,16 @@ package libp2pwebrtc import ( - "errors" + "fmt" "io" "time" "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/p2p/transport/webrtc/pb" + "google.golang.org/protobuf/proto" ) func (s *stream) Read(b []byte) (int, error) { - if len(b) == 0 { - return 0, nil - } - s.mx.Lock() defer s.mx.Unlock() @@ -27,13 +24,18 @@ func (s *stream) Read(b []byte) (int, error) { return 0, network.ErrReset } + if len(b) == 0 { + return 0, nil + } + var read int for { if s.nextMessage == nil { // load the next message s.mx.Unlock() var msg pb.Message - if err := s.reader.ReadMsg(&msg); err != nil { + + if err := s.readMsgFromReader(&msg); err != nil { s.mx.Lock() if err == io.EOF { // if the channel was properly closed, return EOF @@ -42,7 +44,8 @@ func (s *stream) Read(b []byte) (int, error) { } // This case occurs when the remote node closes the stream without writing a FIN message // There's little we can do here - return 0, errors.New("didn't receive final state for stream") + s.receiveState = receiveStateReset + return 0, fmt.Errorf("didn't receive final state for stream: %w", network.ErrReset) } if s.receiveState == receiveStateReset { return 0, network.ErrReset @@ -81,20 +84,25 @@ func (s *stream) SetReadDeadline(t time.Time) error { return s.dataChannel.SetRe func (s *stream) CloseRead() error { s.mx.Lock() defer s.mx.Unlock() - if s.nextMessage != nil { s.processIncomingFlag(s.nextMessage.Flag) s.nextMessage = nil } + var err error if s.receiveState == receiveStateReceiving && s.closeErr == nil { - err = s.sendControlMessage(&pb.Message{Flag: pb.Message_STOP_SENDING.Enum()}) + err = s.writeMsgOnWriter(&pb.Message{Flag: pb.Message_STOP_SENDING.Enum()}) + s.receiveState = receiveStateReset } - s.receiveState = receiveStateReset - s.maybeDeclareStreamDone() // make any calls to Read blocking on ReadMsg return immediately - s.dataChannel.SetReadDeadline(time.Now()) + s.dataChannel.SetReadDeadline(time.Now().Add(-1 * time.Hour)) return err } + +func (s *stream) readMsgFromReader(msg proto.Message) error { + s.readerMx.Lock() + defer s.readerMx.Unlock() + return s.reader.ReadMsg(msg) +} diff --git a/p2p/transport/webrtc/stream_test.go b/p2p/transport/webrtc/stream_test.go index f1442b9bfd..22244f8323 100644 --- a/p2p/transport/webrtc/stream_test.go +++ b/p2p/transport/webrtc/stream_test.go @@ -9,6 +9,7 @@ import ( "time" 
"github.com/libp2p/go-libp2p/p2p/transport/webrtc/pb" + "go.uber.org/goleak" "github.com/libp2p/go-libp2p/core/network" @@ -125,13 +126,20 @@ func TestStreamSimpleReadWriteClose(t *testing.T) { _, err = serverStr.Write([]byte("lorem ipsum")) require.NoError(t, err) require.NoError(t, serverStr.CloseWrite()) - require.True(t, serverDone) + // and read it at the client require.False(t, clientDone) b, err = io.ReadAll(clientStr) require.NoError(t, err) require.Equal(t, []byte("lorem ipsum"), b) - require.True(t, clientDone) + + // stream is only cleaned up on calling Close or AsyncClose or Reset + clientStr.AsyncClose(nil) + serverStr.AsyncClose(nil) + require.Eventually(t, func() bool { return clientDone }, 10*time.Second, 100*time.Millisecond) + // Need to call Close for cleanup. Otherwise the FIN_ACK is never read + require.NoError(t, serverStr.Close()) + require.Eventually(t, func() bool { return serverDone }, 10*time.Second, 100*time.Millisecond) } func TestStreamPartialReads(t *testing.T) { @@ -202,7 +210,7 @@ func TestStreamReadReturnsOnClose(t *testing.T) { errChan <- err }() time.Sleep(50 * time.Millisecond) // give the Read call some time to hit the loop - require.NoError(t, clientStr.Close()) + require.NoError(t, clientStr.AsyncClose(nil)) select { case err := <-errChan: require.ErrorIs(t, err, network.ErrReset) @@ -242,6 +250,7 @@ func TestStreamResets(t *testing.T) { _, err := serverStr.Write([]byte("foobar")) return errors.Is(err, network.ErrReset) }, time.Second, 50*time.Millisecond) + serverStr.Close() require.True(t, serverDone) } @@ -305,3 +314,178 @@ func TestStreamWriteDeadlineAsync(t *testing.T) { require.GreaterOrEqual(t, took, timeout) require.LessOrEqual(t, took, timeout*3/2) } + +func TestStreamReadAfterClose(t *testing.T) { + client, server := getDetachedDataChannels(t) + + clientStr := newStream(client.dc, client.rwc, func() {}) + serverStr := newStream(server.dc, server.rwc, func() {}) + + serverStr.AsyncClose(nil) + b := make([]byte, 1) + _, err := clientStr.Read(b) + require.Equal(t, io.EOF, err) + _, err = clientStr.Read(nil) + require.Equal(t, io.EOF, err) + + client, server = getDetachedDataChannels(t) + + clientStr = newStream(client.dc, client.rwc, func() {}) + serverStr = newStream(server.dc, server.rwc, func() {}) + + serverStr.Reset() + b = make([]byte, 1) + _, err = clientStr.Read(b) + require.ErrorIs(t, err, network.ErrReset) + _, err = clientStr.Read(nil) + require.ErrorIs(t, err, network.ErrReset) +} + +func TestStreamCloseAfterFINACK(t *testing.T) { + client, server := getDetachedDataChannels(t) + + done := make(chan bool, 1) + clientStr := newStream(client.dc, client.rwc, func() { done <- true }) + serverStr := newStream(server.dc, server.rwc, func() {}) + + go func() { + done <- true + clientStr.Close() + }() + <-done + + select { + case <-done: + t.Fatalf("Close should not have completed without processing FIN_ACK") + case <-time.After(2 * time.Second): + } + + b := make([]byte, 1) + _, err := serverStr.Read(b) + require.Error(t, err) + require.ErrorIs(t, err, io.EOF) + select { + case <-done: + case <-time.After(3 * time.Second): + t.Fatalf("Close should have completed") + } +} + +func TestStreamFinAckAfterStopSending(t *testing.T) { + client, server := getDetachedDataChannels(t) + + done := make(chan bool, 1) + clientStr := newStream(client.dc, client.rwc, func() { done <- true }) + serverStr := newStream(server.dc, server.rwc, func() {}) + + go func() { + done <- true + clientStr.CloseRead() + clientStr.Write([]byte("hello world")) + 
clientStr.Close() + }() + <-done + + select { + case <-done: + t.Fatalf("Close should not have completed without processing FIN_ACK") + case <-time.After(1 * time.Second): + } + + // serverStr has write half of the stream closed but the read half should + // respond correctly + b := make([]byte, 24) + _, err := serverStr.Read(b) + require.NoError(t, err) + serverStr.Close() // Sends stop_sending, fin + select { + case <-done: + case <-time.After(5 * time.Second): + t.Fatalf("Close should have completed") + } +} + +func TestStreamConcurrentClose(t *testing.T) { + client, server := getDetachedDataChannels(t) + + start := make(chan bool, 1) + done := make(chan bool, 2) + clientStr := newStream(client.dc, client.rwc, func() { done <- true }) + serverStr := newStream(server.dc, server.rwc, func() { done <- true }) + + go func() { + start <- true + clientStr.Close() + }() + go func() { + start <- true + serverStr.Close() + }() + <-start + <-start + + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatalf("concurrent close should succeed quickly") + } + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatalf("concurrent close should succeed quickly") + } +} + +func TestStreamCloseGoRoutineLeak(t *testing.T) { + t.Cleanup(func() { goleak.VerifyNone(t) }) + client, server := getDetachedDataChannels(t) + + start := make(chan bool, 1) + done := make(chan bool, 2) + clientStr := newStream(client.dc, client.rwc, func() { done <- true }) + serverStr := newStream(server.dc, server.rwc, func() { done <- true }) + + go func() { + start <- true + clientStr.Close() + }() + go func() { + start <- true + serverStr.Close() + }() + <-start + <-start + + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatalf("concurrent close should succeed quickly") + } + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatalf("concurrent close should succeed quickly") + } + + go func() { + start <- true + clientStr.Close() + }() + go func() { + start <- true + serverStr.Close() + }() + <-start + <-start + + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatalf("concurrent close should succeed quickly") + } + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatalf("concurrent close should succeed quickly") + } +} diff --git a/p2p/transport/webrtc/stream_write.go b/p2p/transport/webrtc/stream_write.go index 698af9c4d6..837fc2d9b4 100644 --- a/p2p/transport/webrtc/stream_write.go +++ b/p2p/transport/webrtc/stream_write.go @@ -7,6 +7,7 @@ import ( "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/p2p/transport/webrtc/pb" + "google.golang.org/protobuf/proto" ) var errWriteAfterClose = errors.New("write after close") @@ -27,12 +28,14 @@ func (s *stream) Write(b []byte) (int, error) { return 0, network.ErrReset case sendStateDataSent: return 0, errWriteAfterClose + case sendStateDataReceived: + return 0, errWriteAfterClose } // Check if there is any message on the wire. 
This is used for control // messages only when the read side of the stream is closed if s.receiveState != receiveStateReceiving { - s.readLoopOnce.Do(s.spawnControlMessageReader) + s.controlMessageReaderOnce.Do(s.spawnControlMessageReader) } if !s.writeDeadline.IsZero() && time.Now().After(s.writeDeadline) { @@ -100,7 +103,7 @@ func (s *stream) Write(b []byte) (int, error) { end = len(b) } msg := &pb.Message{Message: b[:end]} - if err := s.writer.WriteMsg(msg); err != nil { + if err := s.writeMsgOnWriter(msg); err != nil { return n, err } n += end @@ -109,30 +112,6 @@ func (s *stream) Write(b []byte) (int, error) { return n, nil } -// used for reading control messages while writing, in case the reader is closed, -// as to ensure we do still get control messages. This is important as according to the spec -// our data and control channels are intermixed on the same conn. -func (s *stream) spawnControlMessageReader() { - if s.nextMessage != nil { - s.processIncomingFlag(s.nextMessage.Flag) - s.nextMessage = nil - } - - go func() { - // no deadline needed, Read will return once there's a new message, or an error occurred - _ = s.dataChannel.SetReadDeadline(time.Time{}) - for { - var msg pb.Message - if err := s.reader.ReadMsg(&msg); err != nil { - return - } - s.mx.Lock() - s.processIncomingFlag(msg.Flag) - s.mx.Unlock() - } - }() -} - func (s *stream) SetWriteDeadline(t time.Time) error { s.mx.Lock() defer s.mx.Unlock() @@ -153,19 +132,6 @@ func (s *stream) availableSendSpace() int { return availableSpace } -// There's no way to determine the size of a Protobuf message in the pbio package. -// Setting the size to 100 works as long as the control messages (incl. the varint prefix) are smaller than that value. -const controlMsgSize = 100 - -func (s *stream) sendControlMessage(msg *pb.Message) error { - available := s.availableSendSpace() - if controlMsgSize < available { - return s.writer.WriteMsg(msg) - } - s.controlMsgQueue = append(s.controlMsgQueue, msg) - return nil -} - func (s *stream) cancelWrite() error { s.mx.Lock() defer s.mx.Unlock() @@ -178,10 +144,9 @@ func (s *stream) cancelWrite() error { case s.sendStateChanged <- struct{}{}: default: } - if err := s.sendControlMessage(&pb.Message{Flag: pb.Message_RESET.Enum()}); err != nil { + if err := s.writeMsgOnWriter(&pb.Message{Flag: pb.Message_RESET.Enum()}); err != nil { return err } - s.maybeDeclareStreamDone() return nil } @@ -193,9 +158,18 @@ func (s *stream) CloseWrite() error { return nil } s.sendState = sendStateDataSent - if err := s.sendControlMessage(&pb.Message{Flag: pb.Message_FIN.Enum()}); err != nil { + select { + case s.sendStateChanged <- struct{}{}: + default: + } + if err := s.writeMsgOnWriter(&pb.Message{Flag: pb.Message_FIN.Enum()}); err != nil { return err } - s.maybeDeclareStreamDone() return nil } + +func (s *stream) writeMsgOnWriter(msg proto.Message) error { + s.writerMx.Lock() + defer s.writerMx.Unlock() + return s.writer.WriteMsg(msg) +} diff --git a/p2p/transport/webrtc/transport_test.go b/p2p/transport/webrtc/transport_test.go index 7f4df94fc1..983e03c00f 100644 --- a/p2p/transport/webrtc/transport_test.go +++ b/p2p/transport/webrtc/transport_test.go @@ -496,7 +496,55 @@ func TestTransportWebRTC_RemoteReadsAfterClose(t *testing.T) { require.NoError(t, err) // require write and close to complete require.NoError(t, <-done) + stream.SetReadDeadline(time.Now().Add(5 * time.Second)) + + buf := make([]byte, 10) + n, err := stream.Read(buf) + require.NoError(t, err) + require.Equal(t, n, 4) +} + +func 
TestTransportWebRTC_RemoteReadsAfterAsyncClose(t *testing.T) { + tr, listeningPeer := getTransport(t) + listenMultiaddr := ma.StringCast("/ip4/127.0.0.1/udp/0/webrtc-direct") + listener, err := tr.Listen(listenMultiaddr) + require.NoError(t, err) + + tr1, _ := getTransport(t) + + done := make(chan error) + go func() { + lconn, err := listener.Accept() + if err != nil { + done <- err + return + } + s, err := lconn.AcceptStream() + if err != nil { + done <- err + return + } + _, err = s.Write([]byte{1, 2, 3, 4}) + if err != nil { + done <- err + return + } + err = s.(*stream).AsyncClose(nil) + if err != nil { + done <- err + return + } + close(done) + }() + conn, err := tr1.Dial(context.Background(), listener.Multiaddr(), listeningPeer) + require.NoError(t, err) + // create a stream + stream, err := conn.OpenStream(context.Background()) + + require.NoError(t, err) + // require write and close to complete + require.NoError(t, <-done) stream.SetReadDeadline(time.Now().Add(5 * time.Second)) buf := make([]byte, 10)