Skip to content

Commit

Permalink
Fix packet handling before init
Browse files Browse the repository at this point in the history
It caused panic when some kind of packets are arrived
before completing initialization.
Add error check to CookieEcho and Data handlers.
  • Loading branch information
at-wat committed Mar 23, 2021
1 parent 4d8cfd4 commit fa0e0eb
Show file tree
Hide file tree
Showing 3 changed files with 148 additions and 0 deletions.
5 changes: 5 additions & 0 deletions association.go
Original file line number Diff line number Diff line change
Expand Up @@ -1146,6 +1146,11 @@ func (a *Association) handleHeartbeat(c *chunkHeartbeat) []*packet {
func (a *Association) handleCookieEcho(c *chunkCookieEcho) []*packet {
state := a.getState()
a.log.Debugf("[%s] COOKIE-ECHO received in state '%s'", a.name, getAssociationStateString(state))

if a.myCookie == nil {
a.log.Debugf("[%s] COOKIE-ECHO received before initialization", a.name)
return nil
}
switch state {
default:
return nil
Expand Down
137 changes: 137 additions & 0 deletions association_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2650,3 +2650,140 @@ func TestAssociation_ShutdownDuringWrite(t *testing.T) {
assert.Fail(t, "timed out waiting for a2 read loop to close")
}
}

func TestAssociation_HandlePacketBeforeInit(t *testing.T) {
loggerFactory := logging.NewDefaultLoggerFactory()

testCases := map[string]struct {
inputPacket *packet
}{
"InitAck": {
inputPacket: &packet{
sourcePort: 1,
destinationPort: 1,
chunks: []chunk{
&chunkInitAck{
chunkInitCommon: chunkInitCommon{
initiateTag: 1,
numInboundStreams: 1,
numOutboundStreams: 1,
advertisedReceiverWindowCredit: 1500,
},
},
},
},
},
"Abort": {
inputPacket: &packet{
sourcePort: 1,
destinationPort: 1,
chunks: []chunk{&chunkAbort{}},
},
},
"CoockeEcho": {
inputPacket: &packet{
sourcePort: 1,
destinationPort: 1,
chunks: []chunk{&chunkCookieEcho{}},
},
},
"HeartBeat": {
inputPacket: &packet{
sourcePort: 1,
destinationPort: 1,
chunks: []chunk{&chunkHeartbeat{}},
},
},
"PayloadData": {
inputPacket: &packet{
sourcePort: 1,
destinationPort: 1,
chunks: []chunk{&chunkPayloadData{}},
},
},
"Sack": {
inputPacket: &packet{
sourcePort: 1,
destinationPort: 1,
chunks: []chunk{&chunkSelectiveAck{
cumulativeTSNAck: 1000,
advertisedReceiverWindowCredit: 1500,
gapAckBlocks: []gapAckBlock{
{start: 100, end: 200},
},
}},
},
},
"Reconfig": {
inputPacket: &packet{
sourcePort: 1,
destinationPort: 1,
chunks: []chunk{&chunkReconfig{
paramA: &paramOutgoingResetRequest{},
paramB: &paramReconfigResponse{},
}},
},
},
"ForwardTSN": {
inputPacket: &packet{
sourcePort: 1,
destinationPort: 1,
chunks: []chunk{&chunkForwardTSN{
newCumulativeTSN: 100,
}},
},
},
"Error": {
inputPacket: &packet{
sourcePort: 1,
destinationPort: 1,
chunks: []chunk{&chunkError{}},
},
},
"Shutdown": {
inputPacket: &packet{
sourcePort: 1,
destinationPort: 1,
chunks: []chunk{&chunkShutdown{}},
},
},
"ShutdownAck": {
inputPacket: &packet{
sourcePort: 1,
destinationPort: 1,
chunks: []chunk{&chunkShutdownAck{}},
},
},
"ShutdownComplete": {
inputPacket: &packet{
sourcePort: 1,
destinationPort: 1,
chunks: []chunk{&chunkShutdownComplete{}},
},
},
}

for name, testCase := range testCases {
testCase := testCase
t.Run(name, func(t *testing.T) {
aConn, charlieConn := pipeDump()
a := createAssociation(Config{
NetConn: aConn,
MaxReceiveBufferSize: 0,
LoggerFactory: loggerFactory,
})
a.init(true)
defer func() {
assert.NoError(t, a.close())
}()

packet, err := testCase.inputPacket.marshal()
assert.NoError(t, err)
_, err = charlieConn.Write(packet)
assert.NoError(t, err)

// Should not panic.
time.Sleep(100 * time.Millisecond)
})
}
}
6 changes: 6 additions & 0 deletions chunk_payload_data.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package sctp

import (
"encoding/binary"
"errors"
"fmt"
"time"
)
Expand Down Expand Up @@ -95,6 +96,8 @@ const (
PayloadTypeWebRTCBinaryEmpty PayloadProtocolIdentifier = 57
)

var errChunkPayloadSmall = errors.New("packet is smaller than the header size")

func (p PayloadProtocolIdentifier) String() string {
switch p {
case PayloadTypeWebRTCDCEP:
Expand Down Expand Up @@ -122,6 +125,9 @@ func (p *chunkPayloadData) unmarshal(raw []byte) error {
p.beginningFragment = p.flags&payloadDataBeginingFragmentBitmask != 0
p.endingFragment = p.flags&payloadDataEndingFragmentBitmask != 0

if len(raw) < payloadDataHeaderSize {
return errChunkPayloadSmall
}
p.tsn = binary.BigEndian.Uint32(p.raw[0:])
p.streamIdentifier = binary.BigEndian.Uint16(p.raw[4:])
p.streamSequenceNumber = binary.BigEndian.Uint16(p.raw[6:])
Expand Down

0 comments on commit fa0e0eb

Please sign in to comment.