diff --git a/swarm_listen.go b/swarm_listen.go index ca54280c..38a46eca 100644 --- a/swarm_listen.go +++ b/swarm_listen.go @@ -9,6 +9,10 @@ import ( ma "github.com/multiformats/go-multiaddr" ) +var ( + ErrSwarmListenerAcceptError = fmt.Errorf("swarm listener accept error") +) + // Listen sets up listeners for all of the given addresses. // It returns as long as we successfully listen on at least *one* address. func (s *Swarm) Listen(addrs ...ma.Multiaddr) error { @@ -35,6 +39,20 @@ func (s *Swarm) Listen(addrs ...ma.Multiaddr) error { return nil } +// ListenClose stop and delete listeners for all of the given addresses. +func (s *Swarm) ListenClose(addrs ...ma.Multiaddr) { + s.listeners.Lock() + defer s.listeners.Unlock() + + for l := range s.listeners.m { + if !containsMultiaddr(addrs, l.Multiaddr()) { + continue + } + + l.Close() + } +} + // AddListenAddr tells the swarm to listen on a single address. Unlike Listen, // this method does not attempt to filter out bad addresses. func (s *Swarm) AddListenAddr(a ma.Multiaddr) error { @@ -79,24 +97,27 @@ func (s *Swarm) AddListenAddr(a ma.Multiaddr) error { go func() { defer func() { list.Close() + s.listeners.Lock() + defer s.listeners.Unlock() + delete(s.listeners.m, list) s.listeners.cacheEOL = time.Time{} - s.listeners.Unlock() // signal to our notifiees on listen close. s.notifyAll(func(n network.Notifiee) { n.ListenClose(s, maddr) }) s.refs.Done() + + // log if no more listener and the swarm is still running. + if len(s.listeners.m) == 0 && s.ctx.Err() == nil { + log.Error(ErrSwarmListenerAcceptError) + } }() for { c, err := list.Accept() if err != nil { - if s.ctx.Err() == nil { - // only log if the swarm is still running. - log.Errorf("swarm listener accept error: %s", err) - } return } @@ -119,3 +140,12 @@ func (s *Swarm) AddListenAddr(a ma.Multiaddr) error { }() return nil } + +func containsMultiaddr(addrs []ma.Multiaddr, addr ma.Multiaddr) bool { + for _, a := range addrs { + if addr == a { + return true + } + } + return false +} diff --git a/swarm_test.go b/swarm_test.go index 714941f5..f5f140ab 100644 --- a/swarm_test.go +++ b/swarm_test.go @@ -539,3 +539,57 @@ func TestResourceManagerAcceptStream(t *testing.T) { _, err = str.Read([]byte{0}) require.EqualError(t, err, "stream reset") } + +func TestListenCloseCount(t *testing.T) { + s := GenSwarm(t, OptDialOnly) + addrsToListen := []ma.Multiaddr{ + ma.StringCast("/ip4/0.0.0.0/tcp/0"), + ma.StringCast("/ip4/0.0.0.0/udp/0/quic"), + } + + if err := s.Listen(addrsToListen...); err != nil { + t.Fatal(err) + } + listenedAddrs := s.ListenAddresses() + require.Equal(t, 2, len(listenedAddrs)) + + s.ListenClose(listenedAddrs...) + time.Sleep(time.Millisecond) // wait for deferred listener deletion + + remainingAddrs := s.ListenAddresses() + require.Equal(t, 0, len(remainingAddrs)) +} + +func TestListenCloseLog(t *testing.T) { + var logs bytes.Buffer + + // pipe logs to buffer + lr := logging.NewPipeReader(logging.PipeFormat(logging.PlaintextOutput)) + defer lr.Close() + go io.Copy(&logs, lr) + + s := GenSwarm(t, OptDialOnly) + addrsToListen := []ma.Multiaddr{ + ma.StringCast("/ip4/0.0.0.0/tcp/0"), + ma.StringCast("/ip4/0.0.0.0/udp/0/quic"), + } + + if err := s.Listen(addrsToListen...); err != nil { + t.Fatal(err) + } + listenedAddrs := s.ListenAddresses() // closed 0, opened: 2 + + s.ListenClose(listenedAddrs[0]) // closed 1, opened: 1 + time.Sleep(time.Millisecond) // wait for deferred listener deletion + + require.False(t, containsError(logs, swarm.ErrSwarmListenerAcceptError)) + + s.ListenClose(listenedAddrs[1]) // closed 2, opened: 0 + time.Sleep(time.Millisecond) // wait for deferred listener deletion + + require.True(t, containsError(logs, swarm.ErrSwarmListenerAcceptError)) +} + +func containsError(logs bytes.Buffer, err error) bool { + return strings.Contains(logs.String(), err.Error()) +}