diff --git a/p2p/transport/webrtc/connection.go b/p2p/transport/webrtc/connection.go index 5ac6cb8d44..09e25d5a8a 100644 --- a/p2p/transport/webrtc/connection.go +++ b/p2p/transport/webrtc/connection.go @@ -53,9 +53,7 @@ type connection struct { remoteKey ic.PubKey remoteMultiaddr ma.Multiaddr - m sync.Mutex - streams map[uint16]*stream - + m sync.Mutex acceptQueue chan dataChannel ctx context.Context @@ -89,7 +87,6 @@ func newConnection( remoteMultiaddr: remoteMultiaddr, ctx: ctx, cancel: cancel, - streams: make(map[uint16]*stream), acceptQueue: make(chan dataChannel, maxAcceptQueueLen), } @@ -126,41 +123,26 @@ func (c *connection) ConnState() network.ConnectionState { // Close closes the underlying peerconnection. func (c *connection) Close() error { - c.closeOnce.Do(func() { - c.closeErr = errors.New("connection closed") - // cancel must be called after closeErr is set. This ensures interested goroutines waiting on - // ctx.Done can read closeErr without holding the conn lock. - c.cancel() - c.m.Lock() - streams := c.streams - c.streams = nil - c.m.Unlock() - for _, str := range streams { - str.Reset() - } - c.pc.Close() - c.scope.Done() - }) + c.closeOnce.Do(func() { c.closeWithError(errors.New("connection closed")) }) return nil } -func (c *connection) closeTimedOut() error { - c.closeOnce.Do(func() { - c.closeErr = errConnectionTimeout{} - // cancel must be called after closeErr is set. This ensures interested goroutines waiting on - // ctx.Done can read closeErr without holding the conn lock. - c.cancel() - c.m.Lock() - streams := c.streams - c.streams = nil - c.m.Unlock() - for _, str := range streams { - str.closeWithError(errConnectionTimeout{}) +// closeWithError is used to Close the connection when the underlying DTLS connection fails +func (c *connection) closeWithError(err error) { + c.closeErr = err + // cancel must be called after closeErr is set. This ensures interested goroutines waiting on + // ctx.Done can read closeErr without holding the conn lock. + c.cancel() + c.pc.Close() +loop: + for { + select { + case <-c.acceptQueue: + default: + break loop } - c.pc.Close() - c.scope.Done() - }) - return nil + } + c.scope.Done() } func (c *connection) IsClosed() bool { @@ -173,10 +155,6 @@ func (c *connection) IsClosed() bool { } func (c *connection) OpenStream(ctx context.Context) (network.MuxedStream, error) { - if c.IsClosed() { - return nil, c.closeErr - } - dc, err := c.pc.CreateDataChannel("", nil) if err != nil { return nil, err @@ -185,12 +163,11 @@ func (c *connection) OpenStream(ctx context.Context) (network.MuxedStream, error if err != nil { return nil, fmt.Errorf("open stream: %w", err) } - str := newStream(dc, rwc, func() { c.removeStream(*dc.ID()) }) - if err := c.addStream(str); err != nil { - str.Reset() - return nil, err + if c.IsClosed() { + dc.Close() + return nil, c.closeErr } - return str, nil + return newStream(dc, rwc, nil), nil } func (c *connection) AcceptStream() (network.MuxedStream, error) { @@ -198,12 +175,11 @@ func (c *connection) AcceptStream() (network.MuxedStream, error) { case <-c.ctx.Done(): return nil, c.closeErr case dc := <-c.acceptQueue: - str := newStream(dc.channel, dc.stream, func() { c.removeStream(*dc.channel.ID()) }) - if err := c.addStream(str); err != nil { - str.Reset() - return nil, err + if c.IsClosed() { + dc.channel.Close() + return nil, c.closeErr } - return str, nil + return newStream(dc.channel, dc.stream, nil), nil } } @@ -215,28 +191,9 @@ func (c *connection) RemoteMultiaddr() ma.Multiaddr { return c.remoteMultiaddr } func (c *connection) Scope() network.ConnScope { return c.scope } func (c *connection) Transport() tpt.Transport { return c.transport } -func (c *connection) addStream(str *stream) error { - c.m.Lock() - defer c.m.Unlock() - if c.IsClosed() { - return fmt.Errorf("connection closed: %w", c.closeErr) - } - if _, ok := c.streams[str.id]; ok { - return errors.New("stream ID already exists") - } - c.streams[str.id] = str - return nil -} - -func (c *connection) removeStream(id uint16) { - c.m.Lock() - defer c.m.Unlock() - delete(c.streams, id) -} - func (c *connection) onConnectionStateChange(state webrtc.PeerConnectionState) { if state == webrtc.PeerConnectionStateFailed || state == webrtc.PeerConnectionStateClosed { - c.closeTimedOut() + c.closeWithError(errConnectionTimeout{}) } } diff --git a/p2p/transport/webrtc/listener.go b/p2p/transport/webrtc/listener.go index c494da7072..5809c918e6 100644 --- a/p2p/transport/webrtc/listener.go +++ b/p2p/transport/webrtc/listener.go @@ -271,7 +271,8 @@ func (l *listener) setupConnection( localMultiaddrWithoutCerthash, _ := ma.SplitFunc(l.localMultiaddr, func(c ma.Component) bool { return c.Protocol().Code == ma.P_CERTHASH }) - handshakeChannel := newStream(rawDatachannel, rwc, func() {}) + s := newStream(rawDatachannel, rwc, func() {}) + defer s.Close() // The connection is instantiated before performing the Noise handshake. This is // to handle the case where the remote is faster and attempts to initiate a stream // before the ondatachannel callback can be set. @@ -291,7 +292,7 @@ func (l *listener) setupConnection( } // we do not yet know A's peer ID so accept any inbound - remotePubKey, err := l.transport.noiseHandshake(ctx, conn, handshakeChannel, "", crypto.SHA256, true) + remotePubKey, err := l.transport.noiseHandshake(ctx, conn, s, "", crypto.SHA256, true) if err != nil { conn.Close() return nil, err diff --git a/p2p/transport/webrtc/transport.go b/p2p/transport/webrtc/transport.go index aad23e82e4..3f52adaed1 100644 --- a/p2p/transport/webrtc/transport.go +++ b/p2p/transport/webrtc/transport.go @@ -385,6 +385,7 @@ func (t *WebRTCTransport) dial(ctx context.Context, scope network.ConnManagement } s := newStream(handshakeChannel, detached, func() {}) + defer s.Close() // the local address of the selected candidate pair should be the // local address for the connection, since different datachannels // are multiplexed over the same SCTP connection