From e71b25e2ad47888481074f677cb4ff89373dbdb7 Mon Sep 17 00:00:00 2001 From: sukun Date: Sat, 14 Oct 2023 15:12:34 +0530 Subject: [PATCH] webrtcprivate: set relay address on the remote side of conn --- p2p/test/swarm/swarm_test.go | 22 ++------- p2p/test/transport/gating_test.go | 46 +++++++++++++------ p2p/test/transport/rcmgr_test.go | 7 ++- p2p/test/transport/transport_test.go | 2 +- p2p/transport/webrtcprivate/listener.go | 17 +++++-- p2p/transport/webrtcprivate/transport.go | 42 ++++++++--------- p2p/transport/webrtcprivate/transport_test.go | 6 +-- 7 files changed, 77 insertions(+), 65 deletions(-) diff --git a/p2p/test/swarm/swarm_test.go b/p2p/test/swarm/swarm_test.go index 5fdafb2913..ea81c77d0b 100644 --- a/p2p/test/swarm/swarm_test.go +++ b/p2p/test/swarm/swarm_test.go @@ -276,8 +276,8 @@ func TestDialPeerWebRTC(t *testing.T) { _, err = client.Reserve(context.Background(), h2, relay1info) require.NoError(t, err) - webrtcAddr := ma.StringCast(relay1info.Addrs[0].String() + "/p2p/" + relay1info.ID.String() + "/p2p-circuit/webrtc/p2p/" + h2.ID().String()) - relayAddrs := ma.StringCast(relay1info.Addrs[0].String() + "/p2p/" + relay1info.ID.String() + "/p2p-circuit/p2p/" + h2.ID().String()) + webrtcAddr := ma.StringCast(relay1info.Addrs[0].String() + "/p2p/" + relay1info.ID.String() + "/p2p-circuit/webrtc") + relayAddrs := ma.StringCast(relay1info.Addrs[0].String() + "/p2p/" + relay1info.ID.String() + "/p2p-circuit/") h1.Peerstore().AddAddrs(h2.ID(), []ma.Multiaddr{webrtcAddr, relayAddrs}, peerstore.TempAddrTTL) @@ -285,30 +285,18 @@ func TestDialPeerWebRTC(t *testing.T) { conn1, err := h1.Network().DialPeer(context.Background(), h2.ID()) require.NoError(t, err) require.NotNil(t, conn1) - require.Condition(t, func() bool { - _, err1 := conn1.RemoteMultiaddr().ValueForProtocol(ma.P_CIRCUIT) - _, err2 := conn1.RemoteMultiaddr().ValueForProtocol(ma.P_WEBRTC) - return err1 == nil && err2 != nil - }) + require.Equal(t, conn1.RemoteMultiaddr(), relayAddrs) // should connect to webrtc address ctx := network.WithForceDirectDial(context.Background(), "test") conn, err := h1.Network().DialPeer(ctx, h2.ID()) require.NoError(t, err) require.NotNil(t, conn) - require.Condition(t, func() bool { - _, err1 := conn.RemoteMultiaddr().ValueForProtocol(ma.P_CIRCUIT) - _, err2 := conn.RemoteMultiaddr().ValueForProtocol(ma.P_WEBRTC) - return err1 != nil && err2 == nil - }) + require.Equal(t, conn.RemoteMultiaddr(), webrtcAddr) done := make(chan struct{}) h2.SetStreamHandler("test-addr", func(s network.Stream) { - s.Conn().LocalMultiaddr() - _, err1 := conn.RemoteMultiaddr().ValueForProtocol(ma.P_CIRCUIT) - assert.Error(t, err1) - _, err2 := conn.RemoteMultiaddr().ValueForProtocol(ma.P_WEBRTC) - assert.NoError(t, err2) + require.Equal(t, conn.RemoteMultiaddr(), webrtcAddr) s.Reset() close(done) }) diff --git a/p2p/test/transport/gating_test.go b/p2p/test/transport/gating_test.go index b0ee31eb47..ce66be3bb2 100644 --- a/p2p/test/transport/gating_test.go +++ b/p2p/test/transport/gating_test.go @@ -118,9 +118,16 @@ func TestInterceptSecuredOutgoing(t *testing.T) { connGater.EXPECT().InterceptUpgraded(gomock.Any()).AnyTimes().Return(true, control.DisconnectReason(0)), // webrtcprivate setup complete - // TODO: fix addresses on both sides of the /webrtc connection - connGater.EXPECT().InterceptSecured(network.DirOutbound, h2.ID(), gomock.Any()), + connGater.EXPECT().InterceptSecured(network.DirOutbound, h2.ID(), gomock.Any()).Do(func(_, _ any, c network.ConnMultiaddrs) { + require.Equal(t, h2.Addrs()[0], c.RemoteMultiaddr()) + }), ) + // As we are using trickle ICE, some candidates are handled + // as prflx candidates. The default wait time for prflx + // candidates is 1 second. Bump the timeout to handle those + // cases. + ctx, cancel = context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() // force a direct connection ctx = network.WithForceDirectDial(ctx, "integration test /webrtc") } else { @@ -181,13 +188,19 @@ func TestInterceptUpgradedOutgoing(t *testing.T) { connGater.EXPECT().InterceptUpgraded(gomock.Any()).Return(true, control.DisconnectReason(0)), // webrtcprivate setup complete - // TODO: fix addresses on both sides of the /webrtc connection connGater.EXPECT().InterceptSecured(network.DirOutbound, h2.ID(), gomock.Any()).Return(true), connGater.EXPECT().InterceptUpgraded(gomock.Any()).Do(func(c network.Conn) { require.Equal(t, h1.ID(), c.LocalPeer()) require.Equal(t, h2.ID(), c.RemotePeer()) + require.Equal(t, h2.Addrs()[0], c.RemoteMultiaddr()) }), ) + // As we are using trickle ICE, some candidates are handled + // as prflx candidates. The default wait time for prflx + // candidates is 1 second. Bump the timeout to handle those + // cases. + ctx, cancel = context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() // force a direct connection ctx = network.WithForceDirectDial(ctx, "integration test /webrtc") } else { @@ -275,25 +288,25 @@ func testInterceptAcceptIncomingWebRTCPrivate(t *testing.T, tc TransportTestCase defer h2.Close() require.Len(t, h2.Addrs(), 1) - ctx, cancel := context.WithTimeout(context.Background(), time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() // relayed connection for incoming stream connGater.EXPECT().InterceptAccept(gomock.Any()).Return(true) connGater.EXPECT().InterceptSecured(gomock.Any(), gomock.Any(), gomock.Any()).Return(true) connGater.EXPECT().InterceptUpgraded(gomock.Any()).Return(true, control.DisconnectReason(0)) // webrtc connection accept - // TODO: Fix webrtc addresses on both sides. - connGater.EXPECT().InterceptAccept(gomock.Any()) + connGater.EXPECT().InterceptAccept(gomock.Any()).Do(func(c network.ConnMultiaddrs) { + // On an incoming webrtc connection, the remote address will be the same as the listen address. + // This is simlar to how addresses are setup on circuit-v2 connections + require.Equal(t, h2.Addrs()[0], c.RemoteMultiaddr(), h2.Addrs(), c.RemoteMultiaddr()) + }) // The basic host dials the first connection. h1.Peerstore().AddAddrs(h2.ID(), h2.Addrs(), time.Hour) _, err := h1.NewStream(ctx, h2.ID(), protocol.TestingID) require.Error(t, err) - if _, err := h2.Addrs()[0].ValueForProtocol(ma.P_WEBRTC_DIRECT); err != nil { - // WebRTC rejects connection attempt before an error can be sent to the client. - // This means that the connection attempt will time out. - require.NotErrorIs(t, err, context.DeadlineExceeded) - } + // Do not check that error is Context Deadline Exceeded + // See details at the end of `testInterceptSecuredIncomingWebRTCPrivate` } func TestInterceptSecuredIncoming(t *testing.T) { @@ -358,12 +371,14 @@ func testInterceptSecuredIncomingWebRTCPrivate(t *testing.T, tc TransportTestCas require.Len(t, h2.Addrs(), 1) - ctx, cancel := context.WithTimeout(context.Background(), time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() ctx = network.WithForceDirectDial(ctx, "transport integration test /webrtc") gomock.InOrder( connGater.EXPECT().InterceptAccept(gomock.Any()).Return(true), - connGater.EXPECT().InterceptSecured(network.DirInbound, h1.ID(), gomock.Any()), + connGater.EXPECT().InterceptSecured(network.DirInbound, h1.ID(), gomock.Any()).Do(func(_, _ any, c network.ConnMultiaddrs) { + require.Equal(t, h2.Addrs()[0], c.RemoteMultiaddr()) + }), ) h1.Peerstore().AddAddrs(h2.ID(), h2.Addrs(), time.Hour) @@ -447,16 +462,19 @@ func testInterceptUpgradeIncomingWebRTCPrivate(t *testing.T, tc TransportTestCas connGater.EXPECT().InterceptAccept(gomock.Any()).Return(true), connGater.EXPECT().InterceptSecured(network.DirInbound, h1.ID(), gomock.Any()).Return(true), connGater.EXPECT().InterceptUpgraded(gomock.Any()).Do(func(c network.Conn) { + require.Equal(t, h2.Addrs()[0], c.RemoteMultiaddr()) require.Equal(t, h1.ID(), c.RemotePeer()) require.Equal(t, h2.ID(), c.LocalPeer()) }), ) - ctx, cancel := context.WithTimeout(context.Background(), time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() ctx = network.WithForceDirectDial(ctx, "transport integration test /webrtc") h1.Peerstore().AddAddrs(h2.ID(), h2.Addrs(), time.Hour) _, err := h1.NewStream(ctx, h2.ID(), protocol.TestingID) require.Error(t, err) + // Do not check that error is Context Deadline Exceeded + // See details at the end of `testInterceptSecuredIncomingWebRTCPrivate` } diff --git a/p2p/test/transport/rcmgr_test.go b/p2p/test/transport/rcmgr_test.go index 4bea604af8..d5e214742f 100644 --- a/p2p/test/transport/rcmgr_test.go +++ b/p2p/test/transport/rcmgr_test.go @@ -208,8 +208,6 @@ func testResourceManagerIsUsedWebRTCPrivateDialer(t *testing.T, tc TransportTest return connScope } - // TODO: Fix addresses on rcmgr calls - var calledSetPeer atomic.Bool // connScope is the scope to be used for the final connection over /webrtc to the peer connScope := mocknetwork.NewMockConnManagementScope(ctrl) @@ -247,7 +245,7 @@ func testResourceManagerIsUsedWebRTCPrivateDialer(t *testing.T, tc TransportTest }) // Now handle /webrtc connection establishment - rcmgr.EXPECT().OpenConnection(network.DirOutbound, true, gomock.Any()).Return(connScope, nil) + rcmgr.EXPECT().OpenConnection(network.DirOutbound, true, listener.Addrs()[0]).Return(connScope, nil) rcmgr.EXPECT().OpenStream(listener.ID(), gomock.Any()).AnyTimes().DoAndReturn(func(id peer.ID, dir network.Direction) (network.StreamManagementScope, error) { allStreamsDone.Add(1) streamScope := mocknetwork.NewMockStreamManagementScope(ctrl) @@ -392,7 +390,8 @@ func testResourceManagerIsUsedWebRTCPrivateListener(t *testing.T, tc TransportTe // Incoming circuit-v2 connection for signaling stream rcmgr.EXPECT().OpenConnection(network.DirInbound, gomock.Any(), gomock.Any()).Return(getRelayConnScope(), nil), // /webrtc connection - rcmgr.EXPECT().OpenConnection(network.DirInbound, true, gomock.Any()).Return(connScope, nil), + // The remote multiaddr is the same as our + rcmgr.EXPECT().OpenConnection(network.DirInbound, true, listener.Addrs()[0]).Return(connScope, nil), ) rcmgr.EXPECT().OpenStream(dialer.ID(), gomock.Any()).AnyTimes().DoAndReturn(func(id peer.ID, dir network.Direction) (network.StreamManagementScope, error) { allStreamsDone.Add(1) diff --git a/p2p/test/transport/transport_test.go b/p2p/test/transport/transport_test.go index 6127ede57c..64d74ecb67 100644 --- a/p2p/test/transport/transport_test.go +++ b/p2p/test/transport/transport_test.go @@ -712,7 +712,7 @@ func TestDiscoverPeerIDFromSecurityNegotiation(t *testing.T) { for _, tc := range transportsToTest { t.Run(tc.Name, func(t *testing.T) { if strings.Contains(tc.Name, "WebRTCPrivate") { - t.Skip("webrtcprivate needs different handling because of the relay required for connection setup") + t.Skip("inapplicable for webrtc private to private since the connection is established over an authenticated channel") } h1 := tc.HostGenerator(t, TransportTestCaseOpts{}) diff --git a/p2p/transport/webrtcprivate/listener.go b/p2p/transport/webrtcprivate/listener.go index 30d8cdb85a..5256506e2f 100644 --- a/p2p/transport/webrtcprivate/listener.go +++ b/p2p/transport/webrtcprivate/listener.go @@ -81,7 +81,8 @@ func (l *listener) handleSignalingStream(s network.Stream) { defer cancel() defer s.Close() - scope, err := l.transport.rcmgr.OpenConnection(network.DirInbound, true, ma.StringCast("/webrtc")) // we don't have a better remote adress right now + remoteAddr := s.Conn().RemoteMultiaddr().Encapsulate(WebRTCAddr) + scope, err := l.transport.rcmgr.OpenConnection(network.DirInbound, true, remoteAddr) if err != nil { s.Reset() log.Debug("failed to create connection scope:", err) @@ -110,7 +111,6 @@ func (l *listener) handleSignalingStream(s network.Stream) { if l.transport.gater != nil { localAddr := s.Conn().LocalMultiaddr().Encapsulate(WebRTCAddr) - remoteAddr := s.Conn().RemoteMultiaddr().Encapsulate(WebRTCAddr) if !l.transport.gater.InterceptAccept(&libp2pwebrtc.ConnMultiaddrs{Local: localAddr, Remote: remoteAddr}) { log.Debug("gater disallowed accepting connection from %s at %s", s.Conn().RemotePeer(), remoteAddr) s.Reset() @@ -141,9 +141,15 @@ func (l *listener) handleSignalingStream(s network.Stream) { } } -func (l *listener) setupConnection(ctx context.Context, s network.Stream, scope network.ConnManagementScope) (tpt.CapableConn, error) { +func (l *listener) setupConnection(ctx context.Context, s network.Stream, scope network.ConnManagementScope) (_ tpt.CapableConn, err error) { + var pc *webrtc.PeerConnection + defer func() { + if err != nil { + pc.Close() + } + }() - pc, err := l.transport.NewPeerConnection() + pc, err = l.transport.NewPeerConnection() if err != nil { err = fmt.Errorf("error creating a webrtc.PeerConnection: %w", err) log.Debug(err) @@ -296,11 +302,12 @@ func (l *listener) setupConnection(ctx context.Context, s network.Stream, scope } } - localAddr, remoteAddr, err := getConnectionAddresses(pc) + localAddr, err := getLocalConnectionAddress(pc) if err != nil { pc.Close() return nil, fmt.Errorf("failed to get connection addresses: %w", err) } + remoteAddr := s.Conn().RemoteMultiaddr().Encapsulate(WebRTCAddr) conn, err := libp2pwebrtc.NewWebRTCConnection( network.DirInbound, diff --git a/p2p/transport/webrtcprivate/transport.go b/p2p/transport/webrtcprivate/transport.go index a72e6d5eef..6bc27ca13b 100644 --- a/p2p/transport/webrtcprivate/transport.go +++ b/p2p/transport/webrtcprivate/transport.go @@ -157,7 +157,7 @@ func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tp return nil, err } - c, err := t.dialWithScope(ctx, p, scope) + c, err := t.dialWithScope(ctx, p, scope, raddr) if err != nil { scope.Done() log.Debug(err) @@ -166,7 +166,7 @@ func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tp return c, nil } -func (t *transport) dialWithScope(ctx context.Context, p peer.ID, scope network.ConnManagementScope) (tpt.CapableConn, error) { +func (t *transport) dialWithScope(ctx context.Context, p peer.ID, scope network.ConnManagementScope, raddr ma.Multiaddr) (tpt.CapableConn, error) { s, err := t.host.NewStream(ctx, p, SignalingProtocol) if err != nil { return nil, fmt.Errorf("error opening stream %s: %w", SignalingProtocol, err) @@ -190,7 +190,7 @@ func (t *transport) dialWithScope(ctx context.Context, p peer.ID, scope network. } s.SetDeadline(deadline) - conn, err := t.setupConnection(ctx, s, scope) + conn, err := t.setupConnection(ctx, s, scope, raddr) if err != nil { s.Reset() return nil, fmt.Errorf("error establishing webrtc.PeerConnection: %w", err) @@ -204,11 +204,17 @@ func (t *transport) dialWithScope(ctx context.Context, p peer.ID, scope network. return conn, nil } -func (t *transport) setupConnection(ctx context.Context, s network.Stream, scope network.ConnManagementScope) (tpt.CapableConn, error) { +func (t *transport) setupConnection(ctx context.Context, s network.Stream, scope network.ConnManagementScope, raddr ma.Multiaddr) (_ tpt.CapableConn, err error) { r := pbio.NewDelimitedReader(s, maxMsgSize) w := pbio.NewDelimitedWriter(s) - pc, err := t.NewPeerConnection() + var pc *webrtc.PeerConnection + defer func() { + if err != nil { + pc.Close() + } + }() + pc, err = t.NewPeerConnection() if err != nil { return nil, fmt.Errorf("failed to create webrtc.PeerConnection: %w", err) } @@ -370,7 +376,7 @@ func (t *transport) setupConnection(ctx context.Context, s network.Stream, scope case webrtc.PeerConnectionStateConnected: } } - localAddr, remoteAddr, err := getConnectionAddresses(pc) + localAddr, err := getLocalConnectionAddress(pc) if err != nil { pc.Close() return nil, fmt.Errorf("failed to get connection addresses: %w", err) @@ -385,7 +391,7 @@ func (t *transport) setupConnection(ctx context.Context, s network.Stream, scope localAddr, s.Conn().RemotePeer(), t.host.Network().Peerstore().PubKey(s.Conn().RemotePeer()), // we have the pubkey from the relayed connection - remoteAddr, + raddr, dataChannelQueue, ) if err != nil { @@ -469,33 +475,27 @@ func getRelayAddr(addr ma.Multiaddr) ma.Multiaddr { return first.Encapsulate(rest) } -// getConnectionAddresses provides multiaddresses on the two sides of the connection pc -func getConnectionAddresses(pc *webrtc.PeerConnection) (local ma.Multiaddr, remote ma.Multiaddr, err error) { +// getLocalConnectionAddress returns the local connection multiaddr +func getLocalConnectionAddress(pc *webrtc.PeerConnection) (local ma.Multiaddr, err error) { if pc.SCTP() == nil { - return nil, nil, errors.New("no sctp transport") + return nil, errors.New("no sctp transport") } if pc.SCTP().Transport() == nil { - return nil, nil, errors.New("no dtls transport") + return nil, errors.New("no dtls transport") } if pc.SCTP().Transport().ICETransport() == nil { - return nil, nil, errors.New("no ice transport") + return nil, errors.New("no ice transport") } cp, err := pc.SCTP().Transport().ICETransport().GetSelectedCandidatePair() if cp == nil || err != nil { - return nil, nil, fmt.Errorf("invalid candidate pair %s: %w", cp, err) + return nil, fmt.Errorf("invalid candidate pair %s: %w", cp, err) } localAddr, err := manet.FromNetAddr(&net.UDPAddr{IP: net.ParseIP(cp.Local.Address), Port: int(cp.Local.Port)}) if err != nil { - return nil, nil, fmt.Errorf("failed to infer local address from candidate %s: %w", cp, err) + return nil, fmt.Errorf("failed to infer local address from candidate %s: %w", cp, err) } localAddr = localAddr.Encapsulate(WebRTCAddr) - remoteAddr, err := manet.FromNetAddr(&net.UDPAddr{IP: net.ParseIP(cp.Remote.Address), Port: int(cp.Remote.Port)}) - if err != nil { - return nil, nil, fmt.Errorf("failed to infer remote address from candidate %s: %w", cp, err) - } - remoteAddr = remoteAddr.Encapsulate(WebRTCAddr) - - return localAddr, remoteAddr, nil + return localAddr, nil } diff --git a/p2p/transport/webrtcprivate/transport_test.go b/p2p/transport/webrtcprivate/transport_test.go index f9d889cee2..cad2cfbaaa 100644 --- a/p2p/transport/webrtcprivate/transport_test.go +++ b/p2p/transport/webrtcprivate/transport_test.go @@ -56,7 +56,7 @@ func newWebRTCHost(t *testing.T) *webrtcHost { } func newRelayedHost(t *testing.T) *relayedHost { - rh := blankhost.NewBlankHost(swarmt.GenSwarm(t)) + rh := blankhost.NewBlankHost(swarmt.GenSwarm(t, swarmt.OptDisableTCP)) rr := relay.DefaultResources() rr.MaxCircuits = 100 _, err := relay.New(rh, relay.WithResources(rr)) @@ -135,9 +135,9 @@ func TestConnectionProperties(t *testing.T) { require.NoError(t, err) } testAddr(ca.LocalMultiaddr()) - testAddr(ca.RemoteMultiaddr()) testAddr(cb.LocalMultiaddr()) - testAddr(cb.RemoteMultiaddr()) + require.Equal(t, ca.RemoteMultiaddr(), b.Addr, "%s\n%s", ca.RemoteMultiaddr(), b.Addr) + require.Equal(t, cb.RemoteMultiaddr(), b.Addr, "%s\n%s", cb.RemoteMultiaddr(), b.Addr) }) t.Run("ConnectionState", func(t *testing.T) {