Skip to content

Commit

Permalink
Fix packet handling before init
Browse files Browse the repository at this point in the history
It caused panic when some kind of packets are arrived
before completing initialization.
Add error check to CookieEcho and Data handlers.
  • Loading branch information
at-wat committed Mar 23, 2021
1 parent 4d8cfd4 commit 156e8a7
Show file tree
Hide file tree
Showing 3 changed files with 149 additions and 0 deletions.
8 changes: 8 additions & 0 deletions association.go
Original file line number Diff line number Diff line change
Expand Up @@ -1150,10 +1150,18 @@ func (a *Association) handleCookieEcho(c *chunkCookieEcho) []*packet {
default:
return nil
case established:
if a.myCookie == nil {
a.log.Debugf("[%s] Association is not initialized", a.name)
return nil
}
if !bytes.Equal(a.myCookie.cookie, c.cookie) {
return nil
}
case closed, cookieWait, cookieEchoed:
if a.myCookie == nil {
a.log.Debugf("[%s] Association is not initialized", a.name)
return nil
}
if !bytes.Equal(a.myCookie.cookie, c.cookie) {
return nil
}
Expand Down
135 changes: 135 additions & 0 deletions association_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2650,3 +2650,138 @@ func TestAssociation_ShutdownDuringWrite(t *testing.T) {
assert.Fail(t, "timed out waiting for a2 read loop to close")
}
}

func TestAssociation_HandlePacketBeforeInit(t *testing.T) {
loggerFactory := logging.NewDefaultLoggerFactory()

testCases := map[string]struct {
inputPacket *packet
}{
"InitAck": {
inputPacket: &packet{
sourcePort: 1,
destinationPort: 1,
chunks: []chunk{
&chunkInitAck{
chunkInitCommon: chunkInitCommon{
initiateTag: 1,
numInboundStreams: 1,
numOutboundStreams: 1,
advertisedReceiverWindowCredit: 1500,
},
},
},
},
},
"Abort": {
inputPacket: &packet{
sourcePort: 1,
destinationPort: 1,
chunks: []chunk{&chunkAbort{}},
},
},
"CoockeEcho": {
inputPacket: &packet{
sourcePort: 1,
destinationPort: 1,
chunks: []chunk{&chunkCookieEcho{}},
},
},
"HeartBeat": {
inputPacket: &packet{
sourcePort: 1,
destinationPort: 1,
chunks: []chunk{&chunkHeartbeat{}},
},
},
"PayloadData": {
inputPacket: &packet{
sourcePort: 1,
destinationPort: 1,
chunks: []chunk{&chunkPayloadData{}},
},
},
"Sack": {
inputPacket: &packet{
sourcePort: 1,
destinationPort: 1,
chunks: []chunk{&chunkSelectiveAck{
cumulativeTSNAck: 1000,
advertisedReceiverWindowCredit: 1500,
gapAckBlocks: []gapAckBlock{
{start: 100, end: 200},
},
}},
},
},
"Reconfig": {
inputPacket: &packet{
sourcePort: 1,
destinationPort: 1,
chunks: []chunk{&chunkReconfig{
paramA: &paramOutgoingResetRequest{},
paramB: &paramReconfigResponse{},
}},
},
},
"ForwardTSN": {
inputPacket: &packet{
sourcePort: 1,
destinationPort: 1,
chunks: []chunk{&chunkForwardTSN{
newCumulativeTSN: 100,
}},
},
},
"Error": {
inputPacket: &packet{
sourcePort: 1,
destinationPort: 1,
chunks: []chunk{&chunkError{}},
},
},
"Shutdown": {
inputPacket: &packet{
sourcePort: 1,
destinationPort: 1,
chunks: []chunk{&chunkShutdown{}},
},
},
"ShutdownAck": {
inputPacket: &packet{
sourcePort: 1,
destinationPort: 1,
chunks: []chunk{&chunkShutdownAck{}},
},
},
"ShutdownComplete": {
inputPacket: &packet{
sourcePort: 1,
destinationPort: 1,
chunks: []chunk{&chunkShutdownComplete{}},
},
},
}

for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
aConn, charlieConn := pipeDump()
a := createAssociation(Config{
NetConn: aConn,
MaxReceiveBufferSize: 0,
LoggerFactory: loggerFactory,
})
a.init(true)
defer a.close()

packet, err := testCase.inputPacket.marshal()
assert.NoError(t, err)
_, err = charlieConn.Write(packet)
assert.NoError(t, err)

// Should not panic.
time.Sleep(100 * time.Millisecond)
})
}
}
6 changes: 6 additions & 0 deletions chunk_payload_data.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package sctp

import (
"encoding/binary"
"errors"
"fmt"
"time"
)
Expand Down Expand Up @@ -95,6 +96,8 @@ const (
PayloadTypeWebRTCBinaryEmpty PayloadProtocolIdentifier = 57
)

var errChunkPayloadSmall = errors.New("packet is smaller than the header size")

func (p PayloadProtocolIdentifier) String() string {
switch p {
case PayloadTypeWebRTCDCEP:
Expand Down Expand Up @@ -122,6 +125,9 @@ func (p *chunkPayloadData) unmarshal(raw []byte) error {
p.beginningFragment = p.flags&payloadDataBeginingFragmentBitmask != 0
p.endingFragment = p.flags&payloadDataEndingFragmentBitmask != 0

if len(raw) < payloadDataHeaderSize {
return errChunkPayloadSmall
}
p.tsn = binary.BigEndian.Uint32(p.raw[0:])
p.streamIdentifier = binary.BigEndian.Uint16(p.raw[4:])
p.streamSequenceNumber = binary.BigEndian.Uint16(p.raw[6:])
Expand Down

0 comments on commit 156e8a7

Please sign in to comment.