diff --git a/p2p/host/basic/basic_host.go b/p2p/host/basic/basic_host.go index 61e113850a..8c608afa58 100644 --- a/p2p/host/basic/basic_host.go +++ b/p2p/host/basic/basic_host.go @@ -717,7 +717,7 @@ func (h *BasicHost) NewStream(ctx context.Context, p peer.ID, pids ...protocol.I } defer func() { if strErr != nil && s != nil { - s.Reset() + s.ResetWithError(network.StreamProtocolNegotiationFailed) } }() @@ -761,13 +761,13 @@ func (h *BasicHost) NewStream(ctx context.Context, p peer.ID, pids ...protocol.I return nil, fmt.Errorf("failed to negotiate protocol: %w", err) } case <-ctx.Done(): - s.ResetWithError(network.StreamProtocolNegotiationFailed) // wait for `SelectOneOf` to error out because of resetting the stream. <-errCh return nil, fmt.Errorf("failed to negotiate protocol: %w", ctx.Err()) } if err := s.SetProtocol(selected); err != nil { + s.ResetWithError(network.StreamResourceLimitExceeded) return nil, err } _ = h.Peerstore().AddProtocols(p, selected) // adding the protocol to the peerstore isn't critical diff --git a/p2p/host/basic/basic_host_test.go b/p2p/host/basic/basic_host_test.go index dac280731b..2a7a772976 100644 --- a/p2p/host/basic/basic_host_test.go +++ b/p2p/host/basic/basic_host_test.go @@ -995,35 +995,3 @@ func TestHostTimeoutNewStream(t *testing.T) { require.Error(t, err) require.ErrorContains(t, err, "context deadline exceeded") } - -func TestMultistreamFailure(t *testing.T) { - h1, err := NewHost(swarmt.GenSwarm(t), nil) - require.NoError(t, err) - h1.Start() - defer h1.Close() - - h2, err := NewHost(swarmt.GenSwarm(t), nil) - require.NoError(t, err) - h2.Start() - defer h2.Close() - - h2.Peerstore().AddProtocols(h1.ID(), "/test") - - err = h2.Connect(context.Background(), h1.Peerstore().PeerInfo(h1.ID())) - require.NoError(t, err) - h2.Peerstore().AddProtocols(h1.ID(), "/test") - s, err := h2.NewStream(context.Background(), h1.ID(), "/test") - require.NoError(t, err) - // Special string to make the other side fail multistream and reset - buf := make([]byte, 1024) - for i := 0; i < len(buf); i++ { - buf[i] = 0xff - } - _, err = s.Write(buf) - require.NoError(t, err) - _, err = s.Read(buf) - var se *network.StreamError - require.ErrorAs(t, err, &se) - require.True(t, se.Remote) - require.Equal(t, network.StreamProtocolNegotiationFailed, se.ErrorCode) -} diff --git a/p2p/net/connmgr/connmgr_test.go b/p2p/net/connmgr/connmgr_test.go index 76e70db5f2..4772837fe1 100644 --- a/p2p/net/connmgr/connmgr_test.go +++ b/p2p/net/connmgr/connmgr_test.go @@ -1024,7 +1024,7 @@ func TestErrorCode(t *testing.T) { } c21 = conns[0] return true - }, 5*time.Second, 100*time.Millisecond) + }, 10*time.Second, 100*time.Millisecond) c13, err := sw1.DialPeer(context.Background(), sw3.LocalPeer()) require.NoError(t, err) @@ -1037,7 +1037,7 @@ func TestErrorCode(t *testing.T) { } c31 = conns[0] return true - }, 5*time.Second, 100*time.Millisecond) + }, 10*time.Second, 100*time.Millisecond) cm.TrimOpenConns(context.Background())