diff --git a/dht.go b/dht.go index c6bd51e39..68f16a8dc 100644 --- a/dht.go +++ b/dht.go @@ -163,6 +163,8 @@ type IpfsDHT struct { // addrFilter is used to filter the addresses we put into the peer store. // Mostly used to filter out localhost and local addresses. addrFilter func([]ma.Multiaddr) []ma.Multiaddr + + onRequestHook func(ctx context.Context, s network.Stream, req pb.Message) } // Assert that IPFS assumptions about interfaces aren't broken. These aren't a @@ -306,6 +308,7 @@ func makeDHT(h host.Host, cfg dhtcfg.Config) (*IpfsDHT, error) { routingTablePeerFilter: cfg.RoutingTable.PeerFilter, rtPeerDiversityFilter: cfg.RoutingTable.DiversityFilter, addrFilter: cfg.AddressFilter, + onRequestHook: cfg.OnRequestHook, fixLowPeersChan: make(chan struct{}, 1), diff --git a/dht_net.go b/dht_net.go index 640c7ea83..af7aee785 100644 --- a/dht_net.go +++ b/dht_net.go @@ -101,6 +101,10 @@ func (dht *IpfsDHT) handleNewMessage(s network.Stream) bool { metrics.ReceivedBytes.M(int64(msgLen)), ) + if dht.onRequestHook != nil { + dht.onRequestHook(ctx, s, req) + } + handler := dht.handlerForMsgType(req.GetType()) if handler == nil { stats.Record(ctx, metrics.ReceivedMessageErrors.M(1)) diff --git a/dht_options.go b/dht_options.go index 1c0b3d13b..5939d2653 100644 --- a/dht_options.go +++ b/dht_options.go @@ -1,6 +1,7 @@ package dht import ( + "context" "fmt" "testing" "time" @@ -12,6 +13,7 @@ import ( "github.com/libp2p/go-libp2p-kbucket/peerdiversity" record "github.com/libp2p/go-libp2p-record" "github.com/libp2p/go-libp2p/core/host" + "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/peer" "github.com/libp2p/go-libp2p/core/protocol" @@ -368,3 +370,14 @@ func WithCustomMessageSender(messageSenderBuilder func(h host.Host, protos []pro return nil } } + +// OnRequestHook registers a callback function that will be invoked for every +// incoming DHT protocol message. +// Note: Ensure that the callback executes efficiently, as it will block the +// entire message handler. +func OnRequestHook(f func(ctx context.Context, s network.Stream, req pb.Message)) Option { + return func(c *dhtcfg.Config) error { + c.OnRequestHook = f + return nil + } +} diff --git a/go.mod b/go.mod index ea11c1236..0de641a9e 100644 --- a/go.mod +++ b/go.mod @@ -16,6 +16,7 @@ require ( github.com/ipfs/go-datastore v0.6.0 github.com/ipfs/go-detect-race v0.0.1 github.com/ipfs/go-log/v2 v2.5.1 + github.com/ipfs/go-test v0.0.4 github.com/libp2p/go-libp2p v0.38.2 github.com/libp2p/go-libp2p-kbucket v0.6.4 github.com/libp2p/go-libp2p-record v0.3.1 @@ -63,6 +64,8 @@ require ( github.com/gorilla/websocket v1.5.3 // indirect github.com/hashicorp/errwrap v1.1.0 // indirect github.com/huin/goupnp v1.3.0 // indirect + github.com/ipfs/go-block-format v0.2.0 // indirect + github.com/ipfs/go-ipfs-util v0.0.3 // indirect github.com/ipfs/go-log v1.0.5 // indirect github.com/ipld/go-ipld-prime v0.21.0 // indirect github.com/jackpal/go-nat-pmp v1.0.2 // indirect diff --git a/internal/config/config.go b/internal/config/config.go index 3248b8171..919c2d0dd 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -1,6 +1,7 @@ package config import ( + "context" "fmt" "time" @@ -14,6 +15,7 @@ import ( "github.com/libp2p/go-libp2p-kbucket/peerdiversity" record "github.com/libp2p/go-libp2p-record" "github.com/libp2p/go-libp2p/core/host" + "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/peer" "github.com/libp2p/go-libp2p/core/protocol" ma "github.com/multiformats/go-multiaddr" @@ -63,6 +65,7 @@ type Config struct { BootstrapPeers func() []peer.AddrInfo AddressFilter func([]ma.Multiaddr) []ma.Multiaddr + OnRequestHook func(ctx context.Context, s network.Stream, req pb.Message) // test specific Config options DisableFixLowPeers bool