diff --git a/config/config.go b/config/config.go index 52b3562e68..21cbbb66e8 100644 --- a/config/config.go +++ b/config/config.go @@ -345,13 +345,13 @@ func (cfg *Config) NewNode() (host.Host, error) { rcmgr.MustRegisterWith(cfg.PrometheusRegisterer) } - var autonatv2Dialer network.Network + var autonatv2Dialer *blankhost.BlankHost if !cfg.DisableAutoNATv2 { ah, err := cfg.makeAutoNATHost() if err != nil { return nil, err } - autonatv2Dialer = ah.Network() + autonatv2Dialer = ah } h, err := bhost.NewHost(swrm, &bhost.HostOpts{ diff --git a/p2p/host/basic/basic_host.go b/p2p/host/basic/basic_host.go index 8873fedfa3..3e507a5184 100644 --- a/p2p/host/basic/basic_host.go +++ b/p2p/host/basic/basic_host.go @@ -165,7 +165,7 @@ type HostOpts struct { PrometheusRegisterer prometheus.Registerer EnableAutoNATv2 bool - AutoNATv2Dialer network.Network + AutoNATv2Dialer host.Host } // NewHost constructs a new *BasicHost and activates it by attaching its stream and connection handlers to the given inet.Network. diff --git a/p2p/protocol/autonatv2/server_test.go b/p2p/protocol/autonatv2/server_test.go index e6c542a9fd..8853dc35df 100644 --- a/p2p/protocol/autonatv2/server_test.go +++ b/p2p/protocol/autonatv2/server_test.go @@ -98,3 +98,27 @@ func TestServerDataRequest(t *testing.T) { require.Equal(t, res[0].Rch, network.ReachabilityPublic) } + +func TestServerDial(t *testing.T) { + h := bhost.NewBlankHost(swarmt.GenSwarm(t)) + dialer := bhost.NewBlankHost(swarmt.GenSwarm(t)) + as := NewServer(h, dialer, nil, true) + defer as.host.Close() + as.Start() + + c := newAutoNAT(t) + c.allowAllAddrs = true + defer c.Close() + defer c.host.Close() + + c.host.Peerstore().AddAddrs(as.host.ID(), as.host.Addrs(), peerstore.PermanentAddrTTL) + c.host.Peerstore().AddProtocols(as.host.ID(), DialProtocol) + randAddr := ma.StringCast("/ip4/1.2.3.4/tcp/2") + res, err := c.CheckReachability(context.Background(), []ma.Multiaddr{randAddr}, c.host.Addrs()) + require.NoError(t, err) + require.Equal(t, res[0].Rch, network.ReachabilityPrivate) + + res, err = c.CheckReachability(context.Background(), nil, c.host.Addrs()) + require.NoError(t, err) + require.Equal(t, res[0].Rch, network.ReachabilityPublic) +}