From 3f82cfdbf0fe1f6210906c3fa872a21ff98006ef Mon Sep 17 00:00:00 2001 From: Atsushi Watanabe Date: Tue, 23 Mar 2021 20:17:23 +0900 Subject: [PATCH] Fix packet handling before init It caused panic when some kind of packets are arrived before completing initialization. Add error check to CookieEcho and Data handlers. --- association.go | 5 ++ association_test.go | 137 ++++++++++++++++++++++++++++++++++++++++++ chunk_payload_data.go | 6 ++ 3 files changed, 148 insertions(+) diff --git a/association.go b/association.go index f811f966..298f564b 100644 --- a/association.go +++ b/association.go @@ -1146,6 +1146,11 @@ func (a *Association) handleHeartbeat(c *chunkHeartbeat) []*packet { func (a *Association) handleCookieEcho(c *chunkCookieEcho) []*packet { state := a.getState() a.log.Debugf("[%s] COOKIE-ECHO received in state '%s'", a.name, getAssociationStateString(state)) + + if a.myCookie == nil { + a.log.Debugf("[%s] COOKIE-ECHO received before initialization", a.name) + return nil + } switch state { default: return nil diff --git a/association_test.go b/association_test.go index d810c5cb..53ed54af 100644 --- a/association_test.go +++ b/association_test.go @@ -2650,3 +2650,140 @@ func TestAssociation_ShutdownDuringWrite(t *testing.T) { assert.Fail(t, "timed out waiting for a2 read loop to close") } } + +func TestAssociation_HandlePacketBeforeInit(t *testing.T) { + loggerFactory := logging.NewDefaultLoggerFactory() + + testCases := map[string]struct { + inputPacket *packet + }{ + "InitAck": { + inputPacket: &packet{ + sourcePort: 1, + destinationPort: 1, + chunks: []chunk{ + &chunkInitAck{ + chunkInitCommon: chunkInitCommon{ + initiateTag: 1, + numInboundStreams: 1, + numOutboundStreams: 1, + advertisedReceiverWindowCredit: 1500, + }, + }, + }, + }, + }, + "Abort": { + inputPacket: &packet{ + sourcePort: 1, + destinationPort: 1, + chunks: []chunk{&chunkAbort{}}, + }, + }, + "CoockeEcho": { + inputPacket: &packet{ + sourcePort: 1, + destinationPort: 1, + chunks: []chunk{&chunkCookieEcho{}}, + }, + }, + "HeartBeat": { + inputPacket: &packet{ + sourcePort: 1, + destinationPort: 1, + chunks: []chunk{&chunkHeartbeat{}}, + }, + }, + "PayloadData": { + inputPacket: &packet{ + sourcePort: 1, + destinationPort: 1, + chunks: []chunk{&chunkPayloadData{}}, + }, + }, + "Sack": { + inputPacket: &packet{ + sourcePort: 1, + destinationPort: 1, + chunks: []chunk{&chunkSelectiveAck{ + cumulativeTSNAck: 1000, + advertisedReceiverWindowCredit: 1500, + gapAckBlocks: []gapAckBlock{ + {start: 100, end: 200}, + }, + }}, + }, + }, + "Reconfig": { + inputPacket: &packet{ + sourcePort: 1, + destinationPort: 1, + chunks: []chunk{&chunkReconfig{ + paramA: ¶mOutgoingResetRequest{}, + paramB: ¶mReconfigResponse{}, + }}, + }, + }, + "ForwardTSN": { + inputPacket: &packet{ + sourcePort: 1, + destinationPort: 1, + chunks: []chunk{&chunkForwardTSN{ + newCumulativeTSN: 100, + }}, + }, + }, + "Error": { + inputPacket: &packet{ + sourcePort: 1, + destinationPort: 1, + chunks: []chunk{&chunkError{}}, + }, + }, + "Shutdown": { + inputPacket: &packet{ + sourcePort: 1, + destinationPort: 1, + chunks: []chunk{&chunkShutdown{}}, + }, + }, + "ShutdownAck": { + inputPacket: &packet{ + sourcePort: 1, + destinationPort: 1, + chunks: []chunk{&chunkShutdownAck{}}, + }, + }, + "ShutdownComplete": { + inputPacket: &packet{ + sourcePort: 1, + destinationPort: 1, + chunks: []chunk{&chunkShutdownComplete{}}, + }, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + aConn, charlieConn := pipeDump() + a := createAssociation(Config{ + NetConn: aConn, + MaxReceiveBufferSize: 0, + LoggerFactory: loggerFactory, + }) + a.init(true) + defer func() { + assert.NoError(t, a.close()) + }() + + packet, err := testCase.inputPacket.marshal() + assert.NoError(t, err) + _, err = charlieConn.Write(packet) + assert.NoError(t, err) + + // Should not panic. + time.Sleep(100 * time.Millisecond) + }) + } +} diff --git a/chunk_payload_data.go b/chunk_payload_data.go index 3e5e3468..5d3057ed 100644 --- a/chunk_payload_data.go +++ b/chunk_payload_data.go @@ -2,6 +2,7 @@ package sctp import ( "encoding/binary" + "errors" "fmt" "time" ) @@ -95,6 +96,8 @@ const ( PayloadTypeWebRTCBinaryEmpty PayloadProtocolIdentifier = 57 ) +var errChunkPayloadSmall = errors.New("packet is smaller than the header size") + func (p PayloadProtocolIdentifier) String() string { switch p { case PayloadTypeWebRTCDCEP: @@ -122,6 +125,9 @@ func (p *chunkPayloadData) unmarshal(raw []byte) error { p.beginningFragment = p.flags&payloadDataBeginingFragmentBitmask != 0 p.endingFragment = p.flags&payloadDataEndingFragmentBitmask != 0 + if len(raw) < payloadDataHeaderSize { + return errChunkPayloadSmall + } p.tsn = binary.BigEndian.Uint32(p.raw[0:]) p.streamIdentifier = binary.BigEndian.Uint16(p.raw[4:]) p.streamSequenceNumber = binary.BigEndian.Uint16(p.raw[6:])