diff --git a/peer/brontide_test.go b/peer/brontide_test.go
index 37803ecbc1..d127176d54 100644
--- a/peer/brontide_test.go
+++ b/peer/brontide_test.go
@@ -1344,3 +1344,109 @@ func TestHandleRemovePendingChannel(t *testing.T) {
 		})
 	}
 }
+
+// TestStartupWriteMessageRace checks that no data race occurs when starting up
+// a peer with an existing channel, while an outgoing message is queuing. Such
+// a race occurred in https://github.com/lightningnetwork/lnd/issues/8184, where
+// a channel reestablish message raced with another outgoing message.
+//
+// Note that races will only be detected with the Go race detector enabled.
+func TestStartupWriteMessageRace(t *testing.T) {
+	t.Parallel()
+
+	// Set up parameters for createTestPeer.
+	notifier := &mock.ChainNotifier{
+		SpendChan: make(chan *chainntnfs.SpendDetail),
+		EpochChan: make(chan *chainntnfs.BlockEpoch),
+		ConfChan:  make(chan *chainntnfs.TxConfirmation),
+	}
+	broadcastTxChan := make(chan *wire.MsgTx)
+	mockSwitch := &mockMessageSwitch{}
+
+	// Use a callback to extract the channel created by createTestPeer, so
+	// we can mark it borked below. We can't mark it borked within the
+	// callback, since the channel hasn't been saved to the DB yet when the
+	// callback executes.
+	var channel *channeldb.OpenChannel
+	getChannels := func(a, b *channeldb.OpenChannel) {
+		channel = a
+	}
+
+	// createTestPeer creates a peer and a channel with that peer.
+	peer, _, err := createTestPeer(
+		t, notifier, broadcastTxChan, getChannels, mockSwitch,
+	)
+	require.NoError(t, err, "unable to create test channel")
+
+	// Avoid the need to mock the channel graph by marking the channel
+	// borked. Borked channels still get a reestablish message sent on
+	// reconnect, while skipping channel graph checks and link creation.
+	require.NoError(t, channel.MarkBorked())
+
+	// Use a mock conn to detect read/write races on the conn.
+	mockConn := newMockConn(t, 2)
+	peer.cfg.Conn = mockConn
+
+	// Set up other configuration necessary to successfully execute
+	// peer.Start().
+	peer.cfg.LegacyFeatures = lnwire.EmptyFeatureVector()
+	writeBufferPool := pool.NewWriteBuffer(
+		pool.DefaultWriteBufferGCInterval,
+		pool.DefaultWriteBufferExpiryInterval,
+	)
+	writePool := pool.NewWrite(
+		writeBufferPool, 1, timeout,
+	)
+	require.NoError(t, writePool.Start())
+	peer.cfg.WritePool = writePool
+	readBufferPool := pool.NewReadBuffer(
+		pool.DefaultReadBufferGCInterval,
+		pool.DefaultReadBufferExpiryInterval,
+	)
+	readPool := pool.NewRead(
+		readBufferPool, 1, timeout,
+	)
+	require.NoError(t, readPool.Start())
+	peer.cfg.ReadPool = readPool
+
+	// Send a message while starting the peer. As the peer starts up, it
+	// should not trigger a data race between the sending of this message
+	// and the sending of the channel reestablish message.
+	sendPingDone := make(chan struct{})
+	go func() {
+		require.NoError(t, peer.SendMessage(true, lnwire.NewPing(0)))
+		close(sendPingDone)
+	}()
+
+	// Handle init messages.
+	go func() {
+		// Read init message.
+		<-mockConn.writtenMessages
+
+		// Write the init reply message.
+		initReplyMsg := lnwire.NewInitMessage(
+			lnwire.NewRawFeatureVector(
+				lnwire.DataLossProtectRequired,
+			),
+			lnwire.NewRawFeatureVector(),
+		)
+		var b bytes.Buffer
+		_, err = lnwire.WriteMessage(&b, initReplyMsg, 0)
+		require.NoError(t, err)
+
+		mockConn.readMessages <- b.Bytes()
+	}()
+
+	// Start the peer. No data race should occur.
+	require.NoError(t, peer.Start())
+
+	// Ensure messages were sent during startup.
+	<-sendPingDone
+	for i := 0; i < 2; i++ {
+		select {
+		case <-mockConn.writtenMessages:
+		default:
+			t.Fatalf("Failed to send all messages during startup")
+		}
+	}
+}
diff --git a/peer/test_utils.go b/peer/test_utils.go
index add15cf19d..c7509c0c82 100644
--- a/peer/test_utils.go
+++ b/peer/test_utils.go
@@ -497,6 +497,16 @@ type mockMessageConn struct {
 
 	readMessages   chan []byte
 	curReadMessage []byte
+
+	// writeRaceDetectingCounter is incremented on any function call
+	// associated with writing to the connection. The race detector will
+	// trigger on this counter if a data race exists.
+	writeRaceDetectingCounter int
+
+	// readRaceDetectingCounter is incremented on any function call
+	// associated with reading from the connection. The race detector will
+	// trigger on this counter if a data race exists.
+	readRaceDetectingCounter int
 }
 
 func newMockConn(t *testing.T, expectedMessages int) *mockMessageConn {
@@ -509,17 +519,20 @@ func newMockConn(t *testing.T, expectedMessages int) *mockMessageConn {
 
 // SetWriteDeadline mocks setting write deadline for our conn.
 func (m *mockMessageConn) SetWriteDeadline(time.Time) error {
+	m.writeRaceDetectingCounter++
 	return nil
 }
 
 // Flush mocks a message conn flush.
 func (m *mockMessageConn) Flush() (int, error) {
+	m.writeRaceDetectingCounter++
 	return 0, nil
 }
 
 // WriteMessage mocks sending of a message on our connection. It will push
 // the bytes sent into the mock's writtenMessages channel.
 func (m *mockMessageConn) WriteMessage(msg []byte) error {
+	m.writeRaceDetectingCounter++
 	select {
 	case m.writtenMessages <- msg:
 	case <-time.After(timeout):
@@ -542,15 +555,18 @@ func (m *mockMessageConn) assertWrite(expected []byte) {
 }
 
 func (m *mockMessageConn) SetReadDeadline(t time.Time) error {
+	m.readRaceDetectingCounter++
 	return nil
 }
 
 func (m *mockMessageConn) ReadNextHeader() (uint32, error) {
+	m.readRaceDetectingCounter++
 	m.curReadMessage = <-m.readMessages
 	return uint32(len(m.curReadMessage)), nil
 }
 
 func (m *mockMessageConn) ReadNextBody(buf []byte) ([]byte, error) {
+	m.readRaceDetectingCounter++
 	return m.curReadMessage, nil
 }
 
@@ -561,3 +577,7 @@ func (m *mockMessageConn) RemoteAddr() net.Addr {
 	return nil
 }
 
 func (m *mockMessageConn) LocalAddr() net.Addr {
 	return nil
 }
+
+func (m *mockMessageConn) Close() error {
+	return nil
+}
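For reference (not part of the patch above): a minimal, self-contained sketch of the race-detecting-counter pattern that the new mockMessageConn fields rely on. The probeConn type and TestProbeCounterRace name are hypothetical, invented only for illustration; the point is that an unsynchronized increment on every connection call gives the Go race detector a memory location to flag whenever two goroutines touch the connection concurrently, which is how the counters above let TestStartupWriteMessageRace surface the issue 8184 race.

package probe

import (
	"sync"
	"testing"
)

// probeConn mimics mockMessageConn in miniature: no locking, only a plain
// counter that the race detector can observe.
type probeConn struct {
	writeRaceDetectingCounter int
}

// WriteMessage increments the counter without synchronization. If two
// goroutines call it concurrently, `go test -race` reports a data race here.
func (c *probeConn) WriteMessage(_ []byte) {
	c.writeRaceDetectingCounter++
}

// TestProbeCounterRace drives two concurrent writers into the probe conn.
// With the race detector enabled the test fails with a race report; without
// it, the test passes silently.
func TestProbeCounterRace(t *testing.T) {
	c := &probeConn{}

	var wg sync.WaitGroup
	for i := 0; i < 2; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			c.WriteMessage(nil)
		}()
	}
	wg.Wait()
}

This also illustrates why the test comment notes that races are only detected with the Go race detector enabled: the counters add no assertions of their own, they only create detectable reads and writes on the mock conn's read and write paths.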