Skip to content

Commit

Permalink
Merge pull request lightningnetwork#8198 from morehouse/brontide_star…
Browse files Browse the repository at this point in the history
…tup_race_test

peer: test for startup writeHander data race
  • Loading branch information
Roasbeef authored Nov 18, 2023
2 parents e8bdf01 + f0ae5b2 commit d9b88fb
Show file tree
Hide file tree
Showing 2 changed files with 126 additions and 0 deletions.
106 changes: 106 additions & 0 deletions peer/brontide_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1344,3 +1344,109 @@ func TestHandleRemovePendingChannel(t *testing.T) {
})
}
}

// TestStartupWriteMessageRace checks that no data race occurs when starting up
// a peer with an existing channel, while an outgoing message is queuing. Such
// a race occurred in https://github.com/lightningnetwork/lnd/issues/8184, where
// a channel reestablish message raced with another outgoing message.
//
// Note that races will only be detected with the Go race detector enabled.
func TestStartupWriteMessageRace(t *testing.T) {
t.Parallel()

// Set up parameters for createTestPeer.
notifier := &mock.ChainNotifier{
SpendChan: make(chan *chainntnfs.SpendDetail),
EpochChan: make(chan *chainntnfs.BlockEpoch),
ConfChan: make(chan *chainntnfs.TxConfirmation),
}
broadcastTxChan := make(chan *wire.MsgTx)
mockSwitch := &mockMessageSwitch{}

// Use a callback to extract the channel created by createTestPeer, so
// we can mark it borked below. We can't mark it borked within the
// callback, since the channel hasn't been saved to the DB yet when the
// callback executes.
var channel *channeldb.OpenChannel
getChannels := func(a, b *channeldb.OpenChannel) {
channel = a
}

// createTestPeer creates a peer and a channel with that peer.
peer, _, err := createTestPeer(
t, notifier, broadcastTxChan, getChannels, mockSwitch,
)
require.NoError(t, err, "unable to create test channel")

// Avoid the need to mock the channel graph by marking the channel
// borked. Borked channels still get a reestablish message sent on
// reconnect, while skipping channel graph checks and link creation.
require.NoError(t, channel.MarkBorked())

// Use a mock conn to detect read/write races on the conn.
mockConn := newMockConn(t, 2)
peer.cfg.Conn = mockConn

// Set up other configuration necessary to successfully execute
// peer.Start().
peer.cfg.LegacyFeatures = lnwire.EmptyFeatureVector()
writeBufferPool := pool.NewWriteBuffer(
pool.DefaultWriteBufferGCInterval,
pool.DefaultWriteBufferExpiryInterval,
)
writePool := pool.NewWrite(
writeBufferPool, 1, timeout,
)
require.NoError(t, writePool.Start())
peer.cfg.WritePool = writePool
readBufferPool := pool.NewReadBuffer(
pool.DefaultReadBufferGCInterval,
pool.DefaultReadBufferExpiryInterval,
)
readPool := pool.NewRead(
readBufferPool, 1, timeout,
)
require.NoError(t, readPool.Start())
peer.cfg.ReadPool = readPool

// Send a message while starting the peer. As the peer starts up, it
// should not trigger a data race between the sending of this message
// and the sending of the channel reestablish message.
sendPingDone := make(chan struct{})
go func() {
require.NoError(t, peer.SendMessage(true, lnwire.NewPing(0)))
close(sendPingDone)
}()

// Handle init messages.
go func() {
// Read init message.
<-mockConn.writtenMessages

// Write the init reply message.
initReplyMsg := lnwire.NewInitMessage(
lnwire.NewRawFeatureVector(
lnwire.DataLossProtectRequired,
),
lnwire.NewRawFeatureVector(),
)
var b bytes.Buffer
_, err = lnwire.WriteMessage(&b, initReplyMsg, 0)
require.NoError(t, err)

mockConn.readMessages <- b.Bytes()
}()

// Start the peer. No data race should occur.
require.NoError(t, peer.Start())

// Ensure messages were sent during startup.
<-sendPingDone
for i := 0; i < 2; i++ {
select {
case <-mockConn.writtenMessages:
default:
t.Fatalf("Failed to send all messages during startup")
}
}
}
20 changes: 20 additions & 0 deletions peer/test_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,16 @@ type mockMessageConn struct {

readMessages chan []byte
curReadMessage []byte

// writeRaceDetectingCounter is incremented on any function call
// associated with writing to the connection. The race detector will
// trigger on this counter if a data race exists.
writeRaceDetectingCounter int

// readRaceDetectingCounter is incremented on any function call
// associated with reading from the connection. The race detector will
// trigger on this counter if a data race exists.
readRaceDetectingCounter int
}

func newMockConn(t *testing.T, expectedMessages int) *mockMessageConn {
Expand All @@ -509,17 +519,20 @@ func newMockConn(t *testing.T, expectedMessages int) *mockMessageConn {

// SetWriteDeadline mocks setting write deadline for our conn.
func (m *mockMessageConn) SetWriteDeadline(time.Time) error {
m.writeRaceDetectingCounter++
return nil
}

// Flush mocks a message conn flush.
func (m *mockMessageConn) Flush() (int, error) {
m.writeRaceDetectingCounter++
return 0, nil
}

// WriteMessage mocks sending of a message on our connection. It will push
// the bytes sent into the mock's writtenMessages channel.
func (m *mockMessageConn) WriteMessage(msg []byte) error {
m.writeRaceDetectingCounter++
select {
case m.writtenMessages <- msg:
case <-time.After(timeout):
Expand All @@ -542,15 +555,18 @@ func (m *mockMessageConn) assertWrite(expected []byte) {
}

func (m *mockMessageConn) SetReadDeadline(t time.Time) error {
m.readRaceDetectingCounter++
return nil
}

func (m *mockMessageConn) ReadNextHeader() (uint32, error) {
m.readRaceDetectingCounter++
m.curReadMessage = <-m.readMessages
return uint32(len(m.curReadMessage)), nil
}

func (m *mockMessageConn) ReadNextBody(buf []byte) ([]byte, error) {
m.readRaceDetectingCounter++
return m.curReadMessage, nil
}

Expand All @@ -561,3 +577,7 @@ func (m *mockMessageConn) RemoteAddr() net.Addr {
func (m *mockMessageConn) LocalAddr() net.Addr {
return nil
}

func (m *mockMessageConn) Close() error {
return nil
}

0 comments on commit d9b88fb

Please sign in to comment.