diff --git a/p2p/test/transport/gating_test.go b/p2p/test/transport/gating_test.go index c914b04d31..b0ee31eb47 100644 --- a/p2p/test/transport/gating_test.go +++ b/p2p/test/transport/gating_test.go @@ -6,6 +6,7 @@ import ( "testing" "time" + "github.com/libp2p/go-libp2p/core/control" "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/peer" "github.com/libp2p/go-libp2p/core/protocol" @@ -83,9 +84,6 @@ func TestInterceptSecuredOutgoing(t *testing.T) { } for _, tc := range transportsToTest { t.Run(tc.Name, func(t *testing.T) { - if strings.Contains(tc.Name, "WebRTCPrivate") { - t.Skip("webrtcprivate needs different handling because of the relay required for connection setup") - } ctrl := gomock.NewController(t) defer ctrl.Finish() @@ -99,14 +97,42 @@ func TestInterceptSecuredOutgoing(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() - gomock.InOrder( - connGater.EXPECT().InterceptPeerDial(h2.ID()).Return(true), - connGater.EXPECT().InterceptAddrDial(h2.ID(), gomock.Any()).Return(true), - connGater.EXPECT().InterceptSecured(network.DirOutbound, h2.ID(), gomock.Any()).Do(func(_ network.Direction, _ peer.ID, addrs network.ConnMultiaddrs) { - // remove the certhash component from WebTransport and WebRTC addresses - require.Equal(t, stripCertHash(h2.Addrs()[0]).String(), addrs.RemoteMultiaddr().String()) - }), - ) + if strings.Contains(tc.Name, "WebRTCPrivate") { + gomock.InOrder( + connGater.EXPECT().InterceptPeerDial(h2.ID()).Return(true), + connGater.EXPECT().InterceptAddrDial(h2.ID(), gomock.Any()).Return(true), + + // Calls for circuit-v2 conn setup + connGater.EXPECT().InterceptPeerDial(h2.ID()).Return(true), + connGater.EXPECT().InterceptAddrDial(h2.ID(), gomock.Any()).Return(true), + connGater.EXPECT().InterceptAddrDial(gomock.Any(), gomock.Any()).Return(true), // two addresses for the peer + + // Calls for connection to relay node + connGater.EXPECT().InterceptPeerDial(gomock.Any()).Return(true), + connGater.EXPECT().InterceptAddrDial(gomock.Any(), gomock.Any()).Return(true), + connGater.EXPECT().InterceptSecured(network.DirOutbound, gomock.Any(), gomock.Any()).Return(true), + connGater.EXPECT().InterceptUpgraded(gomock.Any()).AnyTimes().Return(true, control.DisconnectReason(0)), + + // circuit-v2 setup complete + connGater.EXPECT().InterceptSecured(network.DirOutbound, gomock.Any(), gomock.Any()).Return(true), + connGater.EXPECT().InterceptUpgraded(gomock.Any()).AnyTimes().Return(true, control.DisconnectReason(0)), + + // webrtcprivate setup complete + // TODO: fix addresses on both sides of the /webrtc connection + connGater.EXPECT().InterceptSecured(network.DirOutbound, h2.ID(), gomock.Any()), + ) + // force a direct connection + ctx = network.WithForceDirectDial(ctx, "integration test /webrtc") + } else { + gomock.InOrder( + connGater.EXPECT().InterceptPeerDial(h2.ID()).Return(true), + connGater.EXPECT().InterceptAddrDial(h2.ID(), gomock.Any()).Return(true), + connGater.EXPECT().InterceptSecured(network.DirOutbound, h2.ID(), gomock.Any()).Do(func(_ network.Direction, _ peer.ID, addrs network.ConnMultiaddrs) { + // remove the certhash component from WebTransport and WebRTC addresses + require.Equal(t, stripCertHash(h2.Addrs()[0]).String(), addrs.RemoteMultiaddr().String()) + }), + ) + } err := h1.Connect(ctx, peer.AddrInfo{ID: h2.ID(), Addrs: h2.Addrs()}) require.Error(t, err) @@ -121,9 +147,6 @@ func TestInterceptUpgradedOutgoing(t *testing.T) { } for _, tc := range transportsToTest { t.Run(tc.Name, func(t *testing.T) { - if strings.Contains(tc.Name, "WebRTCPrivate") { - t.Skip("webrtc private to private needs special handling because the listener makes a dial to the relay") - } ctrl := gomock.NewController(t) defer ctrl.Finish() connGater := NewMockConnectionGater(ctrl) @@ -137,16 +160,48 @@ func TestInterceptUpgradedOutgoing(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() - gomock.InOrder( - connGater.EXPECT().InterceptPeerDial(h2.ID()).Return(true), - connGater.EXPECT().InterceptAddrDial(h2.ID(), gomock.Any()).Return(true), - connGater.EXPECT().InterceptSecured(network.DirOutbound, h2.ID(), gomock.Any()).Return(true), - connGater.EXPECT().InterceptUpgraded(gomock.Any()).Do(func(c network.Conn) { - // remove the certhash component from WebTransport addresses - require.Equal(t, stripCertHash(h2.Addrs()[0]), c.RemoteMultiaddr()) - require.Equal(t, h1.ID(), c.LocalPeer()) - require.Equal(t, h2.ID(), c.RemotePeer()) - })) + if strings.Contains(tc.Name, "WebRTCPrivate") { + gomock.InOrder( + connGater.EXPECT().InterceptPeerDial(h2.ID()).Return(true), + connGater.EXPECT().InterceptAddrDial(h2.ID(), gomock.Any()).Return(true), + + // Calls for circuit-v2 conn setup + connGater.EXPECT().InterceptPeerDial(h2.ID()).Return(true), + connGater.EXPECT().InterceptAddrDial(h2.ID(), gomock.Any()).Return(true), + connGater.EXPECT().InterceptAddrDial(gomock.Any(), gomock.Any()).Return(true), // two addresses for the peer + + // Calls for connection to relay node + connGater.EXPECT().InterceptPeerDial(gomock.Any()).Return(true), + connGater.EXPECT().InterceptAddrDial(gomock.Any(), gomock.Any()).Return(true), + connGater.EXPECT().InterceptSecured(network.DirOutbound, gomock.Any(), gomock.Any()).Return(true), + connGater.EXPECT().InterceptUpgraded(gomock.Any()).Return(true, control.DisconnectReason(0)), + + // circuit-v2 setup complete + connGater.EXPECT().InterceptSecured(network.DirOutbound, gomock.Any(), gomock.Any()).Return(true), + connGater.EXPECT().InterceptUpgraded(gomock.Any()).Return(true, control.DisconnectReason(0)), + + // webrtcprivate setup complete + // TODO: fix addresses on both sides of the /webrtc connection + connGater.EXPECT().InterceptSecured(network.DirOutbound, h2.ID(), gomock.Any()).Return(true), + connGater.EXPECT().InterceptUpgraded(gomock.Any()).Do(func(c network.Conn) { + require.Equal(t, h1.ID(), c.LocalPeer()) + require.Equal(t, h2.ID(), c.RemotePeer()) + }), + ) + // force a direct connection + ctx = network.WithForceDirectDial(ctx, "integration test /webrtc") + } else { + gomock.InOrder( + connGater.EXPECT().InterceptPeerDial(h2.ID()).Return(true), + connGater.EXPECT().InterceptAddrDial(h2.ID(), gomock.Any()).Return(true), + connGater.EXPECT().InterceptSecured(network.DirOutbound, h2.ID(), gomock.Any()).Return(true), + connGater.EXPECT().InterceptUpgraded(gomock.Any()).Do(func(c network.Conn) { + // remove the certhash component from WebTransport addresses + require.Equal(t, stripCertHash(h2.Addrs()[0]), c.RemoteMultiaddr()) + require.Equal(t, h1.ID(), c.LocalPeer()) + require.Equal(t, h2.ID(), c.RemotePeer()) + })) + } err := h1.Connect(ctx, peer.AddrInfo{ID: h2.ID(), Addrs: h2.Addrs()}) require.Error(t, err) require.NotErrorIs(t, err, context.DeadlineExceeded) @@ -161,7 +216,8 @@ func TestInterceptAccept(t *testing.T) { for _, tc := range transportsToTest { t.Run(tc.Name, func(t *testing.T) { if strings.Contains(tc.Name, "WebRTCPrivate") { - t.Skip("webrtc private to private needs special handling because the listener makes a dial to the relay") + testInterceptAcceptIncomingWebRTCPrivate(t, tc) + return } ctrl := gomock.NewController(t) defer ctrl.Finish() @@ -202,6 +258,44 @@ func TestInterceptAccept(t *testing.T) { } } +func testInterceptAcceptIncomingWebRTCPrivate(t *testing.T, tc TransportTestCase) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + connGater := NewMockConnectionGater(ctrl) + + // Relay reservation calls + connGater.EXPECT().InterceptPeerDial(gomock.Any()).Return(true) + connGater.EXPECT().InterceptAddrDial(gomock.Any(), gomock.Any()).Return(true) + connGater.EXPECT().InterceptSecured(gomock.Any(), gomock.Any(), gomock.Any()).Return(true) + connGater.EXPECT().InterceptUpgraded(gomock.Any()).Return(true, control.DisconnectReason(0)) + + h1 := tc.HostGenerator(t, TransportTestCaseOpts{NoListen: true}) + h2 := tc.HostGenerator(t, TransportTestCaseOpts{ConnGater: connGater}) + defer h1.Close() + defer h2.Close() + require.Len(t, h2.Addrs(), 1) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + // relayed connection for incoming stream + connGater.EXPECT().InterceptAccept(gomock.Any()).Return(true) + connGater.EXPECT().InterceptSecured(gomock.Any(), gomock.Any(), gomock.Any()).Return(true) + connGater.EXPECT().InterceptUpgraded(gomock.Any()).Return(true, control.DisconnectReason(0)) + // webrtc connection accept + // TODO: Fix webrtc addresses on both sides. + connGater.EXPECT().InterceptAccept(gomock.Any()) + + // The basic host dials the first connection. + h1.Peerstore().AddAddrs(h2.ID(), h2.Addrs(), time.Hour) + _, err := h1.NewStream(ctx, h2.ID(), protocol.TestingID) + require.Error(t, err) + if _, err := h2.Addrs()[0].ValueForProtocol(ma.P_WEBRTC_DIRECT); err != nil { + // WebRTC rejects connection attempt before an error can be sent to the client. + // This means that the connection attempt will time out. + require.NotErrorIs(t, err, context.DeadlineExceeded) + } +} + func TestInterceptSecuredIncoming(t *testing.T) { if race.WithRace() { t.Skip("The upgrader spawns a new Go routine, which leads to race conditions when using GoMock.") @@ -209,7 +303,8 @@ func TestInterceptSecuredIncoming(t *testing.T) { for _, tc := range transportsToTest { t.Run(tc.Name, func(t *testing.T) { if strings.Contains(tc.Name, "WebRTCPrivate") { - t.Skip("webrtc private to private needs special handling because the listener makes a dial to the relay") + testInterceptSecuredIncomingWebRTCPrivate(t, tc) + return } ctrl := gomock.NewController(t) @@ -239,6 +334,50 @@ func TestInterceptSecuredIncoming(t *testing.T) { } } +func testInterceptSecuredIncomingWebRTCPrivate(t *testing.T, tc TransportTestCase) { + t.Helper() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + connGater := NewMockConnectionGater(ctrl) + gomock.InOrder( + // Dial to relay node for circuit reservation + connGater.EXPECT().InterceptPeerDial(gomock.Any()).Return(true), + connGater.EXPECT().InterceptAddrDial(gomock.Any(), gomock.Any()).Return(true), + connGater.EXPECT().InterceptSecured(gomock.Any(), gomock.Any(), gomock.Any()).Return(true), + connGater.EXPECT().InterceptUpgraded(gomock.Any()).Return(true, control.DisconnectReason(0)), + // Incoming relay connection for signaling stream + connGater.EXPECT().InterceptAccept(gomock.Any()).Return(true), + connGater.EXPECT().InterceptSecured(gomock.Any(), gomock.Any(), gomock.Any()).Return(true), + connGater.EXPECT().InterceptUpgraded(gomock.Any()).Return(true, control.DisconnectReason(0)), + ) + h1 := tc.HostGenerator(t, TransportTestCaseOpts{NoListen: true}) + h2 := tc.HostGenerator(t, TransportTestCaseOpts{ConnGater: connGater}) + defer h1.Close() + defer h2.Close() + + require.Len(t, h2.Addrs(), 1) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + ctx = network.WithForceDirectDial(ctx, "transport integration test /webrtc") + gomock.InOrder( + connGater.EXPECT().InterceptAccept(gomock.Any()).Return(true), + connGater.EXPECT().InterceptSecured(network.DirInbound, h1.ID(), gomock.Any()), + ) + + h1.Peerstore().AddAddrs(h2.ID(), h2.Addrs(), time.Hour) + _, err := h1.NewStream(ctx, h2.ID(), protocol.TestingID) + require.Error(t, err) + + // Do not check that error is Context Deadline Exceeded + // WebRTCPrivate connection establishment is considered complete when the DTLS handshake finishes. + // At this point no SCTP association is established. Closing the connection on the listener side, + // immediately after accepting, closes the listener side of the connection before the SCTP association + // is established. Pion doesn't handle this case nicely and webrtc.PeerConnection on the dialer will + // only be considered closed once the ICE transport times out. +} + func TestInterceptUpgradedIncoming(t *testing.T) { if race.WithRace() { t.Skip("The upgrader spawns a new Go routine, which leads to race conditions when using GoMock.") @@ -246,7 +385,8 @@ func TestInterceptUpgradedIncoming(t *testing.T) { for _, tc := range transportsToTest { t.Run(tc.Name, func(t *testing.T) { if strings.Contains(tc.Name, "WebRTCPrivate") { - t.Skip("webrtc private to private needs special handling because the listener makes a dial to the relay") + testInterceptUpgradeIncomingWebRTCPrivate(t, tc) + return } ctrl := gomock.NewController(t) defer ctrl.Finish() @@ -258,8 +398,6 @@ func TestInterceptUpgradedIncoming(t *testing.T) { defer h2.Close() require.Len(t, h2.Addrs(), 1) - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() gomock.InOrder( connGater.EXPECT().InterceptAccept(gomock.Any()).Return(true), connGater.EXPECT().InterceptSecured(network.DirInbound, h1.ID(), gomock.Any()).Return(true), @@ -270,6 +408,8 @@ func TestInterceptUpgradedIncoming(t *testing.T) { require.Equal(t, h2.ID(), c.LocalPeer()) }), ) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() h1.Peerstore().AddAddrs(h2.ID(), h2.Addrs(), time.Hour) _, err := h1.NewStream(ctx, h2.ID(), protocol.TestingID) require.Error(t, err) @@ -277,3 +417,46 @@ func TestInterceptUpgradedIncoming(t *testing.T) { }) } } + +func testInterceptUpgradeIncomingWebRTCPrivate(t *testing.T, tc TransportTestCase) { + t.Helper() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + connGater := NewMockConnectionGater(ctrl) + + gomock.InOrder( + // Dial to relay node for circuit reservation + connGater.EXPECT().InterceptPeerDial(gomock.Any()).Return(true), + connGater.EXPECT().InterceptAddrDial(gomock.Any(), gomock.Any()).Return(true), + connGater.EXPECT().InterceptSecured(gomock.Any(), gomock.Any(), gomock.Any()).Return(true), + connGater.EXPECT().InterceptUpgraded(gomock.Any()).Return(true, control.DisconnectReason(0)), + // Incoming relay connection for signaling stream + connGater.EXPECT().InterceptAccept(gomock.Any()).Return(true), + connGater.EXPECT().InterceptSecured(gomock.Any(), gomock.Any(), gomock.Any()).Return(true), + connGater.EXPECT().InterceptUpgraded(gomock.Any()).Return(true, control.DisconnectReason(0)), + ) + + h1 := tc.HostGenerator(t, TransportTestCaseOpts{NoListen: true}) + h2 := tc.HostGenerator(t, TransportTestCaseOpts{ConnGater: connGater}) + defer h1.Close() + defer h2.Close() + + require.Len(t, h2.Addrs(), 1) + + gomock.InOrder( + connGater.EXPECT().InterceptAccept(gomock.Any()).Return(true), + connGater.EXPECT().InterceptSecured(network.DirInbound, h1.ID(), gomock.Any()).Return(true), + connGater.EXPECT().InterceptUpgraded(gomock.Any()).Do(func(c network.Conn) { + require.Equal(t, h1.ID(), c.RemotePeer()) + require.Equal(t, h2.ID(), c.LocalPeer()) + }), + ) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + ctx = network.WithForceDirectDial(ctx, "transport integration test /webrtc") + + h1.Peerstore().AddAddrs(h2.ID(), h2.Addrs(), time.Hour) + _, err := h1.NewStream(ctx, h2.ID(), protocol.TestingID) + require.Error(t, err) +} diff --git a/p2p/test/transport/rcmgr_test.go b/p2p/test/transport/rcmgr_test.go index 1ea2a0a69e..4bea604af8 100644 --- a/p2p/test/transport/rcmgr_test.go +++ b/p2p/test/transport/rcmgr_test.go @@ -25,7 +25,8 @@ func TestResourceManagerIsUsed(t *testing.T) { for _, testDialer := range []bool{true, false} { t.Run(tc.Name+fmt.Sprintf(" test_dialer=%v", testDialer), func(t *testing.T) { if strings.Contains(tc.Name, "WebRTCPrivate") { - t.Skip("webrtcprivate needs different handling because of the relay required for connection setup") + testResourceManagerIsUsedWebRTCPrivate(t, tc, testDialer) + return } var reservedMemory, releasedMemory atomic.Int32 defer func() { @@ -141,3 +142,299 @@ func TestResourceManagerIsUsed(t *testing.T) { }) } } + +func testResourceManagerIsUsedWebRTCPrivate(t *testing.T, tc TransportTestCase, testDialer bool) { + t.Helper() + if testDialer { + testResourceManagerIsUsedWebRTCPrivateDialer(t, tc) + } else { + testResourceManagerIsUsedWebRTCPrivateListener(t, tc) + } +} + +func testResourceManagerIsUsedWebRTCPrivateDialer(t *testing.T, tc TransportTestCase) { + t.Helper() + var reservedMemory, releasedMemory atomic.Int32 + defer func() { + require.Equal(t, reservedMemory.Load(), releasedMemory.Load()) + require.NotEqual(t, 0, reservedMemory.Load()) + }() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + rcmgr := mocknetwork.NewMockResourceManager(ctrl) + rcmgr.EXPECT().Close() + + listener := tc.HostGenerator(t, TransportTestCaseOpts{NoRcmgr: true}) + dialer := tc.HostGenerator(t, TransportTestCaseOpts{NoListen: true, ResourceManager: rcmgr}) + + getPeerScope := func() *mocknetwork.MockPeerScope { + peerScope := mocknetwork.NewMockPeerScope(ctrl) + peerScope.EXPECT().ReserveMemory(gomock.Any(), gomock.Any()).AnyTimes().Do(func(amount int, pri uint8) { + reservedMemory.Add(int32(amount)) + }) + peerScope.EXPECT().ReleaseMemory(gomock.Any()).AnyTimes().Do(func(amount int) { + releasedMemory.Add(int32(amount)) + }) + peerScope.EXPECT().BeginSpan().AnyTimes().DoAndReturn(func() (network.ResourceScopeSpan, error) { + s := mocknetwork.NewMockResourceScopeSpan(ctrl) + s.EXPECT().BeginSpan().AnyTimes().Return(mocknetwork.NewMockResourceScopeSpan(ctrl), nil) + // No need to track these memory reservations since we assert that Done is called + s.EXPECT().ReserveMemory(gomock.Any(), gomock.Any()) + s.EXPECT().Done() + return s, nil + }) + return peerScope + } + + relayPeerScope := getPeerScope() + listenerPeerScope := getPeerScope() + + // getRelayConnScope creates a connection scope for the connection to the relay node and the circuitv2 connection + // to the listener + getRelayConnScope := func() *mocknetwork.MockConnManagementScope { + var connPeer atomic.Pointer[peer.ID] + connScope := mocknetwork.NewMockConnManagementScope(ctrl) + connScope.EXPECT().SetPeer(gomock.Any()).AnyTimes().Do(func(p peer.ID) { + connPeer.Store(&p) + }) + connScope.EXPECT().PeerScope().AnyTimes().DoAndReturn(func() *mocknetwork.MockPeerScope { + if connPeer.Load() == nil || *connPeer.Load() == listener.ID() { + return listenerPeerScope + } + return relayPeerScope + }) + connScope.EXPECT().Done().MinTimes(1) + return connScope + } + + // TODO: Fix addresses on rcmgr calls + + var calledSetPeer atomic.Bool + // connScope is the scope to be used for the final connection over /webrtc to the peer + connScope := mocknetwork.NewMockConnManagementScope(ctrl) + connScope.EXPECT().SetPeer(listener.ID()).Do(func(peer.ID) { + calledSetPeer.Store(true) + }) + connScope.EXPECT().PeerScope().AnyTimes().DoAndReturn(func() network.PeerScope { + if calledSetPeer.Load() { + return listenerPeerScope + } + return nil + }) + connScope.EXPECT().Done().MinTimes(1) + + var allStreamsDone sync.WaitGroup + + // Dial to relay node and then to listener over circtuiv2 + rcmgr.EXPECT().OpenConnection(network.DirOutbound, gomock.Any(), gomock.Any()).MaxTimes(2).Return(getRelayConnScope(), nil) + + rcmgr.EXPECT().OpenStream(gomock.Any(), gomock.Any()).AnyTimes().DoAndReturn(func(id peer.ID, dir network.Direction) (network.StreamManagementScope, error) { + streamScope := mocknetwork.NewMockStreamManagementScope(ctrl) + // No need to track these memory reservations since we assert that Done is called + streamScope.EXPECT().ReserveMemory(gomock.Any(), gomock.Any()).AnyTimes() + streamScope.EXPECT().ReleaseMemory(gomock.Any()).AnyTimes() + streamScope.EXPECT().BeginSpan().AnyTimes().DoAndReturn(func() (network.ResourceScopeSpan, error) { + s := mocknetwork.NewMockResourceScopeSpan(ctrl) + s.EXPECT().BeginSpan().AnyTimes().Return(mocknetwork.NewMockResourceScopeSpan(ctrl), nil) + s.EXPECT().Done() + return s, nil + }) + streamScope.EXPECT().SetService(gomock.Any()).AnyTimes() + streamScope.EXPECT().SetProtocol(gomock.Any()).AnyTimes() + streamScope.EXPECT().Done() + return streamScope, nil + }) + + // Now handle /webrtc connection establishment + rcmgr.EXPECT().OpenConnection(network.DirOutbound, true, gomock.Any()).Return(connScope, nil) + rcmgr.EXPECT().OpenStream(listener.ID(), gomock.Any()).AnyTimes().DoAndReturn(func(id peer.ID, dir network.Direction) (network.StreamManagementScope, error) { + allStreamsDone.Add(1) + streamScope := mocknetwork.NewMockStreamManagementScope(ctrl) + // No need to track these memory reservations since we assert that Done is called + streamScope.EXPECT().ReserveMemory(gomock.Any(), gomock.Any()).AnyTimes() + streamScope.EXPECT().ReleaseMemory(gomock.Any()).AnyTimes() + streamScope.EXPECT().BeginSpan().AnyTimes().DoAndReturn(func() (network.ResourceScopeSpan, error) { + s := mocknetwork.NewMockResourceScopeSpan(ctrl) + s.EXPECT().BeginSpan().AnyTimes().Return(mocknetwork.NewMockResourceScopeSpan(ctrl), nil) + s.EXPECT().Done() + return s, nil + }) + + streamScope.EXPECT().SetService(gomock.Any()).MaxTimes(1) + streamScope.EXPECT().SetProtocol(gomock.Any()) + + streamScope.EXPECT().Done().Do(func() { + allStreamsDone.Done() + }) + return streamScope, nil + }) + + require.NoError(t, dialer.Connect(context.Background(), peer.AddrInfo{ + ID: listener.ID(), + Addrs: listener.Addrs(), + })) + + // Wait for any in progress identifies to finish. + // We shouldn't have to do this, but basic host currently + // always does an identify. + <-dialer.(interface{ IDService() identify.IDService }).IDService().IdentifyWait(dialer.Network().ConnsToPeer(listener.ID())[0]) + <-listener.(interface{ IDService() identify.IDService }).IDService().IdentifyWait(listener.Network().ConnsToPeer(dialer.ID())[0]) + <-ping.Ping(context.Background(), dialer, listener.ID()) + err := dialer.Network().ClosePeer(listener.ID()) + require.NoError(t, err) + + // Wait a bit for any pending .Adds before we call .Wait to avoid a data race. + // This shouldn't be necessary since it should be impossible + // for an OpenStream to happen *after* a ClosePeer, however + // in practice it does and leads to test flakiness. + time.Sleep(10 * time.Millisecond) + allStreamsDone.Wait() + dialer.Close() + listener.Close() +} + +func testResourceManagerIsUsedWebRTCPrivateListener(t *testing.T, tc TransportTestCase) { + t.Helper() + var reservedMemory, releasedMemory atomic.Int32 + defer func() { + require.Equal(t, reservedMemory.Load(), releasedMemory.Load()) + require.NotEqual(t, 0, reservedMemory.Load()) + }() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + rcmgr := mocknetwork.NewMockResourceManager(ctrl) + rcmgr.EXPECT().Close() + + var listener, dialer host.Host + + getPeerScope := func() *mocknetwork.MockPeerScope { + peerScope := mocknetwork.NewMockPeerScope(ctrl) + peerScope.EXPECT().ReserveMemory(gomock.Any(), gomock.Any()).AnyTimes().Do(func(amount int, pri uint8) { + reservedMemory.Add(int32(amount)) + }) + peerScope.EXPECT().ReleaseMemory(gomock.Any()).AnyTimes().Do(func(amount int) { + releasedMemory.Add(int32(amount)) + }) + peerScope.EXPECT().BeginSpan().AnyTimes().DoAndReturn(func() (network.ResourceScopeSpan, error) { + s := mocknetwork.NewMockResourceScopeSpan(ctrl) + s.EXPECT().BeginSpan().AnyTimes().Return(mocknetwork.NewMockResourceScopeSpan(ctrl), nil) + // No need to track these memory reservations since we assert that Done is called + s.EXPECT().ReserveMemory(gomock.Any(), gomock.Any()) + s.EXPECT().Done() + return s, nil + }) + return peerScope + } + + relayPeerScope := getPeerScope() + dialerPeerScope := getPeerScope() + + // getRelayConnScope creates a connection scope for the connection to the relay node and the circuitv2 connection + // from the dialer + getRelayConnScope := func() *mocknetwork.MockConnManagementScope { + var connPeer atomic.Pointer[peer.ID] + connScope := mocknetwork.NewMockConnManagementScope(ctrl) + connScope.EXPECT().SetPeer(gomock.Any()).AnyTimes().Do(func(p peer.ID) { + connPeer.Store(&p) + }) + connScope.EXPECT().PeerScope().AnyTimes().DoAndReturn(func() *mocknetwork.MockPeerScope { + // Nil returns from DoAndReturn are typed. This causes problems with nil + // checks. So we return dialerPeerScope when PeerScope is called before + // SetPeer. This happens on circuitv2 connection upgrade. + if connPeer.Load() == nil || *connPeer.Load() == dialer.ID() { + return dialerPeerScope + } + return relayPeerScope + }) + connScope.EXPECT().Done().MinTimes(1) + return connScope + } + + // Connection to the relay node + rcmgr.EXPECT().OpenConnection(network.DirOutbound, gomock.Any(), gomock.Any()).Return(getRelayConnScope(), nil) + rcmgr.EXPECT().OpenStream(gomock.Any(), gomock.Any()).AnyTimes().DoAndReturn(func(id peer.ID, dir network.Direction) (network.StreamManagementScope, error) { + streamScope := mocknetwork.NewMockStreamManagementScope(ctrl) + // No need to track these memory reservations since we assert that Done is called + streamScope.EXPECT().ReserveMemory(gomock.Any(), gomock.Any()).AnyTimes() + streamScope.EXPECT().ReleaseMemory(gomock.Any()).AnyTimes() + streamScope.EXPECT().BeginSpan().AnyTimes().DoAndReturn(func() (network.ResourceScopeSpan, error) { + s := mocknetwork.NewMockResourceScopeSpan(ctrl) + s.EXPECT().BeginSpan().AnyTimes().Return(mocknetwork.NewMockResourceScopeSpan(ctrl), nil) + s.EXPECT().Done() + return s, nil + }) + streamScope.EXPECT().SetService(gomock.Any()).AnyTimes() + streamScope.EXPECT().SetProtocol(gomock.Any()).AnyTimes() + streamScope.EXPECT().Done() + return streamScope, nil + }) + + listener = tc.HostGenerator(t, TransportTestCaseOpts{ResourceManager: rcmgr}) + dialer = tc.HostGenerator(t, TransportTestCaseOpts{NoListen: true, NoRcmgr: true}) + + var calledSetPeer atomic.Bool + connScope := mocknetwork.NewMockConnManagementScope(ctrl) + connScope.EXPECT().SetPeer(dialer.ID()).Do(func(peer.ID) { + calledSetPeer.Store(true) + }) + connScope.EXPECT().PeerScope().AnyTimes().DoAndReturn(func() network.PeerScope { + if calledSetPeer.Load() { + return dialerPeerScope + } + return nil + }) + connScope.EXPECT().Done().MinTimes(1) + + var allStreamsDone sync.WaitGroup + gomock.InOrder( + // Incoming circuit-v2 connection for signaling stream + rcmgr.EXPECT().OpenConnection(network.DirInbound, gomock.Any(), gomock.Any()).Return(getRelayConnScope(), nil), + // /webrtc connection + rcmgr.EXPECT().OpenConnection(network.DirInbound, true, gomock.Any()).Return(connScope, nil), + ) + rcmgr.EXPECT().OpenStream(dialer.ID(), gomock.Any()).AnyTimes().DoAndReturn(func(id peer.ID, dir network.Direction) (network.StreamManagementScope, error) { + allStreamsDone.Add(1) + streamScope := mocknetwork.NewMockStreamManagementScope(ctrl) + // No need to track these memory reservations since we assert that Done is called + streamScope.EXPECT().ReserveMemory(gomock.Any(), gomock.Any()).AnyTimes() + streamScope.EXPECT().ReleaseMemory(gomock.Any()).AnyTimes() + streamScope.EXPECT().BeginSpan().AnyTimes().DoAndReturn(func() (network.ResourceScopeSpan, error) { + s := mocknetwork.NewMockResourceScopeSpan(ctrl) + s.EXPECT().BeginSpan().AnyTimes().Return(mocknetwork.NewMockResourceScopeSpan(ctrl), nil) + s.EXPECT().Done() + return s, nil + }) + + streamScope.EXPECT().SetService(gomock.Any()).MaxTimes(1) + streamScope.EXPECT().SetProtocol(gomock.Any()) + + streamScope.EXPECT().Done().Do(func() { + allStreamsDone.Done() + }) + return streamScope, nil + }) + + require.NoError(t, dialer.Connect(context.Background(), peer.AddrInfo{ + ID: listener.ID(), + Addrs: listener.Addrs(), + })) + // Wait for any in progress identifies to finish. + // We shouldn't have to do this, but basic host currently + // always does an identify. + <-dialer.(interface{ IDService() identify.IDService }).IDService().IdentifyWait(dialer.Network().ConnsToPeer(listener.ID())[0]) + <-listener.(interface{ IDService() identify.IDService }).IDService().IdentifyWait(listener.Network().ConnsToPeer(dialer.ID())[0]) + <-ping.Ping(context.Background(), dialer, listener.ID()) + err := dialer.Network().ClosePeer(listener.ID()) + require.NoError(t, err) + + // Wait a bit for any pending .Adds before we call .Wait to avoid a data race. + // This shouldn't be necessary since it should be impossible + // for an OpenStream to happen *after* a ClosePeer, however + // in practice it does and leads to test flakiness. + time.Sleep(10 * time.Millisecond) + allStreamsDone.Wait() + dialer.Close() + listener.Close() +} diff --git a/p2p/test/transport/transport_test.go b/p2p/test/transport/transport_test.go index d6166ce3db..6127ede57c 100644 --- a/p2p/test/transport/transport_test.go +++ b/p2p/test/transport/transport_test.go @@ -22,10 +22,12 @@ import ( "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/peer" "github.com/libp2p/go-libp2p/core/sec" + basichost "github.com/libp2p/go-libp2p/p2p/host/basic" rcmgr "github.com/libp2p/go-libp2p/p2p/host/resource-manager" "github.com/libp2p/go-libp2p/p2p/muxer/yamux" "github.com/libp2p/go-libp2p/p2p/net/swarm" "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/client" + "github.com/libp2p/go-libp2p/p2p/protocol/identify" "github.com/libp2p/go-libp2p/p2p/protocol/ping" "github.com/libp2p/go-libp2p/p2p/security/noise" tls "github.com/libp2p/go-libp2p/p2p/security/tls" @@ -206,6 +208,10 @@ func (h *webrtcHost) Close() error { return nil } +func (h *webrtcHost) IDService() identify.IDService { + return h.Host.(*basichost.BasicHost).IDService() +} + func TestPing(t *testing.T) { for _, tc := range transportsToTest { t.Run(tc.Name, func(t *testing.T) { diff --git a/p2p/transport/webrtc/connection.go b/p2p/transport/webrtc/connection.go index 13300f9e97..7b0fe78d86 100644 --- a/p2p/transport/webrtc/connection.go +++ b/p2p/transport/webrtc/connection.go @@ -48,7 +48,8 @@ type connection struct { transport tpt.Transport scope network.ConnManagementScope - closeErr error + closeOnce sync.Once + closeErr error localPeer peer.ID localMultiaddr ma.Multiaddr @@ -115,7 +116,11 @@ func NewWebRTCConnection( c.nextStreamID.Store(2) } - pc.OnConnectionStateChange(c.onConnectionStateChange) + pc.OnConnectionStateChange(func(state webrtc.PeerConnectionState) { + if state == webrtc.PeerConnectionStateFailed || state == webrtc.PeerConnectionStateClosed { + c.closeTimedOut() + } + }) // Between the connection establishing and the callback update in the above line, the // connection may have been closed @@ -135,16 +140,41 @@ func (c *connection) ConnState() network.ConnectionState { // Close closes the underlying peerconnection. func (c *connection) Close() error { - if c.IsClosed() { - return nil - } + c.closeOnce.Do(func() { + c.closeErr = errors.New("connection closed") + // cancel must be called after closeErr is set. This ensures interested goroutines waiting on + // ctx.Done can read closeErr without holding the conn lock. + c.cancel() + c.m.Lock() + streams := c.streams + c.streams = nil + c.m.Unlock() + for _, str := range streams { + str.Reset() + } + c.pc.Close() + c.scope.Done() + }) + return nil +} - c.m.Lock() - defer c.m.Unlock() - c.scope.Done() - c.closeErr = errors.New("connection closed") - c.cancel() - return c.pc.Close() +func (c *connection) closeTimedOut() error { + c.closeOnce.Do(func() { + c.closeErr = errConnectionTimeout{} + // cancel must be called after closeErr is set. This ensures interested goroutines waiting on + // ctx.Done can read closeErr without holding the conn lock. + c.cancel() + c.m.Lock() + streams := c.streams + c.streams = nil + c.m.Unlock() + for _, str := range streams { + str.setCloseError(errConnectionTimeout{}) + } + c.pc.Close() + c.scope.Done() + }) + return nil } func (c *connection) IsClosed() bool { @@ -214,6 +244,9 @@ func (c *connection) Transport() tpt.Transport { return c.transport } func (c *connection) addStream(str *stream) error { c.m.Lock() defer c.m.Unlock() + if c.streams == nil { + return c.closeErr + } if _, ok := c.streams[str.id]; ok { return errors.New("stream ID already exists") } @@ -227,25 +260,6 @@ func (c *connection) removeStream(id uint16) { delete(c.streams, id) } -func (c *connection) onConnectionStateChange(state webrtc.PeerConnectionState) { - if state == webrtc.PeerConnectionStateFailed || state == webrtc.PeerConnectionStateClosed { - // reset any streams - if c.IsClosed() { - return - } - c.m.Lock() - defer c.m.Unlock() - c.closeErr = errConnectionTimeout{} - for k, str := range c.streams { - str.setCloseError(c.closeErr) - delete(c.streams, k) - } - c.cancel() - c.scope.Done() - c.pc.Close() - } -} - // detachChannel detaches an outgoing channel by taking into account the context // passed to `OpenStream` as well the closure of the underlying peerconnection // diff --git a/p2p/transport/webrtc/stream.go b/p2p/transport/webrtc/stream.go index 0358dce56c..5d1368b1ab 100644 --- a/p2p/transport/webrtc/stream.go +++ b/p2p/transport/webrtc/stream.go @@ -122,25 +122,7 @@ func newStream( return } } - - if s.isDone() { - // onDone removes the stream from the connection and requires the connection lock. - // This callback(onBufferedAmountLow) is executing in the sctp readLoop goroutine. - // If Connection.Close is called concurrently, the closing goroutine will acquire - // the connection lock and wait for sctp readLoop to exit, the sctp readLoop will - // wait for the connection lock before exiting, causing a deadlock. - // Run this in a different goroutine to avoid the deadlock. - go func() { - s.mx.Lock() - defer s.mx.Unlock() - // TODO: we should be closing the underlying datachannel, but this resets the stream - // See https://github.com/libp2p/specs/issues/575 for details. - // _ = s.dataChannel.Close() - // TODO: write for the spawned reader to return - s.onDone() - }() - } - + s.maybeDeclareStreamDone() select { case s.writeAvailable <- struct{}{}: default: @@ -175,7 +157,7 @@ func (s *stream) SetDeadline(t time.Time) error { // processIncomingFlag process the flag on an incoming message // It needs to be called with msg.Flag, not msg.GetFlag(), // otherwise we'd misinterpret the default value. -// It needs to be called while the mutex is locked. +// It must be called with mx acquired. func (s *stream) processIncomingFlag(flag *pb.Message_Flag) { if flag == nil { return @@ -202,27 +184,24 @@ func (s *stream) processIncomingFlag(flag *pb.Message_Flag) { s.maybeDeclareStreamDone() } -// maybeDeclareStreamDone is used to force reset a stream. It should be called with -// the stream lock acquired. It calls stream.onDone which requires the connection lock. +// maybeDeclareStreamDone is used to force reset a stream. It must be called with mx acquired func (s *stream) maybeDeclareStreamDone() { - if s.isDone() { + if (s.sendState == sendStateReset || s.sendState == sendStateDataSent) && + (s.receiveState == receiveStateReset || s.receiveState == receiveStateDataRead) && + len(s.controlMsgQueue) == 0 { + + s.mx.Unlock() + defer s.mx.Lock() _ = s.SetReadDeadline(time.Now().Add(-1 * time.Hour)) // pion ignores zero times // TODO: we should be closing the underlying datachannel, but this resets the stream // See https://github.com/libp2p/specs/issues/575 for details. // _ = s.dataChannel.Close() // TODO: write for the spawned reader to return + s.onDone() } } -// isDone indicates whether the stream is completed and all the control messages have also been -// flushed. It must be called with the stream lock acquired. -func (s *stream) isDone() bool { - return (s.sendState == sendStateReset || s.sendState == sendStateDataSent) && - (s.receiveState == receiveStateReset || s.receiveState == receiveStateDataRead) && - len(s.controlMsgQueue) == 0 -} - func (s *stream) setCloseError(e error) { s.mx.Lock() defer s.mx.Unlock() diff --git a/p2p/transport/webrtcprivate/listener.go b/p2p/transport/webrtcprivate/listener.go index 97807e3ef5..30d8cdb85a 100644 --- a/p2p/transport/webrtcprivate/listener.go +++ b/p2p/transport/webrtcprivate/listener.go @@ -88,19 +88,20 @@ func (l *listener) handleSignalingStream(s network.Stream) { return } if err := scope.SetPeer(s.Conn().RemotePeer()); err != nil { + s.Reset() log.Debugf("resource manager blocked incoming conn from peer %s: %s", s.Conn().RemotePeer(), err) return } if err := s.Scope().SetService(name); err != nil { - log.Debugf("error attaching stream to /webrtc listener: %s", err) s.Reset() + log.Debugf("error attaching stream to /webrtc listener: %s", err) return } if err := s.Scope().ReserveMemory(2*maxMsgSize, network.ReservationPriorityAlways); err != nil { - log.Debugf("error reserving memory for /webrtc signaling stream: %s", err) s.Reset() + log.Debugf("error reserving memory for /webrtc signaling stream: %s", err) return } defer s.Scope().ReleaseMemory(maxMsgSize) @@ -125,8 +126,10 @@ func (l *listener) handleSignalingStream(s network.Stream) { } if l.transport.gater != nil && !l.transport.gater.InterceptSecured(network.DirInbound, s.Conn().RemotePeer(), conn) { + s.Reset() conn.Close() log.Debugf("conn gater refused connection to addr: %s", conn.RemoteMultiaddr()) + return } // Close the stream before we wait for the connection to be accepted s.Close() diff --git a/p2p/transport/webrtcprivate/transport.go b/p2p/transport/webrtcprivate/transport.go index 63d562863d..a72e6d5eef 100644 --- a/p2p/transport/webrtcprivate/transport.go +++ b/p2p/transport/webrtcprivate/transport.go @@ -39,7 +39,7 @@ const ( disconnectedTimeout = 20 * time.Second failedTimeout = 30 * time.Second keepaliveTimeout = 15 * time.Second - maxAcceptQueueLen = 10 + maxAcceptQueueLen = 256 ) var (