diff --git a/dht.go b/dht.go index 0a6f2ecb..3aa962b9 100644 --- a/dht.go +++ b/dht.go @@ -163,6 +163,8 @@ type IpfsDHT struct { // addrFilter is used to filter the addresses we put into the peer store. // Mostly used to filter out localhost and local addresses. addrFilter func([]ma.Multiaddr) []ma.Multiaddr + + onRequestHook func(ctx context.Context, s network.Stream, req pb.Message) } // Assert that IPFS assumptions about interfaces aren't broken. These aren't a @@ -306,6 +308,7 @@ func makeDHT(h host.Host, cfg dhtcfg.Config) (*IpfsDHT, error) { routingTablePeerFilter: cfg.RoutingTable.PeerFilter, rtPeerDiversityFilter: cfg.RoutingTable.DiversityFilter, addrFilter: cfg.AddressFilter, + onRequestHook: cfg.OnRequestHook, fixLowPeersChan: make(chan struct{}, 1), diff --git a/dht_net.go b/dht_net.go index 3e135df1..3fd113fc 100644 --- a/dht_net.go +++ b/dht_net.go @@ -100,6 +100,10 @@ func (dht *IpfsDHT) handleNewMessage(s network.Stream) bool { metrics.ReceivedBytes.M(int64(msgLen)), ) + if dht.onRequestHook != nil { + dht.onRequestHook(ctx, s, req) + } + handler := dht.handlerForMsgType(req.GetType()) if handler == nil { stats.Record(ctx, metrics.ReceivedMessageErrors.M(1)) diff --git a/dht_options.go b/dht_options.go index 1c0b3d13..5939d265 100644 --- a/dht_options.go +++ b/dht_options.go @@ -1,6 +1,7 @@ package dht import ( + "context" "fmt" "testing" "time" @@ -12,6 +13,7 @@ import ( "github.com/libp2p/go-libp2p-kbucket/peerdiversity" record "github.com/libp2p/go-libp2p-record" "github.com/libp2p/go-libp2p/core/host" + "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/peer" "github.com/libp2p/go-libp2p/core/protocol" @@ -368,3 +370,14 @@ func WithCustomMessageSender(messageSenderBuilder func(h host.Host, protos []pro return nil } } + +// OnRequestHook registers a callback function that will be invoked for every +// incoming DHT protocol message. +// Note: Ensure that the callback executes efficiently, as it will block the +// entire message handler. +func OnRequestHook(f func(ctx context.Context, s network.Stream, req pb.Message)) Option { + return func(c *dhtcfg.Config) error { + c.OnRequestHook = f + return nil + } +} diff --git a/internal/config/config.go b/internal/config/config.go index 3248b817..919c2d0d 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -1,6 +1,7 @@ package config import ( + "context" "fmt" "time" @@ -14,6 +15,7 @@ import ( "github.com/libp2p/go-libp2p-kbucket/peerdiversity" record "github.com/libp2p/go-libp2p-record" "github.com/libp2p/go-libp2p/core/host" + "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/peer" "github.com/libp2p/go-libp2p/core/protocol" ma "github.com/multiformats/go-multiaddr" @@ -63,6 +65,7 @@ type Config struct { BootstrapPeers func() []peer.AddrInfo AddressFilter func([]ma.Multiaddr) []ma.Multiaddr + OnRequestHook func(ctx context.Context, s network.Stream, req pb.Message) // test specific Config options DisableFixLowPeers bool