diff --git a/p2p/host/basic/address_service.go b/p2p/host/basic/address_service.go new file mode 100644 index 0000000000..fb1225d76a --- /dev/null +++ b/p2p/host/basic/address_service.go @@ -0,0 +1,470 @@ +package basichost + +import ( + "context" + "net" + "slices" + "sync" + "time" + + "github.com/libp2p/go-libp2p/core/event" + "github.com/libp2p/go-libp2p/core/network" + "github.com/libp2p/go-libp2p/core/transport" + "github.com/libp2p/go-libp2p/p2p/host/basic/internal/backoff" + libp2pwebrtc "github.com/libp2p/go-libp2p/p2p/transport/webrtc" + libp2pwebtransport "github.com/libp2p/go-libp2p/p2p/transport/webtransport" + "github.com/libp2p/go-netroute" + ma "github.com/multiformats/go-multiaddr" + manet "github.com/multiformats/go-multiaddr/net" +) + +type observedAddrsService interface { + OwnObservedAddrs() []ma.Multiaddr + ObservedAddrsFor(local ma.Multiaddr) []ma.Multiaddr +} + +type addressService struct { + net network.Network + addrsFactory AddrsFactory + natmgr NATManager + observedAddrsService observedAddrsService + addrsChangeChan chan struct{} + addrsUpdated chan struct{} + autoRelayAddrsSub event.Subscription + // There are wrapped in to functions for mocking + autoRelayAddrs func() []ma.Multiaddr + reachability func() network.Reachability + ifaceAddrs *interfaceAddrsCache + wg sync.WaitGroup + ctx context.Context + ctxCancel context.CancelFunc +} + +func NewAddressService(h *BasicHost, natmgr func(network.Network) NATManager, + addrFactory AddrsFactory) (*addressService, error) { + var nmgr NATManager + if natmgr != nil { + nmgr = natmgr(h.Network()) + } + addrSub, err := h.EventBus().Subscribe(new(event.EvtAutoRelayAddrs)) + if err != nil { + return nil, err + } + + var autoRelayAddrs func() []ma.Multiaddr + if h.autorelay != nil { + autoRelayAddrs = h.autorelay.RelayAddrs + } + + ctx, cancel := context.WithCancel(context.Background()) + as := &addressService{ + net: h.Network(), + observedAddrsService: h.IDService(), + natmgr: nmgr, + addrsFactory: addrFactory, + addrsChangeChan: make(chan struct{}, 1), + addrsUpdated: make(chan struct{}, 1), + autoRelayAddrsSub: addrSub, + autoRelayAddrs: autoRelayAddrs, + ifaceAddrs: &interfaceAddrsCache{}, + reachability: func() network.Reachability { + if h.GetAutoNat() != nil { + return h.GetAutoNat().Status() + } + return network.ReachabilityUnknown + }, + ctx: ctx, + ctxCancel: cancel, + } + return as, nil +} + +func (a *addressService) Start() { + a.wg.Add(1) + go a.background() +} + +func (a *addressService) Close() { + a.ctxCancel() + a.wg.Wait() + if a.natmgr != nil { + err := a.natmgr.Close() + if err != nil { + log.Warnf("error closing natmgr: %s", err) + } + } + err := a.autoRelayAddrsSub.Close() + if err != nil { + log.Warnf("error closing addrs update emitter: %s", err) + } +} + +func (a *addressService) SignalAddressChange() { + select { + case a.addrsChangeChan <- struct{}{}: + default: + } +} + +func (a *addressService) AddrsUpdated() chan struct{} { + return a.addrsUpdated +} + +func (a *addressService) background() { + defer a.wg.Done() + + var prev []ma.Multiaddr + + ticker := time.NewTicker(addrChangeTickrInterval) + defer ticker.Stop() + for { + curr := a.Addrs() + if a.areAddrsDifferent(prev, curr) { + select { + case a.addrsUpdated <- struct{}{}: + default: + } + } + prev = curr + + select { + case <-ticker.C: + case <-a.addrsChangeChan: + case <-a.autoRelayAddrsSub.Out(): + case <-a.ctx.Done(): + return + } + } +} + +// Addrs returns the node's dialable addresses both public and private. +// If autorealy is enabled and node reachability is private, it returns +// the node's relay addresses and private network addresses. +func (a *addressService) Addrs() []ma.Multiaddr { + addrs := a.AllAddrs() + // Delete public addresses if the node's reachability is private, and we have autorelay. + if a.reachability() == network.ReachabilityPrivate && a.autoRelayAddrs != nil { + addrs = slices.DeleteFunc(addrs, func(a ma.Multiaddr) bool { return manet.IsPublicAddr(a) }) + addrs = append(addrs, a.autoRelayAddrs()...) + } + // Make a copy. Consumers can modify the slice elements + addrs = slices.Clone(a.addrsFactory(addrs)) + // Add certhashes for the addresses provided by the user via address factory. + return a.addCertHashes(ma.Unique(addrs)) +} + +// GetHolePunchAddrs returns the node's public direct listen addresses. +func (a *addressService) GetHolePunchAddrs() []ma.Multiaddr { + addrs := a.AllAddrs() + addrs = slices.Clone(a.addrsFactory(addrs)) + // AllAddrs may ignore observed addresses in favour of NAT mappings. + // Use both for hole punching. + addrs = append(addrs, a.observedAddrsService.OwnObservedAddrs()...) + addrs = ma.Unique(addrs) + return slices.DeleteFunc(addrs, func(a ma.Multiaddr) bool { return !manet.IsPublicAddr(a) }) +} + +var p2pCircuitAddr = ma.StringCast("/p2p-circuit") + +// AllAddrs returns all the addresses the host is listening on except circuit addresses. +func (a *addressService) AllAddrs() []ma.Multiaddr { + listenAddrs := a.net.ListenAddresses() + if len(listenAddrs) == 0 { + return nil + } + + finalAddrs := make([]ma.Multiaddr, 0, 8) + finalAddrs = a.appendInterfaceAddrs(finalAddrs, listenAddrs) + finalAddrs = a.appendNATAddrs(finalAddrs, listenAddrs) + finalAddrs = ma.Unique(finalAddrs) + + // Remove "/p2p-circuit" addresses from the list. + // The p2p-circuit listener reports its address as just /p2p-circuit. This is + // useless for dialing. Users need to manage their circuit addresses themselves, + // or use AutoRelay. + finalAddrs = slices.DeleteFunc(finalAddrs, func(a ma.Multiaddr) bool { + return a.Equal(p2pCircuitAddr) + }) + + // Remove any unspecified address from the list + finalAddrs = slices.DeleteFunc(finalAddrs, func(a ma.Multiaddr) bool { + return manet.IsIPUnspecified(a) + }) + + // Add certhashes for /webrtc-direct, /webtransport, etc addresses discovered + // using identify. + finalAddrs = a.addCertHashes(finalAddrs) + return finalAddrs +} + +func (a *addressService) appendInterfaceAddrs(result []ma.Multiaddr, listenAddrs []ma.Multiaddr) []ma.Multiaddr { + // resolving any unspecified listen addressees to use only the primary + // interface to avoid advertising too many addresses. + if resolved, err := manet.ResolveUnspecifiedAddresses(listenAddrs, a.ifaceAddrs.Filtered()); err != nil { + log.Warnw("failed to resolve listen addrs", "error", err) + } else { + result = append(result, resolved...) + } + result = ma.Unique(result) + return result +} + +// appendNATAddrs appends the NAT-ed addrs for the listenAddrs. For unspecified listen addrs it appends the +// public address for all the interfaces. +// This automatically infers addresses from other transport addresses. For example, it'll infer a webtransport +// address from a quic observed address. +// +// TODO: Merge the natmgr and identify.ObservedAddrManager in to one NatMapper module. +func (a *addressService) appendNATAddrs(result []ma.Multiaddr, listenAddrs []ma.Multiaddr) []ma.Multiaddr { + ifaceAddrs := a.ifaceAddrs.All() + if a.natmgr == nil || !a.natmgr.HasDiscoveredNAT() { + if a.observedAddrsService != nil { + result = append(result, a.observedAddrsService.OwnObservedAddrs()...) + } + return result + } + for _, listen := range listenAddrs { + extMaddr := a.natmgr.GetMapping(listen) + result = appendNATAddrsForListenAddrs(result, listen, extMaddr, a.observedAddrsService.ObservedAddrsFor, ifaceAddrs) + } + return result +} + +func (a *addressService) addCertHashes(addrs []ma.Multiaddr) []ma.Multiaddr { + // This is a temporary workaround/hack that fixes #2233. Once we have a + // proper address pipeline, rework this. See the issue for more context. + type transportForListeninger interface { + TransportForListening(a ma.Multiaddr) transport.Transport + } + + type addCertHasher interface { + AddCertHashes(m ma.Multiaddr) (ma.Multiaddr, bool) + } + + s, ok := a.net.(transportForListeninger) + if !ok { + return addrs + } + + for i, addr := range addrs { + wtOK, wtN := libp2pwebtransport.IsWebtransportMultiaddr(addr) + webrtcOK, webrtcN := libp2pwebrtc.IsWebRTCDirectMultiaddr(addr) + if (wtOK && wtN == 0) || (webrtcOK && webrtcN == 0) { + t := s.TransportForListening(addr) + tpt, ok := t.(addCertHasher) + if !ok { + continue + } + addrWithCerthash, added := tpt.AddCertHashes(addr) + if !added { + log.Debugf("Couldn't add certhashes to multiaddr: %s", addr) + continue + } + addrs[i] = addrWithCerthash + } + } + return addrs +} + +func (a *addressService) areAddrsDifferent(prev, current []ma.Multiaddr) bool { + prevmap := make(map[string]struct{}) + currmap := make(map[string]struct{}) + for _, p := range prev { + prevmap[string(p.Bytes())] = struct{}{} + } + for _, c := range current { + currmap[string(c.Bytes())] = struct{}{} + } + for p := range prevmap { + if _, ok := currmap[p]; !ok { + return true + } + } + for c := range currmap { + if _, ok := prevmap[c]; !ok { + return true + } + } + return false +} + +const ifaceAddrsTTL = time.Minute + +type interfaceAddrsCache struct { + mx sync.RWMutex + filtered []ma.Multiaddr + all []ma.Multiaddr + updateLocalIPv4Backoff backoff.ExpBackoff + updateLocalIPv6Backoff backoff.ExpBackoff + lastUpdated time.Time +} + +func (i *interfaceAddrsCache) Filtered() []ma.Multiaddr { + i.mx.RLock() + if time.Now().After(i.lastUpdated.Add(ifaceAddrsTTL)) { + i.mx.RUnlock() + return i.update(true) + } + defer i.mx.RUnlock() + return i.filtered +} + +func (i *interfaceAddrsCache) All() []ma.Multiaddr { + i.mx.RLock() + if time.Now().After(i.lastUpdated.Add(ifaceAddrsTTL)) { + i.mx.RUnlock() + return i.update(false) + } + defer i.mx.RUnlock() + return i.all +} + +func (i *interfaceAddrsCache) update(filtered bool) []ma.Multiaddr { + i.mx.Lock() + defer i.mx.Unlock() + if !time.Now().After(i.lastUpdated.Add(ifaceAddrsTTL)) { + if filtered { + return i.filtered + } + return i.all + } + i.updateUnlocked() + i.lastUpdated = time.Now() + if filtered { + return i.filtered + } + return i.all +} + +func (i *interfaceAddrsCache) updateUnlocked() { + i.filtered = nil + i.all = nil + + // Try to use the default ipv4/6 addresses. + // TODO: Remove this. We should advertise all interface addresses. + if r, err := netroute.New(); err != nil { + log.Debugw("failed to build Router for kernel's routing table", "error", err) + } else { + + var localIPv4 net.IP + var ran bool + err, ran = i.updateLocalIPv4Backoff.Run(func() error { + _, _, localIPv4, err = r.Route(net.IPv4zero) + return err + }) + + if ran && err != nil { + log.Debugw("failed to fetch local IPv4 address", "error", err) + } else if ran && localIPv4.IsGlobalUnicast() { + maddr, err := manet.FromIP(localIPv4) + if err == nil { + i.filtered = append(i.filtered, maddr) + } + } + + var localIPv6 net.IP + err, ran = i.updateLocalIPv6Backoff.Run(func() error { + _, _, localIPv6, err = r.Route(net.IPv6unspecified) + return err + }) + + if ran && err != nil { + log.Debugw("failed to fetch local IPv6 address", "error", err) + } else if ran && localIPv6.IsGlobalUnicast() { + maddr, err := manet.FromIP(localIPv6) + if err == nil { + i.filtered = append(i.filtered, maddr) + } + } + } + + // Resolve the interface addresses + ifaceAddrs, err := manet.InterfaceMultiaddrs() + if err != nil { + // This usually shouldn't happen, but we could be in some kind + // of funky restricted environment. + log.Errorw("failed to resolve local interface addresses", "error", err) + + // Add the loopback addresses to the filtered addrs and use them as the non-filtered addrs. + // Then bail. There's nothing else we can do here. + i.filtered = append(i.filtered, manet.IP4Loopback, manet.IP6Loopback) + i.all = i.filtered + return + } + + // remove link local ipv6 addresses + i.all = slices.DeleteFunc(ifaceAddrs, manet.IsIP6LinkLocal) + + // If netroute failed to get us any interface addresses, use all of + // them. + if len(i.filtered) == 0 { + // Add all addresses. + i.filtered = i.all + } else { + // Only add loopback addresses. Filter these because we might + // not _have_ an IPv6 loopback address. + for _, addr := range i.all { + if manet.IsIPLoopback(addr) { + i.filtered = append(i.filtered, addr) + } + } + } +} + +// getAllPossibleLocalAddrs gives all the possible address returned for `conn.LocalAddr` correspoinding +// to the `listenAddr` +func getAllPossibleLocalAddrs(listenAddr ma.Multiaddr, ifaceAddrs []ma.Multiaddr) []ma.Multiaddr { + // If the nat mapping fails, use the observed addrs + resolved, err := manet.ResolveUnspecifiedAddress(listenAddr, ifaceAddrs) + if err != nil { + log.Warnf("failed to resolve listen addr %s, %s: %s", listenAddr, ifaceAddrs, err) + return nil + } + return append(resolved, listenAddr) +} + +// appendNATAddrsForListenAddrs adds the NAT-ed addresses to the result. If the NAT device doesn't provide +// us with a public IP address, we use the observed addresses. +func appendNATAddrsForListenAddrs(result []ma.Multiaddr, listenAddr ma.Multiaddr, natMapping ma.Multiaddr, + obsAddrsFunc func(ma.Multiaddr) []ma.Multiaddr, + ifaceAddrs []ma.Multiaddr) []ma.Multiaddr { + if natMapping == nil { + allAddrs := getAllPossibleLocalAddrs(listenAddr, ifaceAddrs) + for _, a := range allAddrs { + result = append(result, obsAddrsFunc(a)...) + } + return result + } + + // if the router reported a sane address, use it. + if !manet.IsIPUnspecified(natMapping) { + result = append(result, natMapping) + } else { + log.Warn("NAT device reported an unspecified IP as it's external address") + } + + // If the router gave us a public address, use it and ignore observed addresses + if manet.IsPublicAddr(natMapping) { + return result + } + + // Router gave us a private IP; maybe we're behind a CGNAT. + // See if we have a public IP from observed addresses. + _, extMaddrNoIP := ma.SplitFirst(natMapping) + if extMaddrNoIP == nil { + return result + } + + allAddrs := getAllPossibleLocalAddrs(listenAddr, ifaceAddrs) + for _, addr := range allAddrs { + for _, obsMaddr := range obsAddrsFunc(addr) { + // Extract a public observed addr. + ip, _ := ma.SplitFirst(obsMaddr) + if ip == nil || !manet.IsPublicAddr(ip) { + continue + } + result = append(result, ma.Join(ip, extMaddrNoIP)) + } + } + return result +} diff --git a/p2p/host/basic/address_service_test.go b/p2p/host/basic/address_service_test.go new file mode 100644 index 0000000000..d86cb162e7 --- /dev/null +++ b/p2p/host/basic/address_service_test.go @@ -0,0 +1,272 @@ +package basichost + +import ( + "testing" + "time" + + "github.com/libp2p/go-libp2p/core/network" + swarmt "github.com/libp2p/go-libp2p/p2p/net/swarm/testing" + ma "github.com/multiformats/go-multiaddr" + manet "github.com/multiformats/go-multiaddr/net" + "github.com/stretchr/testify/require" +) + +func TestAppendNATAddrs(t *testing.T) { + if1, if2 := ma.StringCast("/ip4/192.168.0.100"), ma.StringCast("/ip4/1.1.1.1") + ifaceAddrs := []ma.Multiaddr{if1, if2} + tcpListenAddr, udpListenAddr := ma.StringCast("/ip4/0.0.0.0/tcp/1"), ma.StringCast("/ip4/0.0.0.0/udp/2/quic-v1") + cases := []struct { + Name string + Listen ma.Multiaddr + Nat ma.Multiaddr + ObsAddrFunc func(ma.Multiaddr) []ma.Multiaddr + Expected []ma.Multiaddr + }{ + { + Name: "nat map success", + // nat mapping success, obsaddress ignored + Listen: ma.StringCast("/ip4/0.0.0.0/udp/1/quic-v1"), + Nat: ma.StringCast("/ip4/1.1.1.1/udp/10/quic-v1"), + ObsAddrFunc: func(m ma.Multiaddr) []ma.Multiaddr { + return []ma.Multiaddr{ma.StringCast("/ip4/2.2.2.2/udp/100/quic-v1")} + }, + Expected: []ma.Multiaddr{ma.StringCast("/ip4/1.1.1.1/udp/10/quic-v1")}, + }, + { + Name: "nat map failure", + //nat mapping fails, obs addresses added + Listen: ma.StringCast("/ip4/0.0.0.0/tcp/1"), + Nat: nil, + ObsAddrFunc: func(a ma.Multiaddr) []ma.Multiaddr { + ip, _ := ma.SplitFirst(a) + if ip == nil { + return nil + } + if ip.Equal(if1) { + return []ma.Multiaddr{ma.StringCast("/ip4/2.2.2.2/tcp/100")} + } else { + return []ma.Multiaddr{ma.StringCast("/ip4/3.3.3.3/tcp/100")} + } + }, + Expected: []ma.Multiaddr{ma.StringCast("/ip4/2.2.2.2/tcp/100"), ma.StringCast("/ip4/3.3.3.3/tcp/100")}, + }, + { + Name: "nat map success but CGNAT", + //nat addr added, obs address added with nat provided port + Listen: tcpListenAddr, + Nat: ma.StringCast("/ip4/100.100.1.1/tcp/100"), + ObsAddrFunc: func(a ma.Multiaddr) []ma.Multiaddr { + ip, _ := ma.SplitFirst(a) + if ip == nil { + return nil + } + if ip.Equal(if1) { + return []ma.Multiaddr{ma.StringCast("/ip4/2.2.2.2/tcp/20")} + } else { + return []ma.Multiaddr{ma.StringCast("/ip4/3.3.3.3/tcp/30")} + } + }, + Expected: []ma.Multiaddr{ + ma.StringCast("/ip4/100.100.1.1/tcp/100"), + ma.StringCast("/ip4/2.2.2.2/tcp/100"), + ma.StringCast("/ip4/3.3.3.3/tcp/100"), + }, + }, + { + Name: "uses unspecified address for obs address", + // observed address manager should be queries with both specified and unspecified addresses + // udp observed addresses are mapped to unspecified addresses + Listen: udpListenAddr, + Nat: nil, + ObsAddrFunc: func(a ma.Multiaddr) []ma.Multiaddr { + if manet.IsIPUnspecified(a) { + return []ma.Multiaddr{ma.StringCast("/ip4/3.3.3.3/udp/20/quic-v1")} + } + return []ma.Multiaddr{ma.StringCast("/ip4/2.2.2.2/udp/20/quic-v1")} + }, + Expected: []ma.Multiaddr{ + ma.StringCast("/ip4/2.2.2.2/udp/20/quic-v1"), + ma.StringCast("/ip4/3.3.3.3/udp/20/quic-v1"), + }, + }, + } + for _, tc := range cases { + t.Run(tc.Name, func(t *testing.T) { + res := appendNATAddrsForListenAddrs(nil, + tc.Listen, tc.Nat, tc.ObsAddrFunc, ifaceAddrs) + res = ma.Unique(res) + require.ElementsMatch(t, tc.Expected, res, "%s\n%s", tc.Expected, res) + }) + } +} + +type mockNatManager struct { + GetMappingFunc func(addr ma.Multiaddr) ma.Multiaddr + HasDiscoveredNATFunc func() bool +} + +func (m *mockNatManager) Close() error { + return nil +} + +func (m *mockNatManager) GetMapping(addr ma.Multiaddr) ma.Multiaddr { + return m.GetMappingFunc(addr) +} + +func (m *mockNatManager) HasDiscoveredNAT() bool { + return m.HasDiscoveredNATFunc() +} + +var _ NATManager = &mockNatManager{} + +type mockObservedAddrs struct { + OwnObservedAddrsFunc func() []ma.Multiaddr + ObservedAddrsForFunc func(ma.Multiaddr) []ma.Multiaddr +} + +func (m *mockObservedAddrs) OwnObservedAddrs() []ma.Multiaddr { + return m.OwnObservedAddrsFunc() +} + +func (m *mockObservedAddrs) ObservedAddrsFor(local ma.Multiaddr) []ma.Multiaddr { + return m.ObservedAddrsForFunc(local) +} + +func TestAddressService(t *testing.T) { + getAddrService := func() *addressService { + h, err := NewHost(swarmt.GenSwarm(t), &HostOpts{DisableIdentifyAddressDiscovery: true}) + require.NoError(t, err) + t.Cleanup(func() { h.Close() }) + + as := h.addressService + return as + } + + t.Run("NAT Address", func(t *testing.T) { + as := getAddrService() + as.natmgr = &mockNatManager{ + HasDiscoveredNATFunc: func() bool { return true }, + GetMappingFunc: func(addr ma.Multiaddr) ma.Multiaddr { + if _, err := addr.ValueForProtocol(ma.P_UDP); err == nil { + return ma.StringCast("/ip4/1.2.3.4/udp/1/quic-v1") + } + return nil + }, + } + require.Contains(t, as.Addrs(), ma.StringCast("/ip4/1.2.3.4/udp/1/quic-v1")) + }) + + t.Run("NAT And Observed Address", func(t *testing.T) { + as := getAddrService() + as.natmgr = &mockNatManager{ + HasDiscoveredNATFunc: func() bool { return true }, + GetMappingFunc: func(addr ma.Multiaddr) ma.Multiaddr { + if _, err := addr.ValueForProtocol(ma.P_UDP); err == nil { + return ma.StringCast("/ip4/1.2.3.4/udp/1/quic-v1") + } + return nil + }, + } + as.observedAddrsService = &mockObservedAddrs{ + ObservedAddrsForFunc: func(addr ma.Multiaddr) []ma.Multiaddr { + if _, err := addr.ValueForProtocol(ma.P_TCP); err == nil { + return []ma.Multiaddr{ma.StringCast("/ip4/2.2.2.2/tcp/1")} + } + return nil + }, + } + require.Contains(t, as.Addrs(), ma.StringCast("/ip4/1.2.3.4/udp/1/quic-v1")) + require.Contains(t, as.Addrs(), ma.StringCast("/ip4/2.2.2.2/tcp/1")) + }) + t.Run("Only Observed Address", func(t *testing.T) { + as := getAddrService() + as.natmgr = nil + as.observedAddrsService = &mockObservedAddrs{ + ObservedAddrsForFunc: func(addr ma.Multiaddr) []ma.Multiaddr { + if _, err := addr.ValueForProtocol(ma.P_TCP); err == nil { + return []ma.Multiaddr{ma.StringCast("/ip4/2.2.2.2/tcp/1")} + } + return nil + }, + OwnObservedAddrsFunc: func() []ma.Multiaddr { + return []ma.Multiaddr{ma.StringCast("/ip4/3.3.3.3/udp/1/quic-v1")} + }, + } + require.NotContains(t, as.Addrs(), ma.StringCast("/ip4/2.2.2.2/tcp/1")) + require.Contains(t, as.Addrs(), ma.StringCast("/ip4/3.3.3.3/udp/1/quic-v1")) + }) + t.Run("Public Addrs Removed When Private", func(t *testing.T) { + as := getAddrService() + as.natmgr = nil + as.observedAddrsService = &mockObservedAddrs{ + OwnObservedAddrsFunc: func() []ma.Multiaddr { + return []ma.Multiaddr{ma.StringCast("/ip4/3.3.3.3/udp/1/quic-v1")} + }, + } + as.reachability = func() network.Reachability { + return network.ReachabilityPrivate + } + relayAddr := ma.StringCast("/ip4/1.2.3.4/udp/1/quic-v1/p2p/QmdXGaeGiVA745XorV1jr11RHxB9z4fqykm6xCUPX1aTJo/p2p-circuit") + as.autoRelayAddrs = func() []ma.Multiaddr { + return []ma.Multiaddr{relayAddr} + } + require.NotContains(t, as.Addrs(), ma.StringCast("/ip4/3.3.3.3/udp/1/quic-v1")) + require.Contains(t, as.Addrs(), relayAddr) + require.Contains(t, as.AllAddrs(), ma.StringCast("/ip4/3.3.3.3/udp/1/quic-v1")) + }) + + t.Run("AddressFactory gets relay addresses", func(t *testing.T) { + as := getAddrService() + as.natmgr = nil + as.observedAddrsService = &mockObservedAddrs{ + OwnObservedAddrsFunc: func() []ma.Multiaddr { + return []ma.Multiaddr{ma.StringCast("/ip4/3.3.3.3/udp/1/quic-v1")} + }, + } + as.reachability = func() network.Reachability { + return network.ReachabilityPrivate + } + relayAddr := ma.StringCast("/ip4/1.2.3.4/udp/1/quic-v1/p2p/QmdXGaeGiVA745XorV1jr11RHxB9z4fqykm6xCUPX1aTJo/p2p-circuit") + as.autoRelayAddrs = func() []ma.Multiaddr { + return []ma.Multiaddr{relayAddr} + } + as.addrsFactory = func(addrs []ma.Multiaddr) []ma.Multiaddr { + for _, a := range addrs { + if a.Equal(relayAddr) { + return []ma.Multiaddr{ma.StringCast("/ip4/3.3.3.3/udp/1/quic-v1")} + } + } + return nil + } + require.Contains(t, as.Addrs(), ma.StringCast("/ip4/3.3.3.3/udp/1/quic-v1")) + require.NotContains(t, as.Addrs(), relayAddr) + }) + + t.Run("updates addresses on signaling", func(t *testing.T) { + as := getAddrService() + as.natmgr = nil + updateChan := make(chan struct{}) + a1 := ma.StringCast("/ip4/1.1.1.1/udp/1/quic-v1") + a2 := ma.StringCast("/ip4/1.1.1.1/tcp/1") + as.addrsFactory = func(addrs []ma.Multiaddr) []ma.Multiaddr { + select { + case <-updateChan: + return []ma.Multiaddr{a2} + default: + return []ma.Multiaddr{a1} + } + } + as.Start() + require.Contains(t, as.Addrs(), a1) + require.NotContains(t, as.Addrs(), a2) + close(updateChan) + as.SignalAddressChange() + select { + case <-as.AddrsUpdated(): + require.Contains(t, as.Addrs(), a2) + require.NotContains(t, as.Addrs(), a1) + case <-time.After(2 * time.Second): + t.Fatal("expected addrs to be updated") + } + }) +} diff --git a/p2p/host/basic/basic_host.go b/p2p/host/basic/basic_host.go index 5da306d54a..404949a287 100644 --- a/p2p/host/basic/basic_host.go +++ b/p2p/host/basic/basic_host.go @@ -5,7 +5,6 @@ import ( "errors" "fmt" "io" - "net" "slices" "sync" "time" @@ -19,10 +18,8 @@ import ( "github.com/libp2p/go-libp2p/core/peerstore" "github.com/libp2p/go-libp2p/core/protocol" "github.com/libp2p/go-libp2p/core/record" - "github.com/libp2p/go-libp2p/core/transport" "github.com/libp2p/go-libp2p/p2p/host/autonat" "github.com/libp2p/go-libp2p/p2p/host/autorelay" - "github.com/libp2p/go-libp2p/p2p/host/basic/internal/backoff" "github.com/libp2p/go-libp2p/p2p/host/eventbus" "github.com/libp2p/go-libp2p/p2p/host/pstoremanager" "github.com/libp2p/go-libp2p/p2p/host/relaysvc" @@ -35,8 +32,6 @@ import ( libp2pwebtransport "github.com/libp2p/go-libp2p/p2p/transport/webtransport" "github.com/prometheus/client_golang/prometheus" - "github.com/libp2p/go-netroute" - logging "github.com/ipfs/go-log/v2" ma "github.com/multiformats/go-multiaddr" manet "github.com/multiformats/go-multiaddr/net" @@ -86,8 +81,6 @@ type BasicHost struct { eventbus event.Bus relayManager *relaysvc.RelayManager - AddrsFactory AddrsFactory - negtimeout time.Duration emitters struct { @@ -95,23 +88,16 @@ type BasicHost struct { evtLocalAddrsUpdated event.Emitter } - addrChangeChan chan struct{} - - addrMu sync.RWMutex - updateLocalIPv4Backoff backoff.ExpBackoff - updateLocalIPv6Backoff backoff.ExpBackoff - filteredInterfaceAddrs []ma.Multiaddr - allInterfaceAddrs []ma.Multiaddr - disableSignedPeerRecord bool signKey crypto.PrivKey caBook peerstore.CertifiedAddrBook - autoNat autonat.AutoNAT + autoNATMx sync.RWMutex + autoNat autonat.AutoNAT - autonatv2 *autonatv2.AutoNAT - addrSub event.Subscription - autorelay *autorelay.AutoRelay + autonatv2 *autonatv2.AutoNAT + autorelay *autorelay.AutoRelay + addressService *addressService } var _ host.Host = (*BasicHost)(nil) @@ -192,27 +178,18 @@ func NewHost(n network.Network, opts *HostOpts) (*BasicHost, error) { return nil, err } - addrSub, err := opts.EventBus.Subscribe(new(event.EvtAutoRelayAddrs)) - if err != nil { - return nil, err - } hostCtx, cancel := context.WithCancel(context.Background()) h := &BasicHost{ network: n, psManager: psManager, mux: msmux.NewMultistreamMuxer[protocol.ID](), negtimeout: DefaultNegotiationTimeout, - AddrsFactory: DefaultAddrsFactory, eventbus: opts.EventBus, - addrChangeChan: make(chan struct{}, 1), ctx: hostCtx, ctxCancel: cancel, disableSignedPeerRecord: opts.DisableSignedPeerRecord, - addrSub: addrSub, } - h.updateLocalIpAddr() - if h.emitters.evtLocalProtocolsUpdated, err = h.eventbus.Emitter(&event.EvtLocalProtocolsUpdated{}, eventbus.Stateful); err != nil { return nil, err } @@ -220,30 +197,8 @@ func NewHost(n network.Network, opts *HostOpts) (*BasicHost, error) { return nil, err } - if !h.disableSignedPeerRecord { - cab, ok := peerstore.GetCertifiedAddrBook(n.Peerstore()) - if !ok { - return nil, errors.New("peerstore should also be a certified address book") - } - h.caBook = cab - + if !opts.DisableSignedPeerRecord { h.signKey = h.Peerstore().PrivKey(h.ID()) - if h.signKey == nil { - return nil, errors.New("unable to access host key") - } - - // persist a signed peer record for self to the peerstore. - rec := peer.PeerRecordFromAddrInfo(peer.AddrInfo{ - ID: h.ID(), - Addrs: h.Addrs(), - }) - ev, err := record.Seal(rec, h.signKey) - if err != nil { - return nil, fmt.Errorf("failed to create signed record for self: %w", err) - } - if _, err := cab.ConsumePeerRecord(ev, peerstore.PermanentAddrTTL); err != nil { - return nil, fmt.Errorf("failed to persist signed record to peerstore: %w", err) - } } if opts.MultistreamMuxer != nil { @@ -273,6 +228,31 @@ func NewHost(n network.Network, opts *HostOpts) (*BasicHost, error) { return nil, fmt.Errorf("failed to create Identify service: %s", err) } + if opts.EnableAutoRelay { + if opts.EnableMetrics { + mt := autorelay.WithMetricsTracer( + autorelay.NewMetricsTracer(autorelay.WithRegisterer(opts.PrometheusRegisterer))) + mtOpts := []autorelay.Option{mt} + opts.AutoRelayOpts = append(mtOpts, opts.AutoRelayOpts...) + } + + ar, err := autorelay.NewAutoRelay(h, opts.AutoRelayOpts...) + if err != nil { + return nil, fmt.Errorf("failed to create autorelay: %w", err) + } + h.autorelay = ar + } + + addrFactory := DefaultAddrsFactory + if opts.AddrsFactory != nil { + addrFactory = opts.AddrsFactory + } + + h.addressService, err = NewAddressService(h, opts.NATManager, addrFactory) + if err != nil { + return nil, fmt.Errorf("failed to create address service: %w", err) + } + if opts.EnableHolePunching { if opts.EnableMetrics { hpOpts := []holepunch.Option{ @@ -280,16 +260,7 @@ func NewHost(n network.Network, opts *HostOpts) (*BasicHost, error) { opts.HolePunchingOptions = append(hpOpts, opts.HolePunchingOptions...) } - h.hps, err = holepunch.NewService(h, h.ids, func() []ma.Multiaddr { - addrs := h.AllAddrs() - if opts.AddrsFactory != nil { - addrs = slices.Clone(opts.AddrsFactory(addrs)) - } - // AllAddrs may ignore observed addresses in favour of NAT mappings. Use both for hole punching. - addrs = append(addrs, h.ids.OwnObservedAddrs()...) - addrs = ma.Unique(addrs) - return slices.DeleteFunc(addrs, func(a ma.Multiaddr) bool { return !manet.IsPublicAddr(a) }) - }, opts.HolePunchingOptions...) + h.hps, err = holepunch.NewService(h, h.ids, h.addressService.GetHolePunchAddrs, opts.HolePunchingOptions...) if err != nil { return nil, fmt.Errorf("failed to create hole punch service: %w", err) } @@ -299,14 +270,6 @@ func NewHost(n network.Network, opts *HostOpts) (*BasicHost, error) { h.negtimeout = opts.NegotiationTimeout } - if opts.AddrsFactory != nil { - h.AddrsFactory = opts.AddrsFactory - } - - if opts.NATManager != nil { - h.natmgr = opts.NATManager(n) - } - if opts.ConnManager == nil { h.cmgr = &connmgr.NullConnMgr{} } else { @@ -340,27 +303,27 @@ func NewHost(n network.Network, opts *HostOpts) (*BasicHost, error) { } } - if opts.EnableAutoRelay { - if opts.EnableMetrics { - mt := autorelay.WithMetricsTracer( - autorelay.NewMetricsTracer(autorelay.WithRegisterer(opts.PrometheusRegisterer))) - mtOpts := []autorelay.Option{mt} - opts.AutoRelayOpts = append(mtOpts, opts.AutoRelayOpts...) - } + n.SetStreamHandler(h.newStreamHandler) - ar, err := autorelay.NewAutoRelay(h, opts.AutoRelayOpts...) + if !h.disableSignedPeerRecord { + cab, ok := peerstore.GetCertifiedAddrBook(h.Peerstore()) + if !ok { + return nil, errors.New("peerstore should also be a certified address book") + } + h.caBook = cab + rec, err := h.makeSignedPeerRecord(h.addressService.Addrs()) if err != nil { - return nil, fmt.Errorf("failed to create autorelay: %w", err) + return nil, fmt.Errorf("failed to create signed record for self: %w", err) + } + if _, err := h.caBook.ConsumePeerRecord(rec, peerstore.PermanentAddrTTL); err != nil { + return nil, fmt.Errorf("failed to persist signed record to peerstore: %w", err) } - h.autorelay = ar } - n.SetStreamHandler(h.newStreamHandler) - // register to be notified when the network's listen addrs change, // so we can update our address set and push events if needed listenHandler := func(network.Network, ma.Multiaddr) { - h.SignalAddressChange() + h.addressService.SignalAddressChange() } n.Notify(&network.NotifyBundle{ ListenF: listenHandler, @@ -370,92 +333,9 @@ func NewHost(n network.Network, opts *HostOpts) (*BasicHost, error) { return h, nil } -func (h *BasicHost) updateLocalIpAddr() { - h.addrMu.Lock() - defer h.addrMu.Unlock() - - h.filteredInterfaceAddrs = nil - h.allInterfaceAddrs = nil - - // Try to use the default ipv4/6 addresses. - - if r, err := netroute.New(); err != nil { - log.Debugw("failed to build Router for kernel's routing table", "error", err) - } else { - - var localIPv4 net.IP - var ran bool - err, ran = h.updateLocalIPv4Backoff.Run(func() error { - _, _, localIPv4, err = r.Route(net.IPv4zero) - return err - }) - - if ran && err != nil { - log.Debugw("failed to fetch local IPv4 address", "error", err) - } else if ran && localIPv4.IsGlobalUnicast() { - maddr, err := manet.FromIP(localIPv4) - if err == nil { - h.filteredInterfaceAddrs = append(h.filteredInterfaceAddrs, maddr) - } - } - - var localIPv6 net.IP - err, ran = h.updateLocalIPv6Backoff.Run(func() error { - _, _, localIPv6, err = r.Route(net.IPv6unspecified) - return err - }) - - if ran && err != nil { - log.Debugw("failed to fetch local IPv6 address", "error", err) - } else if ran && localIPv6.IsGlobalUnicast() { - maddr, err := manet.FromIP(localIPv6) - if err == nil { - h.filteredInterfaceAddrs = append(h.filteredInterfaceAddrs, maddr) - } - } - } - - // Resolve the interface addresses - ifaceAddrs, err := manet.InterfaceMultiaddrs() - if err != nil { - // This usually shouldn't happen, but we could be in some kind - // of funky restricted environment. - log.Errorw("failed to resolve local interface addresses", "error", err) - - // Add the loopback addresses to the filtered addrs and use them as the non-filtered addrs. - // Then bail. There's nothing else we can do here. - h.filteredInterfaceAddrs = append(h.filteredInterfaceAddrs, manet.IP4Loopback, manet.IP6Loopback) - h.allInterfaceAddrs = h.filteredInterfaceAddrs - return - } - - for _, addr := range ifaceAddrs { - // Skip link-local addrs, they're mostly useless. - if !manet.IsIP6LinkLocal(addr) { - h.allInterfaceAddrs = append(h.allInterfaceAddrs, addr) - } - } - - // If netroute failed to get us any interface addresses, use all of - // them. - if len(h.filteredInterfaceAddrs) == 0 { - // Add all addresses. - h.filteredInterfaceAddrs = h.allInterfaceAddrs - } else { - // Only add loopback addresses. Filter these because we might - // not _have_ an IPv6 loopback address. - for _, addr := range h.allInterfaceAddrs { - if manet.IsIPLoopback(addr) { - h.filteredInterfaceAddrs = append(h.filteredInterfaceAddrs, addr) - } - } - } -} - // Start starts background tasks in the host func (h *BasicHost) Start() { h.psManager.Start() - h.refCount.Add(1) h.ids.Start() if h.autorelay != nil { h.autorelay.Start() @@ -466,7 +346,10 @@ func (h *BasicHost) Start() { log.Errorf("autonat v2 failed to start: %s", err) } } + h.refCount.Add(1) go h.background() + h.addressService.Start() + } // newStreamHandler is the remote-opened stream handler for network.Network @@ -517,16 +400,6 @@ func (h *BasicHost) newStreamHandler(s network.Stream) { handle(protoID, s) } -// SignalAddressChange signals to the host that it needs to determine whether our listen addresses have recently -// changed. -// Warning: this interface is unstable and may disappear in the future. -func (h *BasicHost) SignalAddressChange() { - select { - case h.addrChangeChan <- struct{}{}: - default: - } -} - func (h *BasicHost) makeUpdatedAddrEvent(prev, current []ma.Multiaddr) *event.EvtLocalAddressesUpdated { if prev == nil && current == nil { return nil @@ -598,6 +471,7 @@ func (h *BasicHost) background() { defer h.refCount.Done() var lastAddrs []ma.Multiaddr + // TODO: Deprecate this event and logic once we have a new event for address with reachability emitAddrChange := func(currentAddrs []ma.Multiaddr, lastAddrs []ma.Multiaddr) { changeEvt := h.makeUpdatedAddrEvent(lastAddrs, currentAddrs) if changeEvt == nil { @@ -625,24 +499,13 @@ func (h *BasicHost) background() { } } - // periodically schedules an IdentifyPush to update our peers for changes - // in our address set (if needed) - ticker := time.NewTicker(addrChangeTickrInterval) - defer ticker.Stop() - for { - // Update our local IP addresses before checking our current addresses. - if len(h.network.ListenAddresses()) > 0 { - h.updateLocalIpAddr() - } curr := h.Addrs() emitAddrChange(curr, lastAddrs) lastAddrs = curr select { - case <-ticker.C: - case <-h.addrChangeChan: - case <-h.addrSub.Out(): + case <-h.addressService.AddrsUpdated(): case <-h.ctx.Done(): return } @@ -835,7 +698,6 @@ func (h *BasicHost) Connect(ctx context.Context, pi peer.AddrInfo) error { return nil } } - return h.dialPeer(ctx, pi.ID) } @@ -872,15 +734,7 @@ func (h *BasicHost) ConnManager() connmgr.ConnManager { // When used with AutoRelay, and if the host is not publicly reachable, // this will only have host's private, relay, and no public addresses. func (h *BasicHost) Addrs() []ma.Multiaddr { - addrs := h.AllAddrs() - // Make a copy. Consumers can modify the slice elements - if h.autoNat != nil && h.autorelay != nil && h.autoNat.Status() == network.ReachabilityPrivate { - addrs = slices.DeleteFunc(addrs, func(a ma.Multiaddr) bool { return manet.IsPublicAddr(a) }) - addrs = append(addrs, h.autorelay.RelayAddrs()...) - } - addrs = slices.Clone(h.AddrsFactory(addrs)) - // Add certhashes for the addresses provided by the user via address factory. - return h.addCertHashes(ma.Unique(addrs)) + return h.addressService.Addrs() } // NormalizeMultiaddr returns a multiaddr suitable for equality checks. @@ -900,153 +754,78 @@ func (h *BasicHost) NormalizeMultiaddr(addr ma.Multiaddr) ma.Multiaddr { return addr } -var p2pCircuitAddr = ma.StringCast("/p2p-circuit") - // AllAddrs returns all the addresses the host is listening on except circuit addresses. func (h *BasicHost) AllAddrs() []ma.Multiaddr { - listenAddrs := h.Network().ListenAddresses() - if len(listenAddrs) == 0 { - return nil - } + return h.addressService.AllAddrs() +} - h.addrMu.RLock() - filteredIfaceAddrs := h.filteredInterfaceAddrs - allIfaceAddrs := h.allInterfaceAddrs - h.addrMu.RUnlock() - - // Iterate over all _unresolved_ listen addresses, resolving our primary - // interface only to avoid advertising too many addresses. - finalAddrs := make([]ma.Multiaddr, 0, 8) - if resolved, err := manet.ResolveUnspecifiedAddresses(listenAddrs, filteredIfaceAddrs); err != nil { - // This can happen if we're listening on no addrs, or listening - // on IPv6 addrs, but only have IPv4 interface addrs. - log.Debugw("failed to resolve listen addrs", "error", err) - } else { - finalAddrs = append(finalAddrs, resolved...) +// SetAutoNat sets the autonat service for the host. +func (h *BasicHost) SetAutoNat(a autonat.AutoNAT) { + h.autoNATMx.Lock() + defer h.autoNATMx.Unlock() + if h.autoNat == nil { + h.autoNat = a } +} - finalAddrs = ma.Unique(finalAddrs) - - // use nat mappings if we have them - if h.natmgr != nil && h.natmgr.HasDiscoveredNAT() { - // We have successfully mapped ports on our NAT. Use those - // instead of observed addresses (mostly). - // Next, apply this mapping to our addresses. - for _, listen := range listenAddrs { - extMaddr := h.natmgr.GetMapping(listen) - if extMaddr == nil { - // not mapped - continue - } +// GetAutoNat returns the host's AutoNAT service, if AutoNAT is enabled. +func (h *BasicHost) GetAutoNat() autonat.AutoNAT { + h.autoNATMx.Lock() + defer h.autoNATMx.Unlock() + return h.autoNat +} - // if the router reported a sane address - if !manet.IsIPUnspecified(extMaddr) { - // Add in the mapped addr. - finalAddrs = append(finalAddrs, extMaddr) - } else { - log.Warn("NAT device reported an unspecified IP as it's external address") - } +// Close shuts down the Host's services (network, etc). +func (h *BasicHost) Close() error { + h.closeSync.Do(func() { + h.ctxCancel() + if h.natmgr != nil { + h.natmgr.Close() + } + if h.cmgr != nil { + h.cmgr.Close() + } - // Did the router give us a routable public addr? - if manet.IsPublicAddr(extMaddr) { - // well done - continue - } + h.addressService.Close() - // No. - // in case the router gives us a wrong address or we're behind a double-NAT. - // also add observed addresses - resolved, err := manet.ResolveUnspecifiedAddress(listen, allIfaceAddrs) - if err != nil { - // This can happen if we try to resolve /ip6/::/... - // without any IPv6 interface addresses. - continue - } + if h.ids != nil { + h.ids.Close() + } + if h.autoNat != nil { + h.autoNat.Close() + } + if h.relayManager != nil { + h.relayManager.Close() + } + if h.hps != nil { + h.hps.Close() + } + if h.autonatv2 != nil { + h.autonatv2.Close() + } + if h.autorelay != nil { + h.autorelay.Close() + } - for _, addr := range resolved { - // Now, check if we have any observed addresses that - // differ from the one reported by the router. Routers - // don't always give the most accurate information. - observed := h.ids.ObservedAddrsFor(addr) + _ = h.emitters.evtLocalProtocolsUpdated.Close() - if len(observed) == 0 { - continue - } + if err := h.network.Close(); err != nil { + log.Errorf("swarm close failed: %v", err) + } - // Drop the IP from the external maddr - _, extMaddrNoIP := ma.SplitFirst(extMaddr) + h.psManager.Close() + if h.Peerstore() != nil { + h.Peerstore().Close() + } - for _, obsMaddr := range observed { - // Extract a public observed addr. - ip, _ := ma.SplitFirst(obsMaddr) - if ip == nil || !manet.IsPublicAddr(ip) { - continue - } + h.refCount.Wait() - finalAddrs = append(finalAddrs, ma.Join(ip, extMaddrNoIP)) - } - } - } - } else { - var observedAddrs []ma.Multiaddr - if h.ids != nil { - observedAddrs = h.ids.OwnObservedAddrs() + if h.Network().ResourceManager() != nil { + h.Network().ResourceManager().Close() } - finalAddrs = append(finalAddrs, observedAddrs...) - } - finalAddrs = ma.Unique(finalAddrs) - // Remove /p2p-circuit addresses from the list. - // The p2p-circuit tranport listener reports its address as just /p2p-circuit - // This is useless for dialing. Users need to manage their circuit addresses themselves, - // or use AutoRelay. - finalAddrs = slices.DeleteFunc(finalAddrs, func(a ma.Multiaddr) bool { - return a.Equal(p2pCircuitAddr) }) - // Add certhashes for /webrtc-direct, /webtransport, etc addresses discovered - // using identify. - finalAddrs = h.addCertHashes(finalAddrs) - return finalAddrs -} -func (h *BasicHost) addCertHashes(addrs []ma.Multiaddr) []ma.Multiaddr { - // This is a temporary workaround/hack that fixes #2233. Once we have a - // proper address pipeline, rework this. See the issue for more context. - type transportForListeninger interface { - TransportForListening(a ma.Multiaddr) transport.Transport - } - - type addCertHasher interface { - AddCertHashes(m ma.Multiaddr) (ma.Multiaddr, bool) - } - - s, ok := h.Network().(transportForListeninger) - if !ok { - return addrs - } - - // Copy addrs slice since we'll be modifying it. - addrsOld := addrs - addrs = make([]ma.Multiaddr, len(addrsOld)) - copy(addrs, addrsOld) - - for i, addr := range addrs { - wtOK, wtN := libp2pwebtransport.IsWebtransportMultiaddr(addr) - webrtcOK, webrtcN := libp2pwebrtc.IsWebRTCDirectMultiaddr(addr) - if (wtOK && wtN == 0) || (webrtcOK && webrtcN == 0) { - t := s.TransportForListening(addr) - tpt, ok := t.(addCertHasher) - if !ok { - continue - } - addrWithCerthash, added := tpt.AddCertHashes(addr) - if !added { - log.Debugf("Couldn't add certhashes to multiaddr: %s", addr) - continue - } - addrs[i] = addrWithCerthash - } - } - return addrs + return nil } func trimHostAddrList(addrs []ma.Multiaddr, maxSize int) []ma.Multiaddr { @@ -1101,75 +880,6 @@ func trimHostAddrList(addrs []ma.Multiaddr, maxSize int) []ma.Multiaddr { return addrs } -// SetAutoNat sets the autonat service for the host. -func (h *BasicHost) SetAutoNat(a autonat.AutoNAT) { - h.addrMu.Lock() - defer h.addrMu.Unlock() - if h.autoNat == nil { - h.autoNat = a - } -} - -// GetAutoNat returns the host's AutoNAT service, if AutoNAT is enabled. -func (h *BasicHost) GetAutoNat() autonat.AutoNAT { - h.addrMu.Lock() - defer h.addrMu.Unlock() - return h.autoNat -} - -// Close shuts down the Host's services (network, etc). -func (h *BasicHost) Close() error { - h.closeSync.Do(func() { - h.ctxCancel() - if h.natmgr != nil { - h.natmgr.Close() - } - if h.cmgr != nil { - h.cmgr.Close() - } - if h.ids != nil { - h.ids.Close() - } - if h.autoNat != nil { - h.autoNat.Close() - } - if h.relayManager != nil { - h.relayManager.Close() - } - if h.hps != nil { - h.hps.Close() - } - if h.autonatv2 != nil { - h.autonatv2.Close() - } - if h.autorelay != nil { - h.autorelay.Close() - } - - _ = h.emitters.evtLocalProtocolsUpdated.Close() - _ = h.emitters.evtLocalAddrsUpdated.Close() - - if err := h.network.Close(); err != nil { - log.Errorf("swarm close failed: %v", err) - } - - h.psManager.Close() - if h.Peerstore() != nil { - h.Peerstore().Close() - } - - h.refCount.Wait() - - _ = h.addrSub.Close() - - if h.Network().ResourceManager() != nil { - h.Network().ResourceManager().Close() - } - }) - - return nil -} - type streamWrapper struct { network.Stream rw io.ReadWriteCloser diff --git a/p2p/host/basic/basic_host_test.go b/p2p/host/basic/basic_host_test.go index 2a7a772976..dcf174410b 100644 --- a/p2p/host/basic/basic_host_test.go +++ b/p2p/host/basic/basic_host_test.go @@ -8,6 +8,7 @@ import ( "reflect" "strings" "sync" + "sync/atomic" "testing" "time" @@ -205,29 +206,6 @@ func TestHostAddrsFactory(t *testing.T) { } } -func TestLocalIPChangesWhenListenAddrChanges(t *testing.T) { - // no listen addrs - h, err := NewHost(swarmt.GenSwarm(t, swarmt.OptDialOnly), nil) - require.NoError(t, err) - h.Start() - defer h.Close() - - h.addrMu.Lock() - h.filteredInterfaceAddrs = nil - h.allInterfaceAddrs = nil - h.addrMu.Unlock() - - // change listen addrs and verify local IP addr is not nil again - require.NoError(t, h.Network().Listen(ma.StringCast("/ip4/0.0.0.0/tcp/0"))) - h.SignalAddressChange() - time.Sleep(1 * time.Second) - - h.addrMu.RLock() - defer h.addrMu.RUnlock() - require.NotEmpty(t, h.filteredInterfaceAddrs) - require.NotEmpty(t, h.allInterfaceAddrs) -} - func TestAllAddrs(t *testing.T) { // no listen addrs h, err := NewHost(swarmt.GenSwarm(t, swarmt.OptDialOnly), nil) @@ -619,8 +597,13 @@ func TestAddrChangeImmediatelyIfAddressNonEmpty(t *testing.T) { ctx := context.Background() taddrs := []ma.Multiaddr{ma.StringCast("/ip4/1.2.3.4/tcp/1234")} - starting := make(chan struct{}) + starting := make(chan struct{}, 1) + var count atomic.Int32 h, err := NewHost(swarmt.GenSwarm(t), &HostOpts{AddrsFactory: func(addrs []ma.Multiaddr) []ma.Multiaddr { + // The first call here is made from the constructor. Don't block. + if count.Add(1) == 1 { + return addrs + } <-starting return taddrs }}) @@ -628,11 +611,11 @@ func TestAddrChangeImmediatelyIfAddressNonEmpty(t *testing.T) { defer h.Close() sub, err := h.EventBus().Subscribe(&event.EvtLocalAddressesUpdated{}) - close(starting) if err != nil { t.Error(err) } defer sub.Close() + close(starting) h.Start() expected := event.EvtLocalAddressesUpdated{ @@ -751,7 +734,7 @@ func TestHostAddrChangeDetection(t *testing.T) { lk.Lock() currentAddrSet = i lk.Unlock() - h.SignalAddressChange() + h.addressService.SignalAddressChange() evt := waitForAddrChangeEvent(ctx, sub, t) if !updatedAddrEventsEqual(expectedEvents[i-1], evt) { t.Errorf("change events not equal: \n\texpected: %v \n\tactual: %v", expectedEvents[i-1], evt)