Skip to content

Commit

Permalink
webrtc: reuse port when dialing
Browse files Browse the repository at this point in the history
  • Loading branch information
sukunrt committed Oct 6, 2023
1 parent 92c8f94 commit a8cbee8
Show file tree
Hide file tree
Showing 5 changed files with 286 additions and 11 deletions.
14 changes: 5 additions & 9 deletions p2p/transport/webrtc/listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ type listener struct {

var _ tpt.Listener = &listener{}

func newListener(transport *WebRTCTransport, laddr ma.Multiaddr, socket net.PacketConn, config webrtc.Configuration) (*listener, error) {
func newListener(transport *WebRTCTransport, laddr ma.Multiaddr, socket net.PacketConn, config webrtc.Configuration, mux *udpmux.UDPMux) (*listener, error) {
localFingerprints, err := config.Certificates[0].GetFingerprints()
if err != nil {
return nil, err
Expand Down Expand Up @@ -91,9 +91,7 @@ func newListener(transport *WebRTCTransport, laddr ma.Multiaddr, socket net.Pack
}

l.ctx, l.cancel = context.WithCancel(context.Background())
mux := udpmux.NewUDPMux(socket)
l.mux = mux
mux.Start()

go l.listen()

Expand Down Expand Up @@ -284,7 +282,7 @@ func (l *listener) setupConnection(
localMultiaddrWithoutCerthash,
"", // remotePeer
nil, // remoteKey
remoteMultiaddr,
remoteMultiaddr.Encapsulate(webrtcComponent),
)
if err != nil {
return nil, err
Expand Down Expand Up @@ -321,11 +319,9 @@ func (l *listener) Accept() (tpt.CapableConn, error) {
}

func (l *listener) Close() error {
select {
case <-l.ctx.Done():
default:
l.cancel()
}
l.cancel()
l.mux.Close()
l.transport.RemoveMux(l.mux)
return nil
}

Expand Down
97 changes: 97 additions & 0 deletions p2p/transport/webrtc/reuseudpmux.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
package libp2pwebrtc

import (
"fmt"
"net"
"sync"

"github.com/libp2p/go-libp2p/p2p/transport/webrtc/udpmux"
"github.com/libp2p/go-netroute"
)

// reuseUDPMux provides ability to reuse listening udpMux for dialing. This helps with address
// discovery for nodes that don't have access to their public ip address
type reuseUDPMux struct {
mu sync.RWMutex
loopback map[int]*udpmux.UDPMux
specific map[string]map[int]*udpmux.UDPMux // IP.String() => Port => Mux
unspecified map[int]*udpmux.UDPMux
}

// Put stores mux for reuse later in Get calls.
func (r *reuseUDPMux) Put(mux *udpmux.UDPMux) error {
r.mu.Lock()
defer r.mu.Unlock()
for _, a := range mux.GetListenAddresses() {
udpAddr, err := net.ResolveUDPAddr(a.Network(), a.String())
if err != nil {
return fmt.Errorf("udpmux ResolveUDPAddr failed for %s: %w", a, err)
}
if udpAddr.IP.IsLoopback() {
r.loopback[udpAddr.Port] = mux
continue
}
if udpAddr.IP.IsUnspecified() {
r.unspecified[udpAddr.Port] = mux
continue
}
if r.specific[udpAddr.IP.String()] == nil {
r.specific[udpAddr.IP.String()] = make(map[int]*udpmux.UDPMux)
}
r.specific[udpAddr.IP.String()][udpAddr.Port] = mux
}
return nil
}

// Get retrieves a mux capable of dialing addr. Returns nil if no capable mux is present. If
// multiple muxes capable of dialing addr are available, it returns one arbitrarily
func (r *reuseUDPMux) Get(addr *net.UDPAddr) *udpmux.UDPMux {
r.mu.RLock()
defer r.mu.RUnlock()
if addr.IP.IsLoopback() {
for _, m := range r.loopback {
return m
}
}
if len(r.specific) > 0 {
if router, err := netroute.New(); err == nil {
if _, _, preferredSrc, err := router.Route(addr.IP); err == nil {
if len(r.specific[preferredSrc.String()]) != 0 {
for _, m := range r.specific[preferredSrc.String()] {
return m
}
}
}
}
}
for _, m := range r.unspecified {
return m
}
return nil
}

// Delete removes a mux from the reuse pool.
func (r *reuseUDPMux) Delete(mux *udpmux.UDPMux) {
r.mu.Lock()
defer r.mu.Unlock()
for p, m := range r.loopback {
if mux == m {
delete(r.loopback, p)
}
}
for p, m := range r.unspecified {
if mux == m {
delete(r.unspecified, p)
}
}
for ip, mp := range r.specific {
for p, m := range mp {
if m == mux {
delete(mp, p)
}
}
if len(mp) == 0 {
delete(r.specific, ip)
}
}
}
80 changes: 80 additions & 0 deletions p2p/transport/webrtc/reuseudpmux_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
package libp2pwebrtc

import (
"net"
"testing"

"github.com/libp2p/go-libp2p/p2p/transport/webrtc/udpmux"
"github.com/stretchr/testify/require"
)

func newReuseUDPMux(t *testing.T) reuseUDPMux {
return reuseUDPMux{
loopback: make(map[int]*udpmux.UDPMux),
specific: make(map[string]map[int]*udpmux.UDPMux),
unspecified: make(map[int]*udpmux.UDPMux),
}
}

func udpAddr(t *testing.T, s string) *net.UDPAddr {
a, err := net.ResolveUDPAddr("udp", s)
require.NoError(t, err)
return a
}

func TestReuseUDPMuxLoopback(t *testing.T) {
socket, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0})
require.NoError(t, err)
defer socket.Close()
r := newReuseUDPMux(t)

mux := r.Get(udpAddr(t, "127.0.0.1:1"))
require.Nil(t, mux)

originalMux := udpmux.NewUDPMux(socket)
err = r.Put(originalMux)
require.NoError(t, err)

mux = r.Get(udpAddr(t, "127.0.0.1:1"))
require.Equal(t, originalMux, mux)

mux = r.Get(udpAddr(t, "1.2.3.4:1"))
require.Nil(t, mux)

r.Delete(originalMux)
mux = r.Get(udpAddr(t, "127.0.0.1:1"))
require.Nil(t, mux)
}

func TestReuseUDPMuxUnspecified(t *testing.T) {
s1, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0})
require.NoError(t, err)
defer s1.Close()

s2, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
require.NoError(t, err)
defer s2.Close()

r := newReuseUDPMux(t)

loMux := udpmux.NewUDPMux(s1)
err = r.Put(loMux)
require.NoError(t, err)

mux := r.Get(udpAddr(t, "1.2.3.4:1"))
require.Nil(t, mux)

unMux := udpmux.NewUDPMux(s2)
err = r.Put(unMux)
require.NoError(t, err)

mux = r.Get(udpAddr(t, "127.0.0.1:1"))
require.Equal(t, loMux, mux)

mux = r.Get(udpAddr(t, "1.2.3.4:1"))
require.Equal(t, unMux, mux)

r.Delete(loMux)
mux = r.Get(udpAddr(t, "127.0.0.1:1"))
require.Equal(t, unMux, mux)
}
59 changes: 57 additions & 2 deletions p2p/transport/webrtc/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import (
"github.com/libp2p/go-libp2p/core/sec"
tpt "github.com/libp2p/go-libp2p/core/transport"
"github.com/libp2p/go-libp2p/p2p/security/noise"
"github.com/libp2p/go-libp2p/p2p/transport/webrtc/udpmux"

logging "github.com/ipfs/go-log/v2"
ma "github.com/multiformats/go-multiaddr"
Expand Down Expand Up @@ -94,6 +95,9 @@ type WebRTCTransport struct {

// in-flight connections
maxInFlightConnections uint32

v4Reuse reuseUDPMux
v6Reuse reuseUDPMux
}

var _ tpt.Transport = &WebRTCTransport{}
Expand Down Expand Up @@ -156,6 +160,16 @@ func New(privKey ic.PrivKey, psk pnet.PSK, gater connmgr.ConnectionGater, rcmgr
},

maxInFlightConnections: DefaultMaxInFlightConnections,
v4Reuse: reuseUDPMux{
loopback: make(map[int]*udpmux.UDPMux),
specific: make(map[string]map[int]*udpmux.UDPMux),
unspecified: make(map[int]*udpmux.UDPMux),
},
v6Reuse: reuseUDPMux{
loopback: make(map[int]*udpmux.UDPMux),
specific: make(map[string]map[int]*udpmux.UDPMux),
unspecified: make(map[int]*udpmux.UDPMux),
},
}
for _, opt := range opts {
if err := opt(transport); err != nil {
Expand Down Expand Up @@ -197,15 +211,15 @@ func (t *WebRTCTransport) Listen(addr ma.Multiaddr) (tpt.Listener, error) {
return nil, fmt.Errorf("listen on udp: %w", err)
}

listener, err := t.listenSocket(socket)
listener, err := t.listenSocket(socket, nw)
if err != nil {
socket.Close()
return nil, err
}
return listener, nil
}

func (t *WebRTCTransport) listenSocket(socket *net.UDPConn) (tpt.Listener, error) {
func (t *WebRTCTransport) listenSocket(socket *net.UDPConn, network string) (tpt.Listener, error) {
listenerMultiaddr, err := manet.FromNetAddr(socket.LocalAddr())
if err != nil {
return nil, err
Expand All @@ -225,13 +239,20 @@ func (t *WebRTCTransport) listenSocket(socket *net.UDPConn) (tpt.Listener, error
if err != nil {
return nil, err
}

mux, err := t.newMux(socket, network)
if err != nil {
return nil, err
}

listenerMultiaddr = listenerMultiaddr.Encapsulate(webrtcComponent).Encapsulate(certComp)

return newListener(
t,
listenerMultiaddr,
socket,
t.webrtcConfig,
mux,
)
}

Expand Down Expand Up @@ -306,6 +327,17 @@ func (t *WebRTCTransport) dial(ctx context.Context, scope network.ConnManagement
t.peerConnectionTimeouts.Failed,
t.peerConnectionTimeouts.Keepalive,
)

if rnw == "udp4" {
if mux := t.v4Reuse.Get(raddr); mux != nil {
settingEngine.SetICEUDPMux(mux)
}
} else {
if mux := t.v6Reuse.Get(raddr); mux != nil {
settingEngine.SetICEUDPMux(mux)
}
}

// By default, webrtc will not collect candidates on the loopback address.
// This is disallowed in the ICE specification. However, implementations
// do not strictly follow this, for eg. Chrome gathers TCP loopback candidates.
Expand Down Expand Up @@ -387,6 +419,7 @@ func (t *WebRTCTransport) dial(ctx context.Context, scope network.ConnManagement
if err != nil {
return nil, err
}
localAddr = localAddr.Encapsulate(webrtcComponent)

remoteMultiaddrWithoutCerthash, _ := ma.SplitFunc(remoteMultiaddr, func(c ma.Component) bool { return c.Protocol().Code == ma.P_CERTHASH })

Expand Down Expand Up @@ -521,6 +554,28 @@ func (t *WebRTCTransport) noiseHandshake(ctx context.Context, pc *webrtc.PeerCon
return secureConn.RemotePublicKey(), nil
}

func (t *WebRTCTransport) newMux(socket *net.UDPConn, network string) (*udpmux.UDPMux, error) {
mux := udpmux.NewUDPMux(socket)
if network == "udp4" {
if err := t.v4Reuse.Put(mux); err != nil {
t.v4Reuse.Delete(mux)
return nil, err
}
} else {
if err := t.v6Reuse.Put(mux); err != nil {
t.v6Reuse.Delete(mux)
return nil, err
}
}
mux.Start()
return mux, nil
}

func (t *WebRTCTransport) RemoveMux(mux *udpmux.UDPMux) {
t.v4Reuse.Delete(mux)
t.v6Reuse.Delete(mux)
}

type fakeStreamConn struct{ *stream }

func (fakeStreamConn) LocalAddr() net.Addr { return nil }
Expand Down
Loading

0 comments on commit a8cbee8

Please sign in to comment.