Skip to content

Commit

Permalink
webrtcprivate: set relay address on the remote side of conn
Browse files Browse the repository at this point in the history
  • Loading branch information
sukunrt committed Oct 23, 2023
1 parent 51981cd commit e71b25e
Show file tree
Hide file tree
Showing 7 changed files with 77 additions and 65 deletions.
22 changes: 5 additions & 17 deletions p2p/test/swarm/swarm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -276,39 +276,27 @@ func TestDialPeerWebRTC(t *testing.T) {
_, err = client.Reserve(context.Background(), h2, relay1info)
require.NoError(t, err)

webrtcAddr := ma.StringCast(relay1info.Addrs[0].String() + "/p2p/" + relay1info.ID.String() + "/p2p-circuit/webrtc/p2p/" + h2.ID().String())
relayAddrs := ma.StringCast(relay1info.Addrs[0].String() + "/p2p/" + relay1info.ID.String() + "/p2p-circuit/p2p/" + h2.ID().String())
webrtcAddr := ma.StringCast(relay1info.Addrs[0].String() + "/p2p/" + relay1info.ID.String() + "/p2p-circuit/webrtc")
relayAddrs := ma.StringCast(relay1info.Addrs[0].String() + "/p2p/" + relay1info.ID.String() + "/p2p-circuit/")

h1.Peerstore().AddAddrs(h2.ID(), []ma.Multiaddr{webrtcAddr, relayAddrs}, peerstore.TempAddrTTL)

// swarm.DialPeer should connect over transient connections
conn1, err := h1.Network().DialPeer(context.Background(), h2.ID())
require.NoError(t, err)
require.NotNil(t, conn1)
require.Condition(t, func() bool {
_, err1 := conn1.RemoteMultiaddr().ValueForProtocol(ma.P_CIRCUIT)
_, err2 := conn1.RemoteMultiaddr().ValueForProtocol(ma.P_WEBRTC)
return err1 == nil && err2 != nil
})
require.Equal(t, conn1.RemoteMultiaddr(), relayAddrs)

// should connect to webrtc address
ctx := network.WithForceDirectDial(context.Background(), "test")
conn, err := h1.Network().DialPeer(ctx, h2.ID())
require.NoError(t, err)
require.NotNil(t, conn)
require.Condition(t, func() bool {
_, err1 := conn.RemoteMultiaddr().ValueForProtocol(ma.P_CIRCUIT)
_, err2 := conn.RemoteMultiaddr().ValueForProtocol(ma.P_WEBRTC)
return err1 != nil && err2 == nil
})
require.Equal(t, conn.RemoteMultiaddr(), webrtcAddr)

done := make(chan struct{})
h2.SetStreamHandler("test-addr", func(s network.Stream) {
s.Conn().LocalMultiaddr()
_, err1 := conn.RemoteMultiaddr().ValueForProtocol(ma.P_CIRCUIT)
assert.Error(t, err1)
_, err2 := conn.RemoteMultiaddr().ValueForProtocol(ma.P_WEBRTC)
assert.NoError(t, err2)
require.Equal(t, conn.RemoteMultiaddr(), webrtcAddr)
s.Reset()
close(done)
})
Expand Down
46 changes: 32 additions & 14 deletions p2p/test/transport/gating_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,16 @@ func TestInterceptSecuredOutgoing(t *testing.T) {
connGater.EXPECT().InterceptUpgraded(gomock.Any()).AnyTimes().Return(true, control.DisconnectReason(0)),

// webrtcprivate setup complete
// TODO: fix addresses on both sides of the /webrtc connection
connGater.EXPECT().InterceptSecured(network.DirOutbound, h2.ID(), gomock.Any()),
connGater.EXPECT().InterceptSecured(network.DirOutbound, h2.ID(), gomock.Any()).Do(func(_, _ any, c network.ConnMultiaddrs) {
require.Equal(t, h2.Addrs()[0], c.RemoteMultiaddr())
}),
)
// As we are using trickle ICE, some candidates are handled
// as prflx candidates. The default wait time for prflx
// candidates is 1 second. Bump the timeout to handle those
// cases.
ctx, cancel = context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
// force a direct connection
ctx = network.WithForceDirectDial(ctx, "integration test /webrtc")
} else {
Expand Down Expand Up @@ -181,13 +188,19 @@ func TestInterceptUpgradedOutgoing(t *testing.T) {
connGater.EXPECT().InterceptUpgraded(gomock.Any()).Return(true, control.DisconnectReason(0)),

// webrtcprivate setup complete
// TODO: fix addresses on both sides of the /webrtc connection
connGater.EXPECT().InterceptSecured(network.DirOutbound, h2.ID(), gomock.Any()).Return(true),
connGater.EXPECT().InterceptUpgraded(gomock.Any()).Do(func(c network.Conn) {
require.Equal(t, h1.ID(), c.LocalPeer())
require.Equal(t, h2.ID(), c.RemotePeer())
require.Equal(t, h2.Addrs()[0], c.RemoteMultiaddr())
}),
)
// As we are using trickle ICE, some candidates are handled
// as prflx candidates. The default wait time for prflx
// candidates is 1 second. Bump the timeout to handle those
// cases.
ctx, cancel = context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
// force a direct connection
ctx = network.WithForceDirectDial(ctx, "integration test /webrtc")
} else {
Expand Down Expand Up @@ -275,25 +288,25 @@ func testInterceptAcceptIncomingWebRTCPrivate(t *testing.T, tc TransportTestCase
defer h2.Close()
require.Len(t, h2.Addrs(), 1)

ctx, cancel := context.WithTimeout(context.Background(), time.Second)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
// relayed connection for incoming stream
connGater.EXPECT().InterceptAccept(gomock.Any()).Return(true)
connGater.EXPECT().InterceptSecured(gomock.Any(), gomock.Any(), gomock.Any()).Return(true)
connGater.EXPECT().InterceptUpgraded(gomock.Any()).Return(true, control.DisconnectReason(0))
// webrtc connection accept
// TODO: Fix webrtc addresses on both sides.
connGater.EXPECT().InterceptAccept(gomock.Any())
connGater.EXPECT().InterceptAccept(gomock.Any()).Do(func(c network.ConnMultiaddrs) {
// On an incoming webrtc connection, the remote address will be the same as the listen address.
// This is simlar to how addresses are setup on circuit-v2 connections
require.Equal(t, h2.Addrs()[0], c.RemoteMultiaddr(), h2.Addrs(), c.RemoteMultiaddr())
})

// The basic host dials the first connection.
h1.Peerstore().AddAddrs(h2.ID(), h2.Addrs(), time.Hour)
_, err := h1.NewStream(ctx, h2.ID(), protocol.TestingID)
require.Error(t, err)
if _, err := h2.Addrs()[0].ValueForProtocol(ma.P_WEBRTC_DIRECT); err != nil {
// WebRTC rejects connection attempt before an error can be sent to the client.
// This means that the connection attempt will time out.
require.NotErrorIs(t, err, context.DeadlineExceeded)
}
// Do not check that error is Context Deadline Exceeded
// See details at the end of `testInterceptSecuredIncomingWebRTCPrivate`
}

func TestInterceptSecuredIncoming(t *testing.T) {
Expand Down Expand Up @@ -358,12 +371,14 @@ func testInterceptSecuredIncomingWebRTCPrivate(t *testing.T, tc TransportTestCas

require.Len(t, h2.Addrs(), 1)

ctx, cancel := context.WithTimeout(context.Background(), time.Second)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
ctx = network.WithForceDirectDial(ctx, "transport integration test /webrtc")
gomock.InOrder(
connGater.EXPECT().InterceptAccept(gomock.Any()).Return(true),
connGater.EXPECT().InterceptSecured(network.DirInbound, h1.ID(), gomock.Any()),
connGater.EXPECT().InterceptSecured(network.DirInbound, h1.ID(), gomock.Any()).Do(func(_, _ any, c network.ConnMultiaddrs) {
require.Equal(t, h2.Addrs()[0], c.RemoteMultiaddr())
}),
)

h1.Peerstore().AddAddrs(h2.ID(), h2.Addrs(), time.Hour)
Expand Down Expand Up @@ -447,16 +462,19 @@ func testInterceptUpgradeIncomingWebRTCPrivate(t *testing.T, tc TransportTestCas
connGater.EXPECT().InterceptAccept(gomock.Any()).Return(true),
connGater.EXPECT().InterceptSecured(network.DirInbound, h1.ID(), gomock.Any()).Return(true),
connGater.EXPECT().InterceptUpgraded(gomock.Any()).Do(func(c network.Conn) {
require.Equal(t, h2.Addrs()[0], c.RemoteMultiaddr())
require.Equal(t, h1.ID(), c.RemotePeer())
require.Equal(t, h2.ID(), c.LocalPeer())
}),
)

ctx, cancel := context.WithTimeout(context.Background(), time.Second)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
ctx = network.WithForceDirectDial(ctx, "transport integration test /webrtc")

h1.Peerstore().AddAddrs(h2.ID(), h2.Addrs(), time.Hour)
_, err := h1.NewStream(ctx, h2.ID(), protocol.TestingID)
require.Error(t, err)
// Do not check that error is Context Deadline Exceeded
// See details at the end of `testInterceptSecuredIncomingWebRTCPrivate`
}
7 changes: 3 additions & 4 deletions p2p/test/transport/rcmgr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -208,8 +208,6 @@ func testResourceManagerIsUsedWebRTCPrivateDialer(t *testing.T, tc TransportTest
return connScope
}

// TODO: Fix addresses on rcmgr calls

var calledSetPeer atomic.Bool
// connScope is the scope to be used for the final connection over /webrtc to the peer
connScope := mocknetwork.NewMockConnManagementScope(ctrl)
Expand Down Expand Up @@ -247,7 +245,7 @@ func testResourceManagerIsUsedWebRTCPrivateDialer(t *testing.T, tc TransportTest
})

// Now handle /webrtc connection establishment
rcmgr.EXPECT().OpenConnection(network.DirOutbound, true, gomock.Any()).Return(connScope, nil)
rcmgr.EXPECT().OpenConnection(network.DirOutbound, true, listener.Addrs()[0]).Return(connScope, nil)
rcmgr.EXPECT().OpenStream(listener.ID(), gomock.Any()).AnyTimes().DoAndReturn(func(id peer.ID, dir network.Direction) (network.StreamManagementScope, error) {
allStreamsDone.Add(1)
streamScope := mocknetwork.NewMockStreamManagementScope(ctrl)
Expand Down Expand Up @@ -392,7 +390,8 @@ func testResourceManagerIsUsedWebRTCPrivateListener(t *testing.T, tc TransportTe
// Incoming circuit-v2 connection for signaling stream
rcmgr.EXPECT().OpenConnection(network.DirInbound, gomock.Any(), gomock.Any()).Return(getRelayConnScope(), nil),
// /webrtc connection
rcmgr.EXPECT().OpenConnection(network.DirInbound, true, gomock.Any()).Return(connScope, nil),
// The remote multiaddr is the same as our
rcmgr.EXPECT().OpenConnection(network.DirInbound, true, listener.Addrs()[0]).Return(connScope, nil),
)
rcmgr.EXPECT().OpenStream(dialer.ID(), gomock.Any()).AnyTimes().DoAndReturn(func(id peer.ID, dir network.Direction) (network.StreamManagementScope, error) {
allStreamsDone.Add(1)
Expand Down
2 changes: 1 addition & 1 deletion p2p/test/transport/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -712,7 +712,7 @@ func TestDiscoverPeerIDFromSecurityNegotiation(t *testing.T) {
for _, tc := range transportsToTest {
t.Run(tc.Name, func(t *testing.T) {
if strings.Contains(tc.Name, "WebRTCPrivate") {
t.Skip("webrtcprivate needs different handling because of the relay required for connection setup")
t.Skip("inapplicable for webrtc private to private since the connection is established over an authenticated channel")
}

h1 := tc.HostGenerator(t, TransportTestCaseOpts{})
Expand Down
17 changes: 12 additions & 5 deletions p2p/transport/webrtcprivate/listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,8 @@ func (l *listener) handleSignalingStream(s network.Stream) {
defer cancel()
defer s.Close()

scope, err := l.transport.rcmgr.OpenConnection(network.DirInbound, true, ma.StringCast("/webrtc")) // we don't have a better remote adress right now
remoteAddr := s.Conn().RemoteMultiaddr().Encapsulate(WebRTCAddr)
scope, err := l.transport.rcmgr.OpenConnection(network.DirInbound, true, remoteAddr)
if err != nil {
s.Reset()
log.Debug("failed to create connection scope:", err)
Expand Down Expand Up @@ -110,7 +111,6 @@ func (l *listener) handleSignalingStream(s network.Stream) {

if l.transport.gater != nil {
localAddr := s.Conn().LocalMultiaddr().Encapsulate(WebRTCAddr)
remoteAddr := s.Conn().RemoteMultiaddr().Encapsulate(WebRTCAddr)
if !l.transport.gater.InterceptAccept(&libp2pwebrtc.ConnMultiaddrs{Local: localAddr, Remote: remoteAddr}) {
log.Debug("gater disallowed accepting connection from %s at %s", s.Conn().RemotePeer(), remoteAddr)
s.Reset()
Expand Down Expand Up @@ -141,9 +141,15 @@ func (l *listener) handleSignalingStream(s network.Stream) {
}
}

func (l *listener) setupConnection(ctx context.Context, s network.Stream, scope network.ConnManagementScope) (tpt.CapableConn, error) {
func (l *listener) setupConnection(ctx context.Context, s network.Stream, scope network.ConnManagementScope) (_ tpt.CapableConn, err error) {
var pc *webrtc.PeerConnection
defer func() {
if err != nil {
pc.Close()
}
}()

pc, err := l.transport.NewPeerConnection()
pc, err = l.transport.NewPeerConnection()
if err != nil {
err = fmt.Errorf("error creating a webrtc.PeerConnection: %w", err)
log.Debug(err)
Expand Down Expand Up @@ -296,11 +302,12 @@ func (l *listener) setupConnection(ctx context.Context, s network.Stream, scope
}
}

localAddr, remoteAddr, err := getConnectionAddresses(pc)
localAddr, err := getLocalConnectionAddress(pc)
if err != nil {
pc.Close()
return nil, fmt.Errorf("failed to get connection addresses: %w", err)
}
remoteAddr := s.Conn().RemoteMultiaddr().Encapsulate(WebRTCAddr)

conn, err := libp2pwebrtc.NewWebRTCConnection(
network.DirInbound,
Expand Down
42 changes: 21 additions & 21 deletions p2p/transport/webrtcprivate/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tp
return nil, err
}

c, err := t.dialWithScope(ctx, p, scope)
c, err := t.dialWithScope(ctx, p, scope, raddr)
if err != nil {
scope.Done()
log.Debug(err)
Expand All @@ -166,7 +166,7 @@ func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tp
return c, nil
}

func (t *transport) dialWithScope(ctx context.Context, p peer.ID, scope network.ConnManagementScope) (tpt.CapableConn, error) {
func (t *transport) dialWithScope(ctx context.Context, p peer.ID, scope network.ConnManagementScope, raddr ma.Multiaddr) (tpt.CapableConn, error) {
s, err := t.host.NewStream(ctx, p, SignalingProtocol)
if err != nil {
return nil, fmt.Errorf("error opening stream %s: %w", SignalingProtocol, err)
Expand All @@ -190,7 +190,7 @@ func (t *transport) dialWithScope(ctx context.Context, p peer.ID, scope network.
}
s.SetDeadline(deadline)

conn, err := t.setupConnection(ctx, s, scope)
conn, err := t.setupConnection(ctx, s, scope, raddr)
if err != nil {
s.Reset()
return nil, fmt.Errorf("error establishing webrtc.PeerConnection: %w", err)
Expand All @@ -204,11 +204,17 @@ func (t *transport) dialWithScope(ctx context.Context, p peer.ID, scope network.
return conn, nil
}

func (t *transport) setupConnection(ctx context.Context, s network.Stream, scope network.ConnManagementScope) (tpt.CapableConn, error) {
func (t *transport) setupConnection(ctx context.Context, s network.Stream, scope network.ConnManagementScope, raddr ma.Multiaddr) (_ tpt.CapableConn, err error) {
r := pbio.NewDelimitedReader(s, maxMsgSize)
w := pbio.NewDelimitedWriter(s)

pc, err := t.NewPeerConnection()
var pc *webrtc.PeerConnection
defer func() {
if err != nil {
pc.Close()
}
}()
pc, err = t.NewPeerConnection()
if err != nil {
return nil, fmt.Errorf("failed to create webrtc.PeerConnection: %w", err)
}
Expand Down Expand Up @@ -370,7 +376,7 @@ func (t *transport) setupConnection(ctx context.Context, s network.Stream, scope
case webrtc.PeerConnectionStateConnected:
}
}
localAddr, remoteAddr, err := getConnectionAddresses(pc)
localAddr, err := getLocalConnectionAddress(pc)
if err != nil {
pc.Close()
return nil, fmt.Errorf("failed to get connection addresses: %w", err)
Expand All @@ -385,7 +391,7 @@ func (t *transport) setupConnection(ctx context.Context, s network.Stream, scope
localAddr,
s.Conn().RemotePeer(),
t.host.Network().Peerstore().PubKey(s.Conn().RemotePeer()), // we have the pubkey from the relayed connection
remoteAddr,
raddr,
dataChannelQueue,
)
if err != nil {
Expand Down Expand Up @@ -469,33 +475,27 @@ func getRelayAddr(addr ma.Multiaddr) ma.Multiaddr {
return first.Encapsulate(rest)
}

// getConnectionAddresses provides multiaddresses on the two sides of the connection pc
func getConnectionAddresses(pc *webrtc.PeerConnection) (local ma.Multiaddr, remote ma.Multiaddr, err error) {
// getLocalConnectionAddress returns the local connection multiaddr
func getLocalConnectionAddress(pc *webrtc.PeerConnection) (local ma.Multiaddr, err error) {
if pc.SCTP() == nil {
return nil, nil, errors.New("no sctp transport")
return nil, errors.New("no sctp transport")
}
if pc.SCTP().Transport() == nil {
return nil, nil, errors.New("no dtls transport")
return nil, errors.New("no dtls transport")
}
if pc.SCTP().Transport().ICETransport() == nil {
return nil, nil, errors.New("no ice transport")
return nil, errors.New("no ice transport")
}
cp, err := pc.SCTP().Transport().ICETransport().GetSelectedCandidatePair()
if cp == nil || err != nil {
return nil, nil, fmt.Errorf("invalid candidate pair %s: %w", cp, err)
return nil, fmt.Errorf("invalid candidate pair %s: %w", cp, err)
}

localAddr, err := manet.FromNetAddr(&net.UDPAddr{IP: net.ParseIP(cp.Local.Address), Port: int(cp.Local.Port)})
if err != nil {
return nil, nil, fmt.Errorf("failed to infer local address from candidate %s: %w", cp, err)
return nil, fmt.Errorf("failed to infer local address from candidate %s: %w", cp, err)
}
localAddr = localAddr.Encapsulate(WebRTCAddr)

remoteAddr, err := manet.FromNetAddr(&net.UDPAddr{IP: net.ParseIP(cp.Remote.Address), Port: int(cp.Remote.Port)})
if err != nil {
return nil, nil, fmt.Errorf("failed to infer remote address from candidate %s: %w", cp, err)
}
remoteAddr = remoteAddr.Encapsulate(WebRTCAddr)

return localAddr, remoteAddr, nil
return localAddr, nil
}
6 changes: 3 additions & 3 deletions p2p/transport/webrtcprivate/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func newWebRTCHost(t *testing.T) *webrtcHost {
}

func newRelayedHost(t *testing.T) *relayedHost {
rh := blankhost.NewBlankHost(swarmt.GenSwarm(t))
rh := blankhost.NewBlankHost(swarmt.GenSwarm(t, swarmt.OptDisableTCP))
rr := relay.DefaultResources()
rr.MaxCircuits = 100
_, err := relay.New(rh, relay.WithResources(rr))
Expand Down Expand Up @@ -135,9 +135,9 @@ func TestConnectionProperties(t *testing.T) {
require.NoError(t, err)
}
testAddr(ca.LocalMultiaddr())
testAddr(ca.RemoteMultiaddr())
testAddr(cb.LocalMultiaddr())
testAddr(cb.RemoteMultiaddr())
require.Equal(t, ca.RemoteMultiaddr(), b.Addr, "%s\n%s", ca.RemoteMultiaddr(), b.Addr)
require.Equal(t, cb.RemoteMultiaddr(), b.Addr, "%s\n%s", cb.RemoteMultiaddr(), b.Addr)
})

t.Run("ConnectionState", func(t *testing.T) {
Expand Down

0 comments on commit e71b25e

Please sign in to comment.