Skip to content

Commit

Permalink
feat: Implement Custom TCP Dialers (#3166)
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoPolo authored Feb 3, 2025
1 parent fde0e3a commit 9584246
Show file tree
Hide file tree
Showing 3 changed files with 158 additions and 0 deletions.
28 changes: 28 additions & 0 deletions libp2p_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -784,3 +784,31 @@ func TestSharedTCPAddr(t *testing.T) {
)
require.ErrorContains(t, err, "cannot use shared TCP listener with PSK")
}

func TestCustomTCPDialer(t *testing.T) {
expectedErr := errors.New("custom dialer called, but not implemented")
customDialer := func(raddr ma.Multiaddr) (tcp.ContextDialer, error) {
// Normally a user would implement this by returning a custom dialer
// Here, we just test that this is called.
return nil, expectedErr
}

h, err := New(
Transport(tcp.NewTCPTransport, tcp.WithDialerForAddr(customDialer)),
)
require.NoError(t, err)
defer h.Close()

var randID peer.ID
priv, _, err := crypto.GenerateKeyPair(crypto.Ed25519, 256)
require.NoError(t, err)
randID, err = peer.IDFromPrivateKey(priv)
require.NoError(t, err)

err = h.Connect(context.Background(), peer.AddrInfo{
ID: randID,
// This won't actually be dialed since we return an error above
Addrs: []ma.Multiaddr{ma.StringCast("/ip4/1.2.3.4/tcp/4")},
})
require.ErrorContains(t, err, expectedErr.Error())
}
57 changes: 57 additions & 0 deletions p2p/transport/tcp/tcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package tcp
import (
"context"
"errors"
"fmt"
"net"
"os"
"runtime"
Expand Down Expand Up @@ -117,12 +118,35 @@ func WithMetrics() Option {
}
}

// WithDialerForAddr sets a custom dialer for the given address.
// If set, it will be the *ONLY* dialer used.
func WithDialerForAddr(d DialerForAddr) Option {
return func(tr *TcpTransport) error {
tr.overrideDialerForAddr = d
return nil
}
}

type ContextDialer interface {
DialContext(ctx context.Context, network, address string) (net.Conn, error)
}

// DialerForAddr is a function that returns a dialer for a given address.
// Implementations must return either a ContextDialer or an error. It is
// invalid to return nil, nil.
type DialerForAddr func(raddr ma.Multiaddr) (ContextDialer, error)

// TcpTransport is the TCP transport.
type TcpTransport struct {
// Connection upgrader for upgrading insecure stream connections to
// secure multiplex connections.
upgrader transport.Upgrader

// optional custom dialer to use for dialing. If set, it will be the *ONLY* dialer
// used. The transport will not attempt to reuse the listen port to
// dial or the shared TCP transport for dialing.
overrideDialerForAddr DialerForAddr

disableReuseport bool // Explicitly disable reuseport.
enableMetrics bool

Expand Down Expand Up @@ -170,6 +194,35 @@ func (t *TcpTransport) CanDial(addr ma.Multiaddr) bool {
return dialMatcher.Matches(addr)
}

func (t *TcpTransport) customDial(ctx context.Context, raddr ma.Multiaddr) (manet.Conn, error) {
// get the net.Dial friendly arguments from the remote addr
rnet, rnaddr, err := manet.DialArgs(raddr)
if err != nil {
return nil, err
}
dialer, err := t.overrideDialerForAddr(raddr)
if err != nil {
return nil, err
}
if dialer == nil {
return nil, fmt.Errorf("dialer for address %s is nil", raddr)
}

// ok, Dial!
var nconn net.Conn
switch rnet {
case "tcp", "tcp4", "tcp6", "udp", "udp4", "udp6", "unix":
nconn, err = dialer.DialContext(ctx, rnet, rnaddr)
if err != nil {
return nil, err
}
default:
return nil, fmt.Errorf("unrecognized network: %s", rnet)
}

return manet.WrapNetConn(nconn)
}

func (t *TcpTransport) maDial(ctx context.Context, raddr ma.Multiaddr) (manet.Conn, error) {
// Apply the deadline iff applicable
if t.connectTimeout > 0 {
Expand All @@ -178,6 +231,10 @@ func (t *TcpTransport) maDial(ctx context.Context, raddr ma.Multiaddr) (manet.Co
defer cancel()
}

if t.overrideDialerForAddr != nil {
return t.customDial(ctx, raddr)
}

if t.sharedTcp != nil {
return t.sharedTcp.DialContext(ctx, raddr)
}
Expand Down
73 changes: 73 additions & 0 deletions p2p/transport/tcp/tcp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package tcp
import (
"context"
"errors"
"net"
"testing"

"github.com/libp2p/go-libp2p/core/crypto"
Expand Down Expand Up @@ -205,3 +206,75 @@ func makeInsecureMuxer(t *testing.T) (peer.ID, []sec.SecureTransport) {
require.NoError(t, err)
return id, []sec.SecureTransport{insecure.NewWithIdentity(insecure.ID, id, priv)}
}

type errDialer struct {
err error
}

func (d errDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
return nil, d.err
}

func TestCustomOverrideTCPDialer(t *testing.T) {
t.Run("success", func(t *testing.T) {
peerA, ia := makeInsecureMuxer(t)
ua, err := tptu.New(ia, muxers, nil, nil, nil)
require.NoError(t, err)
ta, err := NewTCPTransport(ua, nil, nil)
require.NoError(t, err)
ln, err := ta.Listen(ma.StringCast("/ip4/127.0.0.1/tcp/0"))
require.NoError(t, err)
defer ln.Close()

_, ib := makeInsecureMuxer(t)
ub, err := tptu.New(ib, muxers, nil, nil, nil)
require.NoError(t, err)
called := false
customDialer := func(raddr ma.Multiaddr) (ContextDialer, error) {
called = true
return &net.Dialer{}, nil
}
tb, err := NewTCPTransport(ub, nil, nil, WithDialerForAddr(customDialer))
require.NoError(t, err)

conn, err := tb.Dial(context.Background(), ln.Multiaddr(), peerA)
require.NoError(t, err)
require.NotNil(t, conn)
require.True(t, called, "custom dialer should have been called")
conn.Close()
})

t.Run("errors", func(t *testing.T) {
peerA, ia := makeInsecureMuxer(t)
ua, err := tptu.New(ia, muxers, nil, nil, nil)
require.NoError(t, err)
ta, err := NewTCPTransport(ua, nil, nil)
require.NoError(t, err)
ln, err := ta.Listen(ma.StringCast("/ip4/127.0.0.1/tcp/0"))
require.NoError(t, err)
defer ln.Close()

for _, test := range []string{"error in factory", "error in custom dialer"} {
t.Run(test, func(t *testing.T) {
_, ib := makeInsecureMuxer(t)
ub, err := tptu.New(ib, muxers, nil, nil, nil)
require.NoError(t, err)
customErr := errors.New("custom dialer error")
customDialer := func(raddr ma.Multiaddr) (ContextDialer, error) {
if test == "error in factory" {
return nil, customErr
} else {
return errDialer{err: customErr}, nil
}
}
tb, err := NewTCPTransport(ub, nil, nil, WithDialerForAddr(customDialer))
require.NoError(t, err)

conn, err := tb.Dial(context.Background(), ln.Multiaddr(), peerA)
require.Error(t, err)
require.ErrorContains(t, err, customErr.Error())
require.Nil(t, conn)
})
}
})
}

0 comments on commit 9584246

Please sign in to comment.