From fd761008b19dd7f47bc94e9d042457d94334442f Mon Sep 17 00:00:00 2001 From: sukun Date: Tue, 3 Dec 2024 17:34:09 +0530 Subject: [PATCH] make everything mockable --- p2p/host/basic/address_service.go | 206 ++++++++----------------- p2p/host/basic/address_service_test.go | 2 +- p2p/host/basic/basic_host.go | 142 +++++++++++++++-- p2p/host/basic/basic_host_test.go | 2 +- 4 files changed, 189 insertions(+), 163 deletions(-) diff --git a/p2p/host/basic/address_service.go b/p2p/host/basic/address_service.go index 72949f79d1..73d359c857 100644 --- a/p2p/host/basic/address_service.go +++ b/p2p/host/basic/address_service.go @@ -2,8 +2,6 @@ package basichost import ( "context" - "errors" - "fmt" "net" "slices" "sync" @@ -11,14 +9,9 @@ import ( "github.com/libp2p/go-libp2p/core/event" "github.com/libp2p/go-libp2p/core/network" - "github.com/libp2p/go-libp2p/core/peer" - "github.com/libp2p/go-libp2p/core/peerstore" "github.com/libp2p/go-libp2p/core/record" "github.com/libp2p/go-libp2p/core/transport" - "github.com/libp2p/go-libp2p/p2p/host/autonat" - "github.com/libp2p/go-libp2p/p2p/host/autorelay" "github.com/libp2p/go-libp2p/p2p/host/basic/internal/backoff" - "github.com/libp2p/go-libp2p/p2p/host/eventbus" libp2pwebrtc "github.com/libp2p/go-libp2p/p2p/transport/webrtc" libp2pwebtransport "github.com/libp2p/go-libp2p/p2p/transport/webtransport" "github.com/libp2p/go-netroute" @@ -34,19 +27,15 @@ type observedAddrsService interface { } type addressService struct { - net network.Network - peerstore peerstore.Peerstore - id peer.ID - addrsFactory AddrsFactory - // peerRecord is used to create peer records when the addresses change - peerRecord peerRecordFunc - autonat autonat.AutoNAT - autorelay *autorelay.AutoRelay + net network.Network + addrsFactory AddrsFactory natmgr NATManager observedAddrsService observedAddrsService - addrChangeChan chan struct{} - relayAddrsSub event.Subscription - emitter event.Emitter + addrsChangeChan chan struct{} + addrsUpdated chan struct{} + autoRelayAddrsSub event.Subscription + autoRelayAddrs func() []ma.Multiaddr + reachability func() network.Reachability ifaceAddrs *interfaceAddrsCache wg sync.WaitGroup ctx context.Context @@ -55,10 +44,6 @@ type addressService struct { func NewAddressService(h *BasicHost, natmgr func(network.Network) NATManager, addrFactory AddrsFactory) (*addressService, error) { - emitter, err := h.eventbus.Emitter(&event.EvtLocalAddressesUpdated{}, eventbus.Stateful) - if err != nil { - return nil, err - } var nmgr NATManager if natmgr != nil { nmgr = natmgr(h.Network()) @@ -67,47 +52,35 @@ func NewAddressService(h *BasicHost, natmgr func(network.Network) NATManager, if err != nil { return nil, err } - peerRecord := h.makeSignedPeerRecord - if !h.disableSignedPeerRecord { - peerRecord = nil + + var autoRelayAddrs func() []ma.Multiaddr + if h.autorelay != nil { + autoRelayAddrs = h.autorelay.RelayAddrs } + ctx, cancel := context.WithCancel(context.Background()) as := &addressService{ net: h.Network(), - peerstore: h.Peerstore(), observedAddrsService: h.IDService(), - id: h.ID(), - peerRecord: peerRecord, natmgr: nmgr, - emitter: emitter, - autorelay: h.autorelay, addrsFactory: addrFactory, - addrChangeChan: make(chan struct{}, 1), - relayAddrsSub: addrSub, - ctx: ctx, - ctxCancel: cancel, + addrsChangeChan: make(chan struct{}, 1), + addrsUpdated: make(chan struct{}, 1), + autoRelayAddrsSub: addrSub, + autoRelayAddrs: autoRelayAddrs, ifaceAddrs: &interfaceAddrsCache{}, - } - if as.peerRecord != nil { - cab, ok := peerstore.GetCertifiedAddrBook(as.peerstore) - if !ok { - return nil, errors.New("peerstore should also be a certified address book") - } - rec, err := as.peerRecord(as.Addrs()) - if err != nil { - return nil, fmt.Errorf("failed to create signed record for self: %w", err) - } - if _, err := cab.ConsumePeerRecord(rec, peerstore.PermanentAddrTTL); err != nil { - return nil, fmt.Errorf("failed to persist signed record to peerstore: %w", err) - } + reachability: func() network.Reachability { + if h.GetAutoNat() != nil { + return h.GetAutoNat().Status() + } + return network.ReachabilityUnknown + }, + ctx: ctx, + ctxCancel: cancel, } return as, nil } -func (a *addressService) SetAutoNAT(an autonat.AutoNAT) { - a.autonat = an -} - func (a *addressService) Start() { a.wg.Add(1) go a.background() @@ -122,11 +95,7 @@ func (a *addressService) Close() { log.Warnf("error closing natmgr: %s", err) } } - err := a.emitter.Close() - if err != nil { - log.Warnf("error closing addrs update emitter: %s", err) - } - err = a.relayAddrsSub.Close() + err := a.autoRelayAddrsSub.Close() if err != nil { log.Warnf("error closing addrs update emitter: %s", err) } @@ -134,61 +103,36 @@ func (a *addressService) Close() { func (a *addressService) SignalAddressChange() { select { - case a.addrChangeChan <- struct{}{}: + case a.addrsChangeChan <- struct{}{}: default: } } +func (a *addressService) AddrsUpdated() chan struct{} { + return a.addrsUpdated +} + func (a *addressService) background() { defer a.wg.Done() - var lastAddrs []ma.Multiaddr - emitAddrChange := func(currentAddrs []ma.Multiaddr, lastAddrs []ma.Multiaddr) { - changeEvt := a.makeUpdatedAddrEvent(lastAddrs, currentAddrs) - if changeEvt == nil { - return - } - // Our addresses have changed. - // store the signed peer record in the peer store. - if a.peerRecord != nil { - cabook, ok := peerstore.GetCertifiedAddrBook(a.peerstore) - if !ok { - log.Errorf("peerstore doesn't implement certified address book") - return - } - if _, err := cabook.ConsumePeerRecord(changeEvt.SignedPeerRecord, peerstore.PermanentAddrTTL); err != nil { - log.Errorf("failed to persist signed peer record in peer store, err=%s", err) - return - } - } - // update host addresses in the peer store - removedAddrs := make([]ma.Multiaddr, 0, len(changeEvt.Removed)) - for _, ua := range changeEvt.Removed { - removedAddrs = append(removedAddrs, ua.Address) - } - a.peerstore.SetAddrs(a.id, currentAddrs, peerstore.PermanentAddrTTL) - a.peerstore.SetAddrs(a.id, removedAddrs, 0) - - // emit addr change event - if err := a.emitter.Emit(*changeEvt); err != nil { - log.Warnf("error emitting event for updated addrs: %s", err) - } - } + var prev []ma.Multiaddr - // periodically schedules an IdentifyPush to update our peers for changes - // in our address set (if needed) ticker := time.NewTicker(addrChangeTickrInterval) defer ticker.Stop() - for { curr := a.Addrs() - emitAddrChange(curr, lastAddrs) - lastAddrs = curr + if a.areAddrsDifferent(prev, curr) { + select { + case a.addrsUpdated <- struct{}{}: + default: + } + } + prev = curr select { case <-ticker.C: - case <-a.addrChangeChan: - case <-a.relayAddrsSub.Out(): + case <-a.addrsChangeChan: + case <-a.autoRelayAddrsSub.Out(): case <-a.ctx.Done(): return } @@ -201,9 +145,9 @@ func (a *addressService) background() { func (a *addressService) Addrs() []ma.Multiaddr { addrs := a.AllAddrs() // Delete public addresses if the node's reachability is private, and we have autorelay. - if a.autonat != nil && a.autorelay != nil && a.autonat.Status() == network.ReachabilityPrivate { + if a.reachability() == network.ReachabilityPrivate && a.autoRelayAddrs != nil { addrs = slices.DeleteFunc(addrs, func(a ma.Multiaddr) bool { return manet.IsPublicAddr(a) }) - addrs = append(addrs, a.autorelay.RelayAddrs()...) + addrs = append(addrs, a.autoRelayAddrs()...) } // Make a copy. Consumers can modify the slice elements addrs = slices.Clone(a.addrsFactory(addrs)) @@ -245,6 +189,7 @@ func (a *addressService) AllAddrs() []ma.Multiaddr { finalAddrs = slices.DeleteFunc(finalAddrs, func(a ma.Multiaddr) bool { return a.Equal(p2pCircuitAddr) }) + // Add certhashes for /webrtc-direct, /webtransport, etc addresses discovered // using identify. finalAddrs = a.addCertHashes(finalAddrs) @@ -270,7 +215,7 @@ func (a *addressService) appendNATAddrs(result []ma.Multiaddr, listenAddrs []ma. // we have a NAT device for _, listen := range listenAddrs { extMaddr := a.natmgr.GetMapping(listen) - result = appendValidNATAddrs(result, listen, extMaddr, a.observedAddrsService.ObservedAddrsFor, ifaceAddrs) + result = appendNATAddrsForListenAddrs(result, listen, extMaddr, a.observedAddrsService.ObservedAddrsFor, ifaceAddrs) } } else { if a.observedAddrsService != nil { @@ -321,55 +266,26 @@ func (a *addressService) addCertHashes(addrs []ma.Multiaddr) []ma.Multiaddr { return addrs } -func (a *addressService) makeUpdatedAddrEvent(prev, current []ma.Multiaddr) *event.EvtLocalAddressesUpdated { - if prev == nil && current == nil { - return nil +func (a *addressService) areAddrsDifferent(prev, current []ma.Multiaddr) bool { + prevmap := make(map[string]struct{}) + currmap := make(map[string]struct{}) + for _, p := range prev { + prevmap[string(p.Bytes())] = struct{}{} } - prevmap := make(map[string]ma.Multiaddr, len(prev)) - currmap := make(map[string]ma.Multiaddr, len(current)) - evt := &event.EvtLocalAddressesUpdated{Diffs: true} - addrsAdded := false - - for _, addr := range prev { - prevmap[string(addr.Bytes())] = addr - } - for _, addr := range current { - currmap[string(addr.Bytes())] = addr - } - for _, addr := range currmap { - _, ok := prevmap[string(addr.Bytes())] - updated := event.UpdatedAddress{Address: addr} - if ok { - updated.Action = event.Maintained - } else { - updated.Action = event.Added - addrsAdded = true - } - evt.Current = append(evt.Current, updated) - delete(prevmap, string(addr.Bytes())) + for _, c := range current { + currmap[string(c.Bytes())] = struct{}{} } - for _, addr := range prevmap { - updated := event.UpdatedAddress{Action: event.Removed, Address: addr} - evt.Removed = append(evt.Removed, updated) - } - - if !addrsAdded && len(evt.Removed) == 0 { - return nil + for p := range prevmap { + if _, ok := currmap[p]; !ok { + return true + } } - - // Our addresses have changed. Make a new signed peer record. - if a.peerRecord != nil { - // add signed peer record to the event - sr, err := a.peerRecord(current) - if err != nil { - log.Errorf("error creating a signed peer record from the set of current addresses, err=%s", err) - // drop this change - return nil + for c := range currmap { + if _, ok := prevmap[c]; !ok { + return true } - evt.SignedPeerRecord = sr } - - return evt + return false } const ifaceAddrsTTL = time.Minute @@ -505,9 +421,9 @@ func getAllPossibleLocalAddrs(listenAddr ma.Multiaddr, ifaceAddrs []ma.Multiaddr return append(resolved, listenAddr) } -// appendValidNATAddrs adds the NAT-ed addresses to the result. If the NAT device doesn't provide +// appendNATAddrsForListenAddrs adds the NAT-ed addresses to the result. If the NAT device doesn't provide // us with a public IP address, we use the observed addresses. -func appendValidNATAddrs(result []ma.Multiaddr, listenAddr ma.Multiaddr, natMapping ma.Multiaddr, +func appendNATAddrsForListenAddrs(result []ma.Multiaddr, listenAddr ma.Multiaddr, natMapping ma.Multiaddr, obsAddrsFunc func(ma.Multiaddr) []ma.Multiaddr, ifaceAddrs []ma.Multiaddr) []ma.Multiaddr { if natMapping == nil { diff --git a/p2p/host/basic/address_service_test.go b/p2p/host/basic/address_service_test.go index e099ba5737..6ee611c8a2 100644 --- a/p2p/host/basic/address_service_test.go +++ b/p2p/host/basic/address_service_test.go @@ -89,7 +89,7 @@ func TestAppendNATAddrs(t *testing.T) { } for _, tc := range cases { t.Run(tc.Name, func(t *testing.T) { - res := appendValidNATAddrs(nil, + res := appendNATAddrsForListenAddrs(nil, tc.Listen, tc.Nat, tc.ObsAddrFunc, ifaceAddrs) res = ma.Unique(res) require.ElementsMatch(t, tc.Expected, res, "%s\n%s", tc.Expected, res) diff --git a/p2p/host/basic/basic_host.go b/p2p/host/basic/basic_host.go index 61f4e1ebbd..ac7cb9bd0f 100644 --- a/p2p/host/basic/basic_host.go +++ b/p2p/host/basic/basic_host.go @@ -81,8 +81,6 @@ type BasicHost struct { eventbus event.Bus relayManager *relaysvc.RelayManager - AddrsFactory AddrsFactory - negtimeout time.Duration emitters struct { @@ -90,12 +88,11 @@ type BasicHost struct { evtLocalAddrsUpdated event.Emitter } - addrMu sync.RWMutex - disableSignedPeerRecord bool signKey crypto.PrivKey - autoNat autonat.AutoNAT + autoNATMx sync.RWMutex + autoNat autonat.AutoNAT autonatv2 *autonatv2.AutoNAT autorelay *autorelay.AutoRelay @@ -195,6 +192,9 @@ func NewHost(n network.Network, opts *HostOpts) (*BasicHost, error) { if h.emitters.evtLocalProtocolsUpdated, err = h.eventbus.Emitter(&event.EvtLocalProtocolsUpdated{}, eventbus.Stateful); err != nil { return nil, err } + if h.emitters.evtLocalAddrsUpdated, err = h.eventbus.Emitter(&event.EvtLocalAddressesUpdated{}, eventbus.Stateful); err != nil { + return nil, err + } if !opts.DisableSignedPeerRecord { h.signKey = h.Peerstore().PrivKey(h.ID()) @@ -304,10 +304,24 @@ func NewHost(n network.Network, opts *HostOpts) (*BasicHost, error) { n.SetStreamHandler(h.newStreamHandler) + if !h.disableSignedPeerRecord { + cab, ok := peerstore.GetCertifiedAddrBook(h.Peerstore()) + if !ok { + return nil, errors.New("peerstore should also be a certified address book") + } + rec, err := h.makeSignedPeerRecord(h.addressService.Addrs()) + if err != nil { + return nil, fmt.Errorf("failed to create signed record for self: %w", err) + } + if _, err := cab.ConsumePeerRecord(rec, peerstore.PermanentAddrTTL); err != nil { + return nil, fmt.Errorf("failed to persist signed record to peerstore: %w", err) + } + } + // register to be notified when the network's listen addrs change, // so we can update our address set and push events if needed listenHandler := func(network.Network, ma.Multiaddr) { - h.SignalAddressChange() + h.addressService.SignalAddressChange() } n.Notify(&network.NotifyBundle{ ListenF: listenHandler, @@ -330,7 +344,10 @@ func (h *BasicHost) Start() { log.Errorf("autonat v2 failed to start: %s", err) } } + h.refCount.Add(1) + go h.background() h.addressService.Start() + } // newStreamHandler is the remote-opened stream handler for network.Network @@ -381,11 +398,105 @@ func (h *BasicHost) newStreamHandler(s network.Stream) { handle(protoID, s) } -// SignalAddressChange signals to the host that it needs to determine whether our listen addresses have recently -// changed. -// Warning: this interface is unstable and may disappear in the future. -func (h *BasicHost) SignalAddressChange() { - h.addressService.SignalAddressChange() +func (h *BasicHost) background() { + defer h.refCount.Done() + var lastAddrs []ma.Multiaddr + + // TODO: Deprecate this event and logic + emitAddrChange := func(currentAddrs []ma.Multiaddr, lastAddrs []ma.Multiaddr) { + changeEvt := h.makeUpdatedAddrEvent(lastAddrs, currentAddrs) + if changeEvt == nil { + return + } + // Our addresses have changed. + // store the signed peer record in the peer store. + if !h.disableSignedPeerRecord { + cabook, ok := peerstore.GetCertifiedAddrBook(h.Peerstore()) + if !ok { + log.Errorf("peerstore doesn't implement certified address book") + return + } + if _, err := cabook.ConsumePeerRecord(changeEvt.SignedPeerRecord, peerstore.PermanentAddrTTL); err != nil { + log.Errorf("failed to persist signed peer record in peer store, err=%s", err) + return + } + } + // update host addresses in the peer store + removedAddrs := make([]ma.Multiaddr, 0, len(changeEvt.Removed)) + for _, ua := range changeEvt.Removed { + removedAddrs = append(removedAddrs, ua.Address) + } + h.Peerstore().SetAddrs(h.ID(), currentAddrs, peerstore.PermanentAddrTTL) + h.Peerstore().SetAddrs(h.ID(), removedAddrs, 0) + + // emit addr change event + if err := h.emitters.evtLocalAddrsUpdated.Emit(*changeEvt); err != nil { + log.Warnf("error emitting event for updated addrs: %s", err) + } + } + + for { + curr := h.Addrs() + emitAddrChange(curr, lastAddrs) + lastAddrs = curr + + select { + case <-h.addressService.AddrsUpdated(): + case <-h.ctx.Done(): + return + } + } +} + +func (h *BasicHost) makeUpdatedAddrEvent(prev, current []ma.Multiaddr) *event.EvtLocalAddressesUpdated { + if prev == nil && current == nil { + return nil + } + prevmap := make(map[string]ma.Multiaddr, len(prev)) + currmap := make(map[string]ma.Multiaddr, len(current)) + evt := &event.EvtLocalAddressesUpdated{Diffs: true} + addrsAdded := false + + for _, addr := range prev { + prevmap[string(addr.Bytes())] = addr + } + for _, addr := range current { + currmap[string(addr.Bytes())] = addr + } + for _, addr := range currmap { + _, ok := prevmap[string(addr.Bytes())] + updated := event.UpdatedAddress{Address: addr} + if ok { + updated.Action = event.Maintained + } else { + updated.Action = event.Added + addrsAdded = true + } + evt.Current = append(evt.Current, updated) + delete(prevmap, string(addr.Bytes())) + } + for _, addr := range prevmap { + updated := event.UpdatedAddress{Action: event.Removed, Address: addr} + evt.Removed = append(evt.Removed, updated) + } + + if !addrsAdded && len(evt.Removed) == 0 { + return nil + } + + // Our addresses have changed. Make a new signed peer record. + if !h.disableSignedPeerRecord { + // add signed peer record to the event + sr, err := h.makeSignedPeerRecord(current) + if err != nil { + log.Errorf("error creating a signed peer record from the set of current addresses, err=%s", err) + // drop this change + return nil + } + evt.SignedPeerRecord = sr + } + + return evt } func (h *BasicHost) makeSignedPeerRecord(addrs []ma.Multiaddr) (*record.Envelope, error) { @@ -653,18 +764,17 @@ func (h *BasicHost) AllAddrs() []ma.Multiaddr { // SetAutoNat sets the autonat service for the host. func (h *BasicHost) SetAutoNat(a autonat.AutoNAT) { - h.addrMu.Lock() - defer h.addrMu.Unlock() + h.autoNATMx.Lock() + defer h.autoNATMx.Unlock() if h.autoNat == nil { h.autoNat = a } - h.addressService.SetAutoNAT(h.autoNat) } // GetAutoNat returns the host's AutoNAT service, if AutoNAT is enabled. func (h *BasicHost) GetAutoNat() autonat.AutoNAT { - h.addrMu.Lock() - defer h.addrMu.Unlock() + h.autoNATMx.Lock() + defer h.autoNATMx.Unlock() return h.autoNat } diff --git a/p2p/host/basic/basic_host_test.go b/p2p/host/basic/basic_host_test.go index 2d254956d2..dcf174410b 100644 --- a/p2p/host/basic/basic_host_test.go +++ b/p2p/host/basic/basic_host_test.go @@ -734,7 +734,7 @@ func TestHostAddrChangeDetection(t *testing.T) { lk.Lock() currentAddrSet = i lk.Unlock() - h.SignalAddressChange() + h.addressService.SignalAddressChange() evt := waitForAddrChangeEvent(ctx, sub, t) if !updatedAddrEventsEqual(expectedEvents[i-1], evt) { t.Errorf("change events not equal: \n\texpected: %v \n\tactual: %v", expectedEvents[i-1], evt)