diff --git a/config/config.go b/config/config.go index 3f0cd85e91..d9ef1a3aaf 100644 --- a/config/config.go +++ b/config/config.go @@ -24,6 +24,7 @@ import ( "github.com/libp2p/go-libp2p/p2p/host/autonat" "github.com/libp2p/go-libp2p/p2p/host/autorelay" bhost "github.com/libp2p/go-libp2p/p2p/host/basic" + blankhost "github.com/libp2p/go-libp2p/p2p/host/blank" "github.com/libp2p/go-libp2p/p2p/host/eventbus" "github.com/libp2p/go-libp2p/p2p/host/peerstore/pstoremem" rcmgr "github.com/libp2p/go-libp2p/p2p/host/resource-manager" @@ -129,6 +130,8 @@ type Config struct { DialRanker network.DialRanker SwarmOpts []swarm.Option + + DisableAutoNATv2 bool } func (cfg *Config) makeSwarm(eventBus event.Bus, enableMetrics bool) (*swarm.Swarm, error) { @@ -191,6 +194,76 @@ func (cfg *Config) makeSwarm(eventBus event.Bus, enableMetrics bool) (*swarm.Swa return swarm.NewSwarm(pid, cfg.Peerstore, eventBus, opts...) } +func (cfg *Config) makeAutoNATHost() (host.Host, error) { + autonatPrivKey, _, err := crypto.GenerateEd25519Key(rand.Reader) + if err != nil { + return nil, err + } + ps, err := pstoremem.NewPeerstore() + if err != nil { + return nil, err + } + + autoNatCfg := Config{ + Transports: cfg.Transports, + Muxers: cfg.Muxers, + SecurityTransports: cfg.SecurityTransports, + Insecure: cfg.Insecure, + PSK: cfg.PSK, + ConnectionGater: cfg.ConnectionGater, + Reporter: cfg.Reporter, + PeerKey: autonatPrivKey, + Peerstore: ps, + DialRanker: swarm.NoDelayDialRanker, + SwarmOpts: []swarm.Option{ + // Disable black hole detection on autonat dialers + // It is better to attempt a dial and fail for AutoNAT use cases + swarm.WithUDPBlackHoleConfig(false, 0, 0), + swarm.WithIPv6BlackHoleConfig(false, 0, 0), + }, + } + fxopts, err := autoNatCfg.addTransports() + if err != nil { + return nil, err + } + var dialerHost host.Host + fxopts = append(fxopts, + fx.Provide(eventbus.NewBus), + fx.Provide(func(lifecycle fx.Lifecycle, b event.Bus) (*swarm.Swarm, error) { + lifecycle.Append(fx.Hook{ + OnStop: func(context.Context) error { + return ps.Close() + }}) + sw, err := autoNatCfg.makeSwarm(b, false) + return sw, err + }), + fx.Provide(func(sw *swarm.Swarm) *blankhost.BlankHost { + return blankhost.NewBlankHost(sw) + }), + fx.Provide(func(bh *blankhost.BlankHost) host.Host { + return bh + }), + fx.Provide(func() crypto.PrivKey { return autonatPrivKey }), + fx.Provide(func(bh host.Host) peer.ID { return bh.ID() }), + fx.Invoke(func(bh *blankhost.BlankHost) { + dialerHost = bh + }), + ) + app := fx.New(fxopts...) + if err := app.Err(); err != nil { + return nil, err + } + err = app.Start(context.Background()) + if err != nil { + return nil, err + } + go func() { + <-dialerHost.Network().(*swarm.Swarm).Done() + app.Stop(context.Background()) + }() + return dialerHost, nil +} + func (cfg *Config) addTransports() ([]fx.Option, error) { fxopts := []fx.Option{ fx.WithLogger(func() fxevent.Logger { return getFXLogger() }), @@ -289,6 +362,14 @@ func (cfg *Config) addTransports() ([]fx.Option, error) { } func (cfg *Config) newBasicHost(swrm *swarm.Swarm, eventBus event.Bus) (*bhost.BasicHost, error) { + var autonatv2Dialer host.Host + if !cfg.DisableAutoNATv2 { + ah, err := cfg.makeAutoNATHost() + if err != nil { + return nil, err + } + autonatv2Dialer = ah + } h, err := bhost.NewHost(swrm, &bhost.HostOpts{ EventBus: eventBus, ConnManager: cfg.ConnManager, @@ -303,6 +384,8 @@ func (cfg *Config) newBasicHost(swrm *swarm.Swarm, eventBus event.Bus) (*bhost.B RelayServiceOpts: cfg.RelayServiceOpts, EnableMetrics: !cfg.DisableMetrics, PrometheusRegisterer: cfg.PrometheusRegisterer, + EnableAutoNATv2: !cfg.DisableAutoNATv2, + AutoNATv2Dialer: autonatv2Dialer, }) if err != nil { return nil, err diff --git a/core/network/network.go b/core/network/network.go index 66b0a1cd34..42d352a1bd 100644 --- a/core/network/network.go +++ b/core/network/network.go @@ -152,6 +152,20 @@ type Network interface { ResourceManager() ResourceManager } +// Dialability indicates how dialable a (peer, addr) combination is. +type Dialability int + +const ( + // DialabilityUnknown indicates that the dialer cannot dial the (peer, addr) combination + DialabilityUnknown Dialability = iota + + // DialabilityDialable indicates that the dialer can dial the (peer, addr) combination + DialabilityDialable + + // DialabilityUndialable indicates that the dialer cannot dial the (peer, addr) combinaton + DialabilityUndialable +) + // Dialer represents a service that can dial out to peers // (this is usually just a Network, but other services may not need the whole // stack, and thus it becomes easier to mock) @@ -185,6 +199,9 @@ type Dialer interface { // Notify/StopNotify register and unregister a notifiee for signals Notify(Notifiee) StopNotify(Notifiee) + + // CanDial returns whether the dialer can dial peer p at addr + CanDial(p peer.ID, addr ma.Multiaddr) Dialability } // AddrDelay provides an address along with the delay after which the address diff --git a/options.go b/options.go index 747d6c55e6..e356d9dc1b 100644 --- a/options.go +++ b/options.go @@ -598,3 +598,11 @@ func SwarmOpts(opts ...swarm.Option) Option { return nil } } + +// DisableAutoNATv2 disables autonat +func DisableAutoNATv2() Option { + return func(cfg *Config) error { + cfg.DisableAutoNATv2 = true + return nil + } +} diff --git a/p2p/host/basic/basic_host.go b/p2p/host/basic/basic_host.go index 367fca05f2..735923a3a4 100644 --- a/p2p/host/basic/basic_host.go +++ b/p2p/host/basic/basic_host.go @@ -23,6 +23,7 @@ import ( "github.com/libp2p/go-libp2p/p2p/host/eventbus" "github.com/libp2p/go-libp2p/p2p/host/pstoremanager" "github.com/libp2p/go-libp2p/p2p/host/relaysvc" + "github.com/libp2p/go-libp2p/p2p/protocol/autonatv2" relayv2 "github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/relay" "github.com/libp2p/go-libp2p/p2p/protocol/holepunch" "github.com/libp2p/go-libp2p/p2p/protocol/identify" @@ -102,6 +103,8 @@ type BasicHost struct { caBook peerstore.CertifiedAddrBook autoNat autonat.AutoNAT + + autonatv2 *autonatv2.AutoNAT } var _ host.Host = (*BasicHost)(nil) @@ -161,6 +164,9 @@ type HostOpts struct { EnableMetrics bool // PrometheusRegisterer is the PrometheusRegisterer used for metrics PrometheusRegisterer prometheus.Registerer + + EnableAutoNATv2 bool + AutoNATv2Dialer host.Host } // NewHost constructs a new *BasicHost and activates it by attaching its stream and connection handlers to the given inet.Network. @@ -301,6 +307,13 @@ func NewHost(n network.Network, opts *HostOpts) (*BasicHost, error) { h.pings = ping.NewPingService(h) } + if opts.EnableAutoNATv2 { + h.autonatv2, err = autonatv2.New(h, opts.AutoNATv2Dialer) + if err != nil { + return nil, fmt.Errorf("failed to create autonatv2: %w", err) + } + } + n.SetStreamHandler(h.newStreamHandler) // register to be notified when the network's listen addrs change, @@ -1029,6 +1042,9 @@ func (h *BasicHost) Close() error { if h.hps != nil { h.hps.Close() } + if h.autonatv2 != nil { + h.autonatv2.Close() + } _ = h.emitters.evtLocalProtocolsUpdated.Close() _ = h.emitters.evtLocalAddrsUpdated.Close() diff --git a/p2p/host/blank/blank.go b/p2p/host/blank/blank.go index 0fdded30ff..0cf6642f69 100644 --- a/p2p/host/blank/blank.go +++ b/p2p/host/blank/blank.go @@ -63,9 +63,10 @@ func NewBlankHost(n network.Network, options ...Option) *BlankHost { } bh := &BlankHost{ - n: n, - cmgr: cfg.cmgr, - mux: mstream.NewMultistreamMuxer[protocol.ID](), + n: n, + cmgr: cfg.cmgr, + mux: mstream.NewMultistreamMuxer[protocol.ID](), + eventbus: cfg.eventBus, } if bh.eventbus == nil { bh.eventbus = eventbus.NewBus(eventbus.WithMetricsTracer(eventbus.NewMetricsTracer())) diff --git a/p2p/net/mock/mock_peernet.go b/p2p/net/mock/mock_peernet.go index 2e56b7f2bb..8a06a95637 100644 --- a/p2p/net/mock/mock_peernet.go +++ b/p2p/net/mock/mock_peernet.go @@ -434,3 +434,7 @@ func (pn *peernet) notifyAll(notification func(f network.Notifiee)) { func (pn *peernet) ResourceManager() network.ResourceManager { return &network.NullResourceManager{} } + +func (pn *peernet) CanDial(p peer.ID, addr ma.Multiaddr) network.Dialability { + return network.DialabilityUnknown +} diff --git a/p2p/net/swarm/black_hole_detector.go b/p2p/net/swarm/black_hole_detector.go index dd7849eea6..be789beae8 100644 --- a/p2p/net/swarm/black_hole_detector.go +++ b/p2p/net/swarm/black_hole_detector.go @@ -144,6 +144,13 @@ func (b *blackHoleFilter) updateState() { } } +func (b *blackHoleFilter) State() blackHoleState { + b.mu.Lock() + defer b.mu.Unlock() + + return b.state +} + func (b *blackHoleFilter) trackMetrics() { if b.metricsTracer == nil { return @@ -244,6 +251,27 @@ func (d *blackHoleDetector) RecordResult(addr ma.Multiaddr, success bool) { } } +func (d *blackHoleDetector) State(addr ma.Multiaddr) blackHoleState { + if !manet.IsPublicAddr(addr) { + return blackHoleStateAllowed + } + + if d.udp != nil && isProtocolAddr(addr, ma.P_UDP) { + udpState := d.udp.State() + if udpState != blackHoleStateAllowed { + return udpState + } + } + + if d.ipv6 != nil && isProtocolAddr(addr, ma.P_IP6) { + ipv6State := d.ipv6.State() + if ipv6State != blackHoleStateAllowed { + return ipv6State + } + } + return blackHoleStateAllowed +} + // blackHoleConfig is the config used for black hole detection type blackHoleConfig struct { // Enabled enables black hole detection diff --git a/p2p/net/swarm/black_hole_detector_test.go b/p2p/net/swarm/black_hole_detector_test.go index dfbb30f90d..f62e8ca8c6 100644 --- a/p2p/net/swarm/black_hole_detector_test.go +++ b/p2p/net/swarm/black_hole_detector_test.go @@ -17,6 +17,9 @@ func TestBlackHoleFilterReset(t *testing.T) { if bhf.HandleRequest() != blackHoleResultProbing { t.Fatalf("expected calls up to n to be probes") } + if bhf.State() != blackHoleStateProbing { + t.Fatalf("expected state to be probing got %s", bhf.State()) + } bhf.RecordResult(false) } @@ -26,6 +29,9 @@ func TestBlackHoleFilterReset(t *testing.T) { if (i%n == 0 && result != blackHoleResultProbing) || (i%n != 0 && result != blackHoleResultBlocked) { t.Fatalf("expected every nth dial to be a probe") } + if bhf.State() != blackHoleStateBlocked { + t.Fatalf("expected state to be blocked, got %s", bhf.State()) + } } bhf.RecordResult(true) @@ -34,12 +40,18 @@ func TestBlackHoleFilterReset(t *testing.T) { if bhf.HandleRequest() != blackHoleResultProbing { t.Fatalf("expected black hole detector state to reset after success") } + if bhf.State() != blackHoleStateProbing { + t.Fatalf("expected state to be probing got %s", bhf.State()) + } bhf.RecordResult(false) } // next call should be blocked if bhf.HandleRequest() != blackHoleResultBlocked { t.Fatalf("expected dial to be blocked") + if bhf.State() != blackHoleStateBlocked { + t.Fatalf("expected state to be blocked, got %s", bhf.State()) + } } } diff --git a/p2p/net/swarm/swarm.go b/p2p/net/swarm/swarm.go index 0140c3f596..845c6963dc 100644 --- a/p2p/net/swarm/swarm.go +++ b/p2p/net/swarm/swarm.go @@ -112,7 +112,7 @@ func WithDialRanker(d network.DialRanker) Option { } } -// WithUDPBlackHoleConfig configures swarm to use c as the config for UDP black hole detection +// WithUDPBlackHoleConfig configures swarm to use the provided config for UDP black hole detection // n is the size of the sliding window used to evaluate black hole state // min is the minimum number of successes out of n required to not block requests func WithUDPBlackHoleConfig(enabled bool, n, min int) Option { @@ -122,7 +122,7 @@ func WithUDPBlackHoleConfig(enabled bool, n, min int) Option { } } -// WithIPv6BlackHoleConfig configures swarm to use c as the config for IPv6 black hole detection +// WithIPv6BlackHoleConfig configures swarm to use the provided config for IPv6 black hole detection // n is the size of the sliding window used to evaluate black hole state // min is the minimum number of successes out of n required to not block requests func WithIPv6BlackHoleConfig(enabled bool, n, min int) Option { diff --git a/p2p/net/swarm/swarm_dial.go b/p2p/net/swarm/swarm_dial.go index f639ce16a2..706a57032f 100644 --- a/p2p/net/swarm/swarm_dial.go +++ b/p2p/net/swarm/swarm_dial.go @@ -416,6 +416,26 @@ func (s *Swarm) dialNextAddr(ctx context.Context, p peer.ID, addr ma.Multiaddr, return nil } +func (s *Swarm) CanDial(p peer.ID, addr ma.Multiaddr) network.Dialability { + dialable, _ := s.filterKnownUndialables(p, []ma.Multiaddr{addr}) + if len(dialable) == 0 { + return network.DialabilityUndialable + } + + bhState := s.bhd.State(addr) + switch bhState { + case blackHoleStateAllowed: + return network.DialabilityDialable + case blackHoleStateProbing: + return network.DialabilityUnknown + case blackHoleStateBlocked: + return network.DialabilityUndialable + default: + log.Errorf("SWARM BUG: unhandled black hole state for dilability check %s", bhState) + return network.DialabilityUnknown + } +} + func (s *Swarm) nonProxyAddr(addr ma.Multiaddr) bool { t := s.TransportForDialing(addr) return !t.Proxy() diff --git a/p2p/protocol/autonatv2/autonat.go b/p2p/protocol/autonatv2/autonat.go new file mode 100644 index 0000000000..26811d0943 --- /dev/null +++ b/p2p/protocol/autonatv2/autonat.go @@ -0,0 +1,250 @@ +package autonatv2 + +import ( + "context" + "errors" + "fmt" + "sync" + "time" + + logging "github.com/ipfs/go-log/v2" + "github.com/libp2p/go-libp2p/core/event" + "github.com/libp2p/go-libp2p/core/host" + "github.com/libp2p/go-libp2p/core/network" + "github.com/libp2p/go-libp2p/core/peer" + "github.com/libp2p/go-libp2p/p2p/protocol/autonatv2/pb" + ma "github.com/multiformats/go-multiaddr" + manet "github.com/multiformats/go-multiaddr/net" + "golang.org/x/exp/rand" + "golang.org/x/exp/slices" +) + +//go:generate protoc --go_out=. --go_opt=Mpb/autonatv2.proto=./pb pb/autonatv2.proto + +const ( + ServiceName = "libp2p.autonatv2" + DialBackProtocol = "/libp2p/autonat/2/dial-back" + DialProtocol = "/libp2p/autonat/2/dial-request" + + maxMsgSize = 8192 + streamTimeout = time.Minute + dialBackStreamTimeout = 5 * time.Second + dialBackDialTimeout = 30 * time.Second + minHandshakeSizeBytes = 30_000 // for amplification attack prevention + maxHandshakeSizeBytes = 100_000 + // maxPeerAddresses is the number of addresses in a dial request the server + // will inspect, rest are ignored. + maxPeerAddresses = 50 +) + +var ( + ErrNoValidPeers = errors.New("no valid peers for autonat v2") + ErrDialRefused = errors.New("dial refused") + + log = logging.Logger("autonatv2") +) + +// Request is the request to verify reachability of a single address +type Request struct { + // Addr is the multiaddr to verify + Addr ma.Multiaddr + // SendDialData indicates whether to send dial data if the server requests it for Addr + SendDialData bool +} + +// Result is the result of the CheckReachability call +type Result struct { + // Idx is the index of the dialed address + Idx int + // Addr is the dialed address + Addr ma.Multiaddr + // Reachability of the dialed address + Reachability network.Reachability + // Status is the outcome of the dialback + Status pb.DialStatus +} + +// AutoNAT implements the AutoNAT v2 client and server. +// Users can check reachability for their addresses using the CheckReachability method. +// The server provides amplification attack prevention and rate limiting. +type AutoNAT struct { + host host.Host + sub event.Subscription + + // for cleanly closing + ctx context.Context + cancel context.CancelFunc + wg sync.WaitGroup + + srv *server + cli *client + + mx sync.Mutex + peers *peersMap + + allowAllAddrs bool // for testing +} + +// New returns a new AutoNAT instance. The returned instance runs the server when the provided host +// is publicly reachable. +// host and dialerHost should have the same dialing capabilities. In case the host doesn't support +// a transport, dial back requests for address for that transport will be ignored. +func New(host host.Host, dialerHost host.Host, opts ...AutoNATOption) (*AutoNAT, error) { + s := defaultSettings() + for _, o := range opts { + if err := o(s); err != nil { + return nil, fmt.Errorf("failed to apply option: %w", err) + } + } + // We are listening on event.EvtPeerProtocolsUpdated, event.EvtPeerConnectednessChanged + // event.EvtPeerIdentificationCompleted to maintain our set of autonat supporting peers. + // + // We listen on event.EvtLocalReachabilityChanged to Disable the server if we are not + // publicly reachable. Currently this event is sent by the AutoNAT v1 module. During the + // transition period from AutoNAT v1 to v2, there won't be enough v2 servers on the network + // and most clients will be unable to discover a peer which supports AutoNAT v2. So, we use + // v1 to determine reachability for the transition period. + // + // Once there are enough v2 servers on the network for nodes to determine their reachability + // using AutoNAT v2, we'll use Address Pipeline + // (https://github.com/libp2p/go-libp2p/issues/2229)(to be implemented in a future release) + // to determine reachability using v2 client and send this event from Address Pipeline, if + // we are publicly reachable. + sub, err := host.EventBus().Subscribe([]interface{}{ + new(event.EvtLocalReachabilityChanged), + new(event.EvtPeerProtocolsUpdated), + new(event.EvtPeerConnectednessChanged), + new(event.EvtPeerIdentificationCompleted), + }) + if err != nil { + return nil, fmt.Errorf("event subscription: %w", err) + } + + ctx, cancel := context.WithCancel(context.Background()) + an := &AutoNAT{ + host: host, + ctx: ctx, + cancel: cancel, + sub: sub, + srv: newServer(host, dialerHost, s), + cli: newClient(host), + allowAllAddrs: s.allowAllAddrs, + peers: newPeersMap(), + } + an.cli.RegisterDialBack() + + an.wg.Add(1) + go an.background() + return an, nil +} + +func (an *AutoNAT) background() { + for { + select { + case <-an.ctx.Done(): + an.srv.Disable() + an.srv.Close() + an.peers = nil + an.wg.Done() + return + case e := <-an.sub.Out(): + switch evt := e.(type) { + case event.EvtLocalReachabilityChanged: + if evt.Reachability == network.ReachabilityPrivate { + an.srv.Disable() + } else { + an.srv.Enable() + } + case event.EvtPeerProtocolsUpdated: + an.updatePeer(evt.Peer) + case event.EvtPeerConnectednessChanged: + an.updatePeer(evt.Peer) + case event.EvtPeerIdentificationCompleted: + an.updatePeer(evt.Peer) + } + } + } +} + +func (an *AutoNAT) Close() { + an.cancel() + an.wg.Wait() +} + +// CheckReachability makes a single dial request for checking reachability for requested addresses +func (an *AutoNAT) CheckReachability(ctx context.Context, reqs []Request) (Result, error) { + if !an.allowAllAddrs { + for _, r := range reqs { + if !manet.IsPublicAddr(r.Addr) { + return Result{}, fmt.Errorf("private address cannot be verified by autonatv2: %s", r.Addr) + } + } + } + p := an.peers.GetRand() + if p == "" { + return Result{}, ErrNoValidPeers + } + + res, err := an.cli.CheckReachability(ctx, p, reqs) + if err != nil { + log.Debugf("reachability check with %s failed, err: %s", p, err) + return Result{}, fmt.Errorf("reachability check with %s failed: %w", p, err) + } + log.Debugf("reachability check with %s successful", p) + return res, nil +} + +func (an *AutoNAT) updatePeer(p peer.ID) { + an.mx.Lock() + defer an.mx.Unlock() + + // There are no ordering gurantees between identify and swarm events. Check peerstore + // and swarm for the current state + protos, err := an.host.Peerstore().SupportsProtocols(p, DialProtocol) + connectedness := an.host.Network().Connectedness(p) + if err == nil && slices.Contains(protos, DialProtocol) && connectedness == network.Connected { + an.peers.Put(p) + } else { + an.peers.Delete(p) + } +} + +// peersMap provides random access to a set of peers. This is useful when the map iteration order is +// not sufficiently random. +type peersMap struct { + peerIdx map[peer.ID]int + peers []peer.ID +} + +func newPeersMap() *peersMap { + return &peersMap{ + peerIdx: make(map[peer.ID]int), + peers: make([]peer.ID, 0), + } +} + +func (p *peersMap) GetRand() peer.ID { + if len(p.peers) == 0 { + return "" + } + return p.peers[rand.Intn(len(p.peers))] +} + +func (p *peersMap) Put(pid peer.ID) { + if _, ok := p.peerIdx[pid]; ok { + return + } + p.peers = append(p.peers, pid) + p.peerIdx[pid] = len(p.peers) - 1 +} + +func (p *peersMap) Delete(pid peer.ID) { + idx, ok := p.peerIdx[pid] + if !ok { + return + } + p.peers[idx] = p.peers[len(p.peers)-1] + p.peerIdx[p.peers[idx]] = idx + p.peers = p.peers[:len(p.peers)-1] + delete(p.peerIdx, pid) +} diff --git a/p2p/protocol/autonatv2/autonat_test.go b/p2p/protocol/autonatv2/autonat_test.go new file mode 100644 index 0000000000..f1e8299cb2 --- /dev/null +++ b/p2p/protocol/autonatv2/autonat_test.go @@ -0,0 +1,638 @@ +package autonatv2 + +import ( + "context" + "errors" + "fmt" + "sync/atomic" + "testing" + "time" + + "github.com/libp2p/go-libp2p/core/host" + "github.com/libp2p/go-libp2p/core/network" + "github.com/libp2p/go-libp2p/core/peer" + "github.com/libp2p/go-libp2p/core/peerstore" + bhost "github.com/libp2p/go-libp2p/p2p/host/blank" + "github.com/libp2p/go-libp2p/p2p/host/eventbus" + swarmt "github.com/libp2p/go-libp2p/p2p/net/swarm/testing" + "github.com/libp2p/go-libp2p/p2p/protocol/autonatv2/pb" + + "github.com/libp2p/go-msgio/pbio" + ma "github.com/multiformats/go-multiaddr" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func newAutoNAT(t *testing.T, dialer host.Host, opts ...AutoNATOption) *AutoNAT { + t.Helper() + b := eventbus.NewBus() + h := bhost.NewBlankHost(swarmt.GenSwarm(t, swarmt.EventBus(b)), bhost.WithEventBus(b)) + if dialer == nil { + dialer = bhost.NewBlankHost(swarmt.GenSwarm(t)) + } + an, err := New(h, dialer, opts...) + if err != nil { + t.Error(err) + } + an.srv.Enable() + an.cli.RegisterDialBack() + return an +} + +func parseAddrs(t *testing.T, msg *pb.Message) []ma.Multiaddr { + t.Helper() + req := msg.GetDialRequest() + addrs := make([]ma.Multiaddr, 0) + for _, ab := range req.Addrs { + a, err := ma.NewMultiaddrBytes(ab) + if err != nil { + t.Error("invalid addr bytes", ab) + } + addrs = append(addrs, a) + } + return addrs +} + +// idAndConnect identifies b to a and connects them +func idAndConnect(t *testing.T, a, b host.Host) { + a.Peerstore().AddAddrs(b.ID(), b.Addrs(), peerstore.PermanentAddrTTL) + a.Peerstore().AddProtocols(b.ID(), DialProtocol) + + err := a.Connect(context.Background(), peer.AddrInfo{ID: b.ID()}) + require.NoError(t, err) +} + +// waitForPeer waits for a to have 1 peer in the peerMap +func waitForPeer(t *testing.T, a *AutoNAT) { + t.Helper() + require.Eventually(t, func() bool { + a.mx.Lock() + defer a.mx.Unlock() + return a.peers.GetRand() != "" + }, 5*time.Second, 100*time.Millisecond) +} + +// idAndWait provides server address and protocol to client +func idAndWait(t *testing.T, cli *AutoNAT, srv *AutoNAT) { + idAndConnect(t, cli.host, srv.host) + waitForPeer(t, cli) +} + +func TestAutoNATPrivateAddr(t *testing.T) { + an := newAutoNAT(t, nil) + res, err := an.CheckReachability(context.Background(), []Request{{Addr: ma.StringCast("/ip4/192.168.0.1/udp/10/quic-v1")}}) + require.Equal(t, res, Result{}) + require.Contains(t, err.Error(), "private address cannot be verified by autonatv2") +} + +func TestClientRequest(t *testing.T) { + an := newAutoNAT(t, nil, allowAllAddrs) + defer an.Close() + defer an.host.Close() + + b := bhost.NewBlankHost(swarmt.GenSwarm(t)) + defer b.Close() + idAndConnect(t, an.host, b) + waitForPeer(t, an) + + addrs := an.host.Addrs() + addrbs := make([][]byte, len(addrs)) + for i := 0; i < len(addrs); i++ { + addrbs[i] = addrs[i].Bytes() + } + + var receivedRequest atomic.Bool + b.SetStreamHandler(DialProtocol, func(s network.Stream) { + receivedRequest.Store(true) + r := pbio.NewDelimitedReader(s, maxMsgSize) + var msg pb.Message + assert.NoError(t, r.ReadMsg(&msg)) + assert.NotNil(t, msg.GetDialRequest()) + assert.Equal(t, addrbs, msg.GetDialRequest().Addrs) + s.Reset() + }) + + res, err := an.CheckReachability(context.Background(), []Request{ + {Addr: addrs[0], SendDialData: true}, {Addr: addrs[1]}, + }) + require.Equal(t, res, Result{}) + require.NotNil(t, err) + require.True(t, receivedRequest.Load()) +} + +func TestClientServerError(t *testing.T) { + an := newAutoNAT(t, nil, allowAllAddrs) + defer an.Close() + defer an.host.Close() + + b := bhost.NewBlankHost(swarmt.GenSwarm(t)) + defer b.Close() + idAndConnect(t, an.host, b) + waitForPeer(t, an) + + tests := []struct { + handler func(network.Stream) + errorStr string + }{ + { + handler: func(s network.Stream) { + s.Reset() + }, + errorStr: "stream reset", + }, + { + handler: func(s network.Stream) { + w := pbio.NewDelimitedWriter(s) + assert.NoError(t, w.WriteMsg( + &pb.Message{Msg: &pb.Message_DialRequest{DialRequest: &pb.DialRequest{}}})) + }, + errorStr: "invalid msg type", + }, + { + handler: func(s network.Stream) { + w := pbio.NewDelimitedWriter(s) + assert.NoError(t, w.WriteMsg( + &pb.Message{Msg: &pb.Message_DialResponse{ + DialResponse: &pb.DialResponse{ + Status: pb.DialResponse_E_DIAL_REFUSED, + }, + }}, + )) + }, + errorStr: ErrDialRefused.Error(), + }, + } + + for i, tc := range tests { + t.Run(fmt.Sprintf("test-%d", i), func(t *testing.T) { + b.SetStreamHandler(DialProtocol, tc.handler) + addrs := an.host.Addrs() + res, err := an.CheckReachability( + context.Background(), + newTestRequests(addrs, false)) + require.Equal(t, res, Result{}) + require.NotNil(t, err) + require.Contains(t, err.Error(), tc.errorStr) + }) + } +} + +func TestClientDataRequest(t *testing.T) { + an := newAutoNAT(t, nil, allowAllAddrs) + defer an.Close() + defer an.host.Close() + + b := bhost.NewBlankHost(swarmt.GenSwarm(t)) + defer b.Close() + idAndConnect(t, an.host, b) + waitForPeer(t, an) + + tests := []struct { + handler func(network.Stream) + name string + }{ + { + name: "provides dial data", + handler: func(s network.Stream) { + r := pbio.NewDelimitedReader(s, maxMsgSize) + var msg pb.Message + assert.NoError(t, r.ReadMsg(&msg)) + w := pbio.NewDelimitedWriter(s) + if err := w.WriteMsg(&pb.Message{ + Msg: &pb.Message_DialDataRequest{ + DialDataRequest: &pb.DialDataRequest{ + AddrIdx: 0, + NumBytes: 10000, + }, + }}, + ); err != nil { + t.Error(err) + s.Reset() + return + } + var dialData []byte + for len(dialData) < 10000 { + if err := r.ReadMsg(&msg); err != nil { + t.Error(err) + s.Reset() + return + } + if msg.GetDialDataResponse() == nil { + t.Errorf("expected to receive msg of type DialDataResponse") + s.Reset() + return + } + dialData = append(dialData, msg.GetDialDataResponse().Data...) + } + s.Reset() + }, + }, + { + name: "low priority addr", + handler: func(s network.Stream) { + r := pbio.NewDelimitedReader(s, maxMsgSize) + var msg pb.Message + assert.NoError(t, r.ReadMsg(&msg)) + w := pbio.NewDelimitedWriter(s) + if err := w.WriteMsg(&pb.Message{ + Msg: &pb.Message_DialDataRequest{ + DialDataRequest: &pb.DialDataRequest{ + AddrIdx: 1, + NumBytes: 10000, + }, + }}, + ); err != nil { + t.Error(err) + s.Reset() + return + } + assert.Error(t, r.ReadMsg(&msg)) + s.Reset() + }, + }, + { + name: "too high dial data request", + handler: func(s network.Stream) { + r := pbio.NewDelimitedReader(s, maxMsgSize) + var msg pb.Message + assert.NoError(t, r.ReadMsg(&msg)) + w := pbio.NewDelimitedWriter(s) + if err := w.WriteMsg(&pb.Message{ + Msg: &pb.Message_DialDataRequest{ + DialDataRequest: &pb.DialDataRequest{ + AddrIdx: 0, + NumBytes: 1 << 32, + }, + }}, + ); err != nil { + t.Error(err) + s.Reset() + return + } + assert.Error(t, r.ReadMsg(&msg)) + s.Reset() + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + b.SetStreamHandler(DialProtocol, tc.handler) + addrs := an.host.Addrs() + + res, err := an.CheckReachability( + context.Background(), + []Request{ + {Addr: addrs[0], SendDialData: true}, + {Addr: addrs[1]}, + }) + require.Equal(t, res, Result{}) + require.NotNil(t, err) + }) + } +} + +func TestClientDialBacks(t *testing.T) { + an := newAutoNAT(t, nil, allowAllAddrs) + defer an.Close() + defer an.host.Close() + + b := bhost.NewBlankHost(swarmt.GenSwarm(t)) + defer b.Close() + idAndConnect(t, an.host, b) + waitForPeer(t, an) + + dialerHost := bhost.NewBlankHost(swarmt.GenSwarm(t)) + defer dialerHost.Close() + + readReq := func(r pbio.Reader) ([]ma.Multiaddr, uint64, error) { + var msg pb.Message + if err := r.ReadMsg(&msg); err != nil { + return nil, 0, err + } + if msg.GetDialRequest() == nil { + return nil, 0, errors.New("no dial request in msg") + } + addrs := parseAddrs(t, &msg) + return addrs, msg.GetDialRequest().GetNonce(), nil + } + + writeNonce := func(addr ma.Multiaddr, nonce uint64) error { + pid := an.host.ID() + dialerHost.Peerstore().AddAddr(pid, addr, peerstore.PermanentAddrTTL) + defer func() { + dialerHost.Network().ClosePeer(pid) + dialerHost.Peerstore().RemovePeer(pid) + dialerHost.Peerstore().ClearAddrs(pid) + }() + as, err := dialerHost.NewStream(context.Background(), pid, DialBackProtocol) + if err != nil { + return err + } + w := pbio.NewDelimitedWriter(as) + if err := w.WriteMsg(&pb.DialBack{Nonce: nonce}); err != nil { + return err + } + as.CloseWrite() + data := make([]byte, 1) + as.Read(data) + as.Close() + return nil + } + + tests := []struct { + name string + handler func(network.Stream) + success bool + }{ + { + name: "correct dial attempt", + handler: func(s network.Stream) { + r := pbio.NewDelimitedReader(s, maxMsgSize) + w := pbio.NewDelimitedWriter(s) + + addrs, nonce, err := readReq(r) + if err != nil { + s.Reset() + t.Error(err) + return + } + if err := writeNonce(addrs[1], nonce); err != nil { + s.Reset() + t.Error(err) + return + } + w.WriteMsg(&pb.Message{ + Msg: &pb.Message_DialResponse{ + DialResponse: &pb.DialResponse{ + Status: pb.DialResponse_OK, + DialStatus: pb.DialStatus_OK, + AddrIdx: 1, + }, + }, + }) + s.Close() + }, + success: true, + }, + { + name: "no dial attempt", + handler: func(s network.Stream) { + r := pbio.NewDelimitedReader(s, maxMsgSize) + if _, _, err := readReq(r); err != nil { + s.Reset() + t.Error(err) + return + } + resp := &pb.DialResponse{ + Status: pb.DialResponse_OK, + DialStatus: pb.DialStatus_OK, + AddrIdx: 0, + } + w := pbio.NewDelimitedWriter(s) + w.WriteMsg(&pb.Message{ + Msg: &pb.Message_DialResponse{ + DialResponse: resp, + }, + }) + s.Close() + }, + success: false, + }, + { + name: "invalid reported address", + handler: func(s network.Stream) { + r := pbio.NewDelimitedReader(s, maxMsgSize) + addrs, nonce, err := readReq(r) + if err != nil { + s.Reset() + t.Error(err) + return + } + + if err := writeNonce(addrs[1], nonce); err != nil { + s.Reset() + t.Error(err) + return + } + + w := pbio.NewDelimitedWriter(s) + w.WriteMsg(&pb.Message{ + Msg: &pb.Message_DialResponse{ + DialResponse: &pb.DialResponse{ + Status: pb.DialResponse_OK, + DialStatus: pb.DialStatus_OK, + AddrIdx: 0, + }, + }, + }) + s.Close() + }, + success: false, + }, + { + name: "invalid nonce", + handler: func(s network.Stream) { + r := pbio.NewDelimitedReader(s, maxMsgSize) + addrs, nonce, err := readReq(r) + if err != nil { + s.Reset() + t.Error(err) + return + } + if err := writeNonce(addrs[0], nonce-1); err != nil { + s.Reset() + t.Error(err) + return + } + w := pbio.NewDelimitedWriter(s) + w.WriteMsg(&pb.Message{ + Msg: &pb.Message_DialResponse{ + DialResponse: &pb.DialResponse{ + Status: pb.DialResponse_OK, + DialStatus: pb.DialStatus_OK, + AddrIdx: 0, + }, + }, + }) + s.Close() + }, + success: false, + }, + { + name: "invalid addr index", + handler: func(s network.Stream) { + r := pbio.NewDelimitedReader(s, maxMsgSize) + _, _, err := readReq(r) + if err != nil { + s.Reset() + t.Error(err) + return + } + w := pbio.NewDelimitedWriter(s) + w.WriteMsg(&pb.Message{ + Msg: &pb.Message_DialResponse{ + DialResponse: &pb.DialResponse{ + Status: pb.DialResponse_OK, + DialStatus: pb.DialStatus_OK, + AddrIdx: 10, + }, + }, + }) + s.Close() + }, + success: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + addrs := an.host.Addrs() + b.SetStreamHandler(DialProtocol, tc.handler) + res, err := an.CheckReachability( + context.Background(), + []Request{ + {Addr: addrs[0], SendDialData: true}, + {Addr: addrs[1]}, + }) + if !tc.success { + require.Error(t, err) + require.Equal(t, Result{}, res) + } else { + require.NoError(t, err) + require.Equal(t, res.Reachability, network.ReachabilityPublic) + require.Equal(t, res.Status, pb.DialStatus_OK) + } + }) + } +} + +func TestEventSubscription(t *testing.T) { + an := newAutoNAT(t, nil) + defer an.host.Close() + + b := bhost.NewBlankHost(swarmt.GenSwarm(t)) + defer b.Close() + c := bhost.NewBlankHost(swarmt.GenSwarm(t)) + defer c.Close() + + idAndConnect(t, an.host, b) + require.Eventually(t, func() bool { + an.mx.Lock() + defer an.mx.Unlock() + return len(an.peers.peers) == 1 + }, 5*time.Second, 100*time.Millisecond) + + idAndConnect(t, an.host, c) + require.Eventually(t, func() bool { + an.mx.Lock() + defer an.mx.Unlock() + return len(an.peers.peers) == 2 + }, 5*time.Second, 100*time.Millisecond) + + an.host.Network().ClosePeer(b.ID()) + require.Eventually(t, func() bool { + an.mx.Lock() + defer an.mx.Unlock() + return len(an.peers.peers) == 1 + }, 5*time.Second, 100*time.Millisecond) + + an.host.Network().ClosePeer(c.ID()) + require.Eventually(t, func() bool { + an.mx.Lock() + defer an.mx.Unlock() + return len(an.peers.peers) == 0 + }, 5*time.Second, 100*time.Millisecond) +} + +func TestPeersMap(t *testing.T) { + emptyPeerID := peer.ID("") + + t.Run("single_item", func(t *testing.T) { + p := newPeersMap() + p.Put("peer1") + p.Delete("peer1") + p.Put("peer1") + require.Equal(t, peer.ID("peer1"), p.GetRand()) + p.Delete("peer1") + require.Equal(t, emptyPeerID, p.GetRand()) + }) + + t.Run("multiple_items", func(t *testing.T) { + p := newPeersMap() + require.Equal(t, emptyPeerID, p.GetRand()) + + allPeers := make(map[peer.ID]bool) + for i := 0; i < 20; i++ { + pid := peer.ID(fmt.Sprintf("peer-%d", i)) + allPeers[pid] = true + p.Put(pid) + } + foundPeers := make(map[peer.ID]bool) + for i := 0; i < 1000; i++ { + pid := p.GetRand() + require.NotEqual(t, emptyPeerID, p) + require.True(t, allPeers[pid]) + foundPeers[pid] = true + if len(foundPeers) == len(allPeers) { + break + } + } + for pid := range allPeers { + p.Delete(pid) + } + require.Equal(t, emptyPeerID, p.GetRand()) + }) +} + +func TestAreAddrsConsistency(t *testing.T) { + tests := []struct { + name string + localAddr ma.Multiaddr + dialAddr ma.Multiaddr + success bool + }{ + { + name: "simple match", + localAddr: ma.StringCast("/ip4/192.168.0.1/tcp/12345"), + dialAddr: ma.StringCast("/ip4/1.2.3.4/tcp/23232"), + success: true, + }, + { + name: "nat64 match", + localAddr: ma.StringCast("/ip6/1::1/tcp/12345"), + dialAddr: ma.StringCast("/ip4/1.2.3.4/tcp/23232"), + success: true, + }, + { + name: "simple mismatch", + localAddr: ma.StringCast("/ip4/192.168.0.1/tcp/12345"), + dialAddr: ma.StringCast("/ip4/1.2.3.4/udp/23232/quic-v1"), + success: false, + }, + { + name: "quic-vs-webtransport", + localAddr: ma.StringCast("/ip4/192.168.0.1/udp/12345/quic-v1"), + dialAddr: ma.StringCast("/ip4/1.2.3.4/udp/123/quic-v1/webtransport"), + success: false, + }, + { + name: "nat64 mismatch", + localAddr: ma.StringCast("/ip4/192.168.0.1/udp/12345/quic-v1"), + dialAddr: ma.StringCast("/ip6/1::1/udp/123/quic-v1/"), + success: false, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + if areAddrsConsistent(tc.localAddr, tc.dialAddr) != tc.success { + wantStr := "match" + if !tc.success { + wantStr = "mismatch" + } + t.Errorf("expected %s between\nlocal addr: %s\ndial addr: %s", wantStr, tc.localAddr, tc.dialAddr) + } + }) + } + +} diff --git a/p2p/protocol/autonatv2/client.go b/p2p/protocol/autonatv2/client.go new file mode 100644 index 0000000000..e2a7db75cd --- /dev/null +++ b/p2p/protocol/autonatv2/client.go @@ -0,0 +1,286 @@ +package autonatv2 + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/libp2p/go-libp2p/core/host" + "github.com/libp2p/go-libp2p/core/network" + "github.com/libp2p/go-libp2p/core/peer" + "github.com/libp2p/go-libp2p/p2p/protocol/autonatv2/pb" + "github.com/libp2p/go-msgio/pbio" + ma "github.com/multiformats/go-multiaddr" + "golang.org/x/exp/rand" +) + +//go:generate protoc --go_out=. --go_opt=Mpb/autonatv2.proto=./pb pb/autonatv2.proto + +// client implements the client for making dial requests for AutoNAT v2. It verifies successful +// dials and provides an option to send data for dial requests. +type client struct { + host host.Host + dialData []byte + + mu sync.Mutex + // dialBackQueues maps nonce to the channel for providing the local multiaddr of the connection + // the nonce was received on + dialBackQueues map[uint64]chan ma.Multiaddr +} + +func newClient(h host.Host) *client { + return &client{host: h, dialData: make([]byte, 8000), dialBackQueues: make(map[uint64]chan ma.Multiaddr)} +} + +// RegisterDialBack registers the client to receive DialBack streams initiated by the server to send the nonce. +func (ac *client) RegisterDialBack() { + ac.host.SetStreamHandler(DialBackProtocol, ac.handleDialBack) +} + +// CheckReachability verifies address reachability with a AutoNAT v2 server p. +func (ac *client) CheckReachability(ctx context.Context, p peer.ID, reqs []Request) (Result, error) { + ctx, cancel := context.WithTimeout(ctx, streamTimeout) + defer cancel() + + s, err := ac.host.NewStream(ctx, p, DialProtocol) + if err != nil { + return Result{}, fmt.Errorf("open %s stream failed: %w", DialProtocol, err) + } + + if err := s.Scope().SetService(ServiceName); err != nil { + s.Reset() + return Result{}, fmt.Errorf("attach stream %s to service %s failed: %w", DialProtocol, ServiceName, err) + } + + if err := s.Scope().ReserveMemory(maxMsgSize, network.ReservationPriorityAlways); err != nil { + s.Reset() + return Result{}, fmt.Errorf("failed to reserve memory for stream %s: %w", DialProtocol, err) + } + defer s.Scope().ReleaseMemory(maxMsgSize) + + s.SetDeadline(time.Now().Add(streamTimeout)) + defer s.Close() + + nonce := rand.Uint64() + ch := make(chan ma.Multiaddr, 1) + ac.mu.Lock() + ac.dialBackQueues[nonce] = ch + ac.mu.Unlock() + defer func() { + ac.mu.Lock() + delete(ac.dialBackQueues, nonce) + ac.mu.Unlock() + }() + + msg := newDialRequest(reqs, nonce) + w := pbio.NewDelimitedWriter(s) + if err := w.WriteMsg(&msg); err != nil { + s.Reset() + return Result{}, fmt.Errorf("dial request write failed: %w", err) + } + + r := pbio.NewDelimitedReader(s, maxMsgSize) + if err := r.ReadMsg(&msg); err != nil { + s.Reset() + return Result{}, fmt.Errorf("dial msg read failed: %w", err) + } + + switch { + case msg.GetDialResponse() != nil: + break + // provide dial data if appropriate + case msg.GetDialDataRequest() != nil: + idx := int(msg.GetDialDataRequest().AddrIdx) + if idx >= len(reqs) { // invalid address index + s.Reset() + return Result{}, fmt.Errorf("dial data: addr index out of range: %d [0-%d)", idx, len(reqs)) + } + if msg.GetDialDataRequest().NumBytes > maxHandshakeSizeBytes { // data request is too high + s.Reset() + return Result{}, fmt.Errorf("dial data requested too high: %d", msg.GetDialDataRequest().NumBytes) + } + if !reqs[idx].SendDialData { // low priority addr + s.Reset() + return Result{}, fmt.Errorf("dial data requested for low priority addr: %s index %d", reqs[idx].Addr, idx) + } + + // dial data request is valid and we want to send data + if err := ac.sendDialData(msg.GetDialDataRequest(), w, &msg); err != nil { + s.Reset() + return Result{}, fmt.Errorf("dial data send failed: %w", err) + } + if err := r.ReadMsg(&msg); err != nil { + s.Reset() + return Result{}, fmt.Errorf("dial response read failed: %w", err) + } + if msg.GetDialResponse() == nil { + s.Reset() + return Result{}, fmt.Errorf("invalid response type: %T", msg.Msg) + } + default: + s.Reset() + return Result{}, fmt.Errorf("invalid msg type: %T", msg.Msg) + } + + resp := msg.GetDialResponse() + if resp.GetStatus() != pb.DialResponse_OK { + // E_DIAL_REFUSED has implication for deciding future address verificiation priorities + // wrap a distinct error for convenient errors.Is usage + if resp.GetStatus() == pb.DialResponse_E_DIAL_REFUSED { + return Result{}, fmt.Errorf("dial request failed: %w", ErrDialRefused) + } + return Result{}, fmt.Errorf("dial request failed: response status %d %s", resp.GetStatus(), + pb.DialStatus_name[int32(resp.GetStatus())]) + } + if resp.GetDialStatus() == pb.DialStatus_UNUSED { + return Result{}, fmt.Errorf("invalid response: invalid dial status UNUSED") + } + if int(resp.AddrIdx) >= len(reqs) { + return Result{}, fmt.Errorf("invalid response: addr index out of range: %d [0-%d)", resp.AddrIdx, len(reqs)) + } + + // wait for nonce from the server + var dialBackAddr ma.Multiaddr + if resp.GetDialStatus() == pb.DialStatus_OK { + timer := time.NewTimer(dialBackStreamTimeout) + select { + case at := <-ch: + dialBackAddr = at + case <-ctx.Done(): + case <-timer.C: + } + timer.Stop() + } + return ac.newResult(resp, reqs, dialBackAddr) +} + +func (ac *client) newResult(resp *pb.DialResponse, reqs []Request, dialBackAddr ma.Multiaddr) (Result, error) { + idx := int(resp.AddrIdx) + addr := reqs[idx].Addr + + var rch network.Reachability + switch resp.DialStatus { + case pb.DialStatus_OK: + if !areAddrsConsistent(dialBackAddr, addr) { + // The server reported a successful dial back but we didn't receive the nonce. + // Discard the response and fail. + return Result{}, fmt.Errorf("invalid repsonse: no dialback received") + } + rch = network.ReachabilityPublic + case pb.DialStatus_E_DIAL_ERROR: + rch = network.ReachabilityPrivate + case pb.DialStatus_E_DIAL_BACK_ERROR: + rch = network.ReachabilityUnknown + default: + // Unexpected response code. Discard the response and fail. + log.Warnf("invalid status code received in response for addr %s: %d", addr, resp.DialStatus) + return Result{}, fmt.Errorf("invalid response: invalid status code for addr %s: %d", addr, resp.DialStatus) + } + + return Result{ + Idx: idx, + Addr: addr, + Reachability: rch, + Status: resp.DialStatus, + }, nil +} + +func (ac *client) sendDialData(req *pb.DialDataRequest, w pbio.Writer, msg *pb.Message) error { + nb := req.GetNumBytes() + ddResp := &pb.DialDataResponse{Data: ac.dialData} + *msg = pb.Message{ + Msg: &pb.Message_DialDataResponse{ + DialDataResponse: ddResp, + }, + } + for remain := int(nb); remain > 0; { + end := remain + if end > len(ddResp.Data) { + end = len(ddResp.Data) + } + ddResp.Data = ddResp.Data[:end] + if err := w.WriteMsg(msg); err != nil { + return err + } + remain -= end + } + return nil +} + +func newDialRequest(reqs []Request, nonce uint64) pb.Message { + addrbs := make([][]byte, len(reqs)) + for i, r := range reqs { + addrbs[i] = r.Addr.Bytes() + } + return pb.Message{ + Msg: &pb.Message_DialRequest{ + DialRequest: &pb.DialRequest{ + Addrs: addrbs, + Nonce: nonce, + }, + }, + } +} + +// handleDialBack receives the nonce on the dial-back stream +func (ac *client) handleDialBack(s network.Stream) { + if err := s.Scope().SetService(ServiceName); err != nil { + log.Debugf("failed to attach stream to service %s: %w", ServiceName, err) + s.Reset() + return + } + + if err := s.Scope().ReserveMemory(maxMsgSize, network.ReservationPriorityAlways); err != nil { + log.Debugf("failed to reserve memory for stream %s: %w", DialBackProtocol, err) + s.Reset() + return + } + defer s.Scope().ReleaseMemory(maxMsgSize) + + s.SetDeadline(time.Now().Add(dialBackStreamTimeout)) + defer s.Close() + + r := pbio.NewDelimitedReader(s, maxMsgSize) + var msg pb.DialBack + if err := r.ReadMsg(&msg); err != nil { + log.Debugf("failed to read dialback msg from %s: %s", s.Conn().RemotePeer(), err) + s.Reset() + return + } + nonce := msg.GetNonce() + + ac.mu.Lock() + ch := ac.dialBackQueues[nonce] + ac.mu.Unlock() + select { + case ch <- s.Conn().LocalMultiaddr(): + default: + } +} + +func areAddrsConsistent(local, external ma.Multiaddr) bool { + if local == nil || external == nil { + return false + } + + localProtos := local.Protocols() + externalProtos := external.Protocols() + if len(localProtos) != len(externalProtos) { + return false + } + for i := 0; i < len(localProtos); i++ { + if i == 0 { + if localProtos[i].Code == externalProtos[i].Code || + localProtos[i].Code == ma.P_IP6 && externalProtos[i].Code == ma.P_IP4 /* NAT64 */ { + continue + } + return false + } else { + if localProtos[i].Code != externalProtos[i].Code { + return false + } + } + } + return true +} diff --git a/p2p/protocol/autonatv2/options.go b/p2p/protocol/autonatv2/options.go new file mode 100644 index 0000000000..3a59d8d823 --- /dev/null +++ b/p2p/protocol/autonatv2/options.go @@ -0,0 +1,48 @@ +package autonatv2 + +import "time" + +// autoNATSettings is used to configure AutoNAT +type autoNATSettings struct { + allowAllAddrs bool + serverRPM int + serverPerPeerRPM int + serverDialDataRPM int + dataRequestPolicy dataRequestPolicyFunc + now func() time.Time +} + +func defaultSettings() *autoNATSettings { + return &autoNATSettings{ + allowAllAddrs: false, + // TODO: confirm rate limiting defaults + serverRPM: 20, + serverPerPeerRPM: 2, + serverDialDataRPM: 5, + dataRequestPolicy: amplificationAttackPrevention, + now: time.Now, + } +} + +type AutoNATOption func(s *autoNATSettings) error + +func WithServerRateLimit(rpm, perPeerRPM, dialDataRPM int) AutoNATOption { + return func(s *autoNATSettings) error { + s.serverRPM = rpm + s.serverPerPeerRPM = perPeerRPM + s.serverDialDataRPM = dialDataRPM + return nil + } +} + +func withDataRequestPolicy(drp dataRequestPolicyFunc) AutoNATOption { + return func(s *autoNATSettings) error { + s.dataRequestPolicy = drp + return nil + } +} + +func allowAllAddrs(s *autoNATSettings) error { + s.allowAllAddrs = true + return nil +} diff --git a/p2p/protocol/autonatv2/pb/autonatv2.pb.go b/p2p/protocol/autonatv2/pb/autonatv2.pb.go new file mode 100644 index 0000000000..f94b077acf --- /dev/null +++ b/p2p/protocol/autonatv2/pb/autonatv2.pb.go @@ -0,0 +1,706 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.30.0 +// protoc v3.21.12 +// source: pb/autonatv2.proto + +package pb + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type DialStatus int32 + +const ( + DialStatus_UNUSED DialStatus = 0 + DialStatus_E_DIAL_ERROR DialStatus = 100 + DialStatus_E_DIAL_BACK_ERROR DialStatus = 101 + DialStatus_OK DialStatus = 200 +) + +// Enum value maps for DialStatus. +var ( + DialStatus_name = map[int32]string{ + 0: "UNUSED", + 100: "E_DIAL_ERROR", + 101: "E_DIAL_BACK_ERROR", + 200: "OK", + } + DialStatus_value = map[string]int32{ + "UNUSED": 0, + "E_DIAL_ERROR": 100, + "E_DIAL_BACK_ERROR": 101, + "OK": 200, + } +) + +func (x DialStatus) Enum() *DialStatus { + p := new(DialStatus) + *p = x + return p +} + +func (x DialStatus) String() string { + return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x)) +} + +func (DialStatus) Descriptor() protoreflect.EnumDescriptor { + return file_pb_autonatv2_proto_enumTypes[0].Descriptor() +} + +func (DialStatus) Type() protoreflect.EnumType { + return &file_pb_autonatv2_proto_enumTypes[0] +} + +func (x DialStatus) Number() protoreflect.EnumNumber { + return protoreflect.EnumNumber(x) +} + +// Deprecated: Use DialStatus.Descriptor instead. +func (DialStatus) EnumDescriptor() ([]byte, []int) { + return file_pb_autonatv2_proto_rawDescGZIP(), []int{0} +} + +type DialResponse_ResponseStatus int32 + +const ( + DialResponse_E_INTERNAL_ERROR DialResponse_ResponseStatus = 0 + DialResponse_E_REQUEST_REJECTED DialResponse_ResponseStatus = 100 + DialResponse_E_DIAL_REFUSED DialResponse_ResponseStatus = 101 + DialResponse_OK DialResponse_ResponseStatus = 200 +) + +// Enum value maps for DialResponse_ResponseStatus. +var ( + DialResponse_ResponseStatus_name = map[int32]string{ + 0: "E_INTERNAL_ERROR", + 100: "E_REQUEST_REJECTED", + 101: "E_DIAL_REFUSED", + 200: "OK", + } + DialResponse_ResponseStatus_value = map[string]int32{ + "E_INTERNAL_ERROR": 0, + "E_REQUEST_REJECTED": 100, + "E_DIAL_REFUSED": 101, + "OK": 200, + } +) + +func (x DialResponse_ResponseStatus) Enum() *DialResponse_ResponseStatus { + p := new(DialResponse_ResponseStatus) + *p = x + return p +} + +func (x DialResponse_ResponseStatus) String() string { + return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x)) +} + +func (DialResponse_ResponseStatus) Descriptor() protoreflect.EnumDescriptor { + return file_pb_autonatv2_proto_enumTypes[1].Descriptor() +} + +func (DialResponse_ResponseStatus) Type() protoreflect.EnumType { + return &file_pb_autonatv2_proto_enumTypes[1] +} + +func (x DialResponse_ResponseStatus) Number() protoreflect.EnumNumber { + return protoreflect.EnumNumber(x) +} + +// Deprecated: Use DialResponse_ResponseStatus.Descriptor instead. +func (DialResponse_ResponseStatus) EnumDescriptor() ([]byte, []int) { + return file_pb_autonatv2_proto_rawDescGZIP(), []int{3, 0} +} + +type Message struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + // Types that are assignable to Msg: + // + // *Message_DialRequest + // *Message_DialResponse + // *Message_DialDataRequest + // *Message_DialDataResponse + Msg isMessage_Msg `protobuf_oneof:"msg"` +} + +func (x *Message) Reset() { + *x = Message{} + if protoimpl.UnsafeEnabled { + mi := &file_pb_autonatv2_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *Message) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Message) ProtoMessage() {} + +func (x *Message) ProtoReflect() protoreflect.Message { + mi := &file_pb_autonatv2_proto_msgTypes[0] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Message.ProtoReflect.Descriptor instead. +func (*Message) Descriptor() ([]byte, []int) { + return file_pb_autonatv2_proto_rawDescGZIP(), []int{0} +} + +func (m *Message) GetMsg() isMessage_Msg { + if m != nil { + return m.Msg + } + return nil +} + +func (x *Message) GetDialRequest() *DialRequest { + if x, ok := x.GetMsg().(*Message_DialRequest); ok { + return x.DialRequest + } + return nil +} + +func (x *Message) GetDialResponse() *DialResponse { + if x, ok := x.GetMsg().(*Message_DialResponse); ok { + return x.DialResponse + } + return nil +} + +func (x *Message) GetDialDataRequest() *DialDataRequest { + if x, ok := x.GetMsg().(*Message_DialDataRequest); ok { + return x.DialDataRequest + } + return nil +} + +func (x *Message) GetDialDataResponse() *DialDataResponse { + if x, ok := x.GetMsg().(*Message_DialDataResponse); ok { + return x.DialDataResponse + } + return nil +} + +type isMessage_Msg interface { + isMessage_Msg() +} + +type Message_DialRequest struct { + DialRequest *DialRequest `protobuf:"bytes,1,opt,name=dialRequest,proto3,oneof"` +} + +type Message_DialResponse struct { + DialResponse *DialResponse `protobuf:"bytes,2,opt,name=dialResponse,proto3,oneof"` +} + +type Message_DialDataRequest struct { + DialDataRequest *DialDataRequest `protobuf:"bytes,3,opt,name=dialDataRequest,proto3,oneof"` +} + +type Message_DialDataResponse struct { + DialDataResponse *DialDataResponse `protobuf:"bytes,4,opt,name=dialDataResponse,proto3,oneof"` +} + +func (*Message_DialRequest) isMessage_Msg() {} + +func (*Message_DialResponse) isMessage_Msg() {} + +func (*Message_DialDataRequest) isMessage_Msg() {} + +func (*Message_DialDataResponse) isMessage_Msg() {} + +type DialRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Addrs [][]byte `protobuf:"bytes,1,rep,name=addrs,proto3" json:"addrs,omitempty"` + Nonce uint64 `protobuf:"fixed64,2,opt,name=nonce,proto3" json:"nonce,omitempty"` +} + +func (x *DialRequest) Reset() { + *x = DialRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_pb_autonatv2_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *DialRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*DialRequest) ProtoMessage() {} + +func (x *DialRequest) ProtoReflect() protoreflect.Message { + mi := &file_pb_autonatv2_proto_msgTypes[1] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use DialRequest.ProtoReflect.Descriptor instead. +func (*DialRequest) Descriptor() ([]byte, []int) { + return file_pb_autonatv2_proto_rawDescGZIP(), []int{1} +} + +func (x *DialRequest) GetAddrs() [][]byte { + if x != nil { + return x.Addrs + } + return nil +} + +func (x *DialRequest) GetNonce() uint64 { + if x != nil { + return x.Nonce + } + return 0 +} + +type DialDataRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + AddrIdx uint32 `protobuf:"varint,1,opt,name=addrIdx,proto3" json:"addrIdx,omitempty"` + NumBytes uint64 `protobuf:"varint,2,opt,name=numBytes,proto3" json:"numBytes,omitempty"` +} + +func (x *DialDataRequest) Reset() { + *x = DialDataRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_pb_autonatv2_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *DialDataRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*DialDataRequest) ProtoMessage() {} + +func (x *DialDataRequest) ProtoReflect() protoreflect.Message { + mi := &file_pb_autonatv2_proto_msgTypes[2] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use DialDataRequest.ProtoReflect.Descriptor instead. +func (*DialDataRequest) Descriptor() ([]byte, []int) { + return file_pb_autonatv2_proto_rawDescGZIP(), []int{2} +} + +func (x *DialDataRequest) GetAddrIdx() uint32 { + if x != nil { + return x.AddrIdx + } + return 0 +} + +func (x *DialDataRequest) GetNumBytes() uint64 { + if x != nil { + return x.NumBytes + } + return 0 +} + +type DialResponse struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Status DialResponse_ResponseStatus `protobuf:"varint,1,opt,name=status,proto3,enum=autonatv2.pb.DialResponse_ResponseStatus" json:"status,omitempty"` + AddrIdx uint32 `protobuf:"varint,2,opt,name=addrIdx,proto3" json:"addrIdx,omitempty"` + DialStatus DialStatus `protobuf:"varint,3,opt,name=dialStatus,proto3,enum=autonatv2.pb.DialStatus" json:"dialStatus,omitempty"` +} + +func (x *DialResponse) Reset() { + *x = DialResponse{} + if protoimpl.UnsafeEnabled { + mi := &file_pb_autonatv2_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *DialResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*DialResponse) ProtoMessage() {} + +func (x *DialResponse) ProtoReflect() protoreflect.Message { + mi := &file_pb_autonatv2_proto_msgTypes[3] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use DialResponse.ProtoReflect.Descriptor instead. +func (*DialResponse) Descriptor() ([]byte, []int) { + return file_pb_autonatv2_proto_rawDescGZIP(), []int{3} +} + +func (x *DialResponse) GetStatus() DialResponse_ResponseStatus { + if x != nil { + return x.Status + } + return DialResponse_E_INTERNAL_ERROR +} + +func (x *DialResponse) GetAddrIdx() uint32 { + if x != nil { + return x.AddrIdx + } + return 0 +} + +func (x *DialResponse) GetDialStatus() DialStatus { + if x != nil { + return x.DialStatus + } + return DialStatus_UNUSED +} + +type DialDataResponse struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Data []byte `protobuf:"bytes,1,opt,name=data,proto3" json:"data,omitempty"` +} + +func (x *DialDataResponse) Reset() { + *x = DialDataResponse{} + if protoimpl.UnsafeEnabled { + mi := &file_pb_autonatv2_proto_msgTypes[4] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *DialDataResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*DialDataResponse) ProtoMessage() {} + +func (x *DialDataResponse) ProtoReflect() protoreflect.Message { + mi := &file_pb_autonatv2_proto_msgTypes[4] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use DialDataResponse.ProtoReflect.Descriptor instead. +func (*DialDataResponse) Descriptor() ([]byte, []int) { + return file_pb_autonatv2_proto_rawDescGZIP(), []int{4} +} + +func (x *DialDataResponse) GetData() []byte { + if x != nil { + return x.Data + } + return nil +} + +type DialBack struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Nonce uint64 `protobuf:"fixed64,1,opt,name=nonce,proto3" json:"nonce,omitempty"` +} + +func (x *DialBack) Reset() { + *x = DialBack{} + if protoimpl.UnsafeEnabled { + mi := &file_pb_autonatv2_proto_msgTypes[5] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *DialBack) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*DialBack) ProtoMessage() {} + +func (x *DialBack) ProtoReflect() protoreflect.Message { + mi := &file_pb_autonatv2_proto_msgTypes[5] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use DialBack.ProtoReflect.Descriptor instead. +func (*DialBack) Descriptor() ([]byte, []int) { + return file_pb_autonatv2_proto_rawDescGZIP(), []int{5} +} + +func (x *DialBack) GetNonce() uint64 { + if x != nil { + return x.Nonce + } + return 0 +} + +var File_pb_autonatv2_proto protoreflect.FileDescriptor + +var file_pb_autonatv2_proto_rawDesc = []byte{ + 0x0a, 0x12, 0x70, 0x62, 0x2f, 0x61, 0x75, 0x74, 0x6f, 0x6e, 0x61, 0x74, 0x76, 0x32, 0x2e, 0x70, + 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x0c, 0x61, 0x75, 0x74, 0x6f, 0x6e, 0x61, 0x74, 0x76, 0x32, 0x2e, + 0x70, 0x62, 0x22, 0xaa, 0x02, 0x0a, 0x07, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x12, 0x3d, + 0x0a, 0x0b, 0x64, 0x69, 0x61, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x18, 0x01, 0x20, + 0x01, 0x28, 0x0b, 0x32, 0x19, 0x2e, 0x61, 0x75, 0x74, 0x6f, 0x6e, 0x61, 0x74, 0x76, 0x32, 0x2e, + 0x70, 0x62, 0x2e, 0x44, 0x69, 0x61, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x48, 0x00, + 0x52, 0x0b, 0x64, 0x69, 0x61, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x40, 0x0a, + 0x0c, 0x64, 0x69, 0x61, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x18, 0x02, 0x20, + 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x61, 0x75, 0x74, 0x6f, 0x6e, 0x61, 0x74, 0x76, 0x32, 0x2e, + 0x70, 0x62, 0x2e, 0x44, 0x69, 0x61, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x48, + 0x00, 0x52, 0x0c, 0x64, 0x69, 0x61, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, + 0x49, 0x0a, 0x0f, 0x64, 0x69, 0x61, 0x6c, 0x44, 0x61, 0x74, 0x61, 0x52, 0x65, 0x71, 0x75, 0x65, + 0x73, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1d, 0x2e, 0x61, 0x75, 0x74, 0x6f, 0x6e, + 0x61, 0x74, 0x76, 0x32, 0x2e, 0x70, 0x62, 0x2e, 0x44, 0x69, 0x61, 0x6c, 0x44, 0x61, 0x74, 0x61, + 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x48, 0x00, 0x52, 0x0f, 0x64, 0x69, 0x61, 0x6c, 0x44, + 0x61, 0x74, 0x61, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x4c, 0x0a, 0x10, 0x64, 0x69, + 0x61, 0x6c, 0x44, 0x61, 0x74, 0x61, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x18, 0x04, + 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1e, 0x2e, 0x61, 0x75, 0x74, 0x6f, 0x6e, 0x61, 0x74, 0x76, 0x32, + 0x2e, 0x70, 0x62, 0x2e, 0x44, 0x69, 0x61, 0x6c, 0x44, 0x61, 0x74, 0x61, 0x52, 0x65, 0x73, 0x70, + 0x6f, 0x6e, 0x73, 0x65, 0x48, 0x00, 0x52, 0x10, 0x64, 0x69, 0x61, 0x6c, 0x44, 0x61, 0x74, 0x61, + 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x42, 0x05, 0x0a, 0x03, 0x6d, 0x73, 0x67, 0x22, + 0x39, 0x0a, 0x0b, 0x44, 0x69, 0x61, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x14, + 0x0a, 0x05, 0x61, 0x64, 0x64, 0x72, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0c, 0x52, 0x05, 0x61, + 0x64, 0x64, 0x72, 0x73, 0x12, 0x14, 0x0a, 0x05, 0x6e, 0x6f, 0x6e, 0x63, 0x65, 0x18, 0x02, 0x20, + 0x01, 0x28, 0x06, 0x52, 0x05, 0x6e, 0x6f, 0x6e, 0x63, 0x65, 0x22, 0x47, 0x0a, 0x0f, 0x44, 0x69, + 0x61, 0x6c, 0x44, 0x61, 0x74, 0x61, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x18, 0x0a, + 0x07, 0x61, 0x64, 0x64, 0x72, 0x49, 0x64, 0x78, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x07, + 0x61, 0x64, 0x64, 0x72, 0x49, 0x64, 0x78, 0x12, 0x1a, 0x0a, 0x08, 0x6e, 0x75, 0x6d, 0x42, 0x79, + 0x74, 0x65, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x04, 0x52, 0x08, 0x6e, 0x75, 0x6d, 0x42, 0x79, + 0x74, 0x65, 0x73, 0x22, 0x82, 0x02, 0x0a, 0x0c, 0x44, 0x69, 0x61, 0x6c, 0x52, 0x65, 0x73, 0x70, + 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x41, 0x0a, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x18, 0x01, + 0x20, 0x01, 0x28, 0x0e, 0x32, 0x29, 0x2e, 0x61, 0x75, 0x74, 0x6f, 0x6e, 0x61, 0x74, 0x76, 0x32, + 0x2e, 0x70, 0x62, 0x2e, 0x44, 0x69, 0x61, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, + 0x2e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, + 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x61, 0x64, 0x64, 0x72, 0x49, + 0x64, 0x78, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x07, 0x61, 0x64, 0x64, 0x72, 0x49, 0x64, + 0x78, 0x12, 0x38, 0x0a, 0x0a, 0x64, 0x69, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x18, + 0x03, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x18, 0x2e, 0x61, 0x75, 0x74, 0x6f, 0x6e, 0x61, 0x74, 0x76, + 0x32, 0x2e, 0x70, 0x62, 0x2e, 0x44, 0x69, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, + 0x0a, 0x64, 0x69, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x22, 0x5b, 0x0a, 0x0e, 0x52, + 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x14, 0x0a, + 0x10, 0x45, 0x5f, 0x49, 0x4e, 0x54, 0x45, 0x52, 0x4e, 0x41, 0x4c, 0x5f, 0x45, 0x52, 0x52, 0x4f, + 0x52, 0x10, 0x00, 0x12, 0x16, 0x0a, 0x12, 0x45, 0x5f, 0x52, 0x45, 0x51, 0x55, 0x45, 0x53, 0x54, + 0x5f, 0x52, 0x45, 0x4a, 0x45, 0x43, 0x54, 0x45, 0x44, 0x10, 0x64, 0x12, 0x12, 0x0a, 0x0e, 0x45, + 0x5f, 0x44, 0x49, 0x41, 0x4c, 0x5f, 0x52, 0x45, 0x46, 0x55, 0x53, 0x45, 0x44, 0x10, 0x65, 0x12, + 0x07, 0x0a, 0x02, 0x4f, 0x4b, 0x10, 0xc8, 0x01, 0x22, 0x26, 0x0a, 0x10, 0x44, 0x69, 0x61, 0x6c, + 0x44, 0x61, 0x74, 0x61, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x12, 0x0a, 0x04, + 0x64, 0x61, 0x74, 0x61, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x04, 0x64, 0x61, 0x74, 0x61, + 0x22, 0x20, 0x0a, 0x08, 0x44, 0x69, 0x61, 0x6c, 0x42, 0x61, 0x63, 0x6b, 0x12, 0x14, 0x0a, 0x05, + 0x6e, 0x6f, 0x6e, 0x63, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x06, 0x52, 0x05, 0x6e, 0x6f, 0x6e, + 0x63, 0x65, 0x2a, 0x4a, 0x0a, 0x0a, 0x44, 0x69, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, + 0x12, 0x0a, 0x0a, 0x06, 0x55, 0x4e, 0x55, 0x53, 0x45, 0x44, 0x10, 0x00, 0x12, 0x10, 0x0a, 0x0c, + 0x45, 0x5f, 0x44, 0x49, 0x41, 0x4c, 0x5f, 0x45, 0x52, 0x52, 0x4f, 0x52, 0x10, 0x64, 0x12, 0x15, + 0x0a, 0x11, 0x45, 0x5f, 0x44, 0x49, 0x41, 0x4c, 0x5f, 0x42, 0x41, 0x43, 0x4b, 0x5f, 0x45, 0x52, + 0x52, 0x4f, 0x52, 0x10, 0x65, 0x12, 0x07, 0x0a, 0x02, 0x4f, 0x4b, 0x10, 0xc8, 0x01, 0x62, 0x06, + 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, +} + +var ( + file_pb_autonatv2_proto_rawDescOnce sync.Once + file_pb_autonatv2_proto_rawDescData = file_pb_autonatv2_proto_rawDesc +) + +func file_pb_autonatv2_proto_rawDescGZIP() []byte { + file_pb_autonatv2_proto_rawDescOnce.Do(func() { + file_pb_autonatv2_proto_rawDescData = protoimpl.X.CompressGZIP(file_pb_autonatv2_proto_rawDescData) + }) + return file_pb_autonatv2_proto_rawDescData +} + +var file_pb_autonatv2_proto_enumTypes = make([]protoimpl.EnumInfo, 2) +var file_pb_autonatv2_proto_msgTypes = make([]protoimpl.MessageInfo, 6) +var file_pb_autonatv2_proto_goTypes = []interface{}{ + (DialStatus)(0), // 0: autonatv2.pb.DialStatus + (DialResponse_ResponseStatus)(0), // 1: autonatv2.pb.DialResponse.ResponseStatus + (*Message)(nil), // 2: autonatv2.pb.Message + (*DialRequest)(nil), // 3: autonatv2.pb.DialRequest + (*DialDataRequest)(nil), // 4: autonatv2.pb.DialDataRequest + (*DialResponse)(nil), // 5: autonatv2.pb.DialResponse + (*DialDataResponse)(nil), // 6: autonatv2.pb.DialDataResponse + (*DialBack)(nil), // 7: autonatv2.pb.DialBack +} +var file_pb_autonatv2_proto_depIdxs = []int32{ + 3, // 0: autonatv2.pb.Message.dialRequest:type_name -> autonatv2.pb.DialRequest + 5, // 1: autonatv2.pb.Message.dialResponse:type_name -> autonatv2.pb.DialResponse + 4, // 2: autonatv2.pb.Message.dialDataRequest:type_name -> autonatv2.pb.DialDataRequest + 6, // 3: autonatv2.pb.Message.dialDataResponse:type_name -> autonatv2.pb.DialDataResponse + 1, // 4: autonatv2.pb.DialResponse.status:type_name -> autonatv2.pb.DialResponse.ResponseStatus + 0, // 5: autonatv2.pb.DialResponse.dialStatus:type_name -> autonatv2.pb.DialStatus + 6, // [6:6] is the sub-list for method output_type + 6, // [6:6] is the sub-list for method input_type + 6, // [6:6] is the sub-list for extension type_name + 6, // [6:6] is the sub-list for extension extendee + 0, // [0:6] is the sub-list for field type_name +} + +func init() { file_pb_autonatv2_proto_init() } +func file_pb_autonatv2_proto_init() { + if File_pb_autonatv2_proto != nil { + return + } + if !protoimpl.UnsafeEnabled { + file_pb_autonatv2_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*Message); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_pb_autonatv2_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*DialRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_pb_autonatv2_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*DialDataRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_pb_autonatv2_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*DialResponse); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_pb_autonatv2_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*DialDataResponse); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_pb_autonatv2_proto_msgTypes[5].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*DialBack); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + } + file_pb_autonatv2_proto_msgTypes[0].OneofWrappers = []interface{}{ + (*Message_DialRequest)(nil), + (*Message_DialResponse)(nil), + (*Message_DialDataRequest)(nil), + (*Message_DialDataResponse)(nil), + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: file_pb_autonatv2_proto_rawDesc, + NumEnums: 2, + NumMessages: 6, + NumExtensions: 0, + NumServices: 0, + }, + GoTypes: file_pb_autonatv2_proto_goTypes, + DependencyIndexes: file_pb_autonatv2_proto_depIdxs, + EnumInfos: file_pb_autonatv2_proto_enumTypes, + MessageInfos: file_pb_autonatv2_proto_msgTypes, + }.Build() + File_pb_autonatv2_proto = out.File + file_pb_autonatv2_proto_rawDesc = nil + file_pb_autonatv2_proto_goTypes = nil + file_pb_autonatv2_proto_depIdxs = nil +} diff --git a/p2p/protocol/autonatv2/pb/autonatv2.proto b/p2p/protocol/autonatv2/pb/autonatv2.proto new file mode 100644 index 0000000000..1a02286060 --- /dev/null +++ b/p2p/protocol/autonatv2/pb/autonatv2.proto @@ -0,0 +1,55 @@ +syntax = "proto3"; + +package autonatv2.pb; + +message Message { + oneof msg { + DialRequest dialRequest = 1; + DialResponse dialResponse = 2; + DialDataRequest dialDataRequest = 3; + DialDataResponse dialDataResponse = 4; + } +} + +message DialRequest { + repeated bytes addrs = 1; + fixed64 nonce = 2; +} + + +message DialDataRequest { + uint32 addrIdx = 1; + uint64 numBytes = 2; +} + + +enum DialStatus { + UNUSED = 0; + E_DIAL_ERROR = 100; + E_DIAL_BACK_ERROR = 101; + OK = 200; +} + + +message DialResponse { + enum ResponseStatus { + E_INTERNAL_ERROR = 0; + E_REQUEST_REJECTED = 100; + E_DIAL_REFUSED = 101; + OK = 200; + } + + ResponseStatus status = 1; + uint32 addrIdx = 2; + DialStatus dialStatus = 3; +} + + +message DialDataResponse { + bytes data = 1; +} + + +message DialBack { + fixed64 nonce = 1; +} \ No newline at end of file diff --git a/p2p/protocol/autonatv2/server.go b/p2p/protocol/autonatv2/server.go new file mode 100644 index 0000000000..5e788173cb --- /dev/null +++ b/p2p/protocol/autonatv2/server.go @@ -0,0 +1,362 @@ +package autonatv2 + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/libp2p/go-libp2p/core/host" + "github.com/libp2p/go-libp2p/core/network" + "github.com/libp2p/go-libp2p/core/peer" + "github.com/libp2p/go-libp2p/core/peerstore" + "github.com/libp2p/go-libp2p/p2p/protocol/autonatv2/pb" + "github.com/libp2p/go-msgio/pbio" + + ma "github.com/multiformats/go-multiaddr" + manet "github.com/multiformats/go-multiaddr/net" + "golang.org/x/exp/rand" +) + +type dataRequestPolicyFunc = func(s network.Stream, dialAddr ma.Multiaddr) bool + +// server implements the AutoNATv2 server. +// It can ask client to provide dial data before attempting the requested dial. +// It rate limits requests on a global level, per peer level and on whether the request requires dial data. +// +// This uses the host's dialer as well as the dialerHost's dialer to determine whether an address is +// dialable. +type server struct { + host host.Host + dialerHost host.Host + limiter *rateLimiter + + // dialDataRequestPolicy is used to determine whether dialing the address requires receiving dial data. + // It is set to amplification attack prevention by default. + dialDataRequestPolicy dataRequestPolicyFunc + + // for tests + now func() time.Time + allowAllAddrs bool +} + +func newServer(host, dialer host.Host, s *autoNATSettings) *server { + return &server{ + dialerHost: dialer, + host: host, + dialDataRequestPolicy: s.dataRequestPolicy, + allowAllAddrs: s.allowAllAddrs, + limiter: &rateLimiter{ + RPM: s.serverRPM, + PerPeerRPM: s.serverPerPeerRPM, + DialDataRPM: s.serverDialDataRPM, + now: s.now, + }, + now: s.now, + } +} + +// Enable attaches the stream handler to the host. +func (as *server) Enable() { + as.host.SetStreamHandler(DialProtocol, as.handleDialRequest) +} + +// Disable removes the stream handles from the host. +func (as *server) Disable() { + as.host.RemoveStreamHandler(DialProtocol) +} + +func (as *server) Close() { + as.dialerHost.Close() +} + +// handleDialRequest is the dial-request protocol stream handler +func (as *server) handleDialRequest(s network.Stream) { + if err := s.Scope().SetService(ServiceName); err != nil { + s.Reset() + log.Debugf("failed to attach stream to service %s: %w", ServiceName, err) + return + } + + if err := s.Scope().ReserveMemory(maxMsgSize, network.ReservationPriorityAlways); err != nil { + s.Reset() + log.Debugf("failed to reserve memory for stream %s: %w", DialProtocol, err) + return + } + defer s.Scope().ReleaseMemory(maxMsgSize) + + s.SetDeadline(as.now().Add(streamTimeout)) + defer s.Close() + + p := s.Conn().RemotePeer() + r := pbio.NewDelimitedReader(s, maxMsgSize) + var msg pb.Message + if err := r.ReadMsg(&msg); err != nil { + s.Reset() + log.Debugf("failed to read request from %s: %s", p, err) + return + } + if msg.GetDialRequest() == nil { + s.Reset() + log.Debugf("invalid message type from %s: %T", p, msg.Msg) + return + } + + nonce := msg.GetDialRequest().Nonce + // parse peer's addresses + var dialAddr ma.Multiaddr + var addrIdx int + for i, ab := range msg.GetDialRequest().GetAddrs() { + if i >= maxPeerAddresses { + break + } + a, err := ma.NewMultiaddrBytes(ab) + if err != nil { + continue + } + if (!as.allowAllAddrs && !manet.IsPublicAddr(a)) || + (as.dialerHost.Network().CanDial(p, a) != network.DialabilityDialable) { + continue + } + // Check if the host can dial the address. This check ensures that we do not + // attempt dialing an IPv6 address if we have no IPv6 connectivity as the host dialer's + // black hole detector is likely to be more accurate. + if as.host.Network().CanDial(p, a) != network.DialabilityDialable { + continue + } + dialAddr = a + addrIdx = i + break + } + + w := pbio.NewDelimitedWriter(s) + // No dialable address + if dialAddr == nil { + msg = pb.Message{ + Msg: &pb.Message_DialResponse{ + DialResponse: &pb.DialResponse{ + Status: pb.DialResponse_E_DIAL_REFUSED, + }, + }, + } + if err := w.WriteMsg(&msg); err != nil { + s.Reset() + log.Debugf("failed to write response to %s: %s", p, err) + return + } + return + } + + isDialDataRequired := as.dialDataRequestPolicy(s, dialAddr) + + if !as.limiter.Accept(p, isDialDataRequired) { + msg = pb.Message{ + Msg: &pb.Message_DialResponse{ + DialResponse: &pb.DialResponse{ + Status: pb.DialResponse_E_REQUEST_REJECTED, + }, + }, + } + if err := w.WriteMsg(&msg); err != nil { + s.Reset() + log.Debugf("failed to write response to %s: %s", p, err) + return + } + log.Debugf("rejecting request from %s: rate limit exceeded", p) + return + } + defer as.limiter.CompleteRequest(p) + + if isDialDataRequired { + if err := getDialData(w, r, &msg, addrIdx); err != nil { + s.Reset() + log.Debugf("%s refused dial data request: %s", p, err) + return + } + } + + dialStatus := as.dialBack(s.Conn().RemotePeer(), dialAddr, nonce) + msg = pb.Message{ + Msg: &pb.Message_DialResponse{ + DialResponse: &pb.DialResponse{ + Status: pb.DialResponse_OK, + DialStatus: dialStatus, + AddrIdx: uint32(addrIdx), + }, + }, + } + if err := w.WriteMsg(&msg); err != nil { + s.Reset() + log.Debugf("failed to write response to %s: %s", p, err) + return + } +} + +// getDialData gets data from the client for dialing the address +func getDialData(w pbio.Writer, r pbio.Reader, msg *pb.Message, addrIdx int) error { + numBytes := minHandshakeSizeBytes + rand.Intn(maxHandshakeSizeBytes-minHandshakeSizeBytes) + *msg = pb.Message{ + Msg: &pb.Message_DialDataRequest{ + DialDataRequest: &pb.DialDataRequest{ + AddrIdx: uint32(addrIdx), + NumBytes: uint64(numBytes), + }, + }, + } + if err := w.WriteMsg(msg); err != nil { + return fmt.Errorf("dial data write: %w", err) + } + for remain := numBytes; remain > 0; { + if err := r.ReadMsg(msg); err != nil { + return fmt.Errorf("dial data read: %w", err) + } + if msg.GetDialDataResponse() == nil { + return fmt.Errorf("invalid msg type %T", msg.Msg) + } + remain -= len(msg.GetDialDataResponse().Data) + } + return nil +} + +func (as *server) dialBack(p peer.ID, addr ma.Multiaddr, nonce uint64) pb.DialStatus { + ctx, cancel := context.WithTimeout(context.Background(), dialBackDialTimeout) + ctx = network.WithForceDirectDial(ctx, "autonatv2") + as.dialerHost.Peerstore().AddAddr(p, addr, peerstore.TempAddrTTL) + defer func() { + cancel() + as.dialerHost.Network().ClosePeer(p) + as.dialerHost.Peerstore().ClearAddrs(p) + as.dialerHost.Peerstore().RemovePeer(p) + }() + + err := as.dialerHost.Connect(ctx, peer.AddrInfo{ID: p}) + if err != nil { + return pb.DialStatus_E_DIAL_ERROR + } + + s, err := as.dialerHost.NewStream(ctx, p, DialBackProtocol) + if err != nil { + return pb.DialStatus_E_DIAL_BACK_ERROR + } + + defer s.Close() + s.SetDeadline(as.now().Add(dialBackStreamTimeout)) + + w := pbio.NewDelimitedWriter(s) + if err := w.WriteMsg(&pb.DialBack{Nonce: nonce}); err != nil { + s.Reset() + return pb.DialStatus_E_DIAL_BACK_ERROR + } + + // Since the underlying connection is on a separate dialer, it'll be closed after this function returns. + // Connection close will drop all the queued writes. To ensure message delivery, do a CloseWrite and + // wait a second for the peer to Close its end of the stream. + s.CloseWrite() + s.SetDeadline(as.now().Add(1 * time.Second)) + b := make([]byte, 1) // Read 1 byte here because 0 len reads are free to return (0, nil) immediately + s.Read(b) + + return pb.DialStatus_OK +} + +// rateLimiter implements a sliding window rate limit of requests per minute. It allows 1 concurrent request +// per peer. It rate limits requests globally, at a peer level and depending on whether it requires dial data. +type rateLimiter struct { + // PerPeerRPM is the rate limit per peer + PerPeerRPM int + // RPM is the global rate limit + RPM int + // DialDataRPM is the rate limit for requests that require dial data + DialDataRPM int + + mu sync.Mutex + reqs []time.Time + peerReqs map[peer.ID][]time.Time + dialDataReqs []time.Time + // ongoingReqs tracks in progress requests. This is used to disallow multiple concurrent requests by the + // same peer + ongoingReqs map[peer.ID]struct{} + + now func() time.Time // for tests +} + +func (r *rateLimiter) Accept(p peer.ID, requiresData bool) bool { + r.mu.Lock() + defer r.mu.Unlock() + if r.peerReqs == nil { + r.peerReqs = make(map[peer.ID][]time.Time) + r.ongoingReqs = make(map[peer.ID]struct{}) + } + + nw := r.now() + r.cleanup(p, nw) + + if _, ok := r.ongoingReqs[p]; ok { + return false + } + if len(r.reqs) >= r.RPM || len(r.peerReqs[p]) >= r.PerPeerRPM { + return false + } + if requiresData && len(r.dialDataReqs) >= r.DialDataRPM { + return false + } + + r.ongoingReqs[p] = struct{}{} + r.reqs = append(r.reqs, nw) + r.peerReqs[p] = append(r.peerReqs[p], nw) + if requiresData { + r.dialDataReqs = append(r.dialDataReqs, nw) + } + return true +} + +// cleanup removes stale requests. +// +// This is fast enough in rate limited cases and the state is small enough to +// clean up quickly when blocking requests. +func (r *rateLimiter) cleanup(p peer.ID, now time.Time) { + idx := len(r.reqs) + for i, t := range r.reqs { + if now.Sub(t) < time.Minute { + idx = i + break + } + } + r.reqs = r.reqs[idx:] + + idx = len(r.dialDataReqs) + for i, t := range r.dialDataReqs { + if now.Sub(t) < time.Minute { + idx = i + break + } + } + r.dialDataReqs = r.dialDataReqs[idx:] + + idx = len(r.peerReqs[p]) + for i, t := range r.peerReqs[p] { + if now.Sub(t) < time.Minute { + idx = i + break + } + } + r.peerReqs[p] = r.peerReqs[p][idx:] +} + +func (r *rateLimiter) CompleteRequest(p peer.ID) { + r.mu.Lock() + defer r.mu.Unlock() + + delete(r.ongoingReqs, p) +} + +// amplificationAttackPrevention is a dialDataRequestPolicy which requests data when the peer's observed +// IP address is different from the dial back IP address +func amplificationAttackPrevention(s network.Stream, dialAddr ma.Multiaddr) bool { + connIP, err := manet.ToIP(s.Conn().RemoteMultiaddr()) + if err != nil { + return true + } + dialIP, _ := manet.ToIP(s.Conn().LocalMultiaddr()) // must be an IP multiaddr + return !connIP.Equal(dialIP) +} diff --git a/p2p/protocol/autonatv2/server_test.go b/p2p/protocol/autonatv2/server_test.go new file mode 100644 index 0000000000..00409bac6d --- /dev/null +++ b/p2p/protocol/autonatv2/server_test.go @@ -0,0 +1,187 @@ +package autonatv2 + +import ( + "context" + "testing" + "time" + + "github.com/libp2p/go-libp2p/core/network" + "github.com/libp2p/go-libp2p/core/test" + bhost "github.com/libp2p/go-libp2p/p2p/host/blank" + swarmt "github.com/libp2p/go-libp2p/p2p/net/swarm/testing" + "github.com/libp2p/go-libp2p/p2p/protocol/autonatv2/pb" + ma "github.com/multiformats/go-multiaddr" + "github.com/stretchr/testify/require" +) + +func newTestRequests(addrs []ma.Multiaddr, sendDialData bool) (reqs []Request) { + reqs = make([]Request, len(addrs)) + for i := 0; i < len(addrs); i++ { + reqs[i] = Request{Addr: addrs[i], SendDialData: sendDialData} + } + return +} + +func TestServerInvalidAddrsRejected(t *testing.T) { + c := newAutoNAT(t, nil, allowAllAddrs) + defer c.Close() + defer c.host.Close() + + t.Run("no transport", func(t *testing.T) { + dialer := bhost.NewBlankHost(swarmt.GenSwarm(t, swarmt.OptDisableQUIC, swarmt.OptDisableTCP)) + an := newAutoNAT(t, dialer, allowAllAddrs) + defer an.Close() + defer an.host.Close() + + idAndWait(t, c, an) + + res, err := c.CheckReachability(context.Background(), newTestRequests(c.host.Addrs(), true)) + require.ErrorIs(t, err, ErrDialRefused) + require.Equal(t, Result{}, res) + }) + + t.Run("private addrs", func(t *testing.T) { + an := newAutoNAT(t, nil) + defer an.Close() + defer an.host.Close() + + idAndWait(t, c, an) + + res, err := c.CheckReachability(context.Background(), newTestRequests(c.host.Addrs(), true)) + require.ErrorIs(t, err, ErrDialRefused) + require.Equal(t, Result{}, res) + }) +} + +func TestServerDataRequest(t *testing.T) { + // server will skip all tcp addresses + dialer := bhost.NewBlankHost(swarmt.GenSwarm(t, swarmt.OptDisableTCP)) + // ask for dial data for quic address + an := newAutoNAT(t, dialer, allowAllAddrs, withDataRequestPolicy( + func(s network.Stream, dialAddr ma.Multiaddr) bool { + if _, err := dialAddr.ValueForProtocol(ma.P_QUIC_V1); err == nil { + return true + } + return false + }), + WithServerRateLimit(10, 10, 10), + ) + defer an.Close() + defer an.host.Close() + + c := newAutoNAT(t, nil, allowAllAddrs) + defer c.Close() + defer c.host.Close() + + idAndWait(t, c, an) + + var quicAddr, tcpAddr ma.Multiaddr + for _, a := range c.host.Addrs() { + if _, err := a.ValueForProtocol(ma.P_QUIC_V1); err == nil { + quicAddr = a + } else if _, err := a.ValueForProtocol(ma.P_TCP); err == nil { + tcpAddr = a + } + } + + _, err := c.CheckReachability(context.Background(), []Request{{Addr: tcpAddr, SendDialData: true}, {Addr: quicAddr}}) + require.Error(t, err) + + res, err := c.CheckReachability(context.Background(), []Request{{Addr: quicAddr, SendDialData: true}, {Addr: tcpAddr}}) + require.NoError(t, err) + + require.Equal(t, Result{ + Idx: 0, + Addr: quicAddr, + Reachability: network.ReachabilityPublic, + Status: pb.DialStatus_OK, + }, res) +} + +func TestServerDial(t *testing.T) { + an := newAutoNAT(t, nil, WithServerRateLimit(10, 10, 10), allowAllAddrs) + defer an.Close() + defer an.host.Close() + + c := newAutoNAT(t, nil, allowAllAddrs) + defer c.Close() + defer c.host.Close() + + idAndWait(t, c, an) + + unreachableAddr := ma.StringCast("/ip4/1.2.3.4/tcp/2") + hostAddrs := c.host.Addrs() + + t.Run("unreachable addr", func(t *testing.T) { + res, err := c.CheckReachability(context.Background(), + append([]Request{{Addr: unreachableAddr, SendDialData: true}}, newTestRequests(hostAddrs, false)...)) + require.NoError(t, err) + require.Equal(t, Result{ + Idx: 0, + Addr: unreachableAddr, + Reachability: network.ReachabilityPrivate, + Status: pb.DialStatus_E_DIAL_ERROR, + }, res) + }) + + t.Run("reachable addr", func(t *testing.T) { + res, err := c.CheckReachability(context.Background(), newTestRequests(c.host.Addrs(), false)) + require.NoError(t, err) + require.Equal(t, Result{ + Idx: 0, + Addr: hostAddrs[0], + Reachability: network.ReachabilityPublic, + Status: pb.DialStatus_OK, + }, res) + }) + + t.Run("dialback error", func(t *testing.T) { + c.host.RemoveStreamHandler(DialBackProtocol) + res, err := c.CheckReachability(context.Background(), newTestRequests(c.host.Addrs(), false)) + require.NoError(t, err) + require.Equal(t, Result{ + Idx: 0, + Addr: hostAddrs[0], + Reachability: network.ReachabilityUnknown, + Status: pb.DialStatus_E_DIAL_BACK_ERROR, + }, res) + }) +} + +func TestRateLimiter(t *testing.T) { + cl := test.NewMockClock() + r := rateLimiter{RPM: 3, PerPeerRPM: 2, DialDataRPM: 1, now: cl.Now} + + require.True(t, r.Accept("peer1", false)) + + cl.AdvanceBy(10 * time.Second) + require.False(t, r.Accept("peer1", false)) // first request is still active + r.CompleteRequest("peer1") + + require.True(t, r.Accept("peer1", false)) + r.CompleteRequest("peer1") + + cl.AdvanceBy(10 * time.Second) + require.False(t, r.Accept("peer1", false)) + + cl.AdvanceBy(10 * time.Second) + require.True(t, r.Accept("peer2", false)) + r.CompleteRequest("peer2") + + cl.AdvanceBy(10 * time.Second) + require.False(t, r.Accept("peer3", false)) + + cl.AdvanceBy(21 * time.Second) // first request expired + require.True(t, r.Accept("peer1", false)) + r.CompleteRequest("peer1") + + cl.AdvanceBy(10 * time.Second) + require.True(t, r.Accept("peer3", true)) + r.CompleteRequest("peer3") + + cl.AdvanceBy(50 * time.Second) + require.False(t, r.Accept("peer3", true)) + + cl.AdvanceBy(11 * time.Second) + require.True(t, r.Accept("peer3", true)) +}